Add support for allow/deny clients by IP (#355)

This commit is contained in:
Geoff Bourne
2024-12-19 07:37:08 -06:00
committed by GitHub
parent 513e0b86a7
commit 7526a7078a
5 changed files with 334 additions and 33 deletions
+101
View File
@@ -0,0 +1,101 @@
package server
import (
"github.com/pkg/errors"
"net/netip"
"strings"
)
type addrMatcher struct {
addrs []netip.Addr
prefixes []netip.Prefix
}
func newAddrMatcher(filters []string) (*addrMatcher, error) {
addrs := make([]netip.Addr, 0)
prefixes := make([]netip.Prefix, 0)
if filters != nil {
for _, filter := range filters {
if strings.Contains(filter, "/") {
prefix, err := netip.ParsePrefix(filter)
if err != nil {
return nil, err
}
prefixes = append(prefixes, prefix)
} else {
addr, err := netip.ParseAddr(filter)
if err != nil {
return nil, err
}
addrs = append(addrs, addr)
}
}
}
return &addrMatcher{
addrs: addrs,
prefixes: prefixes,
}, nil
}
func (a *addrMatcher) Match(addr netip.Addr) bool {
for _, a := range a.addrs {
// Before comparison, need to unmap addresses such as
// ::ffff:127.0.0.1
unmapped := addr.Unmap()
if a == unmapped {
return true
}
}
for _, p := range a.prefixes {
if p.Contains(addr) {
return true
}
}
return false
}
func (a *addrMatcher) Empty() bool {
return len(a.addrs) == 0 && len(a.prefixes) == 0
}
// ClientFilter performs allow/deny filtering of client IP addresses
type ClientFilter struct {
allow *addrMatcher
deny *addrMatcher
}
// NewClientFilter provides a mechanism to evaluate client IP addresses and determine if
// they should be allowed access or not.
// The allows and denies can each or both be nil or netip.ParseAddr allowed values.
func NewClientFilter(allows []string, denies []string) (*ClientFilter, error) {
allow, err := newAddrMatcher(allows)
if err != nil {
return nil, errors.Wrap(err, "invalid allow filter")
}
deny, err := newAddrMatcher(denies)
if err != nil {
return nil, errors.Wrap(err, "invalid deny filter")
}
return &ClientFilter{
allow: allow,
deny: deny,
}, nil
}
// Allow determines if the given address is allowed by this filter
// where addrStr is a netip.ParseAddr allowed address
func (f *ClientFilter) Allow(addrPort netip.AddrPort) bool {
if !f.allow.Empty() {
matched := f.allow.Match(addrPort.Addr())
return matched
}
if !f.deny.Empty() {
matched := f.deny.Match(addrPort.Addr())
return !matched
}
return true
}
+173
View File
@@ -0,0 +1,173 @@
package server
import (
"fmt"
"github.com/stretchr/testify/assert"
"net/netip"
"testing"
)
func TestClientFilter_Allow(t *testing.T) {
type args struct {
allow []string
deny []string
input string
}
tests := []struct {
name string
args args
want bool
assertErr assert.ErrorAssertionFunc
}{
{
name: "defaults",
args: args{
input: "192.168.1.1",
},
want: true,
assertErr: assert.NoError,
},
{
name: "just allow - matches",
args: args{
allow: []string{"192.168.1.1"},
input: "192.168.1.1",
},
want: true,
assertErr: assert.NoError,
},
{
name: "just allow - not match",
args: args{
allow: []string{"192.168.1.1"},
input: "192.168.1.2",
},
want: false,
assertErr: assert.NoError,
},
{
name: "just allow cidr - matches",
args: args{
allow: []string{"192.168.1.0/8"},
input: "192.168.1.2",
},
want: true,
assertErr: assert.NoError,
},
{
name: "just allow cidr or specific - matches cidr",
args: args{
allow: []string{"192.168.1.0/8", "192.168.2.5"},
input: "192.168.1.2",
},
want: true,
assertErr: assert.NoError,
},
{
name: "just allow cidr or specific - matches specific",
args: args{
allow: []string{"192.168.1.0/8", "192.168.2.5"},
input: "192.168.2.5",
},
want: true,
assertErr: assert.NoError,
},
{
name: "just deny - matches",
args: args{
deny: []string{"192.168.2.5"},
input: "192.168.2.5",
},
want: false,
assertErr: assert.NoError,
},
{
name: "just deny - not match",
args: args{
deny: []string{"192.168.2.5"},
input: "192.168.1.1",
},
want: true,
assertErr: assert.NoError,
},
{
name: "mix allow",
args: args{
allow: []string{"192.168.1.6"},
deny: []string{"192.168.1.0/8"},
input: "192.168.1.6",
},
want: true,
assertErr: assert.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f, err := NewClientFilter(tt.args.allow, tt.args.deny)
assert.NoError(t, err)
addr, err := netip.ParseAddr(tt.args.input)
assert.NoError(t, err)
got := f.Allow(netip.AddrPortFrom(addr, 25565))
assert.Equalf(t, tt.want, got, "Allow(%v)", tt.args.input)
})
}
}
func TestNewClientFilter(t *testing.T) {
type args struct {
allow []string
deny []string
}
tests := []struct {
name string
args args
assertErr assert.ErrorAssertionFunc
}{
{
name: "default",
assertErr: assert.NoError,
},
{
name: "allow single",
args: args{
allow: []string{"192.168.1.1"},
},
assertErr: assert.NoError,
},
{
name: "allow cidr",
args: args{
allow: []string{"192.168.1.0/8"},
},
assertErr: assert.NoError,
},
{
name: "deny single",
args: args{
deny: []string{"192.168.1.1"},
},
assertErr: assert.NoError,
},
{
name: "allow invalid",
args: args{
allow: []string{"7"},
},
assertErr: assert.Error,
},
{
name: "deny invalid",
args: args{
deny: []string{"7"},
},
assertErr: assert.Error,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewClientFilter(tt.args.allow, tt.args.deny)
tt.assertErr(t, err, fmt.Sprintf("NewClientFilter(%v, %v)", tt.args.allow, tt.args.deny))
})
}
}
+15 -1
View File
@@ -33,13 +33,15 @@ type ConnectorMetrics struct {
ActiveConnections metrics.Gauge
}
func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool, receiveProxyProto bool, trustedProxyNets []*net.IPNet) *Connector {
func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool, receiveProxyProto bool, trustedProxyNets []*net.IPNet,
clientFilter *ClientFilter) *Connector {
return &Connector{
metrics: metrics,
sendProxyProto: sendProxyProto,
connectionsCond: sync.NewCond(&sync.Mutex{}),
receiveProxyProto: receiveProxyProto,
trustedProxyNets: trustedProxyNets,
clientFilter: clientFilter,
}
}
@@ -53,6 +55,7 @@ type Connector struct {
activeConnections int32
connectionsCond *sync.Cond
ngrokToken string
clientFilter *ClientFilter
}
func (c *Connector) StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error {
@@ -164,6 +167,17 @@ func (c *Connector) HandleConnection(ctx context.Context, frontendConn net.Conn)
defer frontendConn.Close()
clientAddr := frontendConn.RemoteAddr()
if tcpAddr, ok := clientAddr.(*net.TCPAddr); ok {
allow := c.clientFilter.Allow(tcpAddr.AddrPort())
if !allow {
logrus.WithField("client", clientAddr).Debug("Client is blocked")
return
}
} else {
logrus.WithField("client", clientAddr).Warn("Remote address is not a TCP address, skipping filtering")
}
logrus.
WithField("client", clientAddr).
Info("Got connection")