Code cleanup in and around connector (#427)

This commit is contained in:
Geoff Bourne
2025-07-05 21:30:23 -05:00
committed by GitHub
parent 05c57c3b85
commit b3e88db48c
6 changed files with 148 additions and 122 deletions
+1 -1
View File
@@ -89,7 +89,7 @@ func NewClientFilter(allows []string, denies []string) (*ClientFilter, error) {
}, nil
}
// Allow determines if the given address is allowed by this filter
// Allow determines if this filter allows the given address
// where addrStr is a netip.ParseAddr allowed address
func (f *ClientFilter) Allow(addrPort netip.AddrPort) bool {
if !f.allow.Empty() {
+74 -107
View File
@@ -12,12 +12,9 @@ import (
"sync/atomic"
"time"
"github.com/google/uuid"
"golang.ngrok.com/ngrok"
"golang.ngrok.com/ngrok/config"
"github.com/go-kit/kit/metrics"
"github.com/itzg/mc-router/mcproto"
"github.com/juju/ratelimit"
"github.com/pires/go-proxyproto"
@@ -30,58 +27,18 @@ const (
var noDeadline time.Time
type ConnectorMetrics struct {
Errors metrics.Counter
BytesTransmitted metrics.Counter
ConnectionsFrontend metrics.Counter
ConnectionsBackend metrics.Counter
ActiveConnections metrics.Gauge
ServerActivePlayer metrics.Gauge
ServerLogins metrics.Counter
ServerActiveConnections metrics.Gauge
}
type ClientInfo struct {
Host string `json:"host"`
Port int `json:"port"`
}
func ClientInfoFromAddr(addr net.Addr) *ClientInfo {
if addr == nil {
return nil
}
return &ClientInfo{
Host: addr.(*net.TCPAddr).IP.String(),
Port: addr.(*net.TCPAddr).Port,
}
}
type PlayerInfo struct {
Name string `json:"name"`
Uuid uuid.UUID `json:"uuid"`
}
func (p *PlayerInfo) String() string {
if p == nil {
return ""
}
return fmt.Sprintf("%s/%s", p.Name, p.Uuid)
}
type ServerMetrics struct {
type ActiveConnections struct {
sync.RWMutex
activeConnections map[string]int
}
func NewServerMetrics() *ServerMetrics {
return &ServerMetrics{
func NewActiveConnections() *ActiveConnections {
return &ActiveConnections{
activeConnections: make(map[string]int),
}
}
func (sm *ServerMetrics) IncrementActiveConnections(serverAddress string) {
func (sm *ActiveConnections) Increment(serverAddress string) {
sm.Lock()
defer sm.Unlock()
if _, ok := sm.activeConnections[serverAddress]; !ok {
@@ -91,7 +48,7 @@ func (sm *ServerMetrics) IncrementActiveConnections(serverAddress string) {
sm.activeConnections[serverAddress] += 1
}
func (sm *ServerMetrics) DecrementActiveConnections(serverAddress string) {
func (sm *ActiveConnections) Decrement(serverAddress string) {
sm.Lock()
defer sm.Unlock()
if activeConnections, ok := sm.activeConnections[serverAddress]; ok && activeConnections <= 0 {
@@ -101,7 +58,7 @@ func (sm *ServerMetrics) DecrementActiveConnections(serverAddress string) {
sm.activeConnections[serverAddress] -= 1
}
func (sm *ServerMetrics) ActiveConnectionsValue(serverAddress string) int {
func (sm *ActiveConnections) GetCount(serverAddress string) int {
sm.Lock()
defer sm.Unlock()
if activeConnections, ok := sm.activeConnections[serverAddress]; ok {
@@ -110,60 +67,58 @@ func (sm *ServerMetrics) ActiveConnectionsValue(serverAddress string) int {
return 0
}
func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool, receiveProxyProto bool, trustedProxyNets []*net.IPNet, recordLogins bool, autoScaleUpAllowDenyConfig *AllowDenyConfig) *Connector {
func NewConnector(ctx context.Context, metrics *ConnectorMetrics, sendProxyProto bool, recordLogins bool, autoScaleUpAllowDenyConfig *AllowDenyConfig) *Connector {
return &Connector{
ctx: ctx,
metrics: metrics,
sendProxyProto: sendProxyProto,
connectionsCond: sync.NewCond(&sync.Mutex{}),
receiveProxyProto: receiveProxyProto,
trustedProxyNets: trustedProxyNets,
recordLogins: recordLogins,
autoScaleUpAllowDenyConfig: autoScaleUpAllowDenyConfig,
serverMetrics: NewServerMetrics(),
activeConnections: NewActiveConnections(),
}
}
type Connector struct {
state mcproto.State
metrics *ConnectorMetrics
sendProxyProto bool
receiveProxyProto bool
recordLogins bool
trustedProxyNets []*net.IPNet
activeConnections int32
serverMetrics *ServerMetrics
ctx context.Context
state mcproto.State
metrics *ConnectorMetrics
sendProxyProto bool
receiveProxyProto bool
recordLogins bool
trustedProxyNets []*net.IPNet
totalActiveConnections int32
activeConnections *ActiveConnections
connectionsCond *sync.Cond
ngrokToken string
clientFilter *ClientFilter
autoScaleUpAllowDenyConfig *AllowDenyConfig
connectionNotifier ConnectionNotifier
connectionNotifier ConnectionNotifier
}
func (c *Connector) SetConnectionNotifier(notifier ConnectionNotifier) {
func (c *Connector) UseConnectionNotifier(notifier ConnectionNotifier) {
c.connectionNotifier = notifier
}
func (c *Connector) SetClientFilter(filter *ClientFilter) {
func (c *Connector) UseClientFilter(filter *ClientFilter) {
c.clientFilter = filter
}
func (c *Connector) StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error {
ln, err := c.createListener(ctx, listenAddress)
func (c *Connector) StartAcceptingConnections(listenAddress string, connRateLimit int) error {
ln, err := c.createListener(listenAddress)
if err != nil {
return err
}
go c.acceptConnections(ctx, ln, connRateLimit)
go c.acceptConnections(ln, connRateLimit)
return nil
}
func (c *Connector) createListener(ctx context.Context, listenAddress string) (net.Listener, error) {
func (c *Connector) createListener(listenAddress string) (net.Listener, error) {
if c.ngrokToken != "" {
ngrokTun, err := ngrok.Listen(ctx,
ngrokTun, err := ngrok.Listen(c.ctx,
config.TCPEndpoint(),
ngrok.WithAuthtoken(c.ngrokToken),
)
@@ -184,8 +139,8 @@ func (c *Connector) createListener(ctx context.Context, listenAddress string) (n
if c.receiveProxyProto {
proxyListener := &proxyproto.Listener{
Listener: listener,
Policy: c.createProxyProtoPolicy(),
Listener: listener,
ConnPolicy: c.createProxyProtoPolicy(),
}
logrus.Info("Using PROXY protocol listener")
return proxyListener, nil
@@ -194,8 +149,8 @@ func (c *Connector) createListener(ctx context.Context, listenAddress string) (n
return listener, nil
}
func (c *Connector) createProxyProtoPolicy() func(upstream net.Addr) (proxyproto.Policy, error) {
return func(upstream net.Addr) (proxyproto.Policy, error) {
func (c *Connector) createProxyProtoPolicy() proxyproto.ConnPolicyFunc {
return func(connPolicyOptions proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) {
trustedIpNets := c.trustedProxyNets
if len(trustedIpNets) == 0 {
@@ -203,6 +158,7 @@ func (c *Connector) createProxyProtoPolicy() func(upstream net.Addr) (proxyproto
return proxyproto.USE, nil
}
upstream := connPolicyOptions.Upstream
upstreamIP := upstream.(*net.TCPAddr).IP
for _, ipNet := range trustedIpNets {
if ipNet.Contains(upstreamIP) {
@@ -221,17 +177,23 @@ func (c *Connector) WaitForConnections() {
defer c.connectionsCond.L.Unlock()
for {
count := atomic.LoadInt32(&c.activeConnections)
count := atomic.LoadInt32(&c.totalActiveConnections)
if count > 0 {
logrus.Infof("Waiting on %d connection(s)", count)
c.connectionsCond.Wait()
} else {
break
return
}
}
}
func (c *Connector) acceptConnections(ctx context.Context, ln net.Listener, connRateLimit int) {
// AcceptConnection provides a way to externally supply a connection to consume.
// Note that this will skip rate limiting.
func (c *Connector) AcceptConnection(conn net.Conn) {
go c.HandleConnection(conn)
}
func (c *Connector) acceptConnections(ln net.Listener, connRateLimit int) {
//noinspection GoUnhandledErrorResult
defer ln.Close()
@@ -239,7 +201,7 @@ func (c *Connector) acceptConnections(ctx context.Context, ln net.Listener, conn
for {
select {
case <-ctx.Done():
case <-c.ctx.Done():
return
case <-time.After(bucket.Take(1)):
@@ -247,13 +209,13 @@ func (c *Connector) acceptConnections(ctx context.Context, ln net.Listener, conn
if err != nil {
logrus.WithError(err).Error("Failed to accept connection")
} else {
go c.HandleConnection(ctx, conn)
go c.HandleConnection(conn)
}
}
}
}
func (c *Connector) HandleConnection(ctx context.Context, frontendConn net.Conn) {
func (c *Connector) HandleConnection(frontendConn net.Conn) {
c.metrics.ConnectionsFrontend.Add(1)
//noinspection GoUnhandledErrorResult
defer frontendConn.Close()
@@ -343,7 +305,7 @@ func (c *Connector) HandleConnection(ctx context.Context, frontendConn net.Conn)
Debug("Got user info")
}
c.findAndConnectBackend(ctx, frontendConn, clientAddr, inspectionBuffer, handshake.ServerAddress, playerInfo, handshake.NextState)
c.findAndConnectBackend(frontendConn, clientAddr, inspectionBuffer, handshake.ServerAddress, playerInfo, handshake.NextState)
} else if packet.PacketID == mcproto.PacketIdLegacyServerListPing {
handshake, ok := packet.Data.(*mcproto.LegacyServerListPing)
@@ -363,7 +325,7 @@ func (c *Connector) HandleConnection(ctx context.Context, frontendConn net.Conn)
serverAddress := handshake.ServerAddress
c.findAndConnectBackend(ctx, frontendConn, clientAddr, inspectionBuffer, serverAddress, nil, mcproto.StateStatus)
c.findAndConnectBackend(frontendConn, clientAddr, inspectionBuffer, serverAddress, nil, mcproto.StateStatus)
} else {
logrus.
WithField("client", clientAddr).
@@ -394,9 +356,9 @@ func (c *Connector) readPlayerInfo(protocolVersion mcproto.ProtocolVersion, buff
}
}
func (c *Connector) cleanupBackendConnection(ctx context.Context, clientAddr net.Addr, serverAddress string, playerInfo *PlayerInfo, backendHostPort string, cleanupMetrics bool, checkScaleDown bool) {
func (c *Connector) cleanupBackendConnection(clientAddr net.Addr, serverAddress string, playerInfo *PlayerInfo, backendHostPort string, cleanupMetrics bool, checkScaleDown bool) {
if c.connectionNotifier != nil {
err := c.connectionNotifier.NotifyDisconnected(ctx, clientAddr, serverAddress, playerInfo, backendHostPort)
err := c.connectionNotifier.NotifyDisconnected(c.ctx, clientAddr, serverAddress, playerInfo, backendHostPort)
if err != nil {
logrus.WithError(err).Warn("failed to notify disconnected")
}
@@ -404,12 +366,12 @@ func (c *Connector) cleanupBackendConnection(ctx context.Context, clientAddr net
if cleanupMetrics {
c.metrics.ActiveConnections.Set(float64(
atomic.AddInt32(&c.activeConnections, -1)))
atomic.AddInt32(&c.totalActiveConnections, -1)))
c.serverMetrics.DecrementActiveConnections(serverAddress)
c.activeConnections.Decrement(serverAddress)
c.metrics.ServerActiveConnections.
With("server_address", serverAddress).
Set(float64(c.serverMetrics.ActiveConnectionsValue(serverAddress)))
Set(float64(c.activeConnections.GetCount(serverAddress)))
if c.recordLogins && playerInfo != nil {
c.metrics.ServerActivePlayer.
@@ -419,21 +381,21 @@ func (c *Connector) cleanupBackendConnection(ctx context.Context, clientAddr net
Set(0)
}
}
if checkScaleDown && c.serverMetrics.ActiveConnectionsValue(serverAddress) <= 0 {
if checkScaleDown && c.activeConnections.GetCount(serverAddress) <= 0 {
DownScaler.Begin(serverAddress)
}
c.connectionsCond.Signal()
}
func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.Conn,
func (c *Connector) findAndConnectBackend(frontendConn net.Conn,
clientAddr net.Addr, preReadContent io.Reader, serverAddress string, playerInfo *PlayerInfo, nextState mcproto.State) {
backendHostPort, resolvedHost, waker, _ := Routes.FindBackendForServerAddress(ctx, serverAddress)
backendHostPort, resolvedHost, waker, _ := Routes.FindBackendForServerAddress(c.ctx, serverAddress)
cleanupMetrics := false
cleanupCheckScaleDown := false
defer func() {
c.cleanupBackendConnection(ctx, clientAddr, serverAddress, playerInfo, backendHostPort, cleanupMetrics, cleanupCheckScaleDown)
c.cleanupBackendConnection(clientAddr, serverAddress, playerInfo, backendHostPort, cleanupMetrics, cleanupCheckScaleDown)
}()
if waker != nil && nextState > mcproto.StateStatus {
@@ -448,7 +410,7 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.
// Cancel down scaler if active before scale up
DownScaler.Cancel(serverAddress)
cleanupCheckScaleDown = true
if err := waker(ctx); err != nil {
if err := waker(c.ctx); err != nil {
logrus.WithFields(logrus.Fields{"serverAddress": serverAddress}).WithError(err).Error("failed to wake up backend")
c.metrics.Errors.With("type", "wakeup_failed").Add(1)
return
@@ -465,7 +427,7 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.
c.metrics.Errors.With("type", "missing_backend").Add(1)
if c.connectionNotifier != nil {
err := c.connectionNotifier.NotifyMissingBackend(ctx, clientAddr, serverAddress, playerInfo)
err := c.connectionNotifier.NotifyMissingBackend(c.ctx, clientAddr, serverAddress, playerInfo)
if err != nil {
logrus.WithError(err).Warn("failed to notify missing backend")
}
@@ -493,7 +455,7 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.
c.metrics.Errors.With("type", "backend_failed").Add(1)
if c.connectionNotifier != nil {
notifyErr := c.connectionNotifier.NotifyFailedBackendConnection(ctx, clientAddr, serverAddress, playerInfo, backendHostPort, err)
notifyErr := c.connectionNotifier.NotifyFailedBackendConnection(c.ctx, clientAddr, serverAddress, playerInfo, backendHostPort, err)
if notifyErr != nil {
logrus.WithError(notifyErr).Warn("failed to notify failed backend connection")
}
@@ -503,7 +465,7 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.
}
if c.connectionNotifier != nil {
err := c.connectionNotifier.NotifyConnected(ctx, clientAddr, serverAddress, playerInfo, backendHostPort)
err := c.connectionNotifier.NotifyConnected(c.ctx, clientAddr, serverAddress, playerInfo, backendHostPort)
if err != nil {
logrus.WithError(err).Warn("failed to notify connected")
}
@@ -512,12 +474,12 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.
c.metrics.ConnectionsBackend.With("host", resolvedHost).Add(1)
c.metrics.ActiveConnections.Set(float64(
atomic.AddInt32(&c.activeConnections, 1)))
atomic.AddInt32(&c.totalActiveConnections, 1)))
c.serverMetrics.IncrementActiveConnections(serverAddress)
c.activeConnections.Increment(serverAddress)
c.metrics.ServerActiveConnections.
With("server_address", serverAddress).
Set(float64(c.serverMetrics.ActiveConnectionsValue(serverAddress)))
Set(float64(c.activeConnections.GetCount(serverAddress)))
if c.recordLogins && playerInfo != nil {
logrus.
@@ -598,23 +560,23 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.
return
}
c.pumpConnections(ctx, frontendConn, backendConn, playerInfo)
c.pumpConnections(frontendConn, backendConn, playerInfo)
}
func (c *Connector) pumpConnections(ctx context.Context, frontendConn, backendConn net.Conn, playerInfo *PlayerInfo) {
func (c *Connector) pumpConnections(frontendConn, backendConn net.Conn, playerInfo *PlayerInfo) {
//noinspection GoUnhandledErrorResult
defer backendConn.Close()
clientAddr := frontendConn.RemoteAddr()
defer logrus.WithField("client", clientAddr).Debug("Closing backend connection")
errors := make(chan error, 2)
errorsChan := make(chan error, 2)
go c.pumpFrames(backendConn, frontendConn, errors, "backend", "frontend", clientAddr, playerInfo)
go c.pumpFrames(frontendConn, backendConn, errors, "frontend", "backend", clientAddr, playerInfo)
go c.pumpFrames(backendConn, frontendConn, errorsChan, "backend", "frontend", clientAddr, playerInfo)
go c.pumpFrames(frontendConn, backendConn, errorsChan, "frontend", "backend", clientAddr, playerInfo)
select {
case err := <-errors:
case err := <-errorsChan:
if err != io.EOF {
logrus.WithError(err).
WithField("client", clientAddr).
@@ -622,8 +584,8 @@ func (c *Connector) pumpConnections(ctx context.Context, frontendConn, backendCo
c.metrics.Errors.With("type", "relay").Add(1)
}
case <-ctx.Done():
logrus.Debug("Observed context cancellation")
case <-c.ctx.Done():
logrus.Debug("Connector observed context cancellation")
}
}
@@ -649,3 +611,8 @@ func (c *Connector) pumpFrames(incoming io.Reader, outgoing io.Writer, errors ch
func (c *Connector) UseNgrok(token string) {
c.ngrokToken = token
}
func (c *Connector) UseReceiveProxyProto(trustedProxyNets []*net.IPNet) {
c.trustedProxyNets = trustedProxyNets
c.receiveProxyProto = true
}
+3 -1
View File
@@ -61,7 +61,9 @@ func TestTrustedProxyNetworkPolicy(t *testing.T) {
policy := c.createProxyProtoPolicy()
upstreamAddr := &net.TCPAddr{IP: net.ParseIP(test.upstreamIP)}
policyResult, _ := policy(upstreamAddr)
policyResult, _ := policy(proxyproto.ConnPolicyOptions{
Upstream: upstreamAddr,
})
assert.Equal(t, test.expectedPolicy, policyResult, "Unexpected policy result for %s", test.name)
})
}
+12
View File
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/go-kit/kit/metrics"
"strings"
"time"
@@ -68,6 +69,17 @@ func (b expvarMetricsBuilder) Start(ctx context.Context) error {
return nil
}
type ConnectorMetrics struct {
Errors metrics.Counter
BytesTransmitted metrics.Counter
ConnectionsFrontend metrics.Counter
ConnectionsBackend metrics.Counter
ActiveConnections metrics.Gauge
ServerActivePlayer metrics.Gauge
ServerLogins metrics.Counter
ServerActiveConnections metrics.Gauge
}
func (b expvarMetricsBuilder) BuildConnectorMetrics() *ConnectorMetrics {
c := expvarMetrics.NewCounter("connections")
return &ConnectorMetrics{
+31
View File
@@ -2,9 +2,40 @@ package server
import (
"context"
"fmt"
"github.com/google/uuid"
"net"
)
type PlayerInfo struct {
Name string `json:"name"`
Uuid uuid.UUID `json:"uuid"`
}
func (p *PlayerInfo) String() string {
if p == nil {
return ""
}
return fmt.Sprintf("%s/%s", p.Name, p.Uuid)
}
type ClientInfo struct {
Host string `json:"host"`
Port int `json:"port"`
}
func ClientInfoFromAddr(addr net.Addr) *ClientInfo {
if addr == nil {
return nil
}
return &ClientInfo{
Host: addr.(*net.TCPAddr).IP.String(),
Port: addr.(*net.TCPAddr).Port,
}
}
type ConnectionNotifier interface {
// NotifyMissingBackend is called when an inbound connection is received for a server that does not have a backend.
NotifyMissingBackend(ctx context.Context, clientAddr net.Addr, server string, playerInfo *PlayerInfo) error
+27 -13
View File
@@ -79,28 +79,23 @@ func NewServer(ctx context.Context, config *Config) (*Server, error) {
config.ConnectionRateLimit = 1
}
trustedIpNets := make([]*net.IPNet, 0)
for _, ip := range config.TrustedProxies {
_, ipNet, err := net.ParseCIDR(ip)
if err != nil {
return nil, fmt.Errorf("could not parse trusted proxy CIDR block: %w", err)
}
trustedIpNets = append(trustedIpNets, ipNet)
}
connector := NewConnector(metricsBuilder.BuildConnectorMetrics(), config.UseProxyProtocol, config.ReceiveProxyProtocol, trustedIpNets, config.RecordLogins, autoScaleAllowDenyConfig)
connector := NewConnector(ctx,
metricsBuilder.BuildConnectorMetrics(),
config.UseProxyProtocol,
config.RecordLogins,
autoScaleAllowDenyConfig)
clientFilter, err := NewClientFilter(config.ClientsToAllow, config.ClientsToDeny)
if err != nil {
return nil, fmt.Errorf("could not create client filter: %w", err)
}
connector.SetClientFilter(clientFilter)
connector.UseClientFilter(clientFilter)
if config.Webhook.Url != "" {
logrus.WithField("url", config.Webhook.Url).
WithField("require-user", config.Webhook.RequireUser).
Info("Using webhook for connection status notifications")
connector.SetConnectionNotifier(
connector.UseConnectionNotifier(
NewWebhookNotifier(config.Webhook.Url, config.Webhook.RequireUser))
}
@@ -108,6 +103,19 @@ func NewServer(ctx context.Context, config *Config) (*Server, error) {
connector.UseNgrok(config.NgrokToken)
}
if config.ReceiveProxyProtocol {
trustedIpNets := make([]*net.IPNet, 0)
for _, ip := range config.TrustedProxies {
_, ipNet, err := net.ParseCIDR(ip)
if err != nil {
return nil, fmt.Errorf("could not parse trusted proxy CIDR block: %w", err)
}
trustedIpNets = append(trustedIpNets, ipNet)
}
connector.UseReceiveProxyProto(trustedIpNets)
}
if config.ApiBinding != "" {
StartApiServer(config.ApiBinding)
}
@@ -177,10 +185,16 @@ func (s *Server) ReloadConfig() {
s.reloadConfigChan <- struct{}{}
}
// AcceptConnection provides a way to externally supply a connection to consume
// Note that this will skip rate limiting.
func (s *Server) AcceptConnection(conn net.Conn) {
s.connector.AcceptConnection(conn)
}
// Run will run the server until the context is done or a fatal error occurs, so this should be
// in a go routine.
func (s *Server) Run() {
err := s.connector.StartAcceptingConnections(s.ctx,
err := s.connector.StartAcceptingConnections(
net.JoinHostPort("", strconv.Itoa(s.config.Port)),
s.config.ConnectionRateLimit,
)