Add support for allow/deny clients by IP (#355)
This commit is contained in:
@@ -0,0 +1,101 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type addrMatcher struct {
|
||||
addrs []netip.Addr
|
||||
prefixes []netip.Prefix
|
||||
}
|
||||
|
||||
func newAddrMatcher(filters []string) (*addrMatcher, error) {
|
||||
addrs := make([]netip.Addr, 0)
|
||||
prefixes := make([]netip.Prefix, 0)
|
||||
|
||||
if filters != nil {
|
||||
for _, filter := range filters {
|
||||
if strings.Contains(filter, "/") {
|
||||
prefix, err := netip.ParsePrefix(filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prefixes = append(prefixes, prefix)
|
||||
} else {
|
||||
addr, err := netip.ParseAddr(filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &addrMatcher{
|
||||
addrs: addrs,
|
||||
prefixes: prefixes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *addrMatcher) Match(addr netip.Addr) bool {
|
||||
for _, a := range a.addrs {
|
||||
|
||||
// Before comparison, need to unmap addresses such as
|
||||
// ::ffff:127.0.0.1
|
||||
unmapped := addr.Unmap()
|
||||
if a == unmapped {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, p := range a.prefixes {
|
||||
if p.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *addrMatcher) Empty() bool {
|
||||
return len(a.addrs) == 0 && len(a.prefixes) == 0
|
||||
}
|
||||
|
||||
// ClientFilter performs allow/deny filtering of client IP addresses
|
||||
type ClientFilter struct {
|
||||
allow *addrMatcher
|
||||
deny *addrMatcher
|
||||
}
|
||||
|
||||
// NewClientFilter provides a mechanism to evaluate client IP addresses and determine if
|
||||
// they should be allowed access or not.
|
||||
// The allows and denies can each or both be nil or netip.ParseAddr allowed values.
|
||||
func NewClientFilter(allows []string, denies []string) (*ClientFilter, error) {
|
||||
allow, err := newAddrMatcher(allows)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid allow filter")
|
||||
}
|
||||
deny, err := newAddrMatcher(denies)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid deny filter")
|
||||
}
|
||||
return &ClientFilter{
|
||||
allow: allow,
|
||||
deny: deny,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Allow determines if the given address is allowed by this filter
|
||||
// where addrStr is a netip.ParseAddr allowed address
|
||||
func (f *ClientFilter) Allow(addrPort netip.AddrPort) bool {
|
||||
if !f.allow.Empty() {
|
||||
matched := f.allow.Match(addrPort.Addr())
|
||||
return matched
|
||||
}
|
||||
if !f.deny.Empty() {
|
||||
matched := f.deny.Match(addrPort.Addr())
|
||||
return !matched
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClientFilter_Allow(t *testing.T) {
|
||||
type args struct {
|
||||
allow []string
|
||||
deny []string
|
||||
input string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
assertErr assert.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "defaults",
|
||||
args: args{
|
||||
input: "192.168.1.1",
|
||||
},
|
||||
want: true,
|
||||
assertErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "just allow - matches",
|
||||
args: args{
|
||||
allow: []string{"192.168.1.1"},
|
||||
input: "192.168.1.1",
|
||||
},
|
||||
want: true,
|
||||
assertErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "just allow - not match",
|
||||
args: args{
|
||||
allow: []string{"192.168.1.1"},
|
||||
input: "192.168.1.2",
|
||||
},
|
||||
want: false,
|
||||
assertErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "just allow cidr - matches",
|
||||
args: args{
|
||||
allow: []string{"192.168.1.0/8"},
|
||||
input: "192.168.1.2",
|
||||
},
|
||||
want: true,
|
||||
assertErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "just allow cidr or specific - matches cidr",
|
||||
args: args{
|
||||
allow: []string{"192.168.1.0/8", "192.168.2.5"},
|
||||
input: "192.168.1.2",
|
||||
},
|
||||
want: true,
|
||||
assertErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "just allow cidr or specific - matches specific",
|
||||
args: args{
|
||||
allow: []string{"192.168.1.0/8", "192.168.2.5"},
|
||||
input: "192.168.2.5",
|
||||
},
|
||||
want: true,
|
||||
assertErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "just deny - matches",
|
||||
args: args{
|
||||
deny: []string{"192.168.2.5"},
|
||||
input: "192.168.2.5",
|
||||
},
|
||||
want: false,
|
||||
assertErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "just deny - not match",
|
||||
args: args{
|
||||
deny: []string{"192.168.2.5"},
|
||||
input: "192.168.1.1",
|
||||
},
|
||||
want: true,
|
||||
assertErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "mix allow",
|
||||
args: args{
|
||||
allow: []string{"192.168.1.6"},
|
||||
deny: []string{"192.168.1.0/8"},
|
||||
input: "192.168.1.6",
|
||||
},
|
||||
want: true,
|
||||
assertErr: assert.NoError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f, err := NewClientFilter(tt.args.allow, tt.args.deny)
|
||||
assert.NoError(t, err)
|
||||
addr, err := netip.ParseAddr(tt.args.input)
|
||||
assert.NoError(t, err)
|
||||
got := f.Allow(netip.AddrPortFrom(addr, 25565))
|
||||
assert.Equalf(t, tt.want, got, "Allow(%v)", tt.args.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientFilter(t *testing.T) {
|
||||
type args struct {
|
||||
allow []string
|
||||
deny []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
assertErr assert.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
assertErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "allow single",
|
||||
args: args{
|
||||
allow: []string{"192.168.1.1"},
|
||||
},
|
||||
assertErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "allow cidr",
|
||||
args: args{
|
||||
allow: []string{"192.168.1.0/8"},
|
||||
},
|
||||
assertErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "deny single",
|
||||
args: args{
|
||||
deny: []string{"192.168.1.1"},
|
||||
},
|
||||
assertErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "allow invalid",
|
||||
args: args{
|
||||
allow: []string{"7"},
|
||||
},
|
||||
assertErr: assert.Error,
|
||||
},
|
||||
{
|
||||
name: "deny invalid",
|
||||
args: args{
|
||||
deny: []string{"7"},
|
||||
},
|
||||
assertErr: assert.Error,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewClientFilter(tt.args.allow, tt.args.deny)
|
||||
tt.assertErr(t, err, fmt.Sprintf("NewClientFilter(%v, %v)", tt.args.allow, tt.args.deny))
|
||||
})
|
||||
}
|
||||
}
|
||||
+15
-1
@@ -33,13 +33,15 @@ type ConnectorMetrics struct {
|
||||
ActiveConnections metrics.Gauge
|
||||
}
|
||||
|
||||
func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool, receiveProxyProto bool, trustedProxyNets []*net.IPNet) *Connector {
|
||||
func NewConnector(metrics *ConnectorMetrics, sendProxyProto bool, receiveProxyProto bool, trustedProxyNets []*net.IPNet,
|
||||
clientFilter *ClientFilter) *Connector {
|
||||
return &Connector{
|
||||
metrics: metrics,
|
||||
sendProxyProto: sendProxyProto,
|
||||
connectionsCond: sync.NewCond(&sync.Mutex{}),
|
||||
receiveProxyProto: receiveProxyProto,
|
||||
trustedProxyNets: trustedProxyNets,
|
||||
clientFilter: clientFilter,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +55,7 @@ type Connector struct {
|
||||
activeConnections int32
|
||||
connectionsCond *sync.Cond
|
||||
ngrokToken string
|
||||
clientFilter *ClientFilter
|
||||
}
|
||||
|
||||
func (c *Connector) StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error {
|
||||
@@ -164,6 +167,17 @@ func (c *Connector) HandleConnection(ctx context.Context, frontendConn net.Conn)
|
||||
defer frontendConn.Close()
|
||||
|
||||
clientAddr := frontendConn.RemoteAddr()
|
||||
|
||||
if tcpAddr, ok := clientAddr.(*net.TCPAddr); ok {
|
||||
allow := c.clientFilter.Allow(tcpAddr.AddrPort())
|
||||
if !allow {
|
||||
logrus.WithField("client", clientAddr).Debug("Client is blocked")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
logrus.WithField("client", clientAddr).Warn("Remote address is not a TCP address, skipping filtering")
|
||||
}
|
||||
|
||||
logrus.
|
||||
WithField("client", clientAddr).
|
||||
Info("Got connection")
|
||||
|
||||
Reference in New Issue
Block a user