From 4dff00dda948da17fe44b15b8c544d11a2121c12 Mon Sep 17 00:00:00 2001 From: Lenart Kos <39205323+Lenart12@users.noreply.github.com> Date: Sat, 20 Dec 2025 20:31:34 +0100 Subject: [PATCH] Docker auto-scale and asleep motd status (#488) --- README.md | 38 ++- examples/docker-autoscale/compose-minimal.yml | 27 ++ examples/docker-autoscale/compose.yml | 50 ++++ mcproto/read.go | 24 +- mcproto/types.go | 31 +++ mcproto/write.go | 148 +++++++++++ server/api_server.go | 4 +- server/configs.go | 9 +- server/connector.go | 205 +++++++++++++-- server/docker.go | 241 ++++++++++++++---- server/docker_swarm.go | 28 +- server/down_scaler.go | 59 +++-- server/k8s.go | 52 +++- server/k8s_test.go | 28 +- server/routes.go | 91 +++++-- server/routes_config_loader.go | 7 +- server/routes_test.go | 2 +- server/server.go | 6 +- 18 files changed, 885 insertions(+), 165 deletions(-) create mode 100644 examples/docker-autoscale/compose-minimal.yml create mode 100644 examples/docker-autoscale/compose.yml create mode 100644 mcproto/write.go diff --git a/README.md b/README.md index bf2db64..bdfd225 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,8 @@ Some other features included: - Rate limits incoming connections to reduce DDoS attacks. - Can be configured to allow/deny IP addresses or ranges - Includes a webhook integration for notifying other systems when a player connects and disconnects from a server. -- Can auto-scale (between zero and one) backend servers deployed as Kubernetes StatefulSets. +- Can auto-scale (between zero and one) backend servers deployed as Kubernetes StatefulSets +- or start and stop backend servers running as docker containers. - Built-in ngrok integration where mc-router acts as an agent - Exports/exposes metrics for various Prometheus and InfluxDB. If enabled, includes player login metrics. 
@@ -25,12 +26,14 @@ Some other features included: The host:port bound for servicing API requests (env API_BINDING) -auto-scale-allow-deny string Path to config for server allowlists and denylists. If a global/server entry is specified, only players allowed to connect to the server will be able to trigger a scale up when -auto-scale-up is enabled or cancel active down scalers when -auto-scale-down is enabled (env AUTO_SCALE_ALLOW_DENY) + -auto-scale-asleep-motd string + MOTD to display when auto-scaled down servers are accessed; if empty, no status will be served (env AUTO_SCALE_ASLEEP_MOTD) -auto-scale-down - Decrease Kubernetes StatefulSet Replicas (only) from 1 to 0 on respective backend servers after there are no connections (env AUTO_SCALE_DOWN) + Scale to zero after idle. For Kubernetes, decreases StatefulSet replicas from 1 to 0. For Docker, gracefully stops the container when there are no connections (env AUTO_SCALE_DOWN) -auto-scale-down-after string Server scale down delay after there are no connections (env AUTO_SCALE_DOWN_AFTER) (default "10m") -auto-scale-up - Increase Kubernetes StatefulSet Replicas (only) from 0 to 1 on respective backend servers when accessed (env AUTO_SCALE_UP) + Scale from zero on access. For Kubernetes, increases StatefulSet replicas from 0 to 1. For Docker, starts or unpauses the container when accessed (env AUTO_SCALE_UP) -clients-to-allow value Zero or more client IP addresses or CIDRs to allow. Takes precedence over deny. (env CLIENTS_TO_ALLOW) -clients-to-deny value @@ -169,6 +172,35 @@ These are the labels scanned: - `mc-router.port`: This value must be set to the port the Minecraft server is listening on. The default value is 25565. - `mc-router.default`: Set this to a truthy value to make this server the default backend. Please note that `mc-router.host` is still required to be set. - `mc-router.network`: Specify the network you are using for the router if multiple are present in the container/service. 
You can either use the network ID, it's full name or an alias. +- `mc-router.auto-scale-up`: Per-container override to enable/disable auto scale up for Docker. When true (or left unspecified and the global `-auto-scale-up` flag is enabled), mc-router will start or unpause this container when a client connects to the declared hostname(s). +- `mc-router.auto-scale-down`: Per-container override to enable/disable auto scale down for Docker. When true (or left unspecified and the global `-auto-scale-down` flag is enabled), mc-router will stop this container after it has been idle for the configured `-auto-scale-down-after` duration. +- `mc-router.auto-scale-asleep-motd`: Per-container override for MOTD to show when container is scaled to zero. If empty or not set the host will +appear unresponsive. + +#### Docker Auto Scale Up/Down + +To use scale-to-zero with Docker containers: + +- Start mc-router with Docker discovery and scaling enabled, for example: + + ```bash + docker run --rm \ + -p 25565:25565 \ + -v /var/run/docker.sock:/var/run/docker.sock:ro \ + itzg/mc-router \ + -in-docker -auto-scale-up -auto-scale-down -auto-scale-down-after=10m + ``` + +- Label each Minecraft container with at least `mc-router.host`. You can also set per-container autoscale overrides using `mc-router.auto-scale-up` and `mc-router.auto-scale-down` labels. + +For usage with docker compose refer to the [examples/docker-autoscale/compose.yml](examples/docker-autoscale/compose.yml) or [examples/docker-autoscale/compose-minimal.yml](examples/docker-autoscale/compose-minimal.yml) examples. + +Behavior: + +- When a client connects to a labeled hostname and the container is stopped or paused, mc-router will start/unpause it and wait until it becomes reachable (up to ~60s). +- When no clients remain connected and the idle timer elapses (`-auto-scale-down-after`), mc-router gracefully stops the container. 
+ +Note: Docker Swarm discovery is supported; however, auto scale up/down is not yet supported for Swarm services. #### Example Docker deployment diff --git a/examples/docker-autoscale/compose-minimal.yml b/examples/docker-autoscale/compose-minimal.yml new file mode 100644 index 0000000..6c5ffaa --- /dev/null +++ b/examples/docker-autoscale/compose-minimal.yml @@ -0,0 +1,27 @@ +services: + router: + image: itzg/mc-router + environment: + IN_DOCKER: true + AUTO_SCALE_DOWN: true + AUTO_SCALE_UP: true + AUTO_SCALE_DOWN_AFTER: 2h + AUTO_SCALE_ASLEEP_MOTD: "Server is asleep. Join again to wake it up!" + ports: + - "25565:25565" + volumes: + - /var/run/docker.sock:/var/run/docker.sock:ro + vanilla: + image: itzg/minecraft-server + environment: + EULA: "TRUE" + labels: + mc-router.host: "vanilla.example.com" + paper: + image: itzg/minecraft-server + environment: + EULA: "TRUE" + TYPE: PAPER + labels: + mc-router.host: "paper.example.com" + diff --git a/examples/docker-autoscale/compose.yml b/examples/docker-autoscale/compose.yml new file mode 100644 index 0000000..296ae93 --- /dev/null +++ b/examples/docker-autoscale/compose.yml @@ -0,0 +1,50 @@ +# This is a verbose example with comments and explanations for configuring auto-scaling behavior +# for Docker backend servers. See compose-minimal.yml for a simple minimal example. 
+services: + router: + image: itzg/mc-router + environment: + IN_DOCKER: true + # Global auto-scaling settings for all docker-backend servers + # Settings can be overridden per-backend using labels + # as shown in the backend services below (except for AUTO_SCALE_DOWN_AFTER which is global only) + # Enable auto-scaling down after inactivity for all backends by default + AUTO_SCALE_DOWN: true + # Enable auto-scaling up after player join for all backends by default + AUTO_SCALE_UP: true + # Time of inactivity after which to scale down (default: 10m) - Global only setting + AUTO_SCALE_DOWN_AFTER: 2h + # MOTD to show when server is asleep (default: empty string - don't show MOTD, show server offline instead) + AUTO_SCALE_ASLEEP_MOTD: "Server is asleep. Join again to wake it up!" + ports: + - "25565:25565" + volumes: + - /var/run/docker.sock:/var/run/docker.sock:ro + vanilla: + image: itzg/minecraft-server + environment: + EULA: "TRUE" + labels: + # If global auto scaling settings are enabled, this backend will + # auto-scale without any additional auto-scale related configuration + mc-router.host: "vanilla.example.com" + fabric: + image: itzg/minecraft-server + environment: + EULA: "TRUE" + TYPE: FABRIC + labels: + mc-router.host: "fabric.example.com" + # Disable auto-scaling for this backend specifically + mc-router.auto-scale-up: false + mc-router.auto-scale-down: false + paper: + image: itzg/minecraft-server + environment: + EULA: "TRUE" + TYPE: PAPER + labels: + mc-router.host: "paper.example.com" + # Override asleep MOTD for this backend + mc-router.auto-scale-asleep-motd: "Paper is folded. Join to unfold!" 
+ diff --git a/mcproto/read.go b/mcproto/read.go index 4dc7f64..fe1fadf 100644 --- a/mcproto/read.go +++ b/mcproto/read.go @@ -58,7 +58,29 @@ func ReadPacket(reader *bufio.Reader, addr net.Addr, state State) (*Packet, erro return nil, err } - packet.Data = remainder.Bytes() + // For status state, decode based on packet ID: + // - 0x00 Status Request: no payload + // - 0x01 Ping: 8-byte long payload + if state == StateStatus { + switch packet.PacketID { + case PacketIdStatusRequest: + // no payload + packet.Data = nil + case PacketIdPingRequest: + timestamp, err := ReadLong(remainder) + if err != nil { + return nil, err + } + packet.Data = &PingPayload{ + Timestamp: timestamp, + } + default: + // unknown in status state; keep raw + packet.Data = remainder.Bytes() + } + } else { + packet.Data = remainder.Bytes() + } logrus. WithField("client", addr). diff --git a/mcproto/types.go b/mcproto/types.go index 421883c..d486905 100644 --- a/mcproto/types.go +++ b/mcproto/types.go @@ -79,6 +79,13 @@ const ( PacketIdHandshake = 0x00 PacketIdLogin = 0x00 // during StateLogin PacketIdLegacyServerListPing = 0xFE + PacketIdStatusRequest = 0x00 + PacketIdPingRequest = 0x01 +) + +const ( + PacketIdStatusResponse = 0x00 + PackedIdPongResponse = 0x01 ) type Handshake struct { @@ -106,6 +113,30 @@ type LegacyServerListPing struct { ServerPort uint16 } +// StatusResponse is a minimal structure for the status JSON +type StatusResponse struct { + Version struct { + Name string `json:"name"` + Protocol int `json:"protocol"` + } `json:"version"` + Players struct { + Max int `json:"max"` + Online int `json:"online"` + Sample []struct { + Name string `json:"name"` + ID string `json:"id"` + } `json:"sample,omitempty"` + } `json:"players"` + Description map[string]interface{} `json:"description"` + Favicon string `json:"favicon,omitempty"` + EnforcesSecureChat *bool `json:"enforcesSecureChat,omitempty"` +} + +// PingPayload represents the status ping payload (packet 0x01) +type PingPayload 
struct { + Timestamp int64 +} + type ByteReader interface { ReadByte() (byte, error) } diff --git a/mcproto/write.go b/mcproto/write.go new file mode 100644 index 0000000..871a44c --- /dev/null +++ b/mcproto/write.go @@ -0,0 +1,148 @@ +package mcproto + +import ( + "bufio" + "bytes" + "encoding/binary" + "encoding/json" + "io" + "unicode/utf16" +) + +// WriteVarInt writes a VarInt (Minecraft format) to w +func WriteVarInt(w io.Writer, value int32) error { + var buf [5]byte + i := 0 + v := uint32(value) + for { + temp := byte(v & 0x7F) + v >>= 7 + if v != 0 { + temp |= 0x80 + } + buf[i] = temp + i++ + if v == 0 { + break + } + } + _, err := w.Write(buf[:i]) + return err +} + +// WriteString writes a Minecraft length-prefixed string +func WriteString(w io.Writer, s string) error { + if err := WriteVarInt(w, int32(len(s))); err != nil { + return err + } + _, err := io.WriteString(w, s) + return err +} + +// buildPacket builds a framed packet: [length VarInt][packetId VarInt][payload] +func buildPacket(packetID int32, payload []byte) []byte { + var b bytes.Buffer + _ = WriteVarInt(&b, packetID) + b.Write(payload) + + var framed bytes.Buffer + _ = WriteVarInt(&framed, int32(b.Len())) + framed.Write(b.Bytes()) + return framed.Bytes() +} + +// WriteStatusJSONPacket writes a Status Response (packet 0x00) with the provided JSON string +func WriteStatusJSONPacket(w io.Writer, jsonString string) error { + // payload is the JSON as a Minecraft string + var payload bytes.Buffer + if err := WriteString(&payload, jsonString); err != nil { + return err + } + pkt := buildPacket(PacketIdStatusResponse, payload.Bytes()) + _, err := w.Write(pkt) + return err +} + +// WriteStatusFromStruct writes a Status Response from a struct +func WriteStatusFromStruct(w io.Writer, status StatusResponse) error { + b, err := json.Marshal(status) + if err != nil { + return err + } + return WriteStatusJSONPacket(w, string(b)) +} + +// WritePongPacket writes Pong (packet 0x01) with the same payload 
+func WritePongPacket(w io.Writer, timestamp int64) error { + var pl bytes.Buffer + // payload is a signed long (64-bit) + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], uint64(timestamp)) + pl.Write(buf[:]) + pkt := buildPacket(PackedIdPongResponse, pl.Bytes()) + _, err := w.Write(pkt) + return err +} + +// WriteLegacySLPResponse writes the 1.6-compatible legacy response packet (0xFF) +// Format: FF, [length short], UTF16BE string beginning with "\u00A7\u0031\u0000" then null-delimited fields +// fields: protocol, version, motd, online, max +func WriteLegacySLPResponse(w io.Writer, protocol int, version string, motd string, online int, max int) error { + // Build the string with null separators + s := "\u00A7\u0031\u0000" + + intToString(protocol) + "\u0000" + + version + "\u0000" + + motd + "\u0000" + + intToString(online) + "\u0000" + + intToString(max) + + // Encode UTF-16BE + runes := []rune(s) + encoded := utf16.Encode(runes) + var be bytes.Buffer + for _, v := range encoded { + var tmp [2]byte + binary.BigEndian.PutUint16(tmp[:], v) + be.Write(tmp[:]) + } + + bw := bufio.NewWriter(w) + // 0xFF + if _, err := bw.Write([]byte{0xFF}); err != nil { + return err + } + // length short in code units + var lenBuf [2]byte + binary.BigEndian.PutUint16(lenBuf[:], uint16(len(encoded))) + if _, err := bw.Write(lenBuf[:]); err != nil { + return err + } + if _, err := bw.Write(be.Bytes()); err != nil { + return err + } + return bw.Flush() +} + +// helpers +func intToString(i int) string { + if i == 0 { + return "0" + } + neg := false + if i < 0 { + neg = true + i = -i + } + var buf [20]byte + pos := len(buf) + for i > 0 { + pos-- + buf[pos] = byte('0' + (i % 10)) + i /= 10 + } + if neg { + pos-- + buf[pos] = '-' + } + return string(buf[pos:]) +} diff --git a/server/api_server.go b/server/api_server.go index a0f55b1..9769d1d 100644 --- a/server/api_server.go +++ b/server/api_server.go @@ -81,7 +81,7 @@ func routesCreateHandler(writer http.ResponseWriter, request 
*http.Request) { return } - Routes.CreateMapping(definition.ServerAddress, definition.Backend, EmptyScalerFunc, EmptyScalerFunc) + Routes.CreateMapping(definition.ServerAddress, definition.Backend, nil, nil, "") RoutesConfigLoader.SaveRoutes() writer.WriteHeader(http.StatusCreated) } @@ -102,7 +102,7 @@ func routesSetDefault(writer http.ResponseWriter, request *http.Request) { return } - Routes.SetDefaultRoute(body.Backend) + Routes.SetDefaultRoute(body.Backend, nil, nil, "") RoutesConfigLoader.SaveRoutes() writer.WriteHeader(http.StatusOK) } diff --git a/server/configs.go b/server/configs.go index 6e1039d..4466469 100644 --- a/server/configs.go +++ b/server/configs.go @@ -6,10 +6,11 @@ type WebhookConfig struct { } type AutoScale struct { - Up bool `usage:"Increase Kubernetes StatefulSet Replicas (only) from 0 to 1 on respective backend servers when accessed"` - Down bool `default:"false" usage:"Decrease Kubernetes StatefulSet Replicas (only) from 1 to 0 on respective backend servers after there are no connections"` - DownAfter string `default:"10m" usage:"Server scale down delay after there are no connections"` - AllowDeny string `usage:"Path to config for server allowlists and denylists. If a global/server entry is specified, only players allowed to connect to the server will be able to trigger a scale up when -auto-scale-up is enabled or cancel active down scalers when -auto-scale-down is enabled"` + Up bool `usage:"Scale from zero on access. For Kubernetes, increases StatefulSet replicas from 0 to 1. For Docker, starts or unpauses the container when accessed"` + Down bool `default:"false" usage:"Scale to zero after idle. For Kubernetes, decreases StatefulSet replicas from 1 to 0. For Docker, gracefully stops the container when there are no connections"` + DownAfter string `default:"10m" usage:"Server scale down delay after there are no connections"` + AllowDeny string `usage:"Path to config for server allowlists and denylists. 
If a global/server entry is specified, only players allowed to connect to the server will be able to trigger a scale up when -auto-scale-up is enabled or cancel active down scalers when -auto-scale-down is enabled"` + AsleepMOTD string `usage:"MOTD to display when auto-scaled down servers are accessed; if empty, no status will be served"` } type RoutesConfig struct { diff --git a/server/connector.go b/server/connector.go index 1cd3744..90a7e4a 100644 --- a/server/connector.go +++ b/server/connector.go @@ -38,30 +38,30 @@ func NewActiveConnections() *ActiveConnections { } } -func (sm *ActiveConnections) Increment(serverAddress string) { +func (sm *ActiveConnections) Increment(backendAddress string) { sm.Lock() defer sm.Unlock() - if _, ok := sm.activeConnections[serverAddress]; !ok { - sm.activeConnections[serverAddress] = 1 + if _, ok := sm.activeConnections[backendAddress]; !ok { + sm.activeConnections[backendAddress] = 1 return } - sm.activeConnections[serverAddress] += 1 + sm.activeConnections[backendAddress] += 1 } -func (sm *ActiveConnections) Decrement(serverAddress string) { +func (sm *ActiveConnections) Decrement(backendAddress string) { sm.Lock() defer sm.Unlock() - if activeConnections, ok := sm.activeConnections[serverAddress]; ok && activeConnections <= 0 { - sm.activeConnections[serverAddress] = 0 + if activeConnections, ok := sm.activeConnections[backendAddress]; ok && activeConnections <= 0 { + sm.activeConnections[backendAddress] = 0 return } - sm.activeConnections[serverAddress] -= 1 + sm.activeConnections[backendAddress] -= 1 } -func (sm *ActiveConnections) GetCount(serverAddress string) int { +func (sm *ActiveConnections) GetCount(backendAddress string) int { sm.Lock() defer sm.Unlock() - if activeConnections, ok := sm.activeConnections[serverAddress]; ok { + if activeConnections, ok := sm.activeConnections[backendAddress]; ok { return activeConnections } return 0 @@ -100,6 +100,7 @@ type Connector struct { clientFilter *ClientFilter 
autoScaleUpAllowDenyConfig *AllowDenyConfig connectionNotifier ConnectionNotifier + asleepMOTD string } func (c *Connector) UseConnectionNotifier(notifier ConnectionNotifier) { @@ -312,7 +313,7 @@ func (c *Connector) HandleConnection(frontendConn net.Conn) { Debug("Got user info") } - c.findAndConnectBackend(frontendConn, clientAddr, inspectionBuffer, handshake.ServerAddress, playerInfo, handshake.NextState) + c.findAndConnectBackend(frontendConn, clientAddr, inspectionBuffer, handshake.ServerAddress, playerInfo, handshake.NextState, false, int(handshake.ProtocolVersion)) } else if packet.PacketID == mcproto.PacketIdLegacyServerListPing { handshake, ok := packet.Data.(*mcproto.LegacyServerListPing) @@ -332,7 +333,7 @@ func (c *Connector) HandleConnection(frontendConn net.Conn) { serverAddress := handshake.ServerAddress - c.findAndConnectBackend(frontendConn, clientAddr, inspectionBuffer, serverAddress, nil, mcproto.StateStatus) + c.findAndConnectBackend(frontendConn, clientAddr, inspectionBuffer, serverAddress, nil, mcproto.StateStatus, true, 0) } else { logrus. WithField("client", clientAddr). 
@@ -343,6 +344,110 @@ func (c *Connector) HandleConnection(frontendConn net.Conn) {
 	}
 }
 
+// serveStatus answers a status handshake with a synthetic "asleep" Status Response (per-route MOTD, else the connector-wide one) and echoes at most one Ping as a Pong; it writes nothing when no MOTD is configured
+func (c *Connector) serveStatus(frontendConn net.Conn, reader *bufio.Reader, serverAddress string, clientProtocol int) {
+	motd := Routes.GetAsleepMOTD(serverAddress)
+	if motd == "" {
+		motd = c.asleepMOTD
+	}
+	if motd == "" {
+		return
+	}
+
+	// Consume Status Request (0x00) if present; some clients may send Ping (0x01) directly
+	_ = frontendConn.SetReadDeadline(time.Now().Add(3 * time.Second))
+	firstPkt, err := mcproto.ReadPacket(reader, frontendConn.RemoteAddr(), mcproto.StateStatus)
+	var pingPending bool
+	var pingVal int64
+	if err == nil && firstPkt != nil {
+		if firstPkt.PacketID == mcproto.PacketIdPingRequest {
+			if payload, ok := firstPkt.Data.(*mcproto.PingPayload); ok { // ReadPacket stores the ping payload as a pointer
+				pingPending = true
+				pingVal = payload.Timestamp
+				logrus.WithFields(logrus.Fields{
+					"client":   frontendConn.RemoteAddr(),
+					"ping_val": pingVal,
+				}).Debug("Predefined status: received immediate ping")
+			}
+		}
+		// else 0x00 is the normal status request; proceed to write response
+	} else if err != nil {
+		logrus.WithFields(logrus.Fields{
+			"client": frontendConn.RemoteAddr(),
+			"error":  err,
+		}).Warn("Predefined status: error reading initial status packet")
+	}
+
+	// Build the Status Response; a read error above is only logged, so the response is still attempted
+	viName, viProto := c.getVersionInfo(serverAddress, clientProtocol)
+	var status mcproto.StatusResponse
+	status.Version.Name = viName
+	status.Version.Protocol = viProto
+	status.Players.Max = 1
+	status.Players.Online = 0
+	status.Description = map[string]interface{}{"text": motd}
+
+	// Write Status Response
+	_ = frontendConn.SetWriteDeadline(time.Now().Add(handshakeTimeout))
+	if err := mcproto.WriteStatusFromStruct(frontendConn, status); err != nil {
+		logrus.WithError(err).Warn("Failed to write predefined status response")
+		return
+	}
+
+	// If we didn't already get a ping, briefly wait for one
+	if !pingPending {
+		_ = frontendConn.SetReadDeadline(time.Now().Add(2 * time.Second))
+		if nextPkt, err2 := mcproto.ReadPacket(reader, frontendConn.RemoteAddr(), mcproto.StateStatus); err2 == nil && nextPkt != nil {
+			if nextPkt.PacketID == mcproto.PacketIdPingRequest {
+				if payload, ok := nextPkt.Data.(*mcproto.PingPayload); ok {
+					pingPending = true
+					pingVal = payload.Timestamp
+					logrus.WithFields(logrus.Fields{
+						"client":   frontendConn.RemoteAddr(),
+						"ping_val": pingVal,
+					}).Debug("Predefined status: received ping after status")
+				}
+			}
+		} else if err2 != nil {
+			logrus.WithFields(logrus.Fields{
+				"client": frontendConn.RemoteAddr(),
+				"error":  err2,
+			}).Debug("Predefined status: error/timeout reading ping after status")
+		}
+	}
+	if pingPending {
+		if err := mcproto.WritePongPacket(frontendConn, pingVal); err != nil {
+			logrus.WithFields(logrus.Fields{
+				"client": frontendConn.RemoteAddr(),
+				"error":  err,
+			}).Warn("Predefined status: failed to write pong")
+		} else {
+			logrus.WithFields(logrus.Fields{
+				"client":   frontendConn.RemoteAddr(),
+				"ping_val": pingVal,
+			}).Debug("Predefined status: wrote pong")
+		}
+	} else {
+		logrus.WithFields(logrus.Fields{
+			"client": frontendConn.RemoteAddr(),
+		}).Debug("Predefined status: no ping received, closing")
+	}
+}
+
+// serveLegacyStatus writes a simple 1.6-compatible legacy SLP response using the connector-wide asleep MOTD; it writes nothing when that MOTD is empty and does not close the connection itself
+func (c *Connector) serveLegacyStatus(frontendConn net.Conn) {
+	motd := c.asleepMOTD
+	if motd == "" {
+		return
+	}
+	_ = frontendConn.SetWriteDeadline(time.Now().Add(handshakeTimeout))
+	// protocol and version name are fixed placeholders; only the connector-wide MOTD is consulted here (no per-route override)
+	// write a basic response: protocol=127, version="1.7+", motd, online=0, max=1
+	if err := mcproto.WriteLegacySLPResponse(frontendConn, 127, "1.7+", motd, 0, 1); err != nil {
+		logrus.WithError(err).Warn("Failed to write legacy SLP response")
+	}
+}
+
 func (c *Connector) readPlayerInfo(protocolVersion mcproto.ProtocolVersion,
bufferedReader *bufio.Reader, clientAddr net.Addr, state mcproto.State) (*PlayerInfo, error) { loginPacket, err := mcproto.ReadPacket(bufferedReader, clientAddr, state) if err != nil { @@ -375,10 +480,10 @@ func (c *Connector) cleanupBackendConnection(clientAddr net.Addr, serverAddress c.metrics.ActiveConnections.Set(float64( atomic.AddInt32(&c.totalActiveConnections, -1))) - c.activeConnections.Decrement(serverAddress) + c.activeConnections.Decrement(backendHostPort) c.metrics.ServerActiveConnections. With("server_address", serverAddress). - Set(float64(c.activeConnections.GetCount(serverAddress))) + Set(float64(c.activeConnections.GetCount(backendHostPort))) if c.recordLogins && playerInfo != nil { c.metrics.ServerActivePlayer. @@ -388,14 +493,19 @@ func (c *Connector) cleanupBackendConnection(clientAddr net.Addr, serverAddress Set(0) } } - if checkScaleDown && c.activeConnections.GetCount(serverAddress) <= 0 { - DownScaler.Begin(serverAddress) + logrus. + WithField("client", clientAddr). + WithField("backendHostPort", backendHostPort). + WithField("connectionCount", c.activeConnections.GetCount(backendHostPort)). 
+ Info("Closed connection to backend") + if checkScaleDown && c.activeConnections.GetCount(backendHostPort) <= 0 { + DownScaler.Begin(backendHostPort) } c.connectionsCond.Signal() } func (c *Connector) findAndConnectBackend(frontendConn net.Conn, - clientAddr net.Addr, preReadContent io.Reader, serverAddress string, playerInfo *PlayerInfo, nextState mcproto.State) { + clientAddr net.Addr, preReadContent io.Reader, serverAddress string, playerInfo *PlayerInfo, nextState mcproto.State, isLegacy bool, clientProtocol int) { backendHostPort, resolvedHost, waker, _ := Routes.FindBackendForServerAddress(c.ctx, serverAddress) cleanupMetrics := false @@ -415,13 +525,29 @@ func (c *Connector) findAndConnectBackend(frontendConn net.Conn, Debug("checked if player is allowed to wake up the server") if serverAllowsPlayer { // Cancel down scaler if active before scale up - DownScaler.Cancel(serverAddress) + if backendHostPort != "" { + DownScaler.Cancel(backendHostPort) + } cleanupCheckScaleDown = true - if err := waker(c.ctx); err != nil { + logrus.WithField("serverAddress", serverAddress).Info("Waking up backend server") + newBackendHostPort, err := waker(c.ctx) + if err != nil { logrus.WithFields(logrus.Fields{"serverAddress": serverAddress}).WithError(err).Error("failed to wake up backend") c.metrics.Errors.With("type", "wakeup_failed").Add(1) return } + if newBackendHostPort == "" { + logrus.WithFields(logrus.Fields{"serverAddress": serverAddress}).Warn("waker did not return a backend address") + c.metrics.Errors.With("type", "wakeup_no_address").Add(1) + return + } + // Cancel again in case any routes were changed during wake up + DownScaler.Cancel(newBackendHostPort) + backendHostPort = newBackendHostPort + logrus.WithFields(logrus.Fields{ + "serverAddress": serverAddress, + "backendHostPort": backendHostPort, + }).Info("Woke up backend server") } } @@ -440,6 +566,22 @@ func (c *Connector) findAndConnectBackend(frontendConn net.Conn, } } + // If status request and 
configured, serve predefined response + if nextState == mcproto.StateStatus && Routes.HasRoute(serverAddress) { + logrus.WithFields(logrus.Fields{ + "client": clientAddr, + "server": serverAddress, + "isLegacy": isLegacy, + }).Debug("Missing backend: serving predefined status response") + + // Read Status Request and Ping directly from the client connection + br := bufio.NewReader(frontendConn) + if isLegacy { + c.serveLegacyStatus(frontendConn) + } else { + c.serveStatus(frontendConn, br, serverAddress, clientProtocol) + } + } return } @@ -483,10 +625,10 @@ func (c *Connector) findAndConnectBackend(frontendConn net.Conn, c.metrics.ActiveConnections.Set(float64( atomic.AddInt32(&c.totalActiveConnections, 1))) - c.activeConnections.Increment(serverAddress) + c.activeConnections.Increment(backendHostPort) c.metrics.ServerActiveConnections. With("server_address", serverAddress). - Set(float64(c.activeConnections.GetCount(serverAddress))) + Set(float64(c.activeConnections.GetCount(backendHostPort))) if c.recordLogins && playerInfo != nil { logrus. 
@@ -624,3 +766,24 @@ func (c *Connector) UseReceiveProxyProto(trustedProxyNets []*net.IPNet) { c.trustedProxyNets = trustedProxyNets c.receiveProxyProto = true } + +// UseAsleepMOTD configures a predefined MOTD to serve when backends are asleep +func (c *Connector) UseAsleepMOTD(motd string) { + c.asleepMOTD = motd +} + +// getVersionInfo falls back to client protocol and a derived name but in future +// could be extended to cache server-reported versions +func (c *Connector) getVersionInfo(_ string, clientProtocol int) (string, int) { + // no cache; use client protocol + return protocolToName(clientProtocol), clientProtocol +} + +// protocolToName maps protocol numbers to a friendly name; falls back to "1.7+" +func protocolToName(proto int) string { + switch proto { + // TODO: expand this mapping as needed + default: + return "1.7+" + } +} diff --git a/server/docker.go b/server/docker.go index 98879ce..ebbdf75 100644 --- a/server/docker.go +++ b/server/docker.go @@ -3,12 +3,12 @@ package server import ( "context" "fmt" + "net" "strconv" "strings" "sync" "time" - dockertypes "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" "github.com/sirupsen/logrus" @@ -19,10 +19,13 @@ type IDockerWatcher interface { } const ( - DockerRouterLabelHost = "mc-router.host" - DockerRouterLabelPort = "mc-router.port" - DockerRouterLabelDefault = "mc-router.default" - DockerRouterLabelNetwork = "mc-router.network" + DockerRouterLabelHost = "mc-router.host" + DockerRouterLabelPort = "mc-router.port" + DockerRouterLabelDefault = "mc-router.default" + DockerRouterLabelNetwork = "mc-router.network" + DockerRouterLabelAutoScaleUp = "mc-router.auto-scale-up" + DockerRouterLabelAutoScaleDown = "mc-router.auto-scale-down" + DockerRouterLabelAutoScaleAsleepMOTD = "mc-router.auto-scale-asleep-motd" ) type dockerWatcherConfig struct { @@ -63,22 +66,94 @@ type dockerWatcherImpl struct { client *client.Client } -func (w 
*dockerWatcherImpl) makeWakerFunc(_ *routableContainer) ScalerFunc { - if !w.config.autoScaleUp { +func (w *dockerWatcherImpl) makeWakerFunc(rc *routableContainer) WakerFunc { + if rc == nil || !rc.autoScaleUp { return nil } - return func(ctx context.Context) error { - logrus.Fatal("Auto scale up is not yet supported for docker") - return nil + return func(ctx context.Context) (string, error) { + containerID := rc.containerID + if containerID == "" { + return "", fmt.Errorf("missing container id for wake") + } + inspect, err := w.client.ContainerInspect(ctx, containerID) + if err != nil { + return "", err + } + if inspect.State == nil { + return "", fmt.Errorf("unable to determine container state") + } + // If paused, unpause; if not running, start; otherwise no-op + if inspect.State.Paused { + logrus.WithFields(logrus.Fields{"containerID": containerID}).Debug("Unpausing container for wake") + if err := w.client.ContainerUnpause(ctx, containerID); err != nil { + return "", err + } + } else if !inspect.State.Running { + logrus.WithFields(logrus.Fields{"containerID": containerID}).Debug("Starting container for wake") + if err := w.client.ContainerStart(ctx, containerID, container.StartOptions{}); err != nil { + return "", err + } + } + + inspect, err = w.client.ContainerInspect(ctx, containerID) + if err != nil { + return "", err + } + data, ok := w.parseContainerData(&inspect) + if !ok { + return "", fmt.Errorf("failed to parse container data after starting") + } + if data.ip == "" { + return "", fmt.Errorf("container has no accessible IP after starting") + } + endpoint := net.JoinHostPort(data.ip, strconv.Itoa(int(data.port))) + + // Wait until the container is reachable + deadline := time.Now().Add(60 * time.Second) + for { + conn, err := net.DialTimeout("tcp", endpoint, 1*time.Second) + if err == nil { + _ = conn.Close() + break + } + if ctx.Err() != nil { + return endpoint, ctx.Err() + } + if time.Now().After(deadline) { + return endpoint, fmt.Errorf("timeout 
waiting for container to become reachable at %s", endpoint) + } + select { + case <-ctx.Done(): + return endpoint, ctx.Err() + case <-time.After(500 * time.Millisecond): + } + } + + return endpoint, nil } } -func (w *dockerWatcherImpl) makeSleeperFunc(_ *routableContainer) ScalerFunc { - if !w.config.autoScaleDown { +func (w *dockerWatcherImpl) makeSleeperFunc(rc *routableContainer) SleeperFunc { + if rc == nil || !rc.autoScaleDown { return nil } return func(ctx context.Context) error { - logrus.Fatal("Auto scale down is not yet supported for docker") + containerID := rc.containerID + if containerID == "" { + return fmt.Errorf("missing container id for sleep") + } + inspect, err := w.client.ContainerInspect(ctx, containerID) + if err != nil { + return err + } + if inspect.State != nil && inspect.State.Running { + // Graceful stop with 60s timeout + timeout := 60 + logrus.WithFields(logrus.Fields{"containerID": containerID}).Debug("Stopping container for sleep") + if err := w.client.ContainerStop(ctx, containerID, container.StopOptions{Timeout: &timeout}); err != nil { + return err + } + } return nil } } @@ -104,7 +179,6 @@ func (w *dockerWatcherImpl) Start(ctx context.Context) error { } ticker := time.NewTicker(refreshInterval) - containerMap := map[string]*routableContainer{} logrus.Trace("Performing initial listing of Docker containers") initialContainers, err := w.listContainers(ctx) @@ -112,12 +186,15 @@ func (w *dockerWatcherImpl) Start(ctx context.Context) error { return err } + containerMap := map[string]*routableContainer{} for _, c := range initialContainers { containerMap[c.externalContainerName] = c + wakerFunc := w.makeWakerFunc(c) + sleeperFunc := w.makeSleeperFunc(c) if c.externalContainerName != "" { - Routes.CreateMapping(c.externalContainerName, c.containerEndpoint, w.makeWakerFunc(c), w.makeSleeperFunc(c)) + Routes.CreateMapping(c.externalContainerName, c.containerEndpoint, wakerFunc, sleeperFunc, c.autoScaleAsleepMOTD) } else { - 
Routes.SetDefaultRoute(c.containerEndpoint) + Routes.SetDefaultRoute(c.containerEndpoint, wakerFunc, sleeperFunc, c.autoScaleAsleepMOTD) } } @@ -137,18 +214,26 @@ func (w *dockerWatcherImpl) Start(ctx context.Context) error { if oldRs, ok := containerMap[rs.externalContainerName]; !ok { containerMap[rs.externalContainerName] = rs logrus.WithField("routableContainer", rs).Debug("ADD") + wakerFunc := w.makeWakerFunc(rs) + sleeperFunc := w.makeSleeperFunc(rs) if rs.externalContainerName != "" { - Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, w.makeWakerFunc(rs), w.makeSleeperFunc(rs)) + Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, wakerFunc, sleeperFunc, rs.autoScaleAsleepMOTD) } else { - Routes.SetDefaultRoute(rs.containerEndpoint) + Routes.SetDefaultRoute(rs.containerEndpoint, wakerFunc, sleeperFunc, rs.autoScaleAsleepMOTD) } - } else if oldRs.containerEndpoint != rs.containerEndpoint { + } else if oldRs.containerEndpoint != rs.containerEndpoint || + oldRs.containerID != rs.containerID || + oldRs.autoScaleUp != rs.autoScaleUp || + oldRs.autoScaleDown != rs.autoScaleDown || + oldRs.autoScaleAsleepMOTD != rs.autoScaleAsleepMOTD { containerMap[rs.externalContainerName] = rs + wakerFunc := w.makeWakerFunc(rs) + sleeperFunc := w.makeSleeperFunc(rs) if rs.externalContainerName != "" { Routes.DeleteMapping(rs.externalContainerName) - Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, w.makeWakerFunc(rs), w.makeSleeperFunc(rs)) + Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, wakerFunc, sleeperFunc, rs.autoScaleAsleepMOTD) } else { - Routes.SetDefaultRoute(rs.containerEndpoint) + Routes.SetDefaultRoute(rs.containerEndpoint, wakerFunc, sleeperFunc, rs.autoScaleAsleepMOTD) } logrus.WithFields(logrus.Fields{"old": oldRs, "new": rs}).Debug("UPDATE") } @@ -160,7 +245,7 @@ func (w *dockerWatcherImpl) Start(ctx context.Context) error { if rs.externalContainerName != "" { 
Routes.DeleteMapping(rs.externalContainerName) } else { - Routes.SetDefaultRoute("") + Routes.SetDefaultRoute("", nil, nil, "") } logrus.WithField("routableContainer", rs).Debug("DELETE") } @@ -179,28 +264,46 @@ func (w *dockerWatcherImpl) Start(ctx context.Context) error { } func (w *dockerWatcherImpl) listContainers(ctx context.Context) ([]*routableContainer, error) { - containers, err := w.client.ContainerList(ctx, container.ListOptions{}) + containers, err := w.client.ContainerList(ctx, container.ListOptions{All: true}) if err != nil { return nil, err } var result []*routableContainer for _, container := range containers { - data, ok := w.parseContainerData(&container) + inspect, err := w.client.ContainerInspect(ctx, container.ID) + if err != nil { + logrus.WithFields(logrus.Fields{"containerID": container.ID}).WithError(err).Error("Failed to inspect Docker container") + continue + } + data, ok := w.parseContainerData(&inspect) if !ok { continue } + endpoint := "" + if !data.notRunning { + endpoint = fmt.Sprintf("%s:%d", data.ip, data.port) + } + for _, host := range data.hosts { result = append(result, &routableContainer{ - containerEndpoint: fmt.Sprintf("%s:%d", data.ip, data.port), + containerEndpoint: endpoint, externalContainerName: host, + containerID: container.ID, + autoScaleUp: data.autoScaleUp, + autoScaleDown: data.autoScaleDown, + autoScaleAsleepMOTD: data.autoScaleAsleepMOTD, }) } if data.def != nil && *data.def { result = append(result, &routableContainer{ - containerEndpoint: fmt.Sprintf("%s:%d", data.ip, data.port), + containerEndpoint: endpoint, externalContainerName: "", + containerID: container.ID, + autoScaleUp: data.autoScaleUp, + autoScaleDown: data.autoScaleDown, + autoScaleAsleepMOTD: data.autoScaleAsleepMOTD, }) } } @@ -209,18 +312,24 @@ func (w *dockerWatcherImpl) listContainers(ctx context.Context) ([]*routableCont } type parsedDockerContainerData struct { - hosts []string - port uint64 - def *bool - network *string - ip string + 
hosts []string + port uint64 + def *bool + network *string + ip string + autoScaleDown bool + autoScaleUp bool + autoScaleAsleepMOTD string + notRunning bool } -func (w *dockerWatcherImpl) parseContainerData(container *dockertypes.Container) (data parsedDockerContainerData, ok bool) { - for key, value := range container.Labels { +func (w *dockerWatcherImpl) parseContainerData(container *container.InspectResponse) (data parsedDockerContainerData, ok bool) { + data.autoScaleUp = w.config.autoScaleUp + data.autoScaleDown = w.config.autoScaleDown + for key, value := range container.Config.Labels { if key == DockerRouterLabelHost { if data.hosts != nil { - logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Names}). + logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Name}). Warnf("ignoring container with duplicate %s label", DockerRouterLabelHost) return } @@ -229,14 +338,14 @@ func (w *dockerWatcherImpl) parseContainerData(container *dockertypes.Container) if key == DockerRouterLabelPort { if data.port != 0 { - logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Names}). + logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Name}). Warnf("ignoring container with duplicate %s label", DockerRouterLabelPort) return } var err error data.port, err = strconv.ParseUint(value, 10, 32) if err != nil { - logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Names}). + logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Name}). WithError(err). 
Warnf("ignoring container with invalid %s label", DockerRouterLabelPort) return @@ -244,24 +353,51 @@ func (w *dockerWatcherImpl) parseContainerData(container *dockertypes.Container) } if key == DockerRouterLabelDefault { if data.def != nil { - logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Names}). + logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Name}). Warnf("ignoring container with duplicate %s label", DockerRouterLabelDefault) return } - data.def = new(bool) - - lowerValue := strings.TrimSpace(strings.ToLower(value)) - *data.def = lowerValue != "" && lowerValue != "0" && lowerValue != "false" && lowerValue != "no" + defaultValue, err := strconv.ParseBool(strings.TrimSpace(value)) + if err != nil { + logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Name}). + WithError(err). + Warnf("ignoring container with invalid value for %s label", DockerRouterLabelDefault) + return + } + data.def = &defaultValue } if key == DockerRouterLabelNetwork { if data.network != nil { - logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Names}). + logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Name}). Warnf("ignoring container with duplicate %s label", DockerRouterLabelNetwork) return } data.network = new(string) *data.network = value } + if key == DockerRouterLabelAutoScaleUp { + autoScaleUp, err := strconv.ParseBool(strings.TrimSpace(value)) + if err != nil { + logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Name}). + WithError(err). 
+ Warnf("ignoring container with invalid value for %s label", DockerRouterLabelAutoScaleUp) + return + } + data.autoScaleUp = autoScaleUp + } + if key == DockerRouterLabelAutoScaleDown { + autoScaleDown, err := strconv.ParseBool(strings.TrimSpace(value)) + if err != nil { + logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Name}). + WithError(err). + Warnf("ignoring container with invalid value for %s label", DockerRouterLabelAutoScaleDown) + return + } + data.autoScaleDown = autoScaleDown + } + if key == DockerRouterLabelAutoScaleAsleepMOTD { + data.autoScaleAsleepMOTD = value + } } // probably not minecraft related @@ -270,7 +406,7 @@ func (w *dockerWatcherImpl) parseContainerData(container *dockertypes.Container) } if len(container.NetworkSettings.Networks) == 0 { - logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Names}). + logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Name}). Warnf("ignoring container, no networks found") return } @@ -304,7 +440,7 @@ func (w *dockerWatcherImpl) parseContainerData(container *dockertypes.Container) // if there's more than one network on this container, we should require that the user specifies a network to avoid // weird problems. if len(container.NetworkSettings.Networks) > 1 { - logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Names}). + logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Name}). Warnf("ignoring container, multiple networks found and none specified using label %s", DockerRouterLabelNetwork) return } @@ -315,12 +451,21 @@ func (w *dockerWatcherImpl) parseContainerData(container *dockertypes.Container) } } - if data.ip == "" { - logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Names}). 
+ if data.ip == "" && container.State != nil && container.State.Running { + logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Name}). Warnf("ignoring container, unable to find accessible ip address") return } + if container.State != nil && !container.State.Running { + if !w.config.autoScaleUp { + logrus.WithFields(logrus.Fields{"containerId": container.ID, "containerNames": container.Name}). + Warnf("ignoring container, not running and auto scale up is disabled") + return + } + data.notRunning = true + } + ok = true return @@ -329,4 +474,8 @@ func (w *dockerWatcherImpl) parseContainerData(container *dockertypes.Container) type routableContainer struct { externalContainerName string containerEndpoint string + containerID string + autoScaleUp bool + autoScaleDown bool + autoScaleAsleepMOTD string } diff --git a/server/docker_swarm.go b/server/docker_swarm.go index e462d1d..64ab55a 100644 --- a/server/docker_swarm.go +++ b/server/docker_swarm.go @@ -38,17 +38,17 @@ type dockerSwarmWatcherImpl struct { client *client.Client } -func (w *dockerSwarmWatcherImpl) makeWakerFunc(_ *routableService) ScalerFunc { +func (w *dockerSwarmWatcherImpl) makeWakerFunc(_ *routableService) WakerFunc { if !w.config.autoScaleUp { return nil } - return func(ctx context.Context) error { + return func(ctx context.Context) (string, error) { logrus.Fatal("Auto scale up is not yet supported for docker swarm") - return nil + return "", nil } } -func (w *dockerSwarmWatcherImpl) makeSleeperFunc(_ *routableService) ScalerFunc { +func (w *dockerSwarmWatcherImpl) makeSleeperFunc(_ *routableService) SleeperFunc { if !w.config.autoScaleDown { return nil } @@ -89,10 +89,12 @@ func (w *dockerSwarmWatcherImpl) Start(ctx context.Context) error { for _, s := range initialServices { serviceMap[s.externalServiceName] = s + wakerFunc := w.makeWakerFunc(s) + sleeperFunc := w.makeSleeperFunc(s) if s.externalServiceName != "" { - Routes.CreateMapping(s.externalServiceName, 
s.containerEndpoint, w.makeWakerFunc(s), w.makeSleeperFunc(s)) + Routes.CreateMapping(s.externalServiceName, s.containerEndpoint, wakerFunc, sleeperFunc, "") } else { - Routes.SetDefaultRoute(s.containerEndpoint) + Routes.SetDefaultRoute(s.containerEndpoint, wakerFunc, sleeperFunc, "") } } @@ -111,18 +113,22 @@ func (w *dockerSwarmWatcherImpl) Start(ctx context.Context) error { if oldRs, ok := serviceMap[rs.externalServiceName]; !ok { serviceMap[rs.externalServiceName] = rs logrus.WithField("routableService", rs).Debug("ADD") + wakerFunc := w.makeWakerFunc(rs) + sleeperFunc := w.makeSleeperFunc(rs) if rs.externalServiceName != "" { - Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, w.makeWakerFunc(rs), w.makeSleeperFunc(rs)) + Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, wakerFunc, sleeperFunc, "") } else { - Routes.SetDefaultRoute(rs.containerEndpoint) + Routes.SetDefaultRoute(rs.containerEndpoint, wakerFunc, sleeperFunc, "") } } else if oldRs.containerEndpoint != rs.containerEndpoint { serviceMap[rs.externalServiceName] = rs + wakerFunc := w.makeWakerFunc(rs) + sleeperFunc := w.makeSleeperFunc(rs) if rs.externalServiceName != "" { Routes.DeleteMapping(rs.externalServiceName) - Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, w.makeWakerFunc(rs), w.makeSleeperFunc(rs)) + Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, wakerFunc, sleeperFunc, "") } else { - Routes.SetDefaultRoute(rs.containerEndpoint) + Routes.SetDefaultRoute(rs.containerEndpoint, wakerFunc, sleeperFunc, "") } logrus.WithFields(logrus.Fields{"old": oldRs, "new": rs}).Debug("UPDATE") } @@ -134,7 +140,7 @@ func (w *dockerSwarmWatcherImpl) Start(ctx context.Context) error { if rs.externalServiceName != "" { Routes.DeleteMapping(rs.externalServiceName) } else { - Routes.SetDefaultRoute("") + Routes.SetDefaultRoute("", nil, nil, "") } logrus.WithField("routableService", rs).Debug("DELETE") } diff --git a/server/down_scaler.go 
b/server/down_scaler.go index 91c6da6..0df1007 100644 --- a/server/down_scaler.go +++ b/server/down_scaler.go @@ -10,17 +10,17 @@ import ( type IDownScaler interface { Reset() - Begin(serverAddress string) - Cancel(serverAddress string) + Begin(backendEndpoint string) + Cancel(backendEndpoint string) } var DownScaler IDownScaler func NewDownScaler(ctx context.Context, enabled bool, delay time.Duration) IDownScaler { ds := &downScalerImpl{ - enabled: enabled, - delay: delay, - parentContext: ctx, + enabled: enabled, + delay: delay, + parentContext: ctx, contextCancellations: make(map[string]context.CancelFunc), } @@ -43,7 +43,7 @@ func (ds *downScalerImpl) Reset() { ds.contextCancellations = make(map[string]context.CancelFunc) } -func (ds *downScalerImpl) Begin(serverAddress string) { +func (ds *downScalerImpl) Begin(backendEndpoint string) { ds.Lock() defer ds.Unlock() @@ -52,17 +52,17 @@ func (ds *downScalerImpl) Begin(serverAddress string) { } // If an existing scale down routine exists, cancel it - if scaleDownCancel, ok := ds.contextCancellations[serverAddress]; ok { + if scaleDownCancel, ok := ds.contextCancellations[backendEndpoint]; ok { scaleDownCancel() } - - logrus.WithField("serverAddress", serverAddress).Debug("Beginning scale down") + + logrus.WithField("backendEndpoint", backendEndpoint).Debug("Beginning scale down") scaleDownContext, scaleDownContextCancellation := context.WithCancel(ds.parentContext) - ds.contextCancellations[serverAddress] = scaleDownContextCancellation - go ds.scaleDown(scaleDownContext, serverAddress) + ds.contextCancellations[backendEndpoint] = scaleDownContextCancellation + go ds.scaleDown(scaleDownContext, backendEndpoint) } -func (ds *downScalerImpl) Cancel(serverAddress string) { +func (ds *downScalerImpl) Cancel(backendEndpoint string) { ds.Lock() defer ds.Unlock() @@ -70,27 +70,34 @@ func (ds *downScalerImpl) Cancel(serverAddress string) { return } - if scaleDownContextCancellation, ok := 
ds.contextCancellations[serverAddress]; ok { - logrus.WithField("serverAddress", serverAddress).Debug("Canceling scale down") + if scaleDownContextCancellation, ok := ds.contextCancellations[backendEndpoint]; ok { + logrus.WithField("backendEndpoint", backendEndpoint).Debug("Canceling scale down") scaleDownContextCancellation() - delete(ds.contextCancellations, serverAddress) + delete(ds.contextCancellations, backendEndpoint) } } -func (ds *downScalerImpl) scaleDown(ctx context.Context, serverAddress string) { +func (ds *downScalerImpl) scaleDown(ctx context.Context, backendEndpoint string) { for { select { - case <-ctx.Done(): - return - case <-time.After(ds.delay): - _, _, _, sleeper := Routes.FindBackendForServerAddress(ctx, serverAddress) - if sleeper == nil { - return - } - if err := sleeper(ctx); err != nil { - logrus.WithField("serverAddress", serverAddress).WithError(err).Error("failed to scale down backend") - } + case <-ctx.Done(): + return + case <-time.After(ds.delay): + sleepers := Routes.GetSleepers(backendEndpoint) + if len(sleepers) == 0 { return + } + for _, sleeper := range sleepers { + go func(s SleeperFunc) { + err := s(ctx) + if err != nil { + logrus.WithError(err). + WithField("backendEndpoint", backendEndpoint). 
+ Error("Error while executing sleeper function") + } + }(sleeper) + } + return } } } diff --git a/server/k8s.go b/server/k8s.go index be32640..05512f7 100644 --- a/server/k8s.go +++ b/server/k8s.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "strconv" + "strings" "sync" "github.com/pkg/errors" @@ -183,9 +184,9 @@ func (w *K8sWatcher) handleUpdate(oldObj interface{}, newObj interface{}) { "new": newRoutableService, }).Debug("UPDATE") if newRoutableService.externalServiceName != "" { - w.routesHandler.CreateMapping(newRoutableService.externalServiceName, newRoutableService.containerEndpoint, newRoutableService.autoScaleUp, newRoutableService.autoScaleDown) + w.routesHandler.CreateMapping(newRoutableService.externalServiceName, newRoutableService.containerEndpoint, newRoutableService.autoScaleUp, newRoutableService.autoScaleDown, "") } else { - w.routesHandler.SetDefaultRoute(newRoutableService.containerEndpoint) + w.routesHandler.SetDefaultRoute(newRoutableService.containerEndpoint, newRoutableService.autoScaleUp, newRoutableService.autoScaleDown, "") } } } @@ -200,7 +201,7 @@ func (w *K8sWatcher) handleDelete(obj interface{}) { if routableService.externalServiceName != "" { w.routesHandler.DeleteMapping(routableService.externalServiceName) } else { - w.routesHandler.SetDefaultRoute("") + w.routesHandler.SetDefaultRoute("", nil, nil, "") } } } @@ -214,9 +215,9 @@ func (w *K8sWatcher) handleAdd(obj interface{}) { logrus.WithField("routableService", routableService).Debug("ADD") if routableService.externalServiceName != "" { - w.routesHandler.CreateMapping(routableService.externalServiceName, routableService.containerEndpoint, routableService.autoScaleUp, routableService.autoScaleDown) + w.routesHandler.CreateMapping(routableService.externalServiceName, routableService.containerEndpoint, routableService.autoScaleUp, routableService.autoScaleDown, "") } else { - w.routesHandler.SetDefaultRoute(routableService.containerEndpoint) + 
w.routesHandler.SetDefaultRoute(routableService.containerEndpoint, routableService.autoScaleUp, routableService.autoScaleDown, "") } } } @@ -225,8 +226,8 @@ func (w *K8sWatcher) handleAdd(obj interface{}) { type routableService struct { externalServiceName string containerEndpoint string - autoScaleUp ScalerFunc - autoScaleDown ScalerFunc + autoScaleUp WakerFunc + autoScaleDown SleeperFunc } // obj is expected to be a *v1.Service @@ -271,22 +272,37 @@ func (w *K8sWatcher) buildDetails(service *core.Service, externalServiceName str } else if len(mcPort) > 0 { port = mcPort } + endpoint := net.JoinHostPort(clusterIp, port) + wakerFunc := w.buildScaleFunction(service, 0, 1) rs := &routableService{ externalServiceName: externalServiceName, - containerEndpoint: net.JoinHostPort(clusterIp, port), - autoScaleUp: w.buildScaleFunction(service, 0, 1), - autoScaleDown: w.buildScaleFunction(service, 1, 0), + containerEndpoint: endpoint, + autoScaleUp: func(ctx context.Context) (string, error) { + if err := wakerFunc(ctx); err != nil { + return "", err + } + return endpoint, nil + }, + autoScaleDown: w.buildScaleFunction(service, 1, 0), } return rs } -func (w *K8sWatcher) buildScaleFunction(service *core.Service, from int32, to int32) ScalerFunc { +func (w *K8sWatcher) buildScaleFunction(service *core.Service, from int32, to int32) SleeperFunc { // Currently, annotations can only be used to opt-out of auto-scaling. - // However, this logic is prepared also for opt-in, as it returns a `ScalerFunc` when flags are false but annotations are set to `enabled`. + // However, this logic is prepared also for opt-in, as it returns a `SleeperFunc` when flags are false but annotations are set to `enabled`. if from <= to { enabled, exists := service.Annotations[AnnotationAutoScaleUp] if exists { - if enabled == "false" { + enabledBool, err := strconv.ParseBool(strings.TrimSpace(enabled)) + if err != nil { + logrus.WithFields(logrus.Fields{"service": service.Name}). + WithError(err). 
+ Warnf("invalid value for %s annotation - disabling service auto-scale-up", AnnotationAutoScaleUp) + return nil + } + + if !enabledBool { return nil } } else { @@ -298,7 +314,15 @@ func (w *K8sWatcher) buildScaleFunction(service *core.Service, from int32, to in if from >= to { enabled, exists := service.Annotations[AnnotationAutoScaleDown] if exists { - if enabled == "false" { + enabledBool, err := strconv.ParseBool(strings.TrimSpace(enabled)) + if err != nil { + logrus.WithFields(logrus.Fields{"service": service.Name}). + WithError(err). + Warnf("invalid value for %s annotation - disabling service auto-scale-down", AnnotationAutoScaleDown) + return nil + } + + if !enabledBool { return nil } } else { diff --git a/server/k8s_test.go b/server/k8s_test.go index 4dc1a48..34f1a63 100644 --- a/server/k8s_test.go +++ b/server/k8s_test.go @@ -3,10 +3,11 @@ package server import ( "context" "encoding/json" - "github.com/stretchr/testify/mock" "testing" "time" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" v1 "k8s.io/api/core/v1" @@ -27,22 +28,27 @@ func (m *MockedRoutesHandler) GetBackendForServer(server string) string { } } -func (m *MockedRoutesHandler) CreateMapping(serverAddress string, backend string, waker ScalerFunc, sleeper ScalerFunc) { - m.MethodCalled("CreateMapping", serverAddress, backend, waker, sleeper) +func (m *MockedRoutesHandler) CreateMapping(serverAddress string, backend string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string) { + m.MethodCalled("CreateMapping", serverAddress, backend, waker, sleeper, asleepMOTD) if m.routes == nil { m.routes = make(map[string]string) } m.routes[serverAddress] = backend } -func (m *MockedRoutesHandler) SetDefaultRoute(backend string) { - m.MethodCalled("SetDefaultRoute", backend) +func (m *MockedRoutesHandler) SetDefaultRoute(backend string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string) { + m.MethodCalled("SetDefaultRoute", backend, 
waker, sleeper, asleepMOTD) if m.routes == nil { m.routes = make(map[string]string) } m.defaultBackend = backend } +func (m *MockedRoutesHandler) GetAsleepMOTD(serverAddress string) string { + args := m.MethodCalled("GetAsleepMOTD", serverAddress) + return args.String(0) +} + func (m *MockedRoutesHandler) DeleteMapping(serverAddress string) bool { args := m.MethodCalled("DeleteMapping", serverAddress) if m.routes == nil { @@ -177,7 +183,9 @@ func TestK8sWatcherImpl_handleAddThenUpdate(t *testing.T) { DownScaler = NewDownScaler(context.Background(), false, 1*time.Second) routesHandler := new(MockedRoutesHandler) - routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("GetAsleepMOTD", mock.Anything).Return("") routesHandler.On("DeleteMapping", mock.Anything).Return(true) watcher := &K8sWatcher{ @@ -256,7 +264,9 @@ func TestK8sWatcherImpl_handleAddThenDelete(t *testing.T) { DownScaler = NewDownScaler(context.Background(), false, 1*time.Second) routesHandler := new(MockedRoutesHandler) - routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("GetAsleepMOTD", mock.Anything).Return("") routesHandler.On("DeleteMapping", mock.Anything).Return(true) watcher := &K8sWatcher{ @@ -353,7 +363,9 @@ func TestK8s_externalName(t *testing.T) { DownScaler = NewDownScaler(context.Background(), false, 1*time.Second) routesHandler := new(MockedRoutesHandler) - routesHandler.On("CreateMapping", mock.Anything, 
mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("CreateMapping", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("SetDefaultRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() + routesHandler.On("GetAsleepMOTD", mock.Anything).Return("") routesHandler.On("DeleteMapping", mock.Anything).Return(true) watcher := &K8sWatcher{ diff --git a/server/routes.go b/server/routes.go index 15340f9..14c56b1 100644 --- a/server/routes.go +++ b/server/routes.go @@ -9,9 +9,11 @@ import ( "github.com/sirupsen/logrus" ) -type ScalerFunc func(ctx context.Context) error +// WakerFunc is a function that wakes up a server and returns its address. +type WakerFunc func(ctx context.Context) (string, error) -var EmptyScalerFunc = func(ctx context.Context) error { return nil } +// SleeperFunc is a function that puts a server to sleep. +type SleeperFunc func(ctx context.Context) error var tcpShieldPattern = regexp.MustCompile("///.*") @@ -22,8 +24,8 @@ type RouteFinder interface { } type RoutesHandler interface { - CreateMapping(serverAddress string, backend string, waker ScalerFunc, sleeper ScalerFunc) - SetDefaultRoute(backend string) + CreateMapping(serverAddress string, backend string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string) + SetDefaultRoute(backend string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string) // DeleteMapping requests that the serverAddress be removed from routes. // Returns true if the route existed. DeleteMapping(serverAddress string) bool @@ -38,9 +40,12 @@ type IRoutes interface { // Otherwise, an empty string is returned. Also returns the normalized version of the given serverAddress. // The 3rd value returned is an (optional) "waker" function which a caller must invoke to wake up serverAddress. // The 4th value returned is an (optional) "sleeper" function which a caller must invoke to shut down serverAddress. 
- FindBackendForServerAddress(ctx context.Context, serverAddress string) (string, string, ScalerFunc, ScalerFunc) + HasRoute(serverAddress string) bool + FindBackendForServerAddress(ctx context.Context, serverAddress string) (string, string, WakerFunc, SleeperFunc) + GetSleepers(backend string) []SleeperFunc GetMappings() map[string]string - GetDefaultRoute() string + GetDefaultRoute() (string, WakerFunc, SleeperFunc) + GetAsleepMOTD(serverAddress string) string SimplifySRV(srvEnabled bool) } @@ -56,20 +61,21 @@ func NewRoutes() IRoutes { func (r *routesImpl) RegisterAll(mappings map[string]string) { for k, v := range mappings { - r.CreateMapping(k, v, EmptyScalerFunc, EmptyScalerFunc) + r.CreateMapping(k, v, nil, nil, "") } } type mapping struct { - backend string - waker ScalerFunc - sleeper ScalerFunc + backend string + waker WakerFunc + sleeper SleeperFunc + asleepMOTD string } type routesImpl struct { sync.RWMutex mappings map[string]mapping - defaultRoute string + defaultRoute mapping simplifySRV bool } @@ -78,23 +84,45 @@ func (r *routesImpl) Reset() { DownScaler.Reset() } -func (r *routesImpl) SetDefaultRoute(backend string) { - r.defaultRoute = backend +func (r *routesImpl) SetDefaultRoute(backend string, waker WakerFunc, sleeper SleeperFunc, asleepMOTD string) { + r.defaultRoute = mapping{backend: backend, waker: waker, sleeper: sleeper, asleepMOTD: asleepMOTD} logrus.WithFields(logrus.Fields{ "backend": backend, }).Info("Using default route") } -func (r *routesImpl) GetDefaultRoute() string { - return r.defaultRoute +func (r *routesImpl) GetDefaultRoute() (string, WakerFunc, SleeperFunc) { + return r.defaultRoute.backend, r.defaultRoute.waker, r.defaultRoute.sleeper +} + +func (r *routesImpl) GetAsleepMOTD(serverAddress string) string { + r.RLock() + defer r.RUnlock() + + if serverAddress == "" { + return r.defaultRoute.asleepMOTD + } + + if m, ok := r.mappings[serverAddress]; ok { + return m.asleepMOTD + } + return "" } func (r *routesImpl) 
SimplifySRV(srvEnabled bool) { r.simplifySRV = srvEnabled } -func (r *routesImpl) FindBackendForServerAddress(_ context.Context, serverAddress string) (string, string, ScalerFunc, ScalerFunc) { +func (r *routesImpl) HasRoute(serverAddress string) bool { + r.RLock() + defer r.RUnlock() + + _, exists := r.mappings[serverAddress] + return exists +} + +func (r *routesImpl) FindBackendForServerAddress(_ context.Context, serverAddress string) (string, string, WakerFunc, SleeperFunc) { r.RLock() defer r.RUnlock() @@ -136,7 +164,23 @@ func (r *routesImpl) FindBackendForServerAddress(_ context.Context, serverAddres return mapping.backend, serverAddress, mapping.waker, mapping.sleeper } } - return r.defaultRoute, serverAddress, nil, nil + return r.defaultRoute.backend, serverAddress, r.defaultRoute.waker, r.defaultRoute.sleeper +} + +func (r *routesImpl) GetSleepers(backend string) []SleeperFunc { + r.RLock() + defer r.RUnlock() + + var sleepers []SleeperFunc + for _, m := range r.mappings { + if m.backend == backend && m.sleeper != nil { + sleepers = append(sleepers, m.sleeper) + } + } + if r.defaultRoute.backend == backend && r.defaultRoute.sleeper != nil { + sleepers = append(sleepers, r.defaultRoute.sleeper) + } + return sleepers } func (r *routesImpl) GetMappings() map[string]string { @@ -155,9 +199,8 @@ func (r *routesImpl) DeleteMapping(serverAddress string) bool { defer r.Unlock() logrus.WithField("serverAddress", serverAddress).Info("Deleting route") - DownScaler.Cancel(serverAddress) - - if _, ok := r.mappings[serverAddress]; ok { + if m, ok := r.mappings[serverAddress]; ok { + DownScaler.Cancel(m.backend) delete(r.mappings, serverAddress) return true } else { @@ -165,7 +208,7 @@ func (r *routesImpl) DeleteMapping(serverAddress string) bool { } } -func (r *routesImpl) CreateMapping(serverAddress string, backend string, waker ScalerFunc, sleeper ScalerFunc) { +func (r *routesImpl) CreateMapping(serverAddress string, backend string, waker WakerFunc, sleeper 
SleeperFunc, asleepMOTD string) { r.Lock() defer r.Unlock() @@ -175,8 +218,10 @@ func (r *routesImpl) CreateMapping(serverAddress string, backend string, waker S "serverAddress": serverAddress, "backend": backend, }).Info("Created route mapping") - r.mappings[serverAddress] = mapping{backend: backend, waker: waker, sleeper: sleeper} + r.mappings[serverAddress] = mapping{backend: backend, waker: waker, sleeper: sleeper, asleepMOTD: asleepMOTD} // Trigger auto scale down when mapping is created to ensure servers are shut down if router restarts - DownScaler.Begin(serverAddress) + if DownScaler != nil && backend != "" { + DownScaler.Begin(backend) + } } diff --git a/server/routes_config_loader.go b/server/routes_config_loader.go index 43a9aeb..9134600 100644 --- a/server/routes_config_loader.go +++ b/server/routes_config_loader.go @@ -44,7 +44,7 @@ func (r *routesConfigLoader) Load(routesConfigFileName string) error { } Routes.RegisterAll(config.Mappings) - Routes.SetDefaultRoute(config.DefaultServer) + Routes.SetDefaultRoute(config.DefaultServer, nil, nil, "") return nil } @@ -62,7 +62,7 @@ func (r *routesConfigLoader) Reload() error { logrus.WithField("routesConfig", r.fileName).Info("Re-loading routes config file") Routes.Reset() Routes.RegisterAll(config.Mappings) - Routes.SetDefaultRoute(config.DefaultServer) + Routes.SetDefaultRoute(config.DefaultServer, nil, nil, "") return nil } @@ -135,8 +135,9 @@ func (r *routesConfigLoader) SaveRoutes() { return } + server, _, _ := Routes.GetDefaultRoute() err := r.writeFile(&RoutesConfigSchema{ - DefaultServer: Routes.GetDefaultRoute(), + DefaultServer: server, Mappings: Routes.GetMappings(), }) if err != nil { diff --git a/server/routes_test.go b/server/routes_test.go index 2001d33..2e09f70 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -66,7 +66,7 @@ func Test_routesImpl_FindBackendForServerAddress(t *testing.T) { t.Run(tt.name, func(t *testing.T) { r := NewRoutes() - 
r.CreateMapping(tt.mapping.serverAddress, tt.mapping.backend, EmptyScalerFunc, EmptyScalerFunc) + r.CreateMapping(tt.mapping.serverAddress, tt.mapping.backend, nil, nil, "") if got, server, _, _ := r.FindBackendForServerAddress(context.Background(), tt.args.serverAddress); got != tt.want { t.Errorf("routesImpl.FindBackendForServerAddress() = %v, want %v", got, tt.want) diff --git a/server/server.go b/server/server.go index 78ddd4e..253cfd2 100644 --- a/server/server.go +++ b/server/server.go @@ -49,7 +49,7 @@ func NewServer(ctx context.Context, config *Config) (*Server, error) { metricsBuilder := NewMetricsBuilder(config.MetricsBackend, &config.MetricsBackendConfig) - downScalerEnabled := config.AutoScale.Down && (config.InKubeCluster || config.KubeConfig != "") + downScalerEnabled := config.AutoScale.Down && (config.InKubeCluster || config.KubeConfig != "" || config.InDocker) downScalerDelay, err := time.ParseDuration(config.AutoScale.DownAfter) if err != nil { return nil, fmt.Errorf("could not parse auto-scale-down-after duration: %w", err) @@ -73,7 +73,7 @@ func NewServer(ctx context.Context, config *Config) (*Server, error) { Routes.RegisterAll(config.Mapping) if config.Default != "" { - Routes.SetDefaultRoute(config.Default) + Routes.SetDefaultRoute(config.Default, nil, nil, "") } if config.ConnectionRateLimit < 1 { @@ -86,6 +86,8 @@ func NewServer(ctx context.Context, config *Config) (*Server, error) { config.RecordLogins, autoScaleAllowDenyConfig) + connector.UseAsleepMOTD(config.AutoScale.AsleepMOTD) + clientFilter, err := NewClientFilter(config.ClientsToAllow, config.ClientsToDeny) if err != nil { return nil, fmt.Errorf("could not create client filter: %w", err)