diff --git a/mcproto/read.go b/mcproto/read.go index ab3986e..9eb5118 100644 --- a/mcproto/read.go +++ b/mcproto/read.go @@ -1,20 +1,38 @@ package mcproto import ( + "bufio" "bytes" - "errors" + "encoding/binary" + "github.com/pkg/errors" "github.com/sirupsen/logrus" + "golang.org/x/text/encoding/unicode" + "golang.org/x/text/transform" "io" "net" "strings" "time" ) -func ReadPacket(reader io.Reader, addr net.Addr) (*Packet, error) { +func ReadPacket(reader io.Reader, addr net.Addr, state State) (*Packet, error) { logrus. WithField("client", addr). Debug("Reading packet") + if state == StateHandshaking { + bufReader := bufio.NewReader(reader) + data, err := bufReader.Peek(1) + if err != nil { + return nil, err + } + + if data[0] == PacketIdLegacyServerListPing { + return ReadLegacyServerListPing(bufReader, addr) + } else { + reader = bufReader + } + } + frame, err := ReadFrame(reader, addr) if err != nil { return nil, err @@ -38,6 +56,97 @@ func ReadPacket(reader io.Reader, addr net.Addr) (*Packet, error) { return packet, nil } +func ReadLegacyServerListPing(reader *bufio.Reader, addr net.Addr) (*Packet, error) { + logrus. + WithField("client", addr). + Debug("Reading legacy server list ping") + + packetId, err := reader.ReadByte() + if err != nil { + return nil, err + } + if packetId != PacketIdLegacyServerListPing { + return nil, errors.Errorf("expected legacy server listing ping packet ID, got %x", packetId) + } + + payload, err := reader.ReadByte() + if err != nil { + return nil, err + } + if payload != 0x01 { + return nil, errors.Errorf("expected payload=1 from legacy server listing ping, got %x", payload) + } + + packetIdForPluginMsg, err := reader.ReadByte() + if err != nil { + return nil, err + } + if packetIdForPluginMsg != 0xFA { + return nil, errors.Errorf("expected packetIdForPluginMsg=0xFA from legacy server listing ping, got %x", packetIdForPluginMsg) + } + + messageNameShortLen, err := ReadUnsignedShort(reader) + if err != nil { + return nil, err + } + if messageNameShortLen != 11 { + return nil, errors.Errorf("expected messageNameShortLen=11 from legacy server listing ping, got %d", messageNameShortLen) + } + + messageName, err := ReadUTF16BEString(reader, messageNameShortLen) + if messageName != "MC|PingHost" { + return nil, errors.Errorf("expected messageName=MC|PingHost, got %s", messageName) + } + + remainingLen, err := ReadUnsignedShort(reader) + remainingReader := io.LimitReader(reader, int64(remainingLen)) + + protocolVersion, err := ReadByte(remainingReader) + if err != nil { + return nil, err + } + + hostnameLen, err := ReadUnsignedShort(remainingReader) + if err != nil { + return nil, err + } + hostname, err := ReadUTF16BEString(remainingReader, hostnameLen) + if err != nil { + return nil, err + } + + port, err := ReadUnsignedInt(remainingReader) + if err != nil { + return nil, err + } + + return &Packet{ + PacketID: PacketIdLegacyServerListPing, + Length: 0, + Data: &LegacyServerListPing{ + ProtocolVersion: int(protocolVersion), + ServerAddress: hostname, + ServerPort: uint16(port), + }, + }, nil +} + +func ReadUTF16BEString(reader io.Reader, symbolLen uint16) (string, error) { + bsUtf16be := make([]byte, symbolLen*2) + + _, err := io.ReadFull(reader, bsUtf16be) + if err != nil { + return "", err + } + + result, _, err := transform.Bytes(unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM).NewDecoder(), bsUtf16be) + if err != nil { + return "", err + } + + return string(result), nil +} + func ReadFrame(reader io.Reader, addr net.Addr) (*Frame, error) { logrus. WithField("client", addr). @@ -136,25 +245,43 @@ func ReadString(reader io.Reader) (string, error) { return strBuilder.String(), nil } -func ReadUnsignedShort(reader io.Reader) (uint16, error) { - upper := make([]byte, 1) - _, err := reader.Read(upper) +func ReadByte(reader io.Reader) (byte, error) { + buf := make([]byte, 1) + _, err := reader.Read(buf) if err != nil { return 0, err + } else { + return buf[0], nil } - lower := make([]byte, 1) - _, err = reader.Read(lower) - if err != nil { - return 0, err - } - - return (uint16(upper[0]) << 8) | uint16(lower[0]), nil } -func ReadHandshake(data []byte) (*Handshake, error) { +func ReadUnsignedShort(reader io.Reader) (uint16, error) { + var value uint16 + err := binary.Read(reader, binary.BigEndian, &value) + if err != nil { + return 0, err + } + return value, nil +} + +func ReadUnsignedInt(reader io.Reader) (uint32, error) { + var value uint32 + err := binary.Read(reader, binary.BigEndian, &value) + if err != nil { + return 0, err + } + return value, nil +} + +func ReadHandshake(data interface{}) (*Handshake, error) { + + dataBytes, ok := data.([]byte) + if !ok { + return nil, errors.New("data is not expected byte slice") + } handshake := &Handshake{} - buffer := bytes.NewBuffer(data) + buffer := bytes.NewBuffer(dataBytes) var err error handshake.ProtocolVersion, err = ReadVarInt(buffer) diff --git a/mcproto/read_test.go b/mcproto/read_test.go new file mode 100644 index 0000000..ed2c208 --- /dev/null +++ b/mcproto/read_test.go @@ -0,0 +1,36 @@ +package mcproto + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestReadVarInt(t *testing.T) { + tests := []struct { + Name string + Input []byte + Expected int + }{ + { + Name: "Single byte", + Input: []byte{0xFA, 0x00}, + Expected: 0x7A, + }, + { + Name: "Two byte", + Input: []byte{0x81, 0x04}, + Expected: 0x0201, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + result, err := ReadVarInt(bytes.NewBuffer(tt.Input)) + require.NoError(t, err) + + assert.Equal(t, tt.Expected, result) + }) + } +} diff --git a/mcproto/types.go b/mcproto/types.go index 8ef5360..5ca7c0d 100644 --- a/mcproto/types.go +++ b/mcproto/types.go @@ -7,6 +7,12 @@ type Frame struct { Payload []byte } +type State int + +const ( + StateHandshaking = iota +) + var trimLimit = 64 func trimBytes(data []byte) ([]byte, string) { @@ -25,15 +31,23 @@ func (f *Frame) String() string { type Packet struct { Length int PacketID int - Data []byte + // Data is either a byte slice of raw content or a parsed message + Data interface{} } func (p *Packet) String() string { - trimmed, cont := trimBytes(p.Data) - return fmt.Sprintf("Frame:[len=%d, packetId=%d, data=%#X%s]", p.Length, p.PacketID, trimmed, cont) + if dataBytes, ok := p.Data.([]byte); ok { + trimmed, cont := trimBytes(dataBytes) + return fmt.Sprintf("Frame:[len=%d, packetId=%d, data=%#X%s]", p.Length, p.PacketID, trimmed, cont) + } else { + return fmt.Sprintf("Frame:[len=%d, packetId=%d, data=%+v]", p.Length, p.PacketID, p.Data) + } } -const PacketIdHandshake = 0x00 +const ( + PacketIdHandshake = 0x00 + PacketIdLegacyServerListPing = 0xFE +) type Handshake struct { ProtocolVersion int @@ -42,6 +56,12 @@ type Handshake struct { NextState int } +type LegacyServerListPing struct { + ProtocolVersion int + ServerAddress string + ServerPort uint16 +} + type ByteReader interface { ReadByte() (byte, error) } diff --git a/server/connector.go b/server/connector.go index 435b6e2..7df2b78 100644 --- a/server/connector.go +++ b/server/connector.go @@ -12,7 +12,7 @@ import ( ) const ( - handshakeTimeout = 2 * time.Second + handshakeTimeout = 5 * time.Second ) var noDeadline time.Time @@ -21,9 +21,12 @@ type IConnector interface { StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error } -var Connector IConnector = &connectorImpl{} +var Connector IConnector = &connectorImpl{ + state: mcproto.StateHandshaking, +} type connectorImpl struct { + state mcproto.State } func (c *connectorImpl) StartAcceptingConnections(ctx context.Context, listenAddress string, connRateLimit int) error { @@ -70,71 +73,65 @@ func (c *connectorImpl) HandleConnection(ctx context.Context, frontendConn net.C logrus. WithField("client", clientAddr). Info("Got connection") + defer logrus.WithField("client", clientAddr).Debug("Closing frontend connection") inspectionBuffer := new(bytes.Buffer) inspectionReader := io.TeeReader(frontendConn, inspectionBuffer) - if err := frontendConn.SetReadDeadline(time.Now().Add(handshakeTimeout)); err != nil { - logrus. - WithError(err). - WithField("client", clientAddr). - Error("Failed to set read deadline") - return - } - packet, err := mcproto.ReadPacket(inspectionReader, clientAddr) + /* if err := frontendConn.SetReadDeadline(time.Now().Add(handshakeTimeout)); err != nil { + logrus. + WithError(err). + WithField("client", clientAddr). + Error("Failed to set read deadline") + return + } + */packet, err := mcproto.ReadPacket(inspectionReader, clientAddr, c.state) if err != nil { logrus.WithError(err).WithField("clientAddr", clientAddr).Error("Failed to read packet") return } - logrus.WithFields(logrus.Fields{"length": packet.Length, "packetID": packet.PacketID}).Info("Got packet") + logrus. + WithField("client", clientAddr). + WithField("length", packet.Length). + WithField("packetID", packet.PacketID). + Debug("Got packet") if packet.PacketID == mcproto.PacketIdHandshake { handshake, err := mcproto.ReadHandshake(packet.Data) if err != nil { - logrus.WithError(err).WithField("clientAddr", clientAddr).Error("Failed to read handshake") + logrus.WithError(err).WithField("clientAddr", clientAddr). + Error("Failed to read handshake") return } - logrus.WithFields(logrus.Fields{ - "protocolVersion": handshake.ProtocolVersion, - "server": handshake.ServerAddress, - "serverPort": handshake.ServerPort, - "nextState": handshake.NextState, - }).Info("Got handshake") + logrus. + WithField("client", clientAddr). + WithField("handshake", handshake). + Debug("Got handshake") - backendHostPort := Routes.FindBackendForServerAddress(handshake.ServerAddress) - if backendHostPort == "" { - logrus.WithField("serverAddress", handshake.ServerAddress).Warn("Unable to find registered backend") - return - } + serverAddress := handshake.ServerAddress - logrus.WithField("backendHostPort", backendHostPort).Info("Connecting to backend") - backendConn, err := net.Dial("tcp", backendHostPort) - if err != nil { - logrus.WithError(err).WithFields(logrus.Fields{ - "serverAddress": handshake.ServerAddress, - "backend": backendHostPort, - }).Warn("Unable to connect to backend") - return - } - - amount, err := io.Copy(backendConn, inspectionBuffer) - if err != nil { - logrus.WithError(err).Error("Failed to write handshake to backend connection") - return - } - logrus.WithField("amount", amount).Debug("Relayed handshake to backend") - - if err = frontendConn.SetReadDeadline(noDeadline); err != nil { + c.findAndConnectBackend(ctx, frontendConn, clientAddr, inspectionBuffer, serverAddress) + } else if packet.PacketID == mcproto.PacketIdLegacyServerListPing { + handshake, ok := packet.Data.(*mcproto.LegacyServerListPing) + if !ok { logrus. - WithError(err). WithField("client", clientAddr). - Error("Failed to clear read deadline") + WithField("packet", packet). + Warn("Unexpected data type for PacketIdLegacyServerListPing") return } - pumpConnections(ctx, frontendConn, backendConn) + + logrus. + WithField("client", clientAddr). + WithField("handshake", handshake). + Debug("Got legacy server list ping") + + serverAddress := handshake.ServerAddress + + c.findAndConnectBackend(ctx, frontendConn, clientAddr, inspectionBuffer, serverAddress) } else { logrus. WithField("client", clientAddr). @@ -144,40 +141,82 @@ func (c *connectorImpl) HandleConnection(ctx context.Context, frontendConn net.C } } +func (c *connectorImpl) findAndConnectBackend(ctx context.Context, frontendConn net.Conn, + clientAddr net.Addr, preReadContent io.Reader, serverAddress string) { + + backendHostPort := Routes.FindBackendForServerAddress(serverAddress) + if backendHostPort == "" { + logrus.WithField("serverAddress", serverAddress).Warn("Unable to find registered backend") + return + } + logrus. + WithField("client", clientAddr). + WithField("server", serverAddress). + WithField("backendHostPort", backendHostPort). + Info("Connecting to backend") + backendConn, err := net.Dial("tcp", backendHostPort) + if err != nil { + logrus. + WithError(err). + WithField("client", clientAddr). + WithField("serverAddress", serverAddress). + WithField("backend", backendHostPort). + Warn("Unable to connect to backend") + return + } + amount, err := io.Copy(backendConn, preReadContent) + if err != nil { + logrus.WithError(err).Error("Failed to write handshake to backend connection") + return + } + logrus.WithField("amount", amount).Debug("Relayed handshake to backend") + if err = frontendConn.SetReadDeadline(noDeadline); err != nil { + logrus. + WithError(err). + WithField("client", clientAddr). + Error("Failed to clear read deadline") + return + } + pumpConnections(ctx, frontendConn, backendConn) + return +} + func pumpConnections(ctx context.Context, frontendConn, backendConn net.Conn) { //noinspection GoUnhandledErrorResult defer backendConn.Close() + clientAddr := frontendConn.RemoteAddr() + defer logrus.WithField("client", clientAddr).Debug("Closing backend connection") errors := make(chan error, 2) go pumpFrames(backendConn, frontendConn, errors, "backend", "frontend", clientAddr) go pumpFrames(frontendConn, backendConn, errors, "frontend", "backend", clientAddr) - for { - select { - case err := <-errors: - if err != io.EOF { - logrus.WithError(err). - WithField("client", clientAddr). - Error("Error observed on connection relay") - } - - return - - case <-ctx.Done(): - return + select { + case err := <-errors: + if err != io.EOF { + logrus.WithError(err). + WithField("client", clientAddr). + Error("Error observed on connection relay") } + + case <-ctx.Done(): + logrus.Debug("Observed context cancellation") } } func pumpFrames(incoming io.Reader, outgoing io.Writer, errors chan<- error, from, to string, clientAddr net.Addr) { amount, err := io.Copy(outgoing, incoming) - if err != nil { - errors <- err - } logrus. WithField("client", clientAddr). WithField("amount", amount). Infof("Finished relay %s->%s", from, to) + + if err != nil { + errors <- err + } else { + // successful io.Copy return nil error, not EOF...to simulate that to trigger outer handling + errors <- io.EOF + } }