Add auto scale down option (#405)

This commit is contained in:
Samuel McBroom
2025-05-02 16:12:53 -07:00
committed by GitHub
parent f6300d6a8a
commit bc81e03f19
11 changed files with 373 additions and 109 deletions
+100 -26
View File
@@ -30,13 +30,14 @@ const (
var noDeadline time.Time
type ConnectorMetrics struct {
Errors metrics.Counter
BytesTransmitted metrics.Counter
ConnectionsFrontend metrics.Counter
ConnectionsBackend metrics.Counter
ActiveConnections metrics.Gauge
ServerActivePlayer metrics.Gauge
ServerLogins metrics.Counter
Errors metrics.Counter
BytesTransmitted metrics.Counter
ConnectionsFrontend metrics.Counter
ConnectionsBackend metrics.Counter
ActiveConnections metrics.Gauge
ServerActivePlayer metrics.Gauge
ServerLogins metrics.Counter
ServerActiveConnections metrics.Gauge
}
type ClientInfo struct {
@@ -67,7 +68,48 @@ func (p *PlayerInfo) String() string {
return fmt.Sprintf("%s/%s", p.Name, p.Uuid)
}
type ServerMetrics struct {
sync.RWMutex
activeConnections map[string]int
}
func NewServerMetrics() *ServerMetrics {
return &ServerMetrics{
activeConnections: make(map[string]int),
}
}
func (sm *ServerMetrics) IncrementActiveConnections(serverAddress string) {
sm.Lock()
defer sm.Unlock()
if _, ok := sm.activeConnections[serverAddress]; !ok {
sm.activeConnections[serverAddress] = 1
return
}
sm.activeConnections[serverAddress] += 1
}
func (sm *ServerMetrics) DecrementActiveConnections(serverAddress string) {
sm.Lock()
defer sm.Unlock()
if activeConnections, ok := sm.activeConnections[serverAddress]; ok && activeConnections <= 0 {
sm.activeConnections[serverAddress] = 0
return
}
sm.activeConnections[serverAddress] -= 1
}
func (sm *ServerMetrics) ActiveConnectionsValue(serverAddress string) int {
sm.Lock()
defer sm.Unlock()
if activeConnections, ok := sm.activeConnections[serverAddress]; ok {
return activeConnections
}
return 0
}
func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool, receiveProxyProto bool, trustedProxyNets []*net.IPNet, recordLogins bool, autoScaleUpAllowDenyConfig *AllowDenyConfig) *Connector {
return &Connector{
metrics: metrics,
sendProxyProto: sendProxyProto,
@@ -76,6 +118,7 @@ func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool, receiveProxyPr
trustedProxyNets: trustedProxyNets,
recordLogins: recordLogins,
autoScaleUpAllowDenyConfig: autoScaleUpAllowDenyConfig,
serverMetrics: NewServerMetrics(),
}
}
@@ -88,6 +131,7 @@ type Connector struct {
trustedProxyNets []*net.IPNet
activeConnections int32
serverMetrics *ServerMetrics
connectionsCond *sync.Cond
ngrokToken string
clientFilter *ClientFilter
@@ -348,10 +392,48 @@ func (c *Connector) readPlayerInfo(bufferedReader *bufio.Reader, clientAddr net.
}
}
func (c *Connector) cleanupBackendConnection(ctx context.Context, clientAddr net.Addr, serverAddress string, playerInfo *PlayerInfo, backendHostPort string, cleanupMetrics bool, checkScaleDown bool) {
if c.connectionNotifier != nil {
err := c.connectionNotifier.NotifyDisconnected(ctx, clientAddr, serverAddress, playerInfo, backendHostPort)
if err != nil {
logrus.WithError(err).Warn("failed to notify disconnected")
}
}
if cleanupMetrics {
c.metrics.ActiveConnections.Set(float64(
atomic.AddInt32(&c.activeConnections, -1)))
c.serverMetrics.DecrementActiveConnections(serverAddress)
c.metrics.ServerActiveConnections.
With("server_address", serverAddress).
Set(float64(c.serverMetrics.ActiveConnectionsValue(serverAddress)))
if c.recordLogins && playerInfo != nil {
c.metrics.ServerActivePlayer.
With("player_name", playerInfo.Name).
With("player_uuid", playerInfo.Uuid.String()).
With("server_address", serverAddress).
Set(0)
}
}
if checkScaleDown && c.serverMetrics.ActiveConnectionsValue(serverAddress) <= 0 {
DownScaler.Begin(serverAddress)
}
c.connectionsCond.Signal()
}
func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.Conn,
clientAddr net.Addr, preReadContent io.Reader, serverAddress string, playerInfo *PlayerInfo, nextState mcproto.State) {
backendHostPort, resolvedHost, waker := Routes.FindBackendForServerAddress(ctx, serverAddress)
backendHostPort, resolvedHost, waker, _ := Routes.FindBackendForServerAddress(ctx, serverAddress)
cleanupMetrics := false
cleanupCheckScaleDown := false
defer func() {
c.cleanupBackendConnection(ctx, clientAddr, serverAddress, playerInfo, backendHostPort, cleanupMetrics, cleanupCheckScaleDown)
}()
if waker != nil && nextState > mcproto.StateStatus {
serverAllowsPlayer := c.autoScaleUpAllowDenyConfig.ServerAllowsPlayer(serverAddress, playerInfo)
logrus.
@@ -361,6 +443,9 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.
WithField("serverAllowsPlayer", serverAllowsPlayer).
Debug("checked if player is allowed to wake up the server")
if serverAllowsPlayer {
// Cancel down scaler if active before scale up
DownScaler.Cancel(serverAddress)
cleanupCheckScaleDown = true
if err := waker(ctx); err != nil {
logrus.WithFields(logrus.Fields{"serverAddress": serverAddress}).WithError(err).Error("failed to wake up backend")
c.metrics.Errors.With("type", "wakeup_failed").Add(1)
@@ -426,6 +511,12 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.
c.metrics.ActiveConnections.Set(float64(
atomic.AddInt32(&c.activeConnections, 1)))
c.serverMetrics.IncrementActiveConnections(serverAddress)
c.metrics.ServerActiveConnections.
With("server_address", serverAddress).
Set(float64(c.serverMetrics.ActiveConnectionsValue(serverAddress)))
if c.recordLogins && playerInfo != nil {
logrus.
WithField("client", clientAddr).
@@ -446,24 +537,7 @@ func (c *Connector) findAndConnectBackend(ctx context.Context, frontendConn net.
Add(1)
}
defer func() {
if c.connectionNotifier != nil {
err := c.connectionNotifier.NotifyDisconnected(ctx, clientAddr, serverAddress, playerInfo, backendHostPort)
if err != nil {
logrus.WithError(err).Warn("failed to notify disconnected")
}
}
c.metrics.ActiveConnections.Set(float64(
atomic.AddInt32(&c.activeConnections, -1)))
if c.recordLogins && playerInfo != nil {
c.metrics.ServerActivePlayer.
With("player_name", playerInfo.Name).
With("player_uuid", playerInfo.Uuid.String()).
With("server_address", serverAddress).
Set(0)
}
c.connectionsCond.Signal()
}()
cleanupMetrics = true
// PROXY protocol implementation
if c.sendProxyProto {
+25 -6
View File
@@ -15,7 +15,7 @@ import (
)
type IDockerWatcher interface {
Start(socket string, timeoutSeconds int, refreshIntervalSeconds int) error
Start(socket string, timeoutSeconds int, refreshIntervalSeconds int, autoScaleUp bool, autoScaleDown bool) error
Stop()
}
@@ -31,19 +31,38 @@ var DockerWatcher IDockerWatcher = &dockerWatcherImpl{}
type dockerWatcherImpl struct {
sync.RWMutex
autoScaleUp bool
autoScaleDown bool
client *client.Client
contextCancel context.CancelFunc
}
func (w *dockerWatcherImpl) makeWakerFunc(_ *routableContainer) func(ctx context.Context) error {
func (w *dockerWatcherImpl) makeWakerFunc(_ *routableContainer) ScalerFunc {
if !w.autoScaleUp {
return nil
}
return func(ctx context.Context) error {
logrus.Fatal("Auto scale up is not yet supported for docker")
return nil
}
}
func (w *dockerWatcherImpl) Start(socket string, timeoutSeconds int, refreshIntervalSeconds int) error {
func (w *dockerWatcherImpl) makeSleeperFunc(_ *routableContainer) ScalerFunc {
if !w.autoScaleDown {
return nil
}
return func(ctx context.Context) error {
logrus.Fatal("Auto scale down is not yet supported for docker")
return nil
}
}
func (w *dockerWatcherImpl) Start(socket string, timeoutSeconds int, refreshIntervalSeconds int, autoScaleUp bool, autoScaleDown bool) error {
var err error
w.autoScaleUp = autoScaleUp
w.autoScaleDown = autoScaleDown
timeout := time.Duration(timeoutSeconds) * time.Second
refreshInterval := time.Duration(refreshIntervalSeconds) * time.Second
@@ -75,7 +94,7 @@ func (w *dockerWatcherImpl) Start(socket string, timeoutSeconds int, refreshInte
for _, c := range initialContainers {
containerMap[c.externalContainerName] = c
if c.externalContainerName != "" {
Routes.CreateMapping(c.externalContainerName, c.containerEndpoint, w.makeWakerFunc(c))
Routes.CreateMapping(c.externalContainerName, c.containerEndpoint, w.makeWakerFunc(c), w.makeSleeperFunc(c))
} else {
Routes.SetDefaultRoute(c.containerEndpoint)
}
@@ -97,7 +116,7 @@ func (w *dockerWatcherImpl) Start(socket string, timeoutSeconds int, refreshInte
containerMap[rs.externalContainerName] = rs
logrus.WithField("routableContainer", rs).Debug("ADD")
if rs.externalContainerName != "" {
Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, w.makeWakerFunc(rs))
Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, w.makeWakerFunc(rs), w.makeSleeperFunc(rs))
} else {
Routes.SetDefaultRoute(rs.containerEndpoint)
}
@@ -105,7 +124,7 @@ func (w *dockerWatcherImpl) Start(socket string, timeoutSeconds int, refreshInte
containerMap[rs.externalContainerName] = rs
if rs.externalContainerName != "" {
Routes.DeleteMapping(rs.externalContainerName)
Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, w.makeWakerFunc(rs))
Routes.CreateMapping(rs.externalContainerName, rs.containerEndpoint, w.makeWakerFunc(rs), w.makeSleeperFunc(rs))
} else {
Routes.SetDefaultRoute(rs.containerEndpoint)
}
+24 -5
View File
@@ -23,19 +23,38 @@ var DockerSwarmWatcher IDockerWatcher = &dockerSwarmWatcherImpl{}
type dockerSwarmWatcherImpl struct {
sync.RWMutex
autoScaleUp bool
autoScaleDown bool
client *client.Client
contextCancel context.CancelFunc
}
func (w *dockerSwarmWatcherImpl) makeWakerFunc(_ *routableService) func(ctx context.Context) error {
func (w *dockerSwarmWatcherImpl) makeWakerFunc(_ *routableService) ScalerFunc {
if !w.autoScaleUp {
return nil
}
return func(ctx context.Context) error {
logrus.Fatal("Auto scale up is not yet supported for docker swarm")
return nil
}
}
func (w *dockerSwarmWatcherImpl) Start(socket string, timeoutSeconds int, refreshIntervalSeconds int) error {
func (w *dockerSwarmWatcherImpl) makeSleeperFunc(_ *routableService) ScalerFunc {
if !w.autoScaleDown {
return nil
}
return func(ctx context.Context) error {
logrus.Fatal("Auto scale down is not yet supported for docker swarm")
return nil
}
}
func (w *dockerSwarmWatcherImpl) Start(socket string, timeoutSeconds int, refreshIntervalSeconds int, autoScaleUp bool, autoScaleDown bool) error {
var err error
w.autoScaleUp = autoScaleUp
w.autoScaleDown = autoScaleDown
timeout := time.Duration(timeoutSeconds) * time.Second
refreshInterval := time.Duration(refreshIntervalSeconds) * time.Second
@@ -67,7 +86,7 @@ func (w *dockerSwarmWatcherImpl) Start(socket string, timeoutSeconds int, refres
for _, s := range initialServices {
serviceMap[s.externalServiceName] = s
if s.externalServiceName != "" {
Routes.CreateMapping(s.externalServiceName, s.containerEndpoint, w.makeWakerFunc(s))
Routes.CreateMapping(s.externalServiceName, s.containerEndpoint, w.makeWakerFunc(s), w.makeSleeperFunc(s))
} else {
Routes.SetDefaultRoute(s.containerEndpoint)
}
@@ -89,7 +108,7 @@ func (w *dockerSwarmWatcherImpl) Start(socket string, timeoutSeconds int, refres
serviceMap[rs.externalServiceName] = rs
logrus.WithField("routableService", rs).Debug("ADD")
if rs.externalServiceName != "" {
Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, w.makeWakerFunc(rs))
Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, w.makeWakerFunc(rs), w.makeSleeperFunc(rs))
} else {
Routes.SetDefaultRoute(rs.containerEndpoint)
}
@@ -97,7 +116,7 @@ func (w *dockerSwarmWatcherImpl) Start(socket string, timeoutSeconds int, refres
serviceMap[rs.externalServiceName] = rs
if rs.externalServiceName != "" {
Routes.DeleteMapping(rs.externalServiceName)
Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, w.makeWakerFunc(rs))
Routes.CreateMapping(rs.externalServiceName, rs.containerEndpoint, w.makeWakerFunc(rs), w.makeSleeperFunc(rs))
} else {
Routes.SetDefaultRoute(rs.containerEndpoint)
}
+96
View File
@@ -0,0 +1,96 @@
package server
import (
"context"
"sync"
"time"
"github.com/sirupsen/logrus"
)
type IDownScaler interface {
Reset()
Begin(serverAddress string)
Cancel(serverAddress string)
}
var DownScaler IDownScaler
func NewDownScaler(ctx context.Context, enabled bool, delay time.Duration) IDownScaler {
ds := &downScalerImpl{
enabled: enabled,
delay: delay,
parentContext: ctx,
contextCancellations: make(map[string]context.CancelFunc),
}
return ds
}
type downScalerImpl struct {
sync.RWMutex
enabled bool
delay time.Duration
parentContext context.Context
contextCancellations map[string]context.CancelFunc
}
func (ds *downScalerImpl) Reset() {
// Cancel all existing scale down routines
for _, scaleDownCancel := range ds.contextCancellations {
scaleDownCancel()
}
ds.contextCancellations = make(map[string]context.CancelFunc)
}
func (ds *downScalerImpl) Begin(serverAddress string) {
ds.Lock()
defer ds.Unlock()
if !ds.enabled {
return
}
// If an existing scale down routine exists, cancel it
if scaleDownCancel, ok := ds.contextCancellations[serverAddress]; ok {
scaleDownCancel()
}
logrus.WithField("serverAddress", serverAddress).Debug("Beginning scale down")
scaleDownContext, scaleDownContextCancellation := context.WithCancel(ds.parentContext)
ds.contextCancellations[serverAddress] = scaleDownContextCancellation
go ds.scaleDown(scaleDownContext, serverAddress)
}
func (ds *downScalerImpl) Cancel(serverAddress string) {
ds.Lock()
defer ds.Unlock()
if !ds.enabled {
return
}
if scaleDownContextCancellation, ok := ds.contextCancellations[serverAddress]; ok {
logrus.WithField("serverAddress", serverAddress).Debug("Canceling scale down")
scaleDownContextCancellation()
delete(ds.contextCancellations, serverAddress)
}
}
func (ds *downScalerImpl) scaleDown(ctx context.Context, serverAddress string) {
for {
select {
case <-ctx.Done():
return
case <-time.After(ds.delay):
_, _, _, sleeper := Routes.FindBackendForServerAddress(ctx, serverAddress)
if sleeper == nil {
return
}
if err := sleeper(ctx); err != nil {
logrus.WithField("serverAddress", serverAddress).WithError(err).Error("failed to scale down backend")
}
return
}
}
}
+29 -17
View File
@@ -27,8 +27,8 @@ const (
)
type IK8sWatcher interface {
StartWithConfig(kubeConfigFile string, autoScaleUp bool) error
StartInCluster(autoScaleUp bool) error
StartWithConfig(kubeConfigFile string, autoScaleUp bool, autoScaleDown bool) error
StartInCluster(autoScaleUp bool, autoScaleDown bool) error
Stop()
}
@@ -36,6 +36,8 @@ var K8sWatcher IK8sWatcher = &k8sWatcherImpl{}
type k8sWatcherImpl struct {
sync.RWMutex
autoScaleUp bool
autoScaleDown bool
// The key in mappings is a Service, and the value the StatefulSet name
mappings map[string]string
@@ -43,26 +45,28 @@ type k8sWatcherImpl struct {
stop chan struct{}
}
func (w *k8sWatcherImpl) StartInCluster(autoScaleUp bool) error {
func (w *k8sWatcherImpl) StartInCluster(autoScaleUp bool, autoScaleDown bool) error {
config, err := rest.InClusterConfig()
if err != nil {
return errors.Wrap(err, "Unable to load in-cluster config")
}
return w.startWithLoadedConfig(config, autoScaleUp)
return w.startWithLoadedConfig(config, autoScaleUp, autoScaleDown)
}
func (w *k8sWatcherImpl) StartWithConfig(kubeConfigFile string, autoScaleUp bool) error {
func (w *k8sWatcherImpl) StartWithConfig(kubeConfigFile string, autoScaleUp bool, autoScaleDown bool) error {
config, err := clientcmd.BuildConfigFromFlags("", kubeConfigFile)
if err != nil {
return errors.Wrap(err, "Could not load kube config file")
}
return w.startWithLoadedConfig(config, autoScaleUp)
return w.startWithLoadedConfig(config, autoScaleUp, autoScaleDown)
}
func (w *k8sWatcherImpl) startWithLoadedConfig(config *rest.Config, autoScaleUp bool) error {
func (w *k8sWatcherImpl) startWithLoadedConfig(config *rest.Config, autoScaleUp bool, autoScaleDown bool) error {
w.stop = make(chan struct{}, 1)
w.autoScaleUp = autoScaleUp
w.autoScaleDown = autoScaleDown
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
@@ -88,7 +92,7 @@ func (w *k8sWatcherImpl) startWithLoadedConfig(config *rest.Config, autoScaleUp
go serviceController.Run(w.stop)
w.mappings = make(map[string]string)
if autoScaleUp {
if autoScaleUp || autoScaleDown {
_, statefulSetController := cache.NewInformer(
cache.NewListWatchFromClient(
clientset.AppsV1().RESTClient(),
@@ -156,7 +160,7 @@ func (w *k8sWatcherImpl) handleUpdate(oldObj interface{}, newObj interface{}) {
"new": newRoutableService,
}).Debug("UPDATE")
if newRoutableService.externalServiceName != "" {
Routes.CreateMapping(newRoutableService.externalServiceName, newRoutableService.containerEndpoint, newRoutableService.autoScaleUp)
Routes.CreateMapping(newRoutableService.externalServiceName, newRoutableService.containerEndpoint, newRoutableService.autoScaleUp, newRoutableService.autoScaleDown)
} else {
Routes.SetDefaultRoute(newRoutableService.containerEndpoint)
}
@@ -187,7 +191,7 @@ func (w *k8sWatcherImpl) handleAdd(obj interface{}) {
logrus.WithField("routableService", routableService).Debug("ADD")
if routableService.externalServiceName != "" {
Routes.CreateMapping(routableService.externalServiceName, routableService.containerEndpoint, routableService.autoScaleUp)
Routes.CreateMapping(routableService.externalServiceName, routableService.containerEndpoint, routableService.autoScaleUp, routableService.autoScaleDown)
} else {
Routes.SetDefaultRoute(routableService.containerEndpoint)
}
@@ -204,7 +208,8 @@ func (w *k8sWatcherImpl) Stop() {
type routableService struct {
externalServiceName string
containerEndpoint string
autoScaleUp func(ctx context.Context) error
autoScaleUp ScalerFunc
autoScaleDown ScalerFunc
}
// obj is expected to be a *v1.Service
@@ -239,12 +244,19 @@ func (w *k8sWatcherImpl) buildDetails(service *core.Service, externalServiceName
rs := &routableService{
externalServiceName: externalServiceName,
containerEndpoint: net.JoinHostPort(clusterIp, port),
autoScaleUp: w.buildScaleUpFunction(service),
autoScaleUp: w.buildScaleFunction(service, 0, 1),
autoScaleDown: w.buildScaleFunction(service, 1, 0),
}
return rs
}
func (w *k8sWatcherImpl) buildScaleUpFunction(service *core.Service) func(ctx context.Context) error {
func (w *k8sWatcherImpl) buildScaleFunction(service *core.Service, from int32, to int32) ScalerFunc {
if from <= to && !w.autoScaleUp {
return nil
}
if from >= to && !w.autoScaleDown {
return nil
}
return func(ctx context.Context) error {
serviceName := service.Name
if statefulSetName, exists := w.mappings[serviceName]; exists {
@@ -255,7 +267,7 @@ func (w *k8sWatcherImpl) buildScaleUpFunction(service *core.Service) func(ctx co
"statefulSet": statefulSetName,
"replicas": replicas,
}).Debug("StatefulSet of Service Replicas")
if replicas == 0 {
if replicas == from {
if _, err := w.clientset.AppsV1().StatefulSets(service.Namespace).UpdateScale(ctx, statefulSetName, &autoscaling.Scale{
ObjectMeta: meta.ObjectMeta{
Name: scale.Name,
@@ -263,15 +275,15 @@ func (w *k8sWatcherImpl) buildScaleUpFunction(service *core.Service) func(ctx co
UID: scale.UID,
ResourceVersion: scale.ResourceVersion,
},
Spec: autoscaling.ScaleSpec{Replicas: 1}}, meta.UpdateOptions{},
Spec: autoscaling.ScaleSpec{Replicas: to}}, meta.UpdateOptions{},
); err == nil {
logrus.WithFields(logrus.Fields{
"service": serviceName,
"statefulSet": statefulSetName,
"replicas": replicas,
}).Info("StatefulSet Replicas Autoscaled from 0 to 1 (wake up)")
}).Infof("StatefulSet Replicas Autoscaled from %d to %d", from, to)
} else {
return errors.Wrap(err, "UpdateScale for Replicas=1 failed for StatefulSet: "+statefulSetName)
return errors.Wrapf(err, "UpdateScale for Replicas=%d failed for StatefulSet: %s", to, statefulSetName)
}
}
} else {
+9 -4
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -78,6 +79,8 @@ func TestK8sWatcherImpl_handleAddThenUpdate(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// DownScaler needs to be instantiated
DownScaler = NewDownScaler(context.Background(), false, 1 * time.Second)
Routes.Reset()
watcher := &k8sWatcherImpl{}
@@ -87,7 +90,7 @@ func TestK8sWatcherImpl_handleAddThenUpdate(t *testing.T) {
watcher.handleAdd(&initialSvc)
for _, s := range test.initial.scenarios {
backend, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given)
backend, _, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given)
assert.Equal(t, s.expect, backend, "initial: given=%s", s.given)
}
@@ -97,7 +100,7 @@ func TestK8sWatcherImpl_handleAddThenUpdate(t *testing.T) {
watcher.handleUpdate(&initialSvc, &updatedSvc)
for _, s := range test.update.scenarios {
backend, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given)
backend, _, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given)
assert.Equal(t, s.expect, backend, "update: given=%s", s.given)
}
})
@@ -149,6 +152,8 @@ func TestK8sWatcherImpl_handleAddThenDelete(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// DownScaler needs to be instantiated
DownScaler = NewDownScaler(context.Background(), false, 1 * time.Second)
Routes.Reset()
watcher := &k8sWatcherImpl{}
@@ -158,13 +163,13 @@ func TestK8sWatcherImpl_handleAddThenDelete(t *testing.T) {
watcher.handleAdd(&initialSvc)
for _, s := range test.initial.scenarios {
backend, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given)
backend, _, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given)
assert.Equal(t, s.expect, backend, "initial: given=%s", s.given)
}
watcher.handleDelete(&initialSvc)
for _, s := range test.delete {
backend, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given)
backend, _, _, _ := Routes.FindBackendForServerAddress(context.Background(), s.given)
assert.Equal(t, s.expect, backend, "update: given=%s", s.given)
}
})
+22 -10
View File
@@ -12,6 +12,10 @@ import (
"github.com/sirupsen/logrus"
)
type ScalerFunc func(ctx context.Context) error
var EmptyScalerFunc = func(ctx context.Context) error { return nil }
var tcpShieldPattern = regexp.MustCompile("///.*")
func init() {
@@ -70,7 +74,7 @@ func routesCreateHandler(writer http.ResponseWriter, request *http.Request) {
return
}
Routes.CreateMapping(definition.ServerAddress, definition.Backend, func(ctx context.Context) error { return nil })
Routes.CreateMapping(definition.ServerAddress, definition.Backend, EmptyScalerFunc, EmptyScalerFunc)
RoutesConfig.AddMapping(definition.ServerAddress, definition.Backend)
writer.WriteHeader(http.StatusCreated)
}
@@ -102,10 +106,11 @@ type IRoutes interface {
// FindBackendForServerAddress returns the host:port for the external server address, if registered.
// Otherwise, an empty string is returned. Also returns the normalized version of the given serverAddress.
// The 3rd value returned is an (optional) "waker" function which a caller must invoke to wake up serverAddress.
FindBackendForServerAddress(ctx context.Context, serverAddress string) (string, string, func(ctx context.Context) error)
// The 4th value returned is an (optional) "sleeper" function which a caller must invoke to shut down serverAddress.
FindBackendForServerAddress(ctx context.Context, serverAddress string) (string, string, ScalerFunc, ScalerFunc)
GetMappings() map[string]string
DeleteMapping(serverAddress string) bool
CreateMapping(serverAddress string, backend string, waker func(ctx context.Context) error)
CreateMapping(serverAddress string, backend string, waker ScalerFunc, sleeper ScalerFunc)
SetDefaultRoute(backend string)
SimplifySRV(srvEnabled bool)
}
@@ -122,13 +127,14 @@ func NewRoutes() IRoutes {
func (r *routesImpl) RegisterAll(mappings map[string]string) {
for k, v := range mappings {
r.CreateMapping(k, v, func(ctx context.Context) error { return nil })
r.CreateMapping(k, v, EmptyScalerFunc, EmptyScalerFunc)
}
}
type mapping struct {
backend string
waker func(ctx context.Context) error
waker ScalerFunc
sleeper ScalerFunc
}
type routesImpl struct {
@@ -140,6 +146,7 @@ type routesImpl struct {
func (r *routesImpl) Reset() {
r.mappings = make(map[string]mapping)
DownScaler.Reset()
}
func (r *routesImpl) SetDefaultRoute(backend string) {
@@ -154,7 +161,7 @@ func (r *routesImpl) SimplifySRV(srvEnabled bool) {
r.simplifySRV = srvEnabled
}
func (r *routesImpl) FindBackendForServerAddress(_ context.Context, serverAddress string) (string, string, func(ctx context.Context) error) {
func (r *routesImpl) FindBackendForServerAddress(_ context.Context, serverAddress string) (string, string, ScalerFunc, ScalerFunc) {
r.RLock()
defer r.RUnlock()
@@ -190,10 +197,10 @@ func (r *routesImpl) FindBackendForServerAddress(_ context.Context, serverAddres
if r.mappings != nil {
if mapping, exists := r.mappings[serverAddress]; exists {
return mapping.backend, serverAddress, mapping.waker
return mapping.backend, serverAddress, mapping.waker, mapping.sleeper
}
}
return r.defaultRoute, serverAddress, nil
return r.defaultRoute, serverAddress, nil, nil
}
func (r *routesImpl) GetMappings() map[string]string {
@@ -212,6 +219,8 @@ func (r *routesImpl) DeleteMapping(serverAddress string) bool {
defer r.Unlock()
logrus.WithField("serverAddress", serverAddress).Info("Deleting route")
DownScaler.Cancel(serverAddress)
if _, ok := r.mappings[serverAddress]; ok {
delete(r.mappings, serverAddress)
return true
@@ -220,7 +229,7 @@ func (r *routesImpl) DeleteMapping(serverAddress string) bool {
}
}
func (r *routesImpl) CreateMapping(serverAddress string, backend string, waker func(ctx context.Context) error) {
func (r *routesImpl) CreateMapping(serverAddress string, backend string, waker ScalerFunc, sleeper ScalerFunc) {
r.Lock()
defer r.Unlock()
@@ -230,5 +239,8 @@ func (r *routesImpl) CreateMapping(serverAddress string, backend string, waker f
"serverAddress": serverAddress,
"backend": backend,
}).Info("Created route mapping")
r.mappings[serverAddress] = mapping{backend: backend, waker: waker}
r.mappings[serverAddress] = mapping{backend: backend, waker: waker, sleeper: sleeper}
// Trigger auto scale down when mapping is created to ensure servers are shut down if router restarts
DownScaler.Begin(serverAddress)
}
+2 -2
View File
@@ -66,9 +66,9 @@ func Test_routesImpl_FindBackendForServerAddress(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
r := NewRoutes()
r.CreateMapping(tt.mapping.serverAddress, tt.mapping.backend, func(ctx context.Context) error { return nil })
r.CreateMapping(tt.mapping.serverAddress, tt.mapping.backend, EmptyScalerFunc, EmptyScalerFunc)
if got, server, _ := r.FindBackendForServerAddress(context.Background(), tt.args.serverAddress); got != tt.want {
if got, server, _, _ := r.FindBackendForServerAddress(context.Background(), tt.args.serverAddress); got != tt.want {
t.Errorf("routesImpl.FindBackendForServerAddress() = %v, want %v", got, tt.want)
} else {
assert.Equal(t, tt.mapping.serverAddress, server)