197 lines
5.2 KiB
Go
197 lines
5.2 KiB
Go
package server
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"io"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/itzg/mc-router/mcproto"
|
|
"github.com/pires/go-proxyproto"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestTrustedProxyNetworkPolicy(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
trustedNets []string
|
|
upstreamIP string
|
|
expectedPolicy proxyproto.Policy
|
|
}{
|
|
{
|
|
name: "trusted IP",
|
|
trustedNets: []string{"10.0.0.0/8"},
|
|
upstreamIP: "10.0.0.1",
|
|
expectedPolicy: proxyproto.USE,
|
|
},
|
|
{
|
|
name: "untrusted IP",
|
|
trustedNets: []string{"10.0.0.0/8"},
|
|
upstreamIP: "192.168.1.1",
|
|
expectedPolicy: proxyproto.IGNORE,
|
|
},
|
|
{
|
|
name: "multiple trusted nets",
|
|
trustedNets: []string{"10.0.0.0/8", "172.16.0.0/12"},
|
|
upstreamIP: "172.16.0.1",
|
|
expectedPolicy: proxyproto.USE,
|
|
},
|
|
{
|
|
name: "no trusted nets",
|
|
trustedNets: []string{},
|
|
upstreamIP: "148.184.129.202",
|
|
expectedPolicy: proxyproto.USE,
|
|
},
|
|
{
|
|
name: "remote trusted IP",
|
|
trustedNets: []string{"203.0.113.0/24"},
|
|
upstreamIP: "203.0.113.10",
|
|
expectedPolicy: proxyproto.USE,
|
|
},
|
|
{
|
|
name: "remote untrusted IP",
|
|
trustedNets: []string{"203.0.113.0/24"},
|
|
upstreamIP: "198.51.100.1",
|
|
expectedPolicy: proxyproto.IGNORE,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
c := &Connector{
|
|
trustedProxyNets: parseTrustedProxyNets(test.trustedNets),
|
|
}
|
|
|
|
policy := c.createProxyProtoPolicy()
|
|
upstreamAddr := &net.TCPAddr{IP: net.ParseIP(test.upstreamIP)}
|
|
policyResult, _ := policy(proxyproto.ConnPolicyOptions{
|
|
Upstream: upstreamAddr,
|
|
})
|
|
assert.Equal(t, test.expectedPolicy, policyResult, "Unexpected policy result for %s", test.name)
|
|
})
|
|
}
|
|
}
|
|
|
|
// parseTrustedProxyNets converts CIDR strings (e.g. "10.0.0.0/8") into
// *net.IPNet values, preserving input order. The result is always non-nil,
// even for empty input.
//
// It is a test fixture helper: a malformed CIDR is a bug in the test data, so
// it panics with a message naming the offending entry rather than returning a
// slice containing a nil network that would fail later with an opaque nil
// dereference.
func parseTrustedProxyNets(nets []string) []*net.IPNet {
	parsedNets := make([]*net.IPNet, 0, len(nets))
	for _, n := range nets {
		_, ipNet, err := net.ParseCIDR(n)
		if err != nil {
			panic("parseTrustedProxyNets: invalid CIDR \"" + n + "\": " + err.Error())
		}
		parsedNets = append(parsedNets, ipNet)
	}
	return parsedNets
}
|
|
|
|
func TestConnectorWakeTracking(t *testing.T) {
|
|
c := &Connector{wakingServers: NewActiveConnections()}
|
|
|
|
assert.False(t, c.isWakeInProgress("scale-target"))
|
|
c.wakingServers.Increment("scale-target")
|
|
assert.True(t, c.isWakeInProgress("scale-target"))
|
|
|
|
// track concurrent wake operations for same route
|
|
c.wakingServers.Increment("scale-target")
|
|
c.wakingServers.Decrement("scale-target")
|
|
assert.True(t, c.isWakeInProgress("scale-target"))
|
|
|
|
c.wakingServers.Decrement("scale-target")
|
|
assert.False(t, c.isWakeInProgress("scale-target"))
|
|
}
|
|
|
|
func TestConnectorGetLoadingMOTD(t *testing.T) {
|
|
oldRoutes := Routes
|
|
defer func() {
|
|
Routes = oldRoutes
|
|
}()
|
|
|
|
Routes = NewRoutes()
|
|
Routes.CreateMapping("mc.example.com", "backend:25565", "", nil, nil, "", "route loading")
|
|
|
|
c := &Connector{loadingMOTD: "global loading"}
|
|
assert.Equal(t, "route loading", c.getLoadingMOTD("mc.example.com"))
|
|
assert.Equal(t, "global loading", c.getLoadingMOTD("other.example.com"))
|
|
|
|
Routes.SetDefaultRoute("default:25565", "", nil, nil, "", "default loading")
|
|
assert.Equal(t, "default loading", c.getLoadingMOTD(""))
|
|
}
|
|
|
|
func writeTestPacket(w io.Writer, packetID int32, payload func(w io.Writer)) error {
|
|
var b bytes.Buffer
|
|
_ = mcproto.WriteVarInt(&b, packetID)
|
|
if payload != nil {
|
|
payload(&b)
|
|
}
|
|
|
|
var framed bytes.Buffer
|
|
_ = mcproto.WriteVarInt(&framed, int32(b.Len()))
|
|
framed.Write(b.Bytes())
|
|
_, err := w.Write(framed.Bytes())
|
|
return err
|
|
}
|
|
|
|
func TestConnectorMOTDFallback(t *testing.T) {
|
|
oldRoutes := Routes
|
|
defer func() {
|
|
Routes = oldRoutes
|
|
}()
|
|
|
|
Routes = NewRoutes()
|
|
|
|
backendAddress := "127.0.0.1:0"
|
|
|
|
scaleUpCalled := false
|
|
waker := func(ctx context.Context) (string, error) {
|
|
scaleUpCalled = true
|
|
return backendAddress, nil
|
|
}
|
|
|
|
Routes.CreateMapping("mc.example.com", backendAddress, "", waker, nil, "fallback asleep", "fallback loading")
|
|
|
|
metricsBuilder := discardMetricsBuilder{}
|
|
c := NewConnector(context.Background(), metricsBuilder.BuildConnectorMetrics(), false, false, nil)
|
|
c.UseAsleepMOTD("global asleep")
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
go c.acceptConnections(ln, 100, 0)
|
|
|
|
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
|
require.NoError(t, err)
|
|
defer clientConn.Close()
|
|
|
|
err = writeTestPacket(clientConn, 0x00, func(w io.Writer) {
|
|
_ = mcproto.WriteVarInt(w, 758)
|
|
_ = mcproto.WriteString(w, "mc.example.com")
|
|
w.Write([]byte{0x63, 0xdd})
|
|
_ = mcproto.WriteVarInt(w, 1)
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// 2. Send Status Request
|
|
err = writeTestPacket(clientConn, 0x00, nil)
|
|
require.NoError(t, err)
|
|
|
|
// 3. Read Status Response
|
|
_ = clientConn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
|
reader := bufio.NewReader(clientConn)
|
|
|
|
frame, err := mcproto.ReadFrame(reader, clientConn.RemoteAddr())
|
|
require.NoError(t, err)
|
|
|
|
packetReader := bytes.NewReader(frame.Payload)
|
|
packetID, err := mcproto.ReadVarInt(packetReader)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 0x00, packetID)
|
|
|
|
jsonStr, err := mcproto.ReadString(packetReader)
|
|
require.NoError(t, err)
|
|
|
|
assert.Contains(t, jsonStr, "fallback asleep")
|
|
assert.False(t, scaleUpCalled, "Waker should NOT be called for a status request")
|
|
}
|