diff --git a/README.md b/README.md index 9ee0fe5..5758bc4 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,9 @@ Some other features included: -auto-scale-allow-deny string Path to config for server allowlists and denylists. If a global/server entry is specified, only players allowed to connect to the server will be able to trigger a scale up when -auto-scale-up is enabled or cancel active down scalers when -auto-scale-down is enabled (env AUTO_SCALE_ALLOW_DENY) -auto-scale-asleep-motd string - MOTD to display when auto-scaled down servers are accessed; if empty, no status will be served (env AUTO_SCALE_ASLEEP_MOTD) + MOTD to display when auto-scaled down servers are accessed; if empty, no status will be served (env AUTO_SCALE_ASLEEP_MOTD) + -auto-scale-loading-motd string + MOTD to display while auto-scaled Docker servers are waking up; if empty, asleep status will be served (env AUTO_SCALE_LOADING_MOTD) -auto-scale-down Scale to zero after idle. For Kubernetes, decreases StatefulSet replicas from 1 to 0. For Docker, gracefully stops the container when there are no connections (env AUTO_SCALE_DOWN) -auto-scale-down-after string @@ -176,6 +178,7 @@ These are the labels scanned: - `mc-router.auto-scale-down`: Per-container override to enable/disable auto scale down for Docker. When true (or left unspecified and the global `-auto-scale-down` flag is enabled), mc-router will stop this container after it has been idle for the configured `-auto-scale-down-after` duration. - `mc-router.auto-scale-asleep-motd`: Per-container override for MOTD to show when container is scaled to zero. If empty or not set the host will appear unresponsive. +- `mc-router.auto-scale-loading-motd`: Per-container override for MOTD to show while the container is waking and not yet reachable. If empty or not set, the global `-auto-scale-loading-motd` value is used. #### Docker Auto Scale Up/Down @@ -198,6 +201,7 @@ For usage with docker compose refer to the [examples/docker-autoscale/compose.ym Behavior: - When a client connects to a labeled hostname and the container is stopped or paused, mc-router will start/unpause it and wait until it becomes reachable (up to ~60s). +- While that wake-up is in progress and status pings are received, mc-router can return a loading MOTD (per-container override or `-auto-scale-loading-motd`). - When no clients remain connected and the idle timer elapses (`-auto-scale-down-after`), mc-router gracefully stops the container. Note: Docker Swarm discovery is supported; however, auto scale up/down is not yet supported for Swarm services. diff --git a/server/api_server.go b/server/api_server.go index a57cb90..eaba753 100644 --- a/server/api_server.go +++ b/server/api_server.go @@ -81,7 +81,7 @@ func routesCreateHandler(writer http.ResponseWriter, request *http.Request) { return } - Routes.CreateMapping(definition.ServerAddress, definition.Backend, "", nil, nil, "") + Routes.CreateMapping(definition.ServerAddress, definition.Backend, "", nil, nil, "", "") RoutesConfigLoader.SaveRoutes() writer.WriteHeader(http.StatusCreated) } @@ -102,7 +102,7 @@ func routesSetDefault(writer http.ResponseWriter, request *http.Request) { return } - Routes.SetDefaultRoute(body.Backend, "", nil, nil, "") + Routes.SetDefaultRoute(body.Backend, "", nil, nil, "", "") RoutesConfigLoader.SaveRoutes() writer.WriteHeader(http.StatusOK) } diff --git a/server/configs.go b/server/configs.go index b688ca7..c22ee72 100644 --- a/server/configs.go +++ b/server/configs.go @@ -8,11 +8,12 @@ type WebhookConfig struct { } type AutoScale struct { - Up bool `usage:"Scale from zero on access. For Kubernetes, increases StatefulSet replicas from 0 to 1. For Docker, starts or unpauses the container when accessed"` - Down bool `default:"false" usage:"Scale to zero after idle. For Kubernetes, decreases StatefulSet replicas from 1 to 0. For Docker, gracefully stops the container when there are no connections"` - DownAfter time.Duration `default:"10m" usage:"Server scale down delay after there are no connections"` - AllowDeny string `usage:"Path to config for server allowlists and denylists. If a global/server entry is specified, only players allowed to connect to the server will be able to trigger a scale up when -auto-scale-up is enabled or cancel active down scalers when -auto-scale-down is enabled"` - AsleepMOTD string `usage:"MOTD to display when auto-scaled down servers are accessed; if empty, no status will be served"` + Up bool `usage:"Scale from zero on access. For Kubernetes, increases StatefulSet replicas from 0 to 1. For Docker, starts or unpauses the container when accessed"` + Down bool `default:"false" usage:"Scale to zero after idle. For Kubernetes, decreases StatefulSet replicas from 1 to 0. For Docker, gracefully stops the container when there are no connections"` + DownAfter time.Duration `default:"10m" usage:"Server scale down delay after there are no connections"` + AllowDeny string `usage:"Path to config for server allowlists and denylists. If a global/server entry is specified, only players allowed to connect to the server will be able to trigger a scale up when -auto-scale-up is enabled or cancel active down scalers when -auto-scale-down is enabled"` + AsleepMOTD string `usage:"MOTD to display when auto-scaled down servers are accessed; if empty, no status will be served"` + LoadingMOTD string `usage:"MOTD to display while auto-scaled Docker servers are waking up; if empty, asleep status will be served"` } type RoutesConfig struct { diff --git a/server/connector.go b/server/connector.go index 750cb51..5e45a2d 100644 --- a/server/connector.go +++ b/server/connector.go @@ -78,6 +78,7 @@ func NewConnector(ctx context.Context, metrics *ConnectorMetrics, sendProxyProto autoScaleUpAllowDenyConfig: autoScaleUpAllowDenyConfig, activeConnections: NewActiveConnections(), scaleActiveConnections: NewActiveConnections(), + wakingServers: NewActiveConnections(), } } @@ -97,12 +98,14 @@ type Connector struct { totalActiveConnections int32 activeConnections *ActiveConnections scaleActiveConnections *ActiveConnections + wakingServers *ActiveConnections connectionsCond *sync.Cond ngrok NgrokConnector clientFilter *ClientFilter autoScaleUpAllowDenyConfig *AllowDenyConfig connectionNotifier ConnectionNotifier asleepMOTD string + loadingMOTD string } func (c *Connector) UseConnectionNotifier(notifier ConnectionNotifier) { @@ -364,9 +367,11 @@ func (c *Connector) HandleConnection(frontendConn net.Conn) { // serveStatus writes a predefined status JSON and optionally handles ping/pong func (c *Connector) serveStatus(frontendConn net.Conn, reader *bufio.Reader, serverAddress string, clientProtocol int) { - motd := Routes.GetAsleepMOTD(serverAddress) - if motd == "" { - motd = c.asleepMOTD + motd := "" + if c.isWakeInProgress(serverAddress) { + motd = c.getLoadingMOTD(serverAddress) + } else { + motd = c.getAsleepMOTD(serverAddress) } if motd == "" { return @@ -453,8 +458,13 @@ func (c *Connector) serveStatus(frontendConn net.Conn, reader *bufio.Reader, ser } // serveLegacyStatus writes a simple legacy SLP response and closes the connection -func (c *Connector) serveLegacyStatus(frontendConn net.Conn) { - motd := c.asleepMOTD +func (c *Connector) serveLegacyStatus(frontendConn net.Conn, serverAddress string) { + motd := "" + if c.isWakeInProgress(serverAddress) { + motd = c.getLoadingMOTD(serverAddress) + } else { + motd = c.getAsleepMOTD(serverAddress) + } if motd == "" { return } @@ -550,7 +560,9 @@ func (c *Connector) findAndConnectBackend(frontendConn net.Conn, } cleanupCheckScaleDown = true logrus.WithField("serverAddress", serverAddress).Info("Waking up backend server") + c.wakingServers.Increment(serverAddress) newBackendHostPort, err := waker(c.ctx) + c.wakingServers.Decrement(serverAddress) if err != nil { logrus.WithFields(logrus.Fields{"serverAddress": serverAddress}).WithError(err).Error("failed to wake up backend") c.metrics.Errors.With("type", "wakeup_failed").Add(1) @@ -561,6 +573,9 @@ func (c *Connector) findAndConnectBackend(frontendConn net.Conn, c.metrics.Errors.With("type", "wakeup_no_address").Add(1) return } + if scalingTarget == "" { + scalingTarget = newBackendHostPort + } // Cancel again in case any routes were changed during wake up DownScaler.Cancel(scalingTarget) backendHostPort = newBackendHostPort @@ -571,13 +586,15 @@ func (c *Connector) findAndConnectBackend(frontendConn net.Conn, } } - if backendHostPort == "" { - logrus. - WithField("serverAddress", serverAddress). - WithField("resolvedHost", resolvedHost). - WithField("player", playerInfo). - Warn("Unable to find registered backend") - c.metrics.Errors.With("type", "missing_backend").Add(1) + if backendHostPort == "" || (c.isWakeInProgress(serverAddress) && nextState == mcproto.StateStatus) { + if waker == nil { + logrus. + WithField("serverAddress", serverAddress). + WithField("resolvedHost", resolvedHost). + WithField("player", playerInfo). + Warn("Unable to find registered backend") + c.metrics.Errors.With("type", "missing_backend").Add(1) + } if c.connectionNotifier != nil { err := c.connectionNotifier.NotifyMissingBackend(c.ctx, clientAddr, serverAddress, playerInfo) @@ -597,7 +614,7 @@ func (c *Connector) findAndConnectBackend(frontendConn net.Conn, // Read Status Request and Ping directly from the client connection br := bufio.NewReader(frontendConn) if isLegacy { - c.serveLegacyStatus(frontendConn) + c.serveLegacyStatus(frontendConn, serverAddress) } else { c.serveStatus(frontendConn, br, serverAddress, clientProtocol) } @@ -605,6 +622,18 @@ func (c *Connector) findAndConnectBackend(frontendConn net.Conn, return } + if c.isWakeInProgress(serverAddress) { + logrus. + WithField("serverAddress", serverAddress). + WithField("resolvedHost", resolvedHost). + WithField("player", playerInfo). + Debug("Waiting for backend to wake up before connecting") + // TODO: replace with event-based notification + for c.isWakeInProgress(serverAddress) { + time.Sleep(500 * time.Millisecond) + } + } + logrus. WithField("client", clientAddr). WithField("server", serverAddress). @@ -793,6 +822,39 @@ func (c *Connector) UseAsleepMOTD(motd string) { c.asleepMOTD = motd } +// UseLoadingMOTD configures a predefined MOTD to serve when backends are waking up +func (c *Connector) UseLoadingMOTD(motd string) { + c.loadingMOTD = motd +} + +func (c *Connector) isWakeInProgress(serverAddress string) bool { + if serverAddress == "" { + return false + } + + return c.wakingServers.GetCount(serverAddress) > 0 +} + +func (c *Connector) getAsleepMOTD(serverAddress string) string { + motd := Routes.GetAsleepMOTD(serverAddress) + if motd == "" { + motd = c.asleepMOTD + } + return motd +} + +func (c *Connector) getLoadingMOTD(serverAddress string) string { + motd := Routes.GetLoadingMOTD(serverAddress) + if motd == "" { + motd = c.loadingMOTD + } + // If no specific loading MOTD, fall back to asleep MOTD + if motd == "" { + return c.getAsleepMOTD(serverAddress) + } + return motd +} + // getVersionInfo falls back to client protocol and a derived name but in future // could be extended to cache server-reported versions func (c *Connector) getVersionInfo(_ string, clientProtocol int) (string, int) { diff --git a/server/connector_test.go b/server/connector_test.go index ab26a53..f428f44 100644 --- a/server/connector_test.go +++ b/server/connector_test.go @@ -77,3 +77,36 @@ func parseTrustedProxyNets(nets []string) []*net.IPNet { } return parsedNets } + +func TestConnectorWakeTracking(t *testing.T) { + c := &Connector{wakingServers: NewActiveConnections()} + + assert.False(t, c.isWakeInProgress("scale-target")) + c.wakingServers.Increment("scale-target") + assert.True(t, c.isWakeInProgress("scale-target")) + + // track concurrent wake operations for same route + c.wakingServers.Increment("scale-target") + c.wakingServers.Decrement("scale-target") + assert.True(t, c.isWakeInProgress("scale-target")) + + c.wakingServers.Decrement("scale-target") + assert.False(t, c.isWakeInProgress("scale-target")) +} + +func TestConnectorGetLoadingMOTD(t *testing.T) { + oldRoutes := Routes + defer func() { + Routes = oldRoutes + }() + + Routes = NewRoutes() + Routes.CreateMapping("mc.example.com", "backend:25565", "", nil, nil, "", "route loading") + + c := &Connector{loadingMOTD: "global loading"} + assert.Equal(t, "route loading", c.getLoadingMOTD("mc.example.com")) + assert.Equal(t, "global loading", c.getLoadingMOTD("other.example.com")) + + Routes.SetDefaultRoute("default:25565", "", nil, nil, "", "default loading") + assert.Equal(t, "default loading", c.getLoadingMOTD("")) +} diff --git a/server/docker.go b/server/docker.go index e29330c..ac740e8 100644 --- a/server/docker.go +++ b/server/docker.go @@ -26,6 +26,7 @@ const ( DockerRouterLabelAutoScaleUp = "mc-router.auto-scale-up" DockerRouterLabelAutoScaleDown = "mc-router.auto-scale-down" DockerRouterLabelAutoScaleAsleepMOTD = "mc-router.auto-scale-asleep-motd" + DockerRouterLabelAutoScaleLoadingMOTD = "mc-router.auto-scale-loading-motd" ) type dockerWatcherConfig struct { @@ -163,6 +164,11 @@ func (w *dockerWatcherImpl) makeSleeperFunc(rc *routableContainer) SleeperFunc { return err } } + err = w.monitorContainers(ctx) + if err != nil { + logrus.WithError(err).Error("Docker monitoring failed") + return err + } return nil } } @@ -186,23 +192,24 @@ func (w *dockerWatcherImpl) monitorContainers(ctx context.Context) error { wakerFunc := w.makeWakerFunc(rs) sleeperFunc := w.makeSleeperFunc(rs) if rs.externalContainerName != "" { - Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, "", wakerFunc, sleeperFunc, rs.autoScaleAsleepMOTD) + Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, "", wakerFunc, sleeperFunc, rs.autoScaleAsleepMOTD, rs.autoScaleLoadingMOTD) } else { - Routes.SetDefaultRoute(rs.containerEndpoint, "", wakerFunc, sleeperFunc, rs.autoScaleAsleepMOTD) + Routes.SetDefaultRoute(rs.containerEndpoint, "", wakerFunc, sleeperFunc, rs.autoScaleAsleepMOTD, rs.autoScaleLoadingMOTD) } } else if oldRs.containerEndpoint != rs.containerEndpoint || oldRs.containerID != rs.containerID || oldRs.autoScaleUp != rs.autoScaleUp || oldRs.autoScaleDown != rs.autoScaleDown || - oldRs.autoScaleAsleepMOTD != rs.autoScaleAsleepMOTD { + oldRs.autoScaleAsleepMOTD != rs.autoScaleAsleepMOTD || + oldRs.autoScaleLoadingMOTD != rs.autoScaleLoadingMOTD { w.containerMap[rs.externalContainerName] = rs wakerFunc := w.makeWakerFunc(rs) sleeperFunc := w.makeSleeperFunc(rs) if rs.externalContainerName != "" { Routes.DeleteMapping(rs.externalContainerName) - Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, "", wakerFunc, sleeperFunc, rs.autoScaleAsleepMOTD) + Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, "", wakerFunc, sleeperFunc, rs.autoScaleAsleepMOTD, rs.autoScaleLoadingMOTD) } else { - Routes.SetDefaultRoute(rs.containerEndpoint, "", wakerFunc, sleeperFunc, rs.autoScaleAsleepMOTD) + Routes.SetDefaultRoute(rs.containerEndpoint, "", wakerFunc, sleeperFunc, rs.autoScaleAsleepMOTD, rs.autoScaleLoadingMOTD) } logrus.WithFields(logrus.Fields{"old": oldRs, "new": rs}).Debug("UPDATE") } @@ -214,7 +221,7 @@ func (w *dockerWatcherImpl) monitorContainers(ctx context.Context) error { if rs.externalContainerName != "" { Routes.DeleteMapping(rs.externalContainerName) } else { - Routes.SetDefaultRoute("", "", nil, nil, "") + Routes.SetDefaultRoute("", "", nil, nil, "", "") } logrus.WithField("routableContainer", rs).Debug("DELETE") } @@ -257,9 +264,9 @@ func (w *dockerWatcherImpl) Start(ctx context.Context) error { wakerFunc := w.makeWakerFunc(c) sleeperFunc := w.makeSleeperFunc(c) if c.externalContainerName != "" { - Routes.CreateMapping(c.externalContainerName, c.containerEndpoint, "", wakerFunc, sleeperFunc, c.autoScaleAsleepMOTD) + Routes.CreateMapping(c.externalContainerName, c.containerEndpoint, "", wakerFunc, sleeperFunc, c.autoScaleAsleepMOTD, c.autoScaleLoadingMOTD) } else { - Routes.SetDefaultRoute(c.containerEndpoint, "", wakerFunc, sleeperFunc, c.autoScaleAsleepMOTD) + Routes.SetDefaultRoute(c.containerEndpoint, "", wakerFunc, sleeperFunc, c.autoScaleAsleepMOTD, c.autoScaleLoadingMOTD) } } @@ -315,6 +322,7 @@ func (w *dockerWatcherImpl) listContainers(ctx context.Context) ([]*routableCont autoScaleUp: data.autoScaleUp, autoScaleDown: data.autoScaleDown, autoScaleAsleepMOTD: data.autoScaleAsleepMOTD, + autoScaleLoadingMOTD: data.autoScaleLoadingMOTD, }) } if data.def != nil && *data.def { @@ -325,6 +333,7 @@ func (w *dockerWatcherImpl) listContainers(ctx context.Context) ([]*routableCont autoScaleUp: data.autoScaleUp, autoScaleDown: data.autoScaleDown, autoScaleAsleepMOTD: data.autoScaleAsleepMOTD, + autoScaleLoadingMOTD: data.autoScaleLoadingMOTD, }) } } @@ -341,6 +350,7 @@ type parsedDockerContainerData struct { autoScaleDown bool autoScaleUp bool autoScaleAsleepMOTD string + autoScaleLoadingMOTD string notRunning bool } @@ -419,6 +429,9 @@ func (w *dockerWatcherImpl) parseContainerData(container *container.InspectRespo if key == DockerRouterLabelAutoScaleAsleepMOTD { data.autoScaleAsleepMOTD = value } + if key == DockerRouterLabelAutoScaleLoadingMOTD { + data.autoScaleLoadingMOTD = value + } } // probably not minecraft related @@ -499,4 +512,5 @@ type routableContainer struct { autoScaleUp bool autoScaleDown bool autoScaleAsleepMOTD string + autoScaleLoadingMOTD string } diff --git a/server/docker_swarm.go b/server/docker_swarm.go index 319d25e..3e246f0 100644 --- a/server/docker_swarm.go +++ b/server/docker_swarm.go @@ -89,9 +89,9 @@ func (w *dockerSwarmWatcherImpl) Start(ctx context.Context) error { wakerFunc := w.makeWakerFunc(s) sleeperFunc := w.makeSleeperFunc(s) if s.externalServiceName != "" { - Routes.CreateMapping(s.externalServiceName, s.containerEndpoint, "", wakerFunc, sleeperFunc, "") + Routes.CreateMapping(s.externalServiceName, s.containerEndpoint, "", wakerFunc, sleeperFunc, "", "") } else { - Routes.SetDefaultRoute(s.containerEndpoint, "", wakerFunc, sleeperFunc, "") + Routes.SetDefaultRoute(s.containerEndpoint, "", wakerFunc, sleeperFunc, "", "") } } @@ -113,9 +113,9 @@ func (w *dockerSwarmWatcherImpl) Start(ctx context.Context) error { wakerFunc := w.makeWakerFunc(rs) sleeperFunc := w.makeSleeperFunc(rs) if rs.externalServiceName != "" { - Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, "", wakerFunc, sleeperFunc, "") + Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, "", wakerFunc, sleeperFunc, "", "") } else { - Routes.SetDefaultRoute(rs.containerEndpoint, "", wakerFunc, sleeperFunc, "") + Routes.SetDefaultRoute(rs.containerEndpoint, "", wakerFunc, sleeperFunc, "", "") } } else if oldRs.containerEndpoint != rs.containerEndpoint { serviceMap[rs.externalServiceName] = rs @@ -123,9 +123,9 @@ func (w *dockerSwarmWatcherImpl) Start(ctx context.Context) error { sleeperFunc := w.makeSleeperFunc(rs) if rs.externalServiceName != "" { Routes.DeleteMapping(rs.externalServiceName) - Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, "", wakerFunc, sleeperFunc, "") + Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, "", wakerFunc, sleeperFunc, "", "") } else { - Routes.SetDefaultRoute(rs.containerEndpoint, "", wakerFunc, sleeperFunc, "") + Routes.SetDefaultRoute(rs.containerEndpoint, "", wakerFunc, sleeperFunc, "", "") } logrus.WithFields(logrus.Fields{"old": oldRs, "new": rs}).Debug("UPDATE") } @@ -137,7 +137,7 @@ func (w *dockerSwarmWatcherImpl) Start(ctx context.Context) error { if rs.externalServiceName != "" { Routes.DeleteMapping(rs.externalServiceName) } else { - Routes.SetDefaultRoute("", "", nil, nil, "") + Routes.SetDefaultRoute("", "", nil, nil, "", "") } logrus.WithField("routableService", rs).Debug("DELETE") } diff --git a/server/k8s.go b/server/k8s.go index 4bbc2b4..39d64f6 100644 --- a/server/k8s.go +++ b/server/k8s.go @@ -185,9 +185,9 @@ func (w *K8sWatcher) handleUpdate(oldObj interface{}, newObj interface{}) { "new": newRoutableService, }).Debug("UPDATE") if newRoutableService.externalServiceName != "" { - w.routesHandler.CreateMapping(newRoutableService.externalServiceName, newRoutableService.containerEndpoint, newRoutableService.scalingTarget, newRoutableService.autoScaleUp, newRoutableService.autoScaleDown, "") + w.routesHandler.CreateMapping(newRoutableService.externalServiceName, newRoutableService.containerEndpoint, newRoutableService.scalingTarget, newRoutableService.autoScaleUp, newRoutableService.autoScaleDown, "", "") } else { - w.routesHandler.SetDefaultRoute(newRoutableService.containerEndpoint, newRoutableService.scalingTarget, newRoutableService.autoScaleUp, newRoutableService.autoScaleDown, "") + w.routesHandler.SetDefaultRoute(newRoutableService.containerEndpoint, newRoutableService.scalingTarget, newRoutableService.autoScaleUp, newRoutableService.autoScaleDown, "", "") } } } @@ -202,7 +202,7 @@ func (w *K8sWatcher) handleDelete(obj interface{}) { if routableService.externalServiceName != "" { w.routesHandler.DeleteMapping(routableService.externalServiceName) } else { - w.routesHandler.SetDefaultRoute("", "", nil, nil, "") + w.routesHandler.SetDefaultRoute("", "", nil, nil, "", "") } } } @@ -216,9 +216,9 @@ func (w *K8sWatcher) handleAdd(obj interface{}) { logrus.WithField("routableService", routableService).Debug("ADD") if routableService.externalServiceName != "" { - w.routesHandler.CreateMapping(routableService.externalServiceName, routableService.containerEndpoint, routableService.scalingTarget, routableService.autoScaleUp, routableService.autoScaleDown, "") + w.routesHandler.CreateMapping(routableService.externalServiceName, routableService.containerEndpoint, routableService.scalingTarget, routableService.autoScaleUp, routableService.autoScaleDown, "", "") } else { - w.routesHandler.SetDefaultRoute(routableService.containerEndpoint, routableService.scalingTarget, routableService.autoScaleUp, routableService.autoScaleDown, "") + w.routesHandler.SetDefaultRoute(routableService.containerEndpoint, routableService.scalingTarget, routableService.autoScaleUp, routableService.autoScaleDown, "", "") } } } diff --git a/server/k8s_test.go b/server/k8s_test.go index f9a56f0..42a6b38 100644 --- a/server/k8s_test.go +++ b/server/k8s_test.go @@ -28,16 +28,16 @@ func (m *MockedRoutesHandler) GetBackendForServer(server string) string { } } -func (m *MockedRoutesHandler) CreateMapping(serverAddress string, backend string, scaleKey string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string) { - m.MethodCalled("CreateMapping", serverAddress, backend, scaleKey, waker, sleeper, asleepMOTD) +func (m *MockedRoutesHandler) CreateMapping(serverAddress string, backend string, scaleKey string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string, loadingMOTD string) { + m.MethodCalled("CreateMapping", serverAddress, backend, scaleKey, waker, sleeper, asleepMOTD, loadingMOTD) if m.routes == nil { m.routes = make(map[string]string) } m.routes[serverAddress] = backend } -func (m *MockedRoutesHandler) SetDefaultRoute(backend string, scaleKey string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string) { - m.MethodCalled("SetDefaultRoute", backend, scaleKey, waker, sleeper, asleepMOTD) +func (m *MockedRoutesHandler) SetDefaultRoute(backend string, scaleKey string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string, loadingMOTD string) { + m.MethodCalled("SetDefaultRoute", backend, scaleKey, waker, sleeper, asleepMOTD, loadingMOTD) if m.routes == nil { m.routes = make(map[string]string) } @@ -183,8 +183,8 @@ func TestK8sWatcherImpl_handleAddThenUpdate(t *testing.T) { DownScaler = NewDownScaler(context.Background(), false, 1*time.Second) routesHandler := new(MockedRoutesHandler) - routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() - routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() routesHandler.On("GetAsleepMOTD", mock.Anything).Return("") routesHandler.On("DeleteMapping", mock.Anything).Return(true) @@ -264,8 +264,8 @@ func TestK8sWatcherImpl_handleAddThenDelete(t *testing.T) { DownScaler = NewDownScaler(context.Background(), false, 1*time.Second) routesHandler := new(MockedRoutesHandler) - routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() - routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() routesHandler.On("GetAsleepMOTD", mock.Anything).Return("") routesHandler.On("DeleteMapping", mock.Anything).Return(true) @@ -363,8 +363,8 @@ func TestK8s_externalName(t *testing.T) { DownScaler = NewDownScaler(context.Background(), false, 1*time.Second) routesHandler := new(MockedRoutesHandler) - routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() - routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() routesHandler.On("GetAsleepMOTD", mock.Anything).Return("") routesHandler.On("DeleteMapping", mock.Anything).Return(true) @@ -431,8 +431,8 @@ func TestK8s_proxyServerName(t *testing.T) { DownScaler = NewDownScaler(context.Background(), false, 1*time.Second) routesHandler := new(MockedRoutesHandler) - routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() - routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() routesHandler.On("GetAsleepMOTD", mock.Anything).Return("") routesHandler.On("DeleteMapping", mock.Anything).Return(true) @@ -456,8 +456,8 @@ func TestK8s_proxyServerNameScaleEndpoint(t *testing.T) { DownScaler = NewDownScaler(context.Background(), false, 1*time.Second) routesHandler := new(MockedRoutesHandler) - routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() - routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() routesHandler.On("GetAsleepMOTD", mock.Anything).Return("") routesHandler.On("DeleteMapping", mock.Anything).Return(true) @@ -472,15 +472,15 @@ func TestK8s_proxyServerNameScaleEndpoint(t *testing.T) { watcher.handleAdd(&svc) // Verify CreateMapping was called with the correct scaleKey (original endpoint) - routesHandler.AssertCalled(t, "CreateMapping", "mc.example.com", "velocity:25577", "10.0.0.5:25565", mock.Anything, mock.Anything, mock.Anything) + routesHandler.AssertCalled(t, "CreateMapping", "mc.example.com", "velocity:25577", "10.0.0.5:25565", mock.Anything, mock.Anything, mock.Anything, mock.Anything) } func TestK8s_proxyServerNameUpdate(t *testing.T) { DownScaler = NewDownScaler(context.Background(), false, 1*time.Second) routesHandler := new(MockedRoutesHandler) - routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() - routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() routesHandler.On("GetAsleepMOTD", mock.Anything).Return("") routesHandler.On("DeleteMapping", mock.Anything).Return(true) @@ -509,8 +509,8 @@ func TestK8s_autoScaleWithoutProxy(t *testing.T) { DownScaler = NewDownScaler(context.Background(), false, 1*time.Second) routesHandler := new(MockedRoutesHandler) - routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() - routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() routesHandler.On("GetAsleepMOTD", mock.Anything).Return("") routesHandler.On("DeleteMapping", mock.Anything).Return(true) @@ -532,5 +532,5 @@ func TestK8s_autoScaleWithoutProxy(t *testing.T) { // CRITICAL: Verify scaleKey is set to the service endpoint (not empty) // This ensures auto-scaling targets the correct StatefulSet - routesHandler.AssertCalled(t, "CreateMapping", "atm-10.example.com", "10.0.0.10:25565", "10.0.0.10:25565", mock.Anything, mock.Anything, mock.Anything) + routesHandler.AssertCalled(t, "CreateMapping", "atm-10.example.com", "10.0.0.10:25565", "10.0.0.10:25565", mock.Anything, mock.Anything, mock.Anything, mock.Anything) } diff --git a/server/routes.go b/server/routes.go index 824bc43..a74bdf9 100644 --- a/server/routes.go +++ b/server/routes.go @@ -36,8 +36,8 @@ type RouteFinder interface { } type RoutesHandler interface { - CreateMapping(serverAddress string, backend string, scalingTarget string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string) - SetDefaultRoute(backend string, scalingTarget string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string) + CreateMapping(serverAddress string, backend string, scalingTarget string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string, loadingMOTD string) + SetDefaultRoute(backend string, scalingTarget string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string, loadingMOTD string) // DeleteMapping requests that the serverAddress be removed from routes. // Returns true if the route existed. DeleteMapping(serverAddress string) bool @@ -59,6 +59,7 @@ type IRoutes interface { GetMappings() map[string]string GetDefaultRoute() (string, string, WakerFunc, SleeperFunc) GetAsleepMOTD(serverAddress string) string + GetLoadingMOTD(serverAddress string) string SimplifySRV(srvEnabled bool) } @@ -74,7 +75,7 @@ func NewRoutes() IRoutes { func (r *routesImpl) RegisterAll(mappings map[string]string) { for k, v := range mappings { - r.CreateMapping(k, v, "", nil, nil, "") + r.CreateMapping(k, v, "", nil, nil, "", "") } } @@ -83,6 +84,7 @@ type mapping struct { waker WakerFunc sleeper SleeperFunc asleepMOTD string + loadingMOTD string scalingTarget string // The endpoint to scale (may differ from backend when using proxy) } @@ -98,11 +100,11 @@ func (r *routesImpl) Reset() { DownScaler.Reset() } -func (r *routesImpl) SetDefaultRoute(backend string, scalingTarget string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string) { +func (r *routesImpl) SetDefaultRoute(backend string, scalingTarget string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string, loadingMOTD string) { if scalingTarget == "" { scalingTarget = backend } - r.defaultRoute = mapping{backend: backend, scalingTarget: scalingTarget, waker: waker, sleeper: sleeper, asleepMOTD: asleepMOTD} + r.defaultRoute = mapping{backend: backend, scalingTarget: scalingTarget, waker: waker, sleeper: sleeper, asleepMOTD: asleepMOTD, loadingMOTD: loadingMOTD} logrus.WithFields(logrus.Fields{ "backend": backend, @@ -127,6 +129,20 @@ func (r *routesImpl) GetAsleepMOTD(serverAddress string) string { return "" } +func (r *routesImpl) GetLoadingMOTD(serverAddress string) string { + r.RLock() + defer r.RUnlock() + + if serverAddress == "" { + return r.defaultRoute.loadingMOTD + } + + if m, ok := r.mappings[serverAddress]; ok { + return m.loadingMOTD + } + return "" +} + func (r *routesImpl) SimplifySRV(srvEnabled bool) { r.simplifySRV = srvEnabled } @@ -225,7 +241,7 @@ func (r *routesImpl) DeleteMapping(serverAddress string) bool { } } -func (r *routesImpl) CreateMapping(serverAddress string, backend string, scalingTarget string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string) { +func (r *routesImpl) CreateMapping(serverAddress string, backend string, scalingTarget string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string, loadingMOTD string) { r.Lock() defer r.Unlock() @@ -239,7 +255,7 @@ func (r *routesImpl) CreateMapping(serverAddress string, backend string, scaling "serverAddress": serverAddress, "backend": backend, }).Info("Created route mapping") - r.mappings[serverAddress] = mapping{backend: backend, scalingTarget: scalingTarget, waker: waker, sleeper: sleeper, asleepMOTD: asleepMOTD} + r.mappings[serverAddress] = mapping{backend: backend, scalingTarget: scalingTarget, waker: waker, sleeper: sleeper, asleepMOTD: asleepMOTD, loadingMOTD: loadingMOTD} // Trigger auto scale down when mapping is created to ensure servers are shut down if router restarts if DownScaler != nil && scalingTarget != "" { diff --git a/server/routes_config_loader.go b/server/routes_config_loader.go index 44c51e2..3104caa 100644 --- a/server/routes_config_loader.go +++ b/server/routes_config_loader.go @@ -44,7 +44,7 @@ func (r *routesConfigLoader) Load(routesConfigFileName string) error { } Routes.RegisterAll(config.Mappings) - Routes.SetDefaultRoute(config.DefaultServer, "", nil, nil, "") + Routes.SetDefaultRoute(config.DefaultServer, "", nil, nil, "", "") return nil } @@ -62,7 +62,7 @@ func (r *routesConfigLoader) Reload() error { logrus.WithField("routesConfig", r.fileName).Info("Re-loading routes config file") Routes.Reset() Routes.RegisterAll(config.Mappings) - Routes.SetDefaultRoute(config.DefaultServer, "", nil, nil, "") + Routes.SetDefaultRoute(config.DefaultServer, "", nil, nil, "", "") return nil } diff --git a/server/routes_test.go b/server/routes_test.go index 9c36194..c18f1f6 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -68,7 +68,7 @@ func Test_routesImpl_FindBackendForServerAddress(t *testing.T) { t.Run(tt.name, func(t *testing.T) { r := NewRoutes() - r.CreateMapping(tt.mapping.serverAddress, tt.mapping.backend, "", nil, nil, "") + r.CreateMapping(tt.mapping.serverAddress, tt.mapping.backend, "", nil, nil, "", "") if got, server, _, _, _ := r.FindBackendForServerAddress(context.Background(), tt.args.serverAddress); got != tt.want { t.Errorf("routesImpl.FindBackendForServerAddress() = %v, want %v", got, tt.want) @@ -84,7 +84,7 @@ func Test_routesImpl_ScaleKey(t *testing.T) { t.Run("scaleKey defaults to backend when empty", func(t *testing.T) { r := NewRoutes() - r.CreateMapping("mc.example.com", "backend:25565", "", nil, nil, "") + r.CreateMapping("mc.example.com", "backend:25565", "", nil, nil, "", "") _, _, scaleKey, _, _ := r.FindBackendForServerAddress(context.Background(), "mc.example.com") assert.Equal(t, "backend:25565", scaleKey) @@ -92,7 +92,7 @@ func Test_routesImpl_ScaleKey(t *testing.T) { t.Run("scaleKey is set when provided", func(t *testing.T) { r := NewRoutes() - r.CreateMapping("mc.example.com", "proxy:25577", "10.0.0.5:25565", nil, nil, "") + r.CreateMapping("mc.example.com", "proxy:25577", "10.0.0.5:25565", nil, nil, "", "") backend, _, scaleKey, _, _ := r.FindBackendForServerAddress(context.Background(), "mc.example.com") assert.Equal(t, "proxy:25577", backend) @@ -108,8 +108,8 @@ func Test_routesImpl_ScaleKey(t *testing.T) { } // Two routes with same proxy backend but different scaleKeys - r.CreateMapping("mc1.example.com", "proxy:25577", "10.0.0.1:25565", nil, sleeper, "") - r.CreateMapping("mc2.example.com", "proxy:25577", "10.0.0.2:25565", nil, nil, "") + r.CreateMapping("mc1.example.com", "proxy:25577", "10.0.0.1:25565", nil, sleeper, "", "") + r.CreateMapping("mc2.example.com", "proxy:25577", "10.0.0.2:25565", nil, nil, "", "") sleepers := r.GetSleepers("10.0.0.1:25565") require.Len(t, sleepers, 1) @@ -127,7 +127,7 @@ func Test_routesImpl_ScaleKey(t *testing.T) { t.Run("default route scaleKey", func(t *testing.T) { r := NewRoutes() - r.SetDefaultRoute("proxy:25577", "10.0.0.5:25565", nil, nil, "") + r.SetDefaultRoute("proxy:25577", "10.0.0.5:25565", nil, nil, "", "") backend, scaleKey, _, _ := r.GetDefaultRoute() assert.Equal(t, "proxy:25577", backend) @@ -136,10 +136,21 @@ func Test_routesImpl_ScaleKey(t *testing.T) { t.Run("default route scaleKey defaults to backend", func(t *testing.T) { r := NewRoutes() - r.SetDefaultRoute("backend:25565", "", nil, nil, "") + r.SetDefaultRoute("backend:25565", "", nil, nil, "", "") backend, scaleKey, _, _ := r.GetDefaultRoute() assert.Equal(t, "backend:25565", backend) assert.Equal(t, "backend:25565", scaleKey) }) } + +func Test_routesImpl_LoadingMOTD(t *testing.T) { + r := NewRoutes() + r.CreateMapping("mc.example.com", "backend:25565", "", nil, nil, "asleep", "loading") + + assert.Equal(t, "loading", r.GetLoadingMOTD("mc.example.com")) + assert.Equal(t, "", r.GetLoadingMOTD("other.example.com")) + + r.SetDefaultRoute("default:25565", "", nil, nil, "default asleep", "default loading") + assert.Equal(t, "default loading", r.GetLoadingMOTD("")) +} diff --git a/server/server.go b/server/server.go index e412ea5..5902514 100644 --- a/server/server.go +++ b/server/server.go @@ -69,7 +69,7 @@ func NewServer(ctx context.Context, config *Config) (*Server, error) { Routes.RegisterAll(config.Mapping) if config.Default != "" { - Routes.SetDefaultRoute(config.Default, "", nil, nil, "") + Routes.SetDefaultRoute(config.Default, "", nil, nil, "", "") } if config.ConnectionRateLimit < 1 { @@ -83,6 +83,7 @@ func NewServer(ctx context.Context, config *Config) (*Server, error) { autoScaleAllowDenyConfig) connector.UseAsleepMOTD(config.AutoScale.AsleepMOTD) + connector.UseLoadingMOTD(config.AutoScale.LoadingMOTD) clientFilter, err := NewClientFilter(config.ClientsToAllow, config.ClientsToDeny) if err != nil {