diff --git a/brutal.go b/brutal.go index 93e76b3..df52d64 100644 --- a/brutal.go +++ b/brutal.go @@ -7,7 +7,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) const ( @@ -32,7 +32,7 @@ func WriteBrutalResponse(writer io.Writer, receiveBPS uint64, ok bool, message s if ok { common.Must(binary.Write(buffer, binary.BigEndian, receiveBPS)) } else { - err := rw.WriteVString(buffer, message) + err := varbin.Write(buffer, binary.BigEndian, message) if err != nil { return err } @@ -52,7 +52,7 @@ func ReadBrutalResponse(reader io.Reader) (uint64, error) { return receiveBPS, err } else { var message string - message, err = rw.ReadVString(reader) + message, err = varbin.ReadValue[string](reader, binary.BigEndian) if err != nil { return 0, err } diff --git a/go.mod b/go.mod index 4ff6c9f..494dd2d 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,13 @@ module github.com/sagernet/sing-mux -go 1.18 +go 1.20 require ( - github.com/hashicorp/yamux v0.1.1 - github.com/sagernet/sing v0.3.0 + github.com/hashicorp/yamux v0.1.2 + github.com/sagernet/sing v0.5.0-rc.4.0.20241020064342-b036e5c3ee02 github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 - golang.org/x/net v0.19.0 - golang.org/x/sys v0.16.0 + golang.org/x/net v0.30.0 + golang.org/x/sys v0.26.0 ) -require golang.org/x/text v0.14.0 // indirect +require golang.org/x/text v0.19.0 // indirect diff --git a/go.sum b/go.sum index 8ee6605..3d068f0 100644 --- a/go.sum +++ b/go.sum @@ -1,18 +1,18 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE= -github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= +github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= +github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= -github.com/sagernet/sing v0.3.0 h1:PIDVFZHnQAAYRL1UYqNM+0k5s8f/tb1lUW6UDcQiOc8= -github.com/sagernet/sing v0.3.0/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g= +github.com/sagernet/sing v0.5.0-rc.4.0.20241020064342-b036e5c3ee02 h1:dbGXq6JHiizl+YC6V+zQYL0/+SXFh32s62kpy7XbTYU= +github.com/sagernet/sing v0.5.0-rc.4.0.20241020064342-b036e5c3ee02/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxeq2z31Fpv8CQqqrUwTbrIRumZqQ= github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7/go.mod h1:FP9X2xjT/Az1EsG/orYYoC+5MojWnuI7hrffz8fGwwo= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/protocol.go b/protocol.go index d93268b..176d51c 100644 --- a/protocol.go +++ b/protocol.go @@ -12,6 +12,7 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) const ( @@ -41,14 +42,18 @@ type Request struct { } func ReadRequest(reader io.Reader) (*Request, error) { - version, err := rw.ReadByte(reader) + var ( + version byte + protocol byte + ) + err := binary.Read(reader, binary.BigEndian, &version) if err != nil { return nil, err } if version < Version0 || version > Version1 { return nil, E.New("unsupported version: ", version) } - protocol, err := rw.ReadByte(reader) + err = binary.Read(reader, binary.BigEndian, &protocol) if err != nil { return nil, err } @@ -166,13 +171,12 @@ type StreamResponse struct { func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) { var response StreamResponse - status, err := rw.ReadByte(reader) + err := binary.Read(reader, binary.BigEndian, &response.Status) if err != nil { return nil, err } - response.Status = status - if status == statusError { - response.Message, err = rw.ReadVString(reader) + if response.Status == statusError { + response.Message, err = varbin.ReadValue[string](reader, binary.BigEndian) if err != nil { return nil, err } diff --git a/server.go b/server.go index b97a97d..a377ce9 100644 --- a/server.go +++ b/server.go @@ -13,15 +13,24 @@ import ( "github.com/sagernet/sing/common/task" ) +// Deprecated: Use ServiceHandlerEx instead. +// +//nolint:staticcheck type ServiceHandler interface { N.TCPConnectionHandler N.UDPConnectionHandler } +type ServiceHandlerEx interface { + N.TCPConnectionHandlerEx + N.UDPConnectionHandlerEx +} + type Service struct { newStreamContext func(context.Context, net.Conn) context.Context logger logger.ContextLogger handler ServiceHandler + handlerEx ServiceHandlerEx padding bool brutal BrutalOptions } @@ -30,6 +39,7 @@ type ServiceOptions struct { NewStreamContext func(context.Context, net.Conn) context.Context Logger logger.ContextLogger Handler ServiceHandler + HandlerEx ServiceHandlerEx Padding bool Brutal BrutalOptions } @@ -42,12 +52,26 @@ func NewService(options ServiceOptions) (*Service, error) { newStreamContext: options.NewStreamContext, logger: options.Logger, handler: options.Handler, + handlerEx: options.HandlerEx, padding: options.Padding, brutal: options.Brutal, }, nil } +// Deprecated: Use NewConnectionEx instead. func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + return s.newConnection(ctx, conn, metadata.Source) +} + +func (s *Service) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandler) { + err := s.newConnection(ctx, conn, source) + if err != nil { + N.CloseOnHandshakeFailure(conn, onClose, err) + s.logger.ErrorContext(ctx, E.Cause(err, "process multiplex connection from ", source)) + } +} + +func (s *Service) newConnection(ctx context.Context, conn net.Conn, source M.Socksaddr) error { request, err := ReadRequest(conn) if err != nil { return err @@ -71,9 +95,10 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M } streamCtx := s.newStreamContext(ctx, stream) go func() { - hErr := s.newConnection(streamCtx, conn, stream, metadata) + hErr := s.newSession(streamCtx, conn, stream, source) if hErr != nil { - s.logger.ErrorContext(streamCtx, E.Cause(hErr, "handle connection")) + stream.Close() + s.logger.ErrorContext(streamCtx, E.Cause(hErr, "process multiplex stream")) } }() } @@ -84,13 +109,13 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M return group.Run(ctx) } -func (s *Service) newConnection(ctx context.Context, sessionConn net.Conn, stream net.Conn, metadata M.Metadata) error { +func (s *Service) newSession(ctx context.Context, sessionConn net.Conn, stream net.Conn, source M.Socksaddr) error { stream = &wrapStream{stream} request, err := ReadStreamRequest(stream) if err != nil { return E.Cause(err, "read multiplex stream request") } - metadata.Destination = request.Destination + destination := request.Destination if request.Network == N.NetworkTCP { conn := &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)} if request.Destination.Fqdn == BrutalExchangeDomain { @@ -128,20 +153,28 @@ func (s *Service) newConnection(ctx context.Context, sessionConn net.Conn, strea } return nil } - s.logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination) - s.handler.NewConnection(ctx, conn, metadata) - stream.Close() + s.logger.InfoContext(ctx, "inbound multiplex connection to ", destination) + if s.handler != nil { + //nolint:staticcheck + s.handler.NewConnection(ctx, conn, M.Metadata{Source: source, Destination: destination}) + } else { + s.handlerEx.NewConnectionEx(ctx, conn, source, destination, nil) + } } else { var packetConn N.PacketConn if !request.PacketAddr { - s.logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination) + s.logger.InfoContext(ctx, "inbound multiplex packet connection to ", destination) packetConn = &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination} } else { s.logger.InfoContext(ctx, "inbound multiplex packet connection") packetConn = &serverPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)} } - s.handler.NewPacketConnection(ctx, packetConn, metadata) - stream.Close() + if s.handler != nil { + //nolint:staticcheck + s.handler.NewPacketConnection(ctx, packetConn, M.Metadata{Source: source, Destination: destination}) + } else { + s.handlerEx.NewPacketConnectionEx(ctx, packetConn, source, destination, nil) + } } return nil } diff --git a/server_conn.go b/server_conn.go index 41151c1..31ed4bf 100644 --- a/server_conn.go +++ b/server_conn.go @@ -10,7 +10,7 @@ import ( "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) type serverConn struct { @@ -24,11 +24,11 @@ func (c *serverConn) NeedHandshake() bool { func (c *serverConn) HandshakeFailure(err error) error { errMessage := err.Error() - buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) + buffer := buf.NewSize(1 + varbin.UvarintLen(uint64(len(errMessage))) + len(errMessage)) defer buffer.Release() common.Must( buffer.WriteByte(statusError), - rw.WriteVString(buffer, errMessage), + varbin.Write(buffer, binary.BigEndian, errMessage), ) return common.Error(c.ExtendedConn.Write(buffer.Bytes())) } @@ -88,11 +88,11 @@ func (c *serverPacketConn) NeedHandshake() bool { func (c *serverPacketConn) HandshakeFailure(err error) error { errMessage := err.Error() - buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) + buffer := buf.NewSize(1 + varbin.UvarintLen(uint64(len(errMessage))) + len(errMessage)) defer buffer.Release() common.Must( buffer.WriteByte(statusError), - rw.WriteVString(buffer, errMessage), + varbin.Write(buffer, binary.BigEndian, errMessage), ) return common.Error(c.ExtendedConn.Write(buffer.Bytes())) } @@ -188,11 +188,11 @@ func (c *serverPacketAddrConn) NeedHandshake() bool { func (c *serverPacketAddrConn) HandshakeFailure(err error) error { errMessage := err.Error() - buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) + buffer := buf.NewSize(1 + varbin.UvarintLen(uint64(len(errMessage))) + len(errMessage)) defer buffer.Release() common.Must( buffer.WriteByte(statusError), - rw.WriteVString(buffer, errMessage), + varbin.Write(buffer, binary.BigEndian, errMessage), ) return common.Error(c.ExtendedConn.Write(buffer.Bytes())) }