diff --git a/cmd/mc-router/main.go b/cmd/mc-router/main.go index 2a14946..2019038 100644 --- a/cmd/mc-router/main.go +++ b/cmd/mc-router/main.go @@ -43,6 +43,8 @@ type Config struct { MetricsBackend string `default:"discard" usage:"Backend to use for metrics exposure/publishing: discard,expvar,influxdb"` UseProxyProtocol bool `default:"false" usage:"Send PROXY protocol to backend servers"` MetricsBackendConfig MetricsBackendConfig + + SimplifySRV bool `default:"false" usage:"Simplify fully qualified SRV records for mapping"` } var ( @@ -129,6 +131,8 @@ func main() { } } + server.Routes.SimplifySRV(config.SimplifySRV) + err = metricsBuilder.Start(ctx) if err != nil { logrus.WithError(err).Fatal("Unable to start metrics reporter") diff --git a/server/routes.go b/server/routes.go index 4cbfd69..a315c27 100644 --- a/server/routes.go +++ b/server/routes.go @@ -95,6 +95,7 @@ type IRoutes interface { DeleteMapping(serverAddress string) bool CreateMapping(serverAddress string, backend string, waker func(ctx context.Context) error) SetDefaultRoute(backend string) + SimplifySRV(srvEnabled bool) } var Routes IRoutes = &routesImpl{} @@ -126,6 +127,7 @@ type routesImpl struct { sync.RWMutex mappings map[string]mapping defaultRoute string + simplifySRV bool } func (r *routesImpl) SetDefaultRoute(backend string) { @@ -136,10 +138,31 @@ func (r *routesImpl) SetDefaultRoute(backend string) { }).Info("Using default route") } +func (r *routesImpl) SimplifySRV(srvEnabled bool) { + r.simplifySRV = srvEnabled +} + func (r *routesImpl) FindBackendForServerAddress(ctx context.Context, serverAddress string) (string, string, func(ctx context.Context) error) { r.RLock() defer r.RUnlock() + if r.simplifySRV { + serverAddress = strings.TrimSuffix(serverAddress, ".") + parts := strings.Split(serverAddress, ".") + tcpIndex := -1 + for i, part := range parts { + if part == "_tcp" { + tcpIndex = i + break + } + } + if tcpIndex != -1 { + parts = parts[tcpIndex+1:] + } + + serverAddress = strings.Join(parts, ".") + } + addressParts := strings.Split(serverAddress, "\x00") address := strings.ToLower(addressParts[0])