diff --git a/server/client_filter.go b/server/client_filter.go index 33b7bf4..fdee7c0 100644 --- a/server/client_filter.go +++ b/server/client_filter.go @@ -89,7 +89,7 @@ func NewClientFilter(allows []string, denies []string) (*ClientFilter, error) { }, nil } -// Allow determines if the given address is allowed by this filter +// Allow determines if this filter allows the given address // where addrStr is a netip.ParseAddr allowed address func (f *ClientFilter) Allow(addrPort netip.AddrPort) bool { if !f.allow.Empty() { diff --git a/server/connector.go b/server/connector.go index 6af8d9f..3f8e4bc 100644 --- a/server/connector.go +++ b/server/connector.go @@ -12,12 +12,9 @@ import ( "sync/atomic" "time" - "github.com/google/uuid" - "golang.ngrok.com/ngrok" "golang.ngrok.com/ngrok/config" - "github.com/go-kit/kit/metrics" "github.com/itzg/mc-router/mcproto" "github.com/juju/ratelimit" "github.com/pires/go-proxyproto" @@ -30,58 +27,18 @@ const ( var noDeadline time.Time -type ConnectorMetrics struct { - Errors metrics.Counter - BytesTransmitted metrics.Counter - ConnectionsFrontend metrics.Counter - ConnectionsBackend metrics.Counter - ActiveConnections metrics.Gauge - ServerActivePlayer metrics.Gauge - ServerLogins metrics.Counter - ServerActiveConnections metrics.Gauge -} - -type ClientInfo struct { - Host string `json:"host"` - Port int `json:"port"` -} - -func ClientInfoFromAddr(addr net.Addr) *ClientInfo { - if addr == nil { - return nil - } - - return &ClientInfo{ - Host: addr.(*net.TCPAddr).IP.String(), - Port: addr.(*net.TCPAddr).Port, - } -} - -type PlayerInfo struct { - Name string `json:"name"` - Uuid uuid.UUID `json:"uuid"` -} - -func (p *PlayerInfo) String() string { - if p == nil { - return "" - } - - return fmt.Sprintf("%s/%s", p.Name, p.Uuid) -} - -type ServerMetrics struct { +type ActiveConnections struct { sync.RWMutex activeConnections map[string]int } -func NewServerMetrics() *ServerMetrics { - return &ServerMetrics{ +func NewActiveConnections() *ActiveConnections { + return &ActiveConnections{ activeConnections: make(map[string]int), } } -func (sm *ServerMetrics) IncrementActiveConnections(serverAddress string) { +func (sm *ActiveConnections) Increment(serverAddress string) { sm.Lock() defer sm.Unlock() if _, ok := sm.activeConnections[serverAddress]; !ok { @@ -91,7 +48,7 @@ func (sm *ServerMetrics) IncrementActiveConnections(serverAddress string) { sm.activeConnections[serverAddress] += 1 } -func (sm *ServerMetrics) DecrementActiveConnections(serverAddress string) { +func (sm *ActiveConnections) Decrement(serverAddress string) { sm.Lock() defer sm.Unlock() if activeConnections, ok := sm.activeConnections[serverAddress]; ok && activeConnections <= 0 { @@ -101,7 +58,7 @@ func (sm *ServerMetrics) DecrementActiveConnections(serverAddress string) { sm.activeConnections[serverAddress] -= 1 } -func (sm *ServerMetrics) ActiveConnectionsValue(serverAddress string) int { +func (sm *ActiveConnections) GetCount(serverAddress string) int { sm.Lock() defer sm.Unlock() if activeConnections, ok := sm.activeConnections[serverAddress]; ok { @@ -110,60 +67,58 @@ func (sm *ServerMetrics) ActiveConnectionsValue(serverAddress string) int { return 0 } -func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool, receiveProxyProto bool, trustedProxyNets []*net.IPNet, recordLogins bool, autoScaleUpAllowDenyConfig *AllowDenyConfig) *Connector { +func NewConnector(ctx context.Context, metrics *ConnectorMetrics, sendProxyProto bool, recordLogins bool, autoScaleUpAllowDenyConfig *AllowDenyConfig) *Connector { return &Connector{ + ctx: ctx, metrics: metrics, sendProxyProto: sendProxyProto, connectionsCond: sync.NewCond(&sync.Mutex{}), - receiveProxyProto: receiveProxyProto, - trustedProxyNets: trustedProxyNets, recordLogins: recordLogins, autoScaleUpAllowDenyConfig: autoScaleUpAllowDenyConfig, - serverMetrics: NewServerMetrics(), + activeConnections: NewActiveConnections(), } } type Connector struct { - state mcproto.State - metrics *ConnectorMetrics - sendProxyProto bool - receiveProxyProto bool - recordLogins bool - trustedProxyNets []*net.IPNet - - activeConnections int32 - serverMetrics *ServerMetrics + ctx context.Context + state mcproto.State + metrics *ConnectorMetrics + sendProxyProto bool + receiveProxyProto bool + recordLogins bool + trustedProxyNets []*net.IPNet + totalActiveConnections int32 + activeConnections *ActiveConnections connectionsCond *sync.Cond ngrokToken string clientFilter *ClientFilter autoScaleUpAllowDenyConfig *AllowDenyConfig - - connectionNotifier ConnectionNotifier + connectionNotifier ConnectionNotifier } -func (c *Connector) SetConnectionNotifier(notifier ConnectionNotifier) { +func (c *Connector) UseConnectionNotifier(notifier ConnectionNotifier) { c.connectionNotifier = notifier } -func (c *Connector) SetClientFilter(filter *ClientFilter) { +func (c *Connector) UseClientFilter(filter *ClientFilter) { c.clientFilter = filter } -func (c *Connector) StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error { - ln, err := c.createListener(ctx, listenAddress) +func (c *Connector) StartAcceptingConnections(listenAddress string, connRateLimit int) error { + ln, err := c.createListener(listenAddress) if err != nil { return err } - go c.acceptConnections(ctx, ln, connRateLimit) + go c.acceptConnections(ln, connRateLimit) return nil } -func (c *Connector) createListener(ctx context.Context, listenAddress string) (net.Listener, error) { +func (c *Connector) createListener(listenAddress string) (net.Listener, error) { if c.ngrokToken != "" { - ngrokTun, err := ngrok.Listen(ctx, + ngrokTun, err := ngrok.Listen(c.ctx, config.TCPEndpoint(), ngrok.WithAuthtoken(c.ngrokToken), ) @@ -184,8 +139,8 @@ func (c *Connector) createListener(ctx context.Context, listenAddress string) (n if c.receiveProxyProto { proxyListener := &proxyproto.Listener{ - Listener: listener, - Policy: c.createProxyProtoPolicy(), + Listener: listener, + ConnPolicy: c.createProxyProtoPolicy(), } logrus.Info("Using PROXY protocol listener") return proxyListener, nil @@ -194,8 +149,8 @@ func (c *Connector) createListener(ctx context.Context, listenAddress string) (n return listener, nil } -func (c *Connector) createProxyProtoPolicy() func(upstream net.Addr) (proxyproto.Policy, error) { - return func(upstream net.Addr) (proxyproto.Policy, error) { +func (c *Connector) createProxyProtoPolicy() proxyproto.ConnPolicyFunc { + return func(connPolicyOptions proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) { trustedIpNets := c.trustedProxyNets if len(trustedIpNets) == 0 { @@ -203,6 +158,7 @@ func (c *Connector) createProxyProtoPolicy() func(upstream net.Addr) (proxyproto return proxyproto.USE, nil } + upstream := connPolicyOptions.Upstream upstreamIP := upstream.(*net.TCPAddr).IP for _, ipNet := range trustedIpNets { if ipNet.Contains(upstreamIP) { @@ -221,17 +177,23 @@ func (c *Connector) WaitForConnections() { defer c.connectionsCond.L.Unlock() for { - count := atomic.LoadInt32(&c.activeConnections) + count := atomic.LoadInt32(&c.totalActiveConnections) if count > 0 { logrus.Infof("Waiting on %d connection(s)", count) c.connectionsCond.Wait() } else { - break + return } } } -func (c *Connector) acceptConnections(ctx context.Context, ln net.Listener, connRateLimit int) { +// AcceptConnection provides a way to externally supply a connection to consume. +// Note that this will skip rate limiting. +func (c *Connector) AcceptConnection(conn net.Conn) { + go c.HandleConnection(conn) +} + +func (c *Connector) acceptConnections(ln net.Listener, connRateLimit int) { //noinspection GoUnhandledErrorResult defer ln.Close() @@ -239,7 +201,7 @@ func (c *Connector) acceptConnections(ctx context.Context, ln net.Listener, conn for { select { - case <-ctx.Done(): + case <-c.ctx.Done(): return case <-time.After(bucket.Take(1)): @@ -247,13 +209,13 @@ func (c *Connector) acceptConnections(ctx context.Context, ln net.Listener, conn if err != nil { logrus.WithError(err).Error("Failed to accept connection") } else { - go c.HandleConnection(ctx, conn) + go c.HandleConnection(conn) } } } } -func (c *Connector) HandleConnection(ctx context.Context, frontendConn net.Conn) { +func (c *Connector) HandleConnection(frontendConn net.Conn) { c.metrics.ConnectionsFrontend.Add(1) //noinspection GoUnhandledErrorResult defer frontendConn.Close() @@ -343,7 +305,7 @@ func (c *Connector) HandleConnection(ctx context.Context, frontendConn net.Conn) Debug("Got user info") } - c.findAndConnectBackend(ctx, frontendConn, clientAddr, inspectionBuffer, handshake.ServerAddress, playerInfo, handshake.NextState) + c.findAndConnectBackend(frontendConn, clientAddr, inspectionBuffer, handshake.ServerAddress, playerInfo, handshake.NextState) } else if packet.PacketID == mcproto.PacketIdLegacyServerListPing { handshake, ok := packet.Data.(*mcproto.LegacyServerListPing) @@ -363,7 +325,7 @@ func (c *Connector) HandleConnection(ctx context.Context, frontendConn net.Conn) serverAddress := handshake.ServerAddress - c.findAndConnectBackend(ctx, frontendConn, clientAddr, inspectionBuffer, serverAddress, nil, mcproto.StateStatus) + c.findAndConnectBackend(frontendConn, clientAddr, inspectionBuffer, serverAddress, nil, mcproto.StateStatus) } else { logrus. WithField("client", clientAddr). @@ -394,9 +356,9 @@ func (c *Connector) readPlayerInfo(protocolVersion mcproto.ProtocolVersion, buff } } -func (c *Connector) cleanupBackendConnection(ctx context.Context, clientAddr net.Addr, serverAddress string, playerInfo *PlayerInfo, backendHostPort string, cleanupMetrics bool, checkScaleDown bool) { +func (c *Connector) cleanupBackendConnection(clientAddr net.Addr, serverAddress string, playerInfo *PlayerInfo, backendHostPort string, cleanupMetrics bool, checkScaleDown bool) { if c.connectionNotifier != nil { - err := c.connectionNotifier.NotifyDisconnected(ctx, clientAddr, serverAddress, playerInfo, backendHostPort) + err := c.connectionNotifier.NotifyDisconnected(c.ctx, clientAddr, serverAddress, playerInfo, backendHostPort) if err != nil { logrus.WithError(err).Warn("failed to notify disconnected") } @@ -404,12 +366,12 @@ func (c *Connector) cleanupBackendConnection(ctx context.Context, clientAddr net if cleanupMetrics { c.metrics.ActiveConnections.Set(float64( - atomic.AddInt32(&c.activeConnections, -1))) + atomic.AddInt32(&c.totalActiveConnections, -1))) - c.serverMetrics.DecrementActiveConnections(serverAddress) + c.activeConnections.Decrement(serverAddress) c.metrics.ServerActiveConnections. With("server_address", serverAddress). - Set(float64(c.serverMetrics.ActiveConnectionsValue(serverAddress))) + Set(float64(c.activeConnections.GetCount(serverAddress))) if c.recordLogins && playerInfo != nil { c.metrics.ServerActivePlayer. @@ -419,21 +381,21 @@ func (c *Connector) cleanupBackendConnection(ctx context.Context, clientAddr net Set(0) } } - if checkScaleDown && c.serverMetrics.ActiveConnectionsValue(serverAddress) <= 0 { + if checkScaleDown && c.activeConnections.GetCount(serverAddress) <= 0 { DownScaler.Begin(serverAddress) } c.connectionsCond.Signal() } -func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.Conn, +func (c *Connector) findAndConnectBackend(frontendConn net.Conn, clientAddr net.Addr, preReadContent io.Reader, serverAddress string, playerInfo *PlayerInfo, nextState mcproto.State) { - backendHostPort, resolvedHost, waker, _ := Routes.FindBackendForServerAddress(ctx, serverAddress) + backendHostPort, resolvedHost, waker, _ := Routes.FindBackendForServerAddress(c.ctx, serverAddress) cleanupMetrics := false cleanupCheckScaleDown := false defer func() { - c.cleanupBackendConnection(ctx, clientAddr, serverAddress, playerInfo, backendHostPort, cleanupMetrics, cleanupCheckScaleDown) + c.cleanupBackendConnection(clientAddr, serverAddress, playerInfo, backendHostPort, cleanupMetrics, cleanupCheckScaleDown) }() if waker != nil && nextState > mcproto.StateStatus { @@ -448,7 +410,7 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net. // Cancel down scaler if active before scale up DownScaler.Cancel(serverAddress) cleanupCheckScaleDown = true - if err := waker(ctx); err != nil { + if err := waker(c.ctx); err != nil { logrus.WithFields(logrus.Fields{"serverAddress": serverAddress}).WithError(err).Error("failed to wake up backend") c.metrics.Errors.With("type", "wakeup_failed").Add(1) return @@ -465,7 +427,7 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net. c.metrics.Errors.With("type", "missing_backend").Add(1) if c.connectionNotifier != nil { - err := c.connectionNotifier.NotifyMissingBackend(ctx, clientAddr, serverAddress, playerInfo) + err := c.connectionNotifier.NotifyMissingBackend(c.ctx, clientAddr, serverAddress, playerInfo) if err != nil { logrus.WithError(err).Warn("failed to notify missing backend") } @@ -493,7 +455,7 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net. c.metrics.Errors.With("type", "backend_failed").Add(1) if c.connectionNotifier != nil { - notifyErr := c.connectionNotifier.NotifyFailedBackendConnection(ctx, clientAddr, serverAddress, playerInfo, backendHostPort, err) + notifyErr := c.connectionNotifier.NotifyFailedBackendConnection(c.ctx, clientAddr, serverAddress, playerInfo, backendHostPort, err) if notifyErr != nil { logrus.WithError(notifyErr).Warn("failed to notify failed backend connection") } @@ -503,7 +465,7 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net. } if c.connectionNotifier != nil { - err := c.connectionNotifier.NotifyConnected(ctx, clientAddr, serverAddress, playerInfo, backendHostPort) + err := c.connectionNotifier.NotifyConnected(c.ctx, clientAddr, serverAddress, playerInfo, backendHostPort) if err != nil { logrus.WithError(err).Warn("failed to notify connected") } @@ -512,12 +474,12 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net. c.metrics.ConnectionsBackend.With("host", resolvedHost).Add(1) c.metrics.ActiveConnections.Set(float64( - atomic.AddInt32(&c.activeConnections, 1))) + atomic.AddInt32(&c.totalActiveConnections, 1))) - c.serverMetrics.IncrementActiveConnections(serverAddress) + c.activeConnections.Increment(serverAddress) c.metrics.ServerActiveConnections. With("server_address", serverAddress). - Set(float64(c.serverMetrics.ActiveConnectionsValue(serverAddress))) + Set(float64(c.activeConnections.GetCount(serverAddress))) if c.recordLogins && playerInfo != nil { logrus. @@ -598,23 +560,23 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net. return } - c.pumpConnections(ctx, frontendConn, backendConn, playerInfo) + c.pumpConnections(frontendConn, backendConn, playerInfo) } -func (c *Connector) pumpConnections(ctx context.Context, frontendConn, backendConn net.Conn, playerInfo *PlayerInfo) { +func (c *Connector) pumpConnections(frontendConn, backendConn net.Conn, playerInfo *PlayerInfo) { //noinspection GoUnhandledErrorResult defer backendConn.Close() clientAddr := frontendConn.RemoteAddr() defer logrus.WithField("client", clientAddr).Debug("Closing backend connection") - errors := make(chan error, 2) + errorsChan := make(chan error, 2) - go c.pumpFrames(backendConn, frontendConn, errors, "backend", "frontend", clientAddr, playerInfo) - go c.pumpFrames(frontendConn, backendConn, errors, "frontend", "backend", clientAddr, playerInfo) + go c.pumpFrames(backendConn, frontendConn, errorsChan, "backend", "frontend", clientAddr, playerInfo) + go c.pumpFrames(frontendConn, backendConn, errorsChan, "frontend", "backend", clientAddr, playerInfo) select { - case err := <-errors: + case err := <-errorsChan: if err != io.EOF { logrus.WithError(err). WithField("client", clientAddr). @@ -622,8 +584,8 @@ func (c *Connector) pumpConnections(ctx context.Context, frontendConn, backendCo c.metrics.Errors.With("type", "relay").Add(1) } - case <-ctx.Done(): - logrus.Debug("Observed context cancellation") + case <-c.ctx.Done(): + logrus.Debug("Connector observed context cancellation") } } @@ -649,3 +611,8 @@ func (c *Connector) pumpFrames(incoming io.Reader, outgoing io.Writer, errors ch func (c *Connector) UseNgrok(token string) { c.ngrokToken = token } + +func (c *Connector) UseReceiveProxyProto(trustedProxyNets []*net.IPNet) { + c.trustedProxyNets = trustedProxyNets + c.receiveProxyProto = true +} diff --git a/server/connector_test.go b/server/connector_test.go index 357aa44..ab26a53 100644 --- a/server/connector_test.go +++ b/server/connector_test.go @@ -61,7 +61,9 @@ func TestTrustedProxyNetworkPolicy(t *testing.T) { policy := c.createProxyProtoPolicy() upstreamAddr := &net.TCPAddr{IP: net.ParseIP(test.upstreamIP)} - policyResult, _ := policy(upstreamAddr) + policyResult, _ := policy(proxyproto.ConnPolicyOptions{ + Upstream: upstreamAddr, + }) assert.Equal(t, test.expectedPolicy, policyResult, "Unexpected policy result for %s", test.name) }) } diff --git a/server/metrics.go b/server/metrics.go index 9da6dc5..d704e49 100644 --- a/server/metrics.go +++ b/server/metrics.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/go-kit/kit/metrics" "strings" "time" @@ -68,6 +69,17 @@ func (b expvarMetricsBuilder) Start(ctx context.Context) error { return nil } +type ConnectorMetrics struct { + Errors metrics.Counter + BytesTransmitted metrics.Counter + ConnectionsFrontend metrics.Counter + ConnectionsBackend metrics.Counter + ActiveConnections metrics.Gauge + ServerActivePlayer metrics.Gauge + ServerLogins metrics.Counter + ServerActiveConnections metrics.Gauge +} + func (b expvarMetricsBuilder) BuildConnectorMetrics() *ConnectorMetrics { c := expvarMetrics.NewCounter("connections") return &ConnectorMetrics{ diff --git a/server/notifier.go b/server/notifier.go index ad17c51..7d322a3 100644 --- a/server/notifier.go +++ b/server/notifier.go @@ -2,9 +2,40 @@ package server import ( "context" + "fmt" + "github.com/google/uuid" "net" ) +type PlayerInfo struct { + Name string `json:"name"` + Uuid uuid.UUID `json:"uuid"` +} + +func (p *PlayerInfo) String() string { + if p == nil { + return "" + } + + return fmt.Sprintf("%s/%s", p.Name, p.Uuid) +} + +type ClientInfo struct { + Host string `json:"host"` + Port int `json:"port"` +} + +func ClientInfoFromAddr(addr net.Addr) *ClientInfo { + if addr == nil { + return nil + } + + return &ClientInfo{ + Host: addr.(*net.TCPAddr).IP.String(), + Port: addr.(*net.TCPAddr).Port, + } +} + type ConnectionNotifier interface { // NotifyMissingBackend is called when an inbound connection is received for a server that does not have a backend. NotifyMissingBackend(ctx context.Context, clientAddr net.Addr, server string, playerInfo *PlayerInfo) error diff --git a/server/server.go b/server/server.go index f20cdb8..1ca7652 100644 --- a/server/server.go +++ b/server/server.go @@ -79,28 +79,23 @@ func NewServer(ctx context.Context, config *Config) (*Server, error) { config.ConnectionRateLimit = 1 } - trustedIpNets := make([]*net.IPNet, 0) - for _, ip := range config.TrustedProxies { - _, ipNet, err := net.ParseCIDR(ip) - if err != nil { - return nil, fmt.Errorf("could not parse trusted proxy CIDR block: %w", err) - } - trustedIpNets = append(trustedIpNets, ipNet) - } - - connector := NewConnector(metricsBuilder.BuildConnectorMetrics(), config.UseProxyProtocol, config.ReceiveProxyProtocol, trustedIpNets, config.RecordLogins, autoScaleAllowDenyConfig) + connector := NewConnector(ctx, + metricsBuilder.BuildConnectorMetrics(), + config.UseProxyProtocol, + config.RecordLogins, + autoScaleAllowDenyConfig) clientFilter, err := NewClientFilter(config.ClientsToAllow, config.ClientsToDeny) if err != nil { return nil, fmt.Errorf("could not create client filter: %w", err) } - connector.SetClientFilter(clientFilter) + connector.UseClientFilter(clientFilter) if config.Webhook.Url != "" { logrus.WithField("url", config.Webhook.Url). WithField("require-user", config.Webhook.RequireUser). Info("Using webhook for connection status notifications") - connector.SetConnectionNotifier( + connector.UseConnectionNotifier( NewWebhookNotifier(config.Webhook.Url, config.Webhook.RequireUser)) } @@ -108,6 +103,19 @@ func NewServer(ctx context.Context, config *Config) (*Server, error) { connector.UseNgrok(config.NgrokToken) } + if config.ReceiveProxyProtocol { + trustedIpNets := make([]*net.IPNet, 0) + for _, ip := range config.TrustedProxies { + _, ipNet, err := net.ParseCIDR(ip) + if err != nil { + return nil, fmt.Errorf("could not parse trusted proxy CIDR block: %w", err) + } + trustedIpNets = append(trustedIpNets, ipNet) + } + + connector.UseReceiveProxyProto(trustedIpNets) + } + if config.ApiBinding != "" { StartApiServer(config.ApiBinding) } @@ -177,10 +185,16 @@ func (s *Server) ReloadConfig() { s.reloadConfigChan <- struct{}{} } +// AcceptConnection provides a way to externally supply a connection to consume +// Note that this will skip rate limiting. +func (s *Server) AcceptConnection(conn net.Conn) { + s.connector.AcceptConnection(conn) +} + // Run will run the server until the context is done or a fatal error occurs, so this should be // in a go routine. func (s *Server) Run() { - err := s.connector.StartAcceptingConnections(s.ctx, + err := s.connector.StartAcceptingConnections( net.JoinHostPort("", strconv.Itoa(s.config.Port)), s.config.ConnectionRateLimit, )