From bc81e03f19a5cf1d7707346f68c19dee1050b118 Mon Sep 17 00:00:00 2001 From: Samuel McBroom Date: Fri, 2 May 2025 16:12:53 -0700 Subject: [PATCH] Add auto scale down option (#405) --- README.md | 20 ++++--- cmd/mc-router/main.go | 35 +++++++---- cmd/mc-router/metrics.go | 50 +++++++++------- server/connector.go | 126 +++++++++++++++++++++++++++++++-------- server/docker.go | 31 ++++++++-- server/docker_swarm.go | 29 +++++++-- server/down_scaler.go | 96 +++++++++++++++++++++++++++++ server/k8s.go | 46 ++++++++------ server/k8s_test.go | 13 ++-- server/routes.go | 32 ++++++---- server/routes_test.go | 4 +- 11 files changed, 373 insertions(+), 109 deletions(-) create mode 100644 server/down_scaler.go diff --git a/README.md b/README.md index ea98ed4..3288c37 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,12 @@ Routes Minecraft client connections to backend servers based upon the requested The host:port bound for servicing API requests (env API_BINDING) -auto-scale-up Increase Kubernetes StatefulSet Replicas (only) from 0 to 1 on respective backend servers when accessed (env AUTO_SCALE_UP) + -auto-scale-down + Decrease Kubernetes StatefulSet Replicas (only) from 1 to 0 after all backend connections have stopped and a configurable amount of delay has passed (env AUTO_SCALE_DOWN) + -auto-scale-down-after + String indicating how long an auto scale down should wait before scaling down a backend server. If a player rejoins the server during this delay, the scale down will be canceled (env AUTO_SCALE_DOWN_AFTER) + -auto-scale-allow-deny string + Path to config for server allowlists and denylists. If a global/server entry is specified, only players allowed to connect to the server will be able to trigger a scale up when -auto-scale-up is enabled or cancel active down scalers when -auto-scale-down is enabled (env AUTO_SCALE_ALLOW_DENY) -clients-to-allow value Zero or more client IP addresses or CIDRs to allow. Takes precedence over deny. (env CLIENTS_TO_ALLOW) -clients-to-deny value @@ -80,8 +86,6 @@ Routes Minecraft client connections to backend servers based upon the requested If set, a POST request that contains connection status notifications will be sent to this HTTP address (env WEBHOOK_URL) -record-logins Log and generate metrics on player logins. Metrics only supported with influxdb or prometheus backend (env RECORD_LOGINS) - -auto-scale-up-allow-deny string - Path to config for server allowlists and denylists. If -auto-scale-up is enabled and a global/server entry is specified, only players allowed to connect to the server will be able to trigger a scale up (env AUTO_SCALE_UP_ALLOW_DENY) ``` ## Docker Multi-Architecture Image @@ -172,9 +176,9 @@ The following shows a JSON file for routes config, where `default-server` can al } ``` -## Auto Scale Up Allow/Deny List +## Auto Scale Allow/Deny List -The allow/deny list configuration allows limiting which players can scale up servers when using the `-auto-scale-up` option or the `AUTO_SCALE_UP` env variable. Global allow/deny lists can be configured that apply to all backend servers, but server-specific lists can be added as well. There are a few important things to note about the configuration: +The allow/deny list configuration allows limiting which players can scale up servers when using the `-auto-scale-up` option (`AUTO_SCALE_UP` env variable) and which players can cancel an active down scaler when using the `-auto-scale-down` option (`AUTO_SCALE_DOWN` env variable). Global allow/deny lists can be configured that apply to all backend servers, but server-specific lists can be added as well. There are a few important things to note about the configuration: - The `mc-router` process will not automatically pick up changes to the config. If updates to the config are made, the router must be restarted. - Allowlists always take priority over denylists. This means if a player is included in a sever-specific allowlist and the global denylist, the player will still be considered allowed on that server. If a player is listed in both a global allowlist and denylist, the denylist entry will be ignored. - Player entries only require a `uuid` or `name`. Both will be checked if specified, but otherwise a `uuid` will take priority over a `name`. @@ -267,13 +271,13 @@ kubectl apply -f https://raw.githubusercontent.com/itzg/mc-router/master/docs/k8 * I extended the allowed node port range by adding `--service-node-port-range=25000-32767` to `/etc/kubernetes/manifests/kube-apiserver.yaml` -##### Auto Scale Up +##### Auto Scale Up/Down -The `-auto-scale-up` flag argument makes the router "wake up" any stopped backend servers, by changing `replicas: 0` to `replicas: 1`. +The `-auto-scale-up` flag argument makes the router "wake up" any stopped backend servers by changing `replicas: 0` to `replicas: 1`. The `-auto-scale-down` flag argument makes the router shut down any running backend servers with no active connections by changing `replicas: 1` to `replicas: 0`. The scale down will occur after a configurable (using the `-auto-scale-down-after` argument) waiting period, such as `10m` (10 minutes), `2h` (2 hours), etc. If any players connect to the server during this period the scale down will be canceled. It is recommended to set this value high enough so a temporary player disconnect will not immediately shut down the server (`1m` or higher). -This requires using `kind: StatefulSet` instead of `kind: Service` for the Minecraft backend servers. +Both options require using `kind: StatefulSet` instead of `kind: Service` for the Minecraft backend servers. -It also requires the `ClusterRole` to permit `get` + `update` for `statefulsets` & `statefulsets/scale`, +They also require the `ClusterRole` to permit `get` + `update` for `statefulsets` & `statefulsets/scale`, e.g. like this (or some equivalent more fine-grained one to only watch/list services+statefulsets, and only get+update scale): ```yaml diff --git a/cmd/mc-router/main.go b/cmd/mc-router/main.go index 23d062a..9bf7c95 100644 --- a/cmd/mc-router/main.go +++ b/cmd/mc-router/main.go @@ -33,6 +33,13 @@ type WebhookConfig struct { RequireUser bool `default:"false" usage:"Indicates if the webhook will only be called if a user is connecting rather than just server list/ping"` } +type AutoScale struct { + Up bool `usage:"Increase Kubernetes StatefulSet Replicas (only) from 0 to 1 on respective backend servers when accessed"` + Down bool `default:"false" usage:"Decrease Kubernetes StatefulSet Replicas (only) from 1 to 0 on respective backend servers after there are no connections"` + DownAfter string `default:"10m" usage:"Server scale down delay after there are no connections"` + AllowDeny string `usage:"Path to config for server allowlists and denylists. If a global/server entry is specified, only players allowed to connect to the server will be able to trigger a scale up when -auto-scale-up is enabled or cancel active down scalers when -auto-scale-down is enabled"` +} + type Config struct { Port int `default:"25565" usage:"The [port] bound to listen for Minecraft client connections"` Default string `usage:"host:port of a default Minecraft server to use when mapping not found"` @@ -44,7 +51,6 @@ type Config struct { ConnectionRateLimit int `default:"1" usage:"Max number of connections to allow per second"` InKubeCluster bool `usage:"Use in-cluster Kubernetes config"` KubeConfig string `usage:"The path to a Kubernetes configuration file"` - AutoScaleUp bool `usage:"Increase Kubernetes StatefulSet Replicas (only) from 0 to 1 on respective backend servers when accessed"` InDocker bool `usage:"Use Docker service discovery"` InDockerSwarm bool `usage:"Use Docker Swarm service discovery"` DockerSocket string `default:"unix:///var/run/docker.sock" usage:"Path to Docker socket to use"` @@ -58,7 +64,7 @@ type Config struct { MetricsBackendConfig MetricsBackendConfig RoutesConfig string `usage:"Name or full path to routes config file"` NgrokToken string `usage:"If set, an ngrok tunnel will be established. It is HIGHLY recommended to pass as an environment variable."` - AutoScaleUpAllowDeny string `usage:"Path to config for server allowlists and denylists. If -auto-scale-up is enabled and a global/server entry is specified, only players allowed to connect to the server will be able to trigger a scale up"` + AutoScale AutoScale ClientsToAllow []string `usage:"Zero or more client IP addresses or CIDRs to allow. Takes precedence over deny."` ClientsToDeny []string `usage:"Zero or more client IP addresses or CIDRs to deny. Ignored if any configured to allow"` @@ -111,9 +117,9 @@ func main() { defer pprof.StopCPUProfile() } - var autoScaleUpAllowDenyConfig *server.AllowDenyConfig = nil - if config.AutoScaleUpAllowDeny != "" { - autoScaleUpAllowDenyConfig, err = server.ParseAllowDenyConfig(config.AutoScaleUpAllowDeny) + var autoScaleAllowDenyConfig *server.AllowDenyConfig = nil + if config.AutoScale.AllowDeny != "" { + autoScaleAllowDenyConfig, err = server.ParseAllowDenyConfig(config.AutoScale.AllowDeny) if err != nil { logrus.WithError(err).Fatal("trying to parse autoscale up allow-deny-list file") } @@ -124,6 +130,15 @@ func main() { metricsBuilder := NewMetricsBuilder(config.MetricsBackend, &config.MetricsBackendConfig) + downScalerEnabled := config.AutoScale.Down && (config.InKubeCluster || config.KubeConfig != "") + downScalerDelay, err := time.ParseDuration(config.AutoScale.DownAfter) + if err != nil { + logrus.WithError(err).Fatal("Unable to parse auto scale down after duration") + } + // Only one instance should be created + server.DownScaler = server.NewDownScaler(ctx, downScalerEnabled, downScalerDelay) + + c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) @@ -152,7 +167,7 @@ func main() { trustedIpNets = append(trustedIpNets, ipNet) } - connector := server.NewConnector(metricsBuilder.BuildConnectorMetrics(), config.UseProxyProtocol, config.ReceiveProxyProtocol, trustedIpNets, config.RecordLogins, autoScaleUpAllowDenyConfig) + connector := server.NewConnector(metricsBuilder.BuildConnectorMetrics(), config.UseProxyProtocol, config.ReceiveProxyProtocol, trustedIpNets, config.RecordLogins, autoScaleAllowDenyConfig) clientFilter, err := server.NewClientFilter(config.ClientsToAllow, config.ClientsToDeny) if err != nil { @@ -185,14 +200,14 @@ func main() { } if config.InKubeCluster { - err = server.K8sWatcher.StartInCluster(config.AutoScaleUp) + err = server.K8sWatcher.StartInCluster(config.AutoScale.Up, config.AutoScale.Down) if err != nil { logrus.WithError(err).Fatal("Unable to start k8s integration") } else { defer server.K8sWatcher.Stop() } } else if config.KubeConfig != "" { - err := server.K8sWatcher.StartWithConfig(config.KubeConfig, config.AutoScaleUp) + err := server.K8sWatcher.StartWithConfig(config.KubeConfig, config.AutoScale.Up, config.AutoScale.Down) if err != nil { logrus.WithError(err).Fatal("Unable to start k8s integration") } else { @@ -201,7 +216,7 @@ func main() { } if config.InDocker { - err = server.DockerWatcher.Start(config.DockerSocket, config.DockerTimeout, config.DockerRefreshInterval) + err = server.DockerWatcher.Start(config.DockerSocket, config.DockerTimeout, config.DockerRefreshInterval, config.AutoScale.Up, config.AutoScale.Down) if err != nil { logrus.WithError(err).Fatal("Unable to start docker integration") } else { @@ -210,7 +225,7 @@ func main() { } if config.InDockerSwarm { - err = server.DockerSwarmWatcher.Start(config.DockerSocket, config.DockerTimeout, config.DockerRefreshInterval) + err = server.DockerSwarmWatcher.Start(config.DockerSocket, config.DockerTimeout, config.DockerRefreshInterval, config.AutoScale.Up, config.AutoScale.Down) if err != nil { logrus.WithError(err).Fatal("Unable to start docker swarm integration") } else { diff --git a/cmd/mc-router/metrics.go b/cmd/mc-router/metrics.go index 8552517..dbb5936 100644 --- a/cmd/mc-router/metrics.go +++ b/cmd/mc-router/metrics.go @@ -60,13 +60,14 @@ func (b expvarMetricsBuilder) Start(ctx context.Context) error { func (b expvarMetricsBuilder) BuildConnectorMetrics() *server.ConnectorMetrics { c := expvarMetrics.NewCounter("connections") return &server.ConnectorMetrics{ - Errors: expvarMetrics.NewCounter("errors").With("subsystem", "connector"), - BytesTransmitted: expvarMetrics.NewCounter("bytes"), - ConnectionsFrontend: c, - ConnectionsBackend: c, - ActiveConnections: expvarMetrics.NewGauge("active_connections"), - ServerActivePlayer: expvarMetrics.NewGauge("server_active_player"), - ServerLogins: expvarMetrics.NewCounter("server_logins"), + Errors: expvarMetrics.NewCounter("errors").With("subsystem", "connector"), + BytesTransmitted: expvarMetrics.NewCounter("bytes"), + ConnectionsFrontend: c, + ConnectionsBackend: c, + ActiveConnections: expvarMetrics.NewGauge("active_connections"), + ServerActivePlayer: expvarMetrics.NewGauge("server_active_player"), + ServerLogins: expvarMetrics.NewCounter("server_logins"), + ServerActiveConnections: expvarMetrics.NewGauge("server_active_connections"), } } @@ -80,13 +81,14 @@ func (b discardMetricsBuilder) Start(ctx context.Context) error { func (b discardMetricsBuilder) BuildConnectorMetrics() *server.ConnectorMetrics { return &server.ConnectorMetrics{ - Errors: discardMetrics.NewCounter(), - BytesTransmitted: discardMetrics.NewCounter(), - ConnectionsFrontend: discardMetrics.NewCounter(), - ConnectionsBackend: discardMetrics.NewCounter(), - ActiveConnections: discardMetrics.NewGauge(), - ServerActivePlayer: discardMetrics.NewGauge(), - ServerLogins: discardMetrics.NewCounter(), + Errors: discardMetrics.NewCounter(), + BytesTransmitted: discardMetrics.NewCounter(), + ConnectionsFrontend: discardMetrics.NewCounter(), + ConnectionsBackend: discardMetrics.NewCounter(), + ActiveConnections: discardMetrics.NewGauge(), + ServerActivePlayer: discardMetrics.NewGauge(), + ServerLogins: discardMetrics.NewCounter(), + ServerActiveConnections: discardMetrics.NewGauge(), } } @@ -131,13 +133,14 @@ func (b *influxMetricsBuilder) BuildConnectorMetrics() *server.ConnectorMetrics c := metrics.NewCounter("mc_router_connections") return &server.ConnectorMetrics{ - Errors: metrics.NewCounter("mc_router_errors"), - BytesTransmitted: metrics.NewCounter("mc_router_transmitted_bytes"), - ConnectionsFrontend: c.With("side", "frontend"), - ConnectionsBackend: c.With("side", "backend"), - ActiveConnections: metrics.NewGauge("mc_router_connections_active"), - ServerActivePlayer: metrics.NewGauge("mc_router_server_player_active"), - ServerLogins: metrics.NewCounter("mc_router_server_logins"), + Errors: metrics.NewCounter("mc_router_errors"), + BytesTransmitted: metrics.NewCounter("mc_router_transmitted_bytes"), + ConnectionsFrontend: c.With("side", "frontend"), + ConnectionsBackend: c.With("side", "backend"), + ActiveConnections: metrics.NewGauge("mc_router_connections_active"), + ServerActivePlayer: metrics.NewGauge("mc_router_server_player_active"), + ServerLogins: metrics.NewCounter("mc_router_server_logins"), + ServerActiveConnections: metrics.NewGauge("mc_router_server_active_connections"), } } @@ -194,5 +197,10 @@ func (b prometheusMetricsBuilder) BuildConnectorMetrics() *server.ConnectorMetri Name: "server_logins", Help: "The total number of player logins", }, []string{"player_name", "player_uuid", "server_address"})), + ServerActiveConnections: prometheusMetrics.NewGauge(promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "mc_router", + Name: "server_active_connections", + Help: "The number of active connections per server", + }, []string{"server_address"})), } } diff --git a/server/connector.go b/server/connector.go index 687a8f7..d805462 100644 --- a/server/connector.go +++ b/server/connector.go @@ -30,13 +30,14 @@ const ( var noDeadline time.Time type ConnectorMetrics struct { - Errors metrics.Counter - BytesTransmitted metrics.Counter - ConnectionsFrontend metrics.Counter - ConnectionsBackend metrics.Counter - ActiveConnections metrics.Gauge - ServerActivePlayer metrics.Gauge - ServerLogins metrics.Counter + Errors metrics.Counter + BytesTransmitted metrics.Counter + ConnectionsFrontend metrics.Counter + ConnectionsBackend metrics.Counter + ActiveConnections metrics.Gauge + ServerActivePlayer metrics.Gauge + ServerLogins metrics.Counter + ServerActiveConnections metrics.Gauge } type ClientInfo struct { @@ -67,7 +68,48 @@ func (p *PlayerInfo) String() string { return fmt.Sprintf("%s/%s", p.Name, p.Uuid) } +type ServerMetrics struct { + sync.RWMutex + activeConnections map[string]int +} + +func NewServerMetrics() *ServerMetrics { + return &ServerMetrics{ + activeConnections: make(map[string]int), + } +} + +func (sm *ServerMetrics) IncrementActiveConnections(serverAddress string) { + sm.Lock() + defer sm.Unlock() + if _, ok := sm.activeConnections[serverAddress]; !ok { + sm.activeConnections[serverAddress] = 1 + return + } + sm.activeConnections[serverAddress] += 1 +} + +func (sm *ServerMetrics) DecrementActiveConnections(serverAddress string) { + sm.Lock() + defer sm.Unlock() + if activeConnections, ok := sm.activeConnections[serverAddress]; ok && activeConnections <= 0 { + sm.activeConnections[serverAddress] = 0 + return + } + sm.activeConnections[serverAddress] -= 1 +} + +func (sm *ServerMetrics) ActiveConnectionsValue(serverAddress string) int { + sm.Lock() + defer sm.Unlock() + if activeConnections, ok := sm.activeConnections[serverAddress]; ok { + return activeConnections + } + return 0 +} + func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool, receiveProxyProto bool, trustedProxyNets []*net.IPNet, recordLogins bool, autoScaleUpAllowDenyConfig *AllowDenyConfig) *Connector { + return &Connector{ metrics: metrics, sendProxyProto: sendProxyProto, @@ -76,6 +118,7 @@ func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool, receiveProxyPr trustedProxyNets: trustedProxyNets, recordLogins: recordLogins, autoScaleUpAllowDenyConfig: autoScaleUpAllowDenyConfig, + serverMetrics: NewServerMetrics(), } } @@ -88,6 +131,7 @@ type Connector struct { trustedProxyNets []*net.IPNet activeConnections int32 + serverMetrics *ServerMetrics connectionsCond *sync.Cond ngrokToken string clientFilter *ClientFilter @@ -348,10 +392,48 @@ func (c *Connector) readPlayerInfo(bufferedReader *bufio.Reader, clientAddr net. } } +func (c *Connector) cleanupBackendConnection(ctx context.Context, clientAddr net.Addr, serverAddress string, playerInfo *PlayerInfo, backendHostPort string, cleanupMetrics bool, checkScaleDown bool) { + if c.connectionNotifier != nil { + err := c.connectionNotifier.NotifyDisconnected(ctx, clientAddr, serverAddress, playerInfo, backendHostPort) + if err != nil { + logrus.WithError(err).Warn("failed to notify disconnected") + } + } + + if cleanupMetrics { + c.metrics.ActiveConnections.Set(float64( + atomic.AddInt32(&c.activeConnections, -1))) + + c.serverMetrics.DecrementActiveConnections(serverAddress) + c.metrics.ServerActiveConnections. + With("server_address", serverAddress). + Set(float64(c.serverMetrics.ActiveConnectionsValue(serverAddress))) + + if c.recordLogins && playerInfo != nil { + c.metrics.ServerActivePlayer. + With("player_name", playerInfo.Name). + With("player_uuid", playerInfo.Uuid.String()). + With("server_address", serverAddress). + Set(0) + } + } + if checkScaleDown && c.serverMetrics.ActiveConnectionsValue(serverAddress) <= 0 { + DownScaler.Begin(serverAddress) + } + c.connectionsCond.Signal() +} + func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.Conn, clientAddr net.Addr, preReadContent io.Reader, serverAddress string, playerInfo *PlayerInfo, nextState mcproto.State) { - backendHostPort, resolvedHost, waker := Routes.FindBackendForServerAddress(ctx, serverAddress) + backendHostPort, resolvedHost, waker, _ := Routes.FindBackendForServerAddress(ctx, serverAddress) + cleanupMetrics := false + cleanupCheckScaleDown := false + + defer func() { + c.cleanupBackendConnection(ctx, clientAddr, serverAddress, playerInfo, backendHostPort, cleanupMetrics, cleanupCheckScaleDown) + }() + if waker != nil && nextState > mcproto.StateStatus { serverAllowsPlayer := c.autoScaleUpAllowDenyConfig.ServerAllowsPlayer(serverAddress, playerInfo) logrus. @@ -361,6 +443,9 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net. WithField("serverAllowsPlayer", serverAllowsPlayer). Debug("checked if player is allowed to wake up the server") if serverAllowsPlayer { + // Cancel down scaler if active before scale up + DownScaler.Cancel(serverAddress) + cleanupCheckScaleDown = true if err := waker(ctx); err != nil { logrus.WithFields(logrus.Fields{"serverAddress": serverAddress}).WithError(err).Error("failed to wake up backend") c.metrics.Errors.With("type", "wakeup_failed").Add(1) @@ -426,6 +511,12 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net. c.metrics.ActiveConnections.Set(float64( atomic.AddInt32(&c.activeConnections, 1))) + + c.serverMetrics.IncrementActiveConnections(serverAddress) + c.metrics.ServerActiveConnections. + With("server_address", serverAddress). + Set(float64(c.serverMetrics.ActiveConnectionsValue(serverAddress))) + if c.recordLogins && playerInfo != nil { logrus. WithField("client", clientAddr). @@ -446,24 +537,7 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net. Add(1) } - defer func() { - if c.connectionNotifier != nil { - err := c.connectionNotifier.NotifyDisconnected(ctx, clientAddr, serverAddress, playerInfo, backendHostPort) - if err != nil { - logrus.WithError(err).Warn("failed to notify disconnected") - } - } - c.metrics.ActiveConnections.Set(float64( - atomic.AddInt32(&c.activeConnections, -1))) - if c.recordLogins && playerInfo != nil { - c.metrics.ServerActivePlayer. - With("player_name", playerInfo.Name). - With("player_uuid", playerInfo.Uuid.String()). - With("server_address", serverAddress). - Set(0) - } - c.connectionsCond.Signal() - }() + cleanupMetrics = true // PROXY protocol implementation if c.sendProxyProto { diff --git a/server/docker.go b/server/docker.go index 87688c4..2e744fb 100644 --- a/server/docker.go +++ b/server/docker.go @@ -15,7 +15,7 @@ import ( ) type IDockerWatcher interface { - Start(socket string, timeoutSeconds int, refreshIntervalSeconds int) error + Start(socket string, timeoutSeconds int, refreshIntervalSeconds int, autoScaleUp bool, autoScaleDown bool) error Stop() } @@ -31,19 +31,38 @@ var DockerWatcher IDockerWatcher = &dockerWatcherImpl{} type dockerWatcherImpl struct { sync.RWMutex + autoScaleUp bool + autoScaleDown bool client *client.Client contextCancel context.CancelFunc } -func (w *dockerWatcherImpl) makeWakerFunc(_ *routableContainer) func(ctx context.Context) error { +func (w *dockerWatcherImpl) makeWakerFunc(_ *routableContainer) ScalerFunc { + if !w.autoScaleUp { + return nil + } return func(ctx context.Context) error { + logrus.Fatal("Auto scale up is not yet supported for docker") return nil } } -func (w *dockerWatcherImpl) Start(socket string, timeoutSeconds int, refreshIntervalSeconds int) error { +func (w *dockerWatcherImpl) makeSleeperFunc(_ *routableContainer) ScalerFunc { + if !w.autoScaleDown { + return nil + } + return func(ctx context.Context) error { + logrus.Fatal("Auto scale down is not yet supported for docker") + return nil + } +} + +func (w *dockerWatcherImpl) Start(socket string, timeoutSeconds int, refreshIntervalSeconds int, autoScaleUp bool, autoScaleDown bool) error { var err error + w.autoScaleUp = autoScaleUp + w.autoScaleDown = autoScaleDown + timeout := time.Duration(timeoutSeconds) * time.Second refreshInterval := time.Duration(refreshIntervalSeconds) * time.Second @@ -75,7 +94,7 @@ func (w *dockerWatcherImpl) Start(socket string, timeoutSeconds int, refreshInte for _, c := range initialContainers { containerMap[c.externalContainerName] = c if c.externalContainerName != "" { - Routes.CreateMapping(c.externalContainerName, c.containerEndpoint, w.makeWakerFunc(c)) + Routes.CreateMapping(c.externalContainerName, c.containerEndpoint, w.makeWakerFunc(c), w.makeSleeperFunc(c)) } else { Routes.SetDefaultRoute(c.containerEndpoint) } @@ -97,7 +116,7 @@ func (w *dockerWatcherImpl) Start(socket string, timeoutSeconds int, refreshInte containerMap[rs.externalContainerName] = rs logrus.WithField("routableContainer", rs).Debug("ADD") if rs.externalContainerName != "" { - Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, w.makeWakerFunc(rs)) + Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, w.makeWakerFunc(rs), w.makeSleeperFunc(rs)) } else { Routes.SetDefaultRoute(rs.containerEndpoint) } @@ -105,7 +124,7 @@ func (w *dockerWatcherImpl) Start(socket string, timeoutSeconds int, refreshInte containerMap[rs.externalContainerName] = rs if rs.externalContainerName != "" { Routes.DeleteMapping(rs.externalContainerName) - Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, w.makeWakerFunc(rs)) + Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, w.makeWakerFunc(rs), w.makeSleeperFunc(rs)) } else { Routes.SetDefaultRoute(rs.containerEndpoint) } diff --git a/server/docker_swarm.go b/server/docker_swarm.go index e056e66..fe2c6bd 100644 --- a/server/docker_swarm.go +++ b/server/docker_swarm.go @@ -23,19 +23,38 @@ var DockerSwarmWatcher IDockerWatcher = &dockerSwarmWatcherImpl{} type dockerSwarmWatcherImpl struct { sync.RWMutex + autoScaleUp bool + autoScaleDown bool client *client.Client contextCancel context.CancelFunc } -func (w *dockerSwarmWatcherImpl) makeWakerFunc(_ *routableService) func(ctx context.Context) error { +func (w *dockerSwarmWatcherImpl) makeWakerFunc(_ *routableService) ScalerFunc { + if !w.autoScaleUp { + return nil + } return func(ctx context.Context) error { + logrus.Fatal("Auto scale up is not yet supported for docker swarm") return nil } } -func (w *dockerSwarmWatcherImpl) Start(socket string, timeoutSeconds int, refreshIntervalSeconds int) error { +func (w *dockerSwarmWatcherImpl) makeSleeperFunc(_ *routableService) ScalerFunc { + if !w.autoScaleDown { + return nil + } + return func(ctx context.Context) error { + logrus.Fatal("Auto scale down is not yet supported for docker swarm") + return nil + } +} + +func (w *dockerSwarmWatcherImpl) Start(socket string, timeoutSeconds int, refreshIntervalSeconds int, autoScaleUp bool, autoScaleDown bool) error { var err error + w.autoScaleUp = autoScaleUp + w.autoScaleDown = autoScaleDown + timeout := time.Duration(timeoutSeconds) * time.Second refreshInterval := time.Duration(refreshIntervalSeconds) * time.Second @@ -67,7 +86,7 @@ func (w *dockerSwarmWatcherImpl) Start(socket string, timeoutSeconds int, refres for _, s := range initialServices { serviceMap[s.externalServiceName] = s if s.externalServiceName != "" { - Routes.CreateMapping(s.externalServiceName, s.containerEndpoint, w.makeWakerFunc(s)) + Routes.CreateMapping(s.externalServiceName, s.containerEndpoint, w.makeWakerFunc(s), w.makeSleeperFunc(s)) } else { Routes.SetDefaultRoute(s.containerEndpoint) } @@ -89,7 +108,7 @@ func (w *dockerSwarmWatcherImpl) Start(socket string, timeoutSeconds int, refres serviceMap[rs.externalServiceName] = rs logrus.WithField("routableService", rs).Debug("ADD") if rs.externalServiceName != "" { - Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, w.makeWakerFunc(rs)) + Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, w.makeWakerFunc(rs), w.makeSleeperFunc(rs)) } else { Routes.SetDefaultRoute(rs.containerEndpoint) } @@ -97,7 +116,7 @@ func (w *dockerSwarmWatcherImpl) Start(socket string, timeoutSeconds int, refres serviceMap[rs.externalServiceName] = rs if rs.externalServiceName != "" { Routes.DeleteMapping(rs.externalServiceName) - Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, w.makeWakerFunc(rs)) + Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, w.makeWakerFunc(rs), w.makeSleeperFunc(rs)) } else { Routes.SetDefaultRoute(rs.containerEndpoint) } diff --git a/server/down_scaler.go b/server/down_scaler.go new file mode 100644 index 0000000..91c6da6 --- /dev/null +++ b/server/down_scaler.go @@ -0,0 +1,96 @@ +package server + +import ( + "context" + "sync" + "time" + + "github.com/sirupsen/logrus" +) + +type IDownScaler interface { + Reset() + Begin(serverAddress string) + Cancel(serverAddress string) +} + +var DownScaler IDownScaler + +func NewDownScaler(ctx context.Context, enabled bool, delay time.Duration) IDownScaler { + ds := &downScalerImpl{ + enabled: enabled, + delay: delay, + parentContext: ctx, + contextCancellations: make(map[string]context.CancelFunc), + } + + return ds +} + +type downScalerImpl struct { + sync.RWMutex + enabled bool + delay time.Duration + parentContext context.Context + contextCancellations map[string]context.CancelFunc +} + +func (ds *downScalerImpl) Reset() { + // Cancel all existing scale down routines + for _, scaleDownCancel := range ds.contextCancellations { + scaleDownCancel() + } + ds.contextCancellations = make(map[string]context.CancelFunc) +} + +func (ds *downScalerImpl) Begin(serverAddress string) { + ds.Lock() + defer ds.Unlock() + + if !ds.enabled { + return + } + + // If an existing scale down routine exists, cancel it + if scaleDownCancel, ok := ds.contextCancellations[serverAddress]; ok { + scaleDownCancel() + } + + logrus.WithField("serverAddress", serverAddress).Debug("Beginning scale down") + scaleDownContext, scaleDownContextCancellation := context.WithCancel(ds.parentContext) + ds.contextCancellations[serverAddress] = scaleDownContextCancellation + go ds.scaleDown(scaleDownContext, serverAddress) +} + +func (ds *downScalerImpl) Cancel(serverAddress string) { + ds.Lock() + defer ds.Unlock() + + if !ds.enabled { + return + } + + if scaleDownContextCancellation, ok := ds.contextCancellations[serverAddress]; ok { + logrus.WithField("serverAddress", serverAddress).Debug("Canceling scale down") + scaleDownContextCancellation() + delete(ds.contextCancellations, serverAddress) + } +} + +func (ds *downScalerImpl) scaleDown(ctx context.Context, serverAddress string) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(ds.delay): + _, _, _, sleeper := Routes.FindBackendForServerAddress(ctx, serverAddress) + if sleeper == nil { + return + } + if err := sleeper(ctx); err != nil { + logrus.WithField("serverAddress", serverAddress).WithError(err).Error("failed to scale down backend") + } + return + } + } +} diff --git a/server/k8s.go b/server/k8s.go index 4d9ddf5..df4cc21 100644 --- a/server/k8s.go +++ b/server/k8s.go @@ -27,8 +27,8 @@ const ( ) type IK8sWatcher interface { - StartWithConfig(kubeConfigFile string, autoScaleUp bool) error - StartInCluster(autoScaleUp bool) error + StartWithConfig(kubeConfigFile string, autoScaleUp bool, autoScaleDown bool) error + StartInCluster(autoScaleUp bool, autoScaleDown bool) error Stop() } @@ -36,6 +36,8 @@ var K8sWatcher IK8sWatcher = &k8sWatcherImpl{} type k8sWatcherImpl struct { sync.RWMutex + autoScaleUp bool + autoScaleDown bool // The key in mappings is a Service, and the value the StatefulSet name mappings map[string]string @@ -43,26 +45,28 @@ type k8sWatcherImpl struct { stop chan struct{} } -func (w *k8sWatcherImpl) StartInCluster(autoScaleUp bool) error { +func (w *k8sWatcherImpl) StartInCluster(autoScaleUp bool, autoScaleDown bool) error { config, err := rest.InClusterConfig() if err != nil { return errors.Wrap(err, "Unable to load in-cluster config") } - return w.startWithLoadedConfig(config, autoScaleUp) + return w.startWithLoadedConfig(config, autoScaleUp, autoScaleDown) } -func (w *k8sWatcherImpl) StartWithConfig(kubeConfigFile string, autoScaleUp bool) error { +func (w *k8sWatcherImpl) StartWithConfig(kubeConfigFile string, autoScaleUp bool, autoScaleDown bool) error { config, err := clientcmd.BuildConfigFromFlags("", kubeConfigFile) if err != nil { return errors.Wrap(err, "Could not load kube config file") } - return w.startWithLoadedConfig(config, autoScaleUp) + return w.startWithLoadedConfig(config, autoScaleUp, autoScaleDown) } -func (w *k8sWatcherImpl) startWithLoadedConfig(config *rest.Config, autoScaleUp bool) error { +func (w *k8sWatcherImpl) startWithLoadedConfig(config *rest.Config, autoScaleUp bool, autoScaleDown bool) error { w.stop = make(chan struct{}, 1) + w.autoScaleUp = autoScaleUp + w.autoScaleDown = autoScaleDown clientset, err := kubernetes.NewForConfig(config) if err != nil { @@ -88,7 +92,7 @@ func (w *k8sWatcherImpl) startWithLoadedConfig(config *rest.Config, autoScaleUp go serviceController.Run(w.stop) w.mappings = make(map[string]string) - if autoScaleUp { + if autoScaleUp || autoScaleDown { _, statefulSetController := cache.NewInformer( cache.NewListWatchFromClient( clientset.AppsV1().RESTClient(), @@ -156,7 +160,7 @@ func (w *k8sWatcherImpl) handleUpdate(oldObj interface{}, newObj interface{}) { "new": newRoutableService, }).Debug("UPDATE") if newRoutableService.externalServiceName != "" { - Routes.CreateMapping(newRoutableService.externalServiceName, newRoutableService.containerEndpoint, newRoutableService.autoScaleUp) + Routes.CreateMapping(newRoutableService.externalServiceName, newRoutableService.containerEndpoint, newRoutableService.autoScaleUp, newRoutableService.autoScaleDown) } else { Routes.SetDefaultRoute(newRoutableService.containerEndpoint) } @@ -187,7 +191,7 @@ func (w *k8sWatcherImpl) handleAdd(obj interface{}) { logrus.WithField("routableService", routableService).Debug("ADD") if routableService.externalServiceName != "" { - Routes.CreateMapping(routableService.externalServiceName, routableService.containerEndpoint, routableService.autoScaleUp) + Routes.CreateMapping(routableService.externalServiceName, routableService.containerEndpoint, routableService.autoScaleUp, routableService.autoScaleDown) } else { Routes.SetDefaultRoute(routableService.containerEndpoint) } @@ -204,7 +208,8 @@ func (w *k8sWatcherImpl) Stop() { type routableService struct { externalServiceName string containerEndpoint string - autoScaleUp func(ctx context.Context) error + autoScaleUp ScalerFunc + autoScaleDown ScalerFunc } // obj is expected to be a *v1.Service @@ -239,12 +244,19 @@ func (w *k8sWatcherImpl) buildDetails(service *core.Service, externalServiceName rs := &routableService{ externalServiceName: externalServiceName, containerEndpoint: net.JoinHostPort(clusterIp, port), - autoScaleUp: w.buildScaleUpFunction(service), + autoScaleUp: w.buildScaleFunction(service, 0, 1), + autoScaleDown: w.buildScaleFunction(service, 1, 0), } return rs } -func (w *k8sWatcherImpl) buildScaleUpFunction(service *core.Service) func(ctx context.Context) error { +func (w *k8sWatcherImpl) buildScaleFunction(service *core.Service, from int32, to int32) ScalerFunc { + if from <= to && !w.autoScaleUp { + return nil + } + if from >= to && !w.autoScaleDown { + return nil + } return func(ctx context.Context) error { serviceName := service.Name if statefulSetName, exists := w.mappings[serviceName]; exists { @@ -255,7 +267,7 @@ func (w *k8sWatcherImpl) buildScaleUpFunction(service *core.Service) func(ctx co "statefulSet": statefulSetName, "replicas": replicas, }).Debug("StatefulSet of Service Replicas") - if replicas == 0 { + if replicas == from { if _, err := w.clientset.AppsV1().StatefulSets(service.Namespace).UpdateScale(ctx, statefulSetName, &autoscaling.Scale{ ObjectMeta: meta.ObjectMeta{ Name: scale.Name, @@ -263,15 +275,15 @@ func (w *k8sWatcherImpl) buildScaleUpFunction(service *core.Service) func(ctx co UID: scale.UID, ResourceVersion: scale.ResourceVersion, }, - Spec: autoscaling.ScaleSpec{Replicas: 1}}, meta.UpdateOptions{}, + Spec: autoscaling.ScaleSpec{Replicas: to}}, meta.UpdateOptions{}, ); err == nil { logrus.WithFields(logrus.Fields{ "service": serviceName, "statefulSet": statefulSetName, "replicas": replicas, - }).Info("StatefulSet Replicas Autoscaled from 0 to 1 (wake up)") + }).Infof("StatefulSet Replicas Autoscaled from %d to %d", from, to) } else { - return errors.Wrap(err, "UpdateScale for Replicas=1 failed for StatefulSet: "+statefulSetName) + return errors.Wrapf(err, "UpdateScale for Replicas=%d failed for StatefulSet: %s", to, statefulSetName) } } } else { diff --git a/server/k8s_test.go b/server/k8s_test.go index eb2e2f7..ac2d41e 100644 --- a/server/k8s_test.go +++ b/server/k8s_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -78,6 +79,8 @@ func TestK8sWatcherImpl_handleAddThenUpdate(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + // DownScaler needs to be instantiated + DownScaler = NewDownScaler(context.Background(), false, 1 * time.Second) Routes.Reset() watcher := &k8sWatcherImpl{} @@ -87,7 +90,7 @@ func TestK8sWatcherImpl_handleAddThenUpdate(t *testing.T) { watcher.handleAdd(&initialSvc) for _, s := range test.initial.scenarios { - backend, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given) + backend, _, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given) assert.Equal(t, s.expect, backend, "initial: given=%s", s.given) } @@ -97,7 +100,7 @@ func TestK8sWatcherImpl_handleAddThenUpdate(t *testing.T) { watcher.handleUpdate(&initialSvc, &updatedSvc) for _, s := range test.update.scenarios { - backend, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given) + backend, _, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given) assert.Equal(t, s.expect, backend, "update: given=%s", s.given) } }) @@ -149,6 +152,8 @@ func TestK8sWatcherImpl_handleAddThenDelete(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + // DownScaler needs to be instantiated + DownScaler = NewDownScaler(context.Background(), false, 1 * time.Second) Routes.Reset() watcher := &k8sWatcherImpl{} @@ -158,13 +163,13 @@ func TestK8sWatcherImpl_handleAddThenDelete(t *testing.T) { watcher.handleAdd(&initialSvc) for _, s := range test.initial.scenarios { - backend, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given) + backend, _, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given) assert.Equal(t, s.expect, backend, "initial: given=%s", s.given) } watcher.handleDelete(&initialSvc) for _, s := range test.delete { - backend, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given) + backend, _, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given) assert.Equal(t, s.expect, backend, "update: given=%s", s.given) } }) diff --git a/server/routes.go b/server/routes.go index f7402b9..f69425b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -12,6 +12,10 @@ import ( "github.com/sirupsen/logrus" ) +type ScalerFunc func(ctx context.Context) error + +var EmptyScalerFunc = func(ctx context.Context) error { return nil } + var tcpShieldPattern = regexp.MustCompile("///.*") func init() { @@ -70,7 +74,7 @@ func routesCreateHandler(writer http.ResponseWriter, request *http.Request) { return } - Routes.CreateMapping(definition.ServerAddress, definition.Backend, func(ctx context.Context) error { return nil }) + Routes.CreateMapping(definition.ServerAddress, definition.Backend, EmptyScalerFunc, EmptyScalerFunc) RoutesConfig.AddMapping(definition.ServerAddress, definition.Backend) writer.WriteHeader(http.StatusCreated) } @@ -102,10 +106,11 @@ type IRoutes interface { // FindBackendForServerAddress returns the host:port for the external server address, if registered. // Otherwise, an empty string is returned. Also returns the normalized version of the given serverAddress. // The 3rd value returned is an (optional) "waker" function which a caller must invoke to wake up serverAddress. - FindBackendForServerAddress(ctx context.Context, serverAddress string) (string, string, func(ctx context.Context) error) + // The 4th value returned is an (optional) "sleeper" function which a caller must invoke to shut down serverAddress. + FindBackendForServerAddress(ctx context.Context, serverAddress string) (string, string, ScalerFunc, ScalerFunc) GetMappings() map[string]string DeleteMapping(serverAddress string) bool - CreateMapping(serverAddress string, backend string, waker func(ctx context.Context) error) + CreateMapping(serverAddress string, backend string, waker ScalerFunc, sleeper ScalerFunc) SetDefaultRoute(backend string) SimplifySRV(srvEnabled bool) } @@ -122,13 +127,14 @@ func NewRoutes() IRoutes { func (r *routesImpl) RegisterAll(mappings map[string]string) { for k, v := range mappings { - r.CreateMapping(k, v, func(ctx context.Context) error { return nil }) + r.CreateMapping(k, v, EmptyScalerFunc, EmptyScalerFunc) } } type mapping struct { backend string - waker func(ctx context.Context) error + waker ScalerFunc + sleeper ScalerFunc } type routesImpl struct { @@ -140,6 +146,7 @@ type routesImpl struct { func (r *routesImpl) Reset() { r.mappings = make(map[string]mapping) + DownScaler.Reset() } func (r *routesImpl) SetDefaultRoute(backend string) { @@ -154,7 +161,7 @@ func (r *routesImpl) SimplifySRV(srvEnabled bool) { r.simplifySRV = srvEnabled } -func (r *routesImpl) FindBackendForServerAddress(_ context.Context, serverAddress string) (string, string, func(ctx context.Context) error) { +func (r *routesImpl) FindBackendForServerAddress(_ context.Context, serverAddress string) (string, string, ScalerFunc, ScalerFunc) { r.RLock() defer r.RUnlock() @@ -190,10 +197,10 @@ func (r *routesImpl) FindBackendForServerAddress(_ context.Context, serverAddres if r.mappings != nil { if mapping, exists := r.mappings[serverAddress]; exists { - return mapping.backend, serverAddress, mapping.waker + return mapping.backend, serverAddress, mapping.waker, mapping.sleeper } } - return r.defaultRoute, serverAddress, nil + return r.defaultRoute, serverAddress, nil, nil } func (r *routesImpl) GetMappings() map[string]string { @@ -212,6 +219,8 @@ func (r *routesImpl) DeleteMapping(serverAddress string) bool { defer r.Unlock() logrus.WithField("serverAddress", serverAddress).Info("Deleting route") + DownScaler.Cancel(serverAddress) + if _, ok := r.mappings[serverAddress]; ok { delete(r.mappings, serverAddress) return true @@ -220,7 +229,7 @@ func (r *routesImpl) DeleteMapping(serverAddress string) bool { } } -func (r *routesImpl) CreateMapping(serverAddress string, backend string, waker func(ctx context.Context) error) { +func (r *routesImpl) CreateMapping(serverAddress string, backend string, waker ScalerFunc, sleeper ScalerFunc) { r.Lock() defer r.Unlock() @@ -230,5 +239,8 @@ func (r *routesImpl) CreateMapping(serverAddress string, backend string, waker f "serverAddress": serverAddress, "backend": backend, }).Info("Created route mapping") - r.mappings[serverAddress] = mapping{backend: backend, waker: waker} + r.mappings[serverAddress] = mapping{backend: backend, waker: waker, sleeper: sleeper} + + // Trigger auto scale down when mapping is created to ensure servers are shut down if router restarts + DownScaler.Begin(serverAddress) } diff --git a/server/routes_test.go b/server/routes_test.go index 8a9d34b..2001d33 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -66,9 +66,9 @@ func Test_routesImpl_FindBackendForServerAddress(t *testing.T) { t.Run(tt.name, func(t *testing.T) { r := NewRoutes() - r.CreateMapping(tt.mapping.serverAddress, tt.mapping.backend, func(ctx context.Context) error { return nil }) + r.CreateMapping(tt.mapping.serverAddress, tt.mapping.backend, EmptyScalerFunc, EmptyScalerFunc) - if got, server, _ := r.FindBackendForServerAddress(context.Background(), tt.args.serverAddress); got != tt.want { + if got, server, _, _ := r.FindBackendForServerAddress(context.Background(), tt.args.serverAddress); got != tt.want { t.Errorf("routesImpl.FindBackendForServerAddress() = %v, want %v", got, tt.want) } else { assert.Equal(t, tt.mapping.serverAddress, server)