feat: add ability to receive proxy protocol (#307)
This commit is contained in:
+64
-23
@@ -3,14 +3,15 @@ package server
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"golang.ngrok.com/ngrok"
|
||||
"golang.ngrok.com/ngrok/config"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.ngrok.com/ngrok"
|
||||
"golang.ngrok.com/ngrok/config"
|
||||
|
||||
"github.com/go-kit/kit/metrics"
|
||||
"github.com/itzg/mc-router/mcproto"
|
||||
"github.com/juju/ratelimit"
|
||||
@@ -31,19 +32,22 @@ type ConnectorMetrics struct {
|
||||
ActiveConnections metrics.Gauge
|
||||
}
|
||||
|
||||
func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool) *Connector {
|
||||
|
||||
func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool, receiveProxyProto bool, trustedProxyNets []*net.IPNet) *Connector {
|
||||
return &Connector{
|
||||
metrics: metrics,
|
||||
sendProxyProto: sendProxyProto,
|
||||
connectionsCond: sync.NewCond(&sync.Mutex{}),
|
||||
metrics: metrics,
|
||||
sendProxyProto: sendProxyProto,
|
||||
connectionsCond: sync.NewCond(&sync.Mutex{}),
|
||||
receiveProxyProto: receiveProxyProto,
|
||||
trustedProxyNets: trustedProxyNets,
|
||||
}
|
||||
}
|
||||
|
||||
type Connector struct {
|
||||
state mcproto.State
|
||||
metrics *ConnectorMetrics
|
||||
sendProxyProto bool
|
||||
state mcproto.State
|
||||
metrics *ConnectorMetrics
|
||||
sendProxyProto bool
|
||||
receiveProxyProto bool
|
||||
trustedProxyNets []*net.IPNet
|
||||
|
||||
activeConnections int32
|
||||
connectionsCond *sync.Cond
|
||||
@@ -51,9 +55,17 @@ type Connector struct {
|
||||
}
|
||||
|
||||
func (c *Connector) StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error {
|
||||
ln, err := c.createListener(ctx, listenAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ln net.Listener
|
||||
var err error
|
||||
go c.acceptConnections(ctx, ln, connRateLimit)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connector) createListener(ctx context.Context, listenAddress string) (net.Listener, error) {
|
||||
if c.ngrokToken != "" {
|
||||
ngrokTun, err := ngrok.Listen(ctx,
|
||||
config.TCPEndpoint(),
|
||||
@@ -61,22 +73,51 @@ func (c *Connector) StartAcceptingConnections(ctx context.Context, listenAddress
|
||||
)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("Unable to start ngrok tunnel")
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
ln = ngrokTun
|
||||
logrus.WithField("ngrokUrl", ngrokTun.URL()).Info("Listening for Minecraft client connections via ngrok tunnel")
|
||||
} else {
|
||||
ln, err = net.Listen("tcp", listenAddress)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("Unable to start listening")
|
||||
return err
|
||||
}
|
||||
logrus.WithField("listenAddress", listenAddress).Info("Listening for Minecraft client connections")
|
||||
return ngrokTun, nil
|
||||
}
|
||||
|
||||
go c.acceptConnections(ctx, ln, connRateLimit)
|
||||
listener, err := net.Listen("tcp", listenAddress)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("Unable to start listening")
|
||||
return nil, err
|
||||
}
|
||||
logrus.WithField("listenAddress", listenAddress).Info("Listening for Minecraft client connections")
|
||||
|
||||
return nil
|
||||
if c.receiveProxyProto {
|
||||
proxyListener := &proxyproto.Listener{
|
||||
Listener: listener,
|
||||
Policy: c.createProxyProtoPolicy(),
|
||||
}
|
||||
logrus.Info("Using PROXY protocol listener")
|
||||
return proxyListener, nil
|
||||
}
|
||||
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func (c *Connector) createProxyProtoPolicy() func(upstream net.Addr) (proxyproto.Policy, error) {
|
||||
return func(upstream net.Addr) (proxyproto.Policy, error) {
|
||||
trustedIpNets := c.trustedProxyNets
|
||||
|
||||
if len(trustedIpNets) == 0 {
|
||||
logrus.Debug("No trusted proxy networks configured, using the PROXY header by default")
|
||||
return proxyproto.USE, nil
|
||||
}
|
||||
|
||||
upstreamIP := upstream.(*net.TCPAddr).IP
|
||||
for _, ipNet := range trustedIpNets {
|
||||
if ipNet.Contains(upstreamIP) {
|
||||
logrus.WithField("upstream", upstream).Debug("IP is in trusted proxies, using the PROXY header")
|
||||
return proxyproto.USE, nil
|
||||
}
|
||||
}
|
||||
|
||||
logrus.WithField("upstream", upstream).Debug("IP is not in trusted proxies, discarding PROXY header")
|
||||
return proxyproto.IGNORE, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connector) WaitForConnections() {
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/pires/go-proxyproto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTrustedProxyNetworkPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
trustedNets []string
|
||||
upstreamIP string
|
||||
expectedPolicy proxyproto.Policy
|
||||
}{
|
||||
{
|
||||
name: "trusted IP",
|
||||
trustedNets: []string{"10.0.0.0/8"},
|
||||
upstreamIP: "10.0.0.1",
|
||||
expectedPolicy: proxyproto.USE,
|
||||
},
|
||||
{
|
||||
name: "untrusted IP",
|
||||
trustedNets: []string{"10.0.0.0/8"},
|
||||
upstreamIP: "192.168.1.1",
|
||||
expectedPolicy: proxyproto.IGNORE,
|
||||
},
|
||||
{
|
||||
name: "multiple trusted nets",
|
||||
trustedNets: []string{"10.0.0.0/8", "172.16.0.0/12"},
|
||||
upstreamIP: "172.16.0.1",
|
||||
expectedPolicy: proxyproto.USE,
|
||||
},
|
||||
{
|
||||
name: "no trusted nets",
|
||||
trustedNets: []string{},
|
||||
upstreamIP: "148.184.129.202",
|
||||
expectedPolicy: proxyproto.USE,
|
||||
},
|
||||
{
|
||||
name: "remote trusted IP",
|
||||
trustedNets: []string{"203.0.113.0/24"},
|
||||
upstreamIP: "203.0.113.10",
|
||||
expectedPolicy: proxyproto.USE,
|
||||
},
|
||||
{
|
||||
name: "remote untrusted IP",
|
||||
trustedNets: []string{"203.0.113.0/24"},
|
||||
upstreamIP: "198.51.100.1",
|
||||
expectedPolicy: proxyproto.IGNORE,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
c := &Connector{
|
||||
trustedProxyNets: parseTrustedProxyNets(test.trustedNets),
|
||||
}
|
||||
|
||||
policy := c.createProxyProtoPolicy()
|
||||
upstreamAddr := &net.TCPAddr{IP: net.ParseIP(test.upstreamIP)}
|
||||
policyResult, _ := policy(upstreamAddr)
|
||||
assert.Equal(t, test.expectedPolicy, policyResult, "Unexpected policy result for %s", test.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func parseTrustedProxyNets(nets []string) []*net.IPNet {
|
||||
parsedNets := make([]*net.IPNet, 0, len(nets))
|
||||
for _, n := range nets {
|
||||
_, ipNet, _ := net.ParseCIDR(n)
|
||||
parsedNets = append(parsedNets, ipNet)
|
||||
}
|
||||
return parsedNets
|
||||
}
|
||||
Reference in New Issue
Block a user