From 9e478d95e2491af216cd19698699ad39ce89b773 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 23 Jun 2024 16:05:08 +0800 Subject: [PATCH] Remove bad rw usages --- common/cond.go | 2 ++ common/metadata/serializer.go | 25 +++++++++----- protocol/socks/client.go | 38 +++++++++++++++++---- protocol/socks/handshake.go | 39 ++++++++++----------- protocol/socks/socks4/protocol.go | 24 ++++++------- protocol/socks/socks5/protocol.go | 57 ++++++++++++++++--------------- service/filemanager/default.go | 2 +- 7 files changed, 110 insertions(+), 77 deletions(-) diff --git a/common/cond.go b/common/cond.go index a4c66a787..6fe11bc2a 100644 --- a/common/cond.go +++ b/common/cond.go @@ -363,10 +363,12 @@ func Close(closers ...any) error { return retErr } +// Deprecated: wtf is this? type Starter interface { Start() error } +// Deprecated: wtf is this? func Start(starters ...any) error { for _, rawStarter := range starters { if rawStarter == nil { diff --git a/common/metadata/serializer.go b/common/metadata/serializer.go index b858f5c56..344d3d325 100644 --- a/common/metadata/serializer.go +++ b/common/metadata/serializer.go @@ -8,7 +8,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) const ( @@ -116,7 +116,7 @@ func (s *Serializer) WriteAddrPort(writer io.Writer, destination Socksaddr) erro return err } if !isBuffer { - err = rw.WriteBytes(writer, buffer.Bytes()) + err = common.Error(writer.Write(buffer.Bytes())) } return err } @@ -129,8 +129,9 @@ func (s *Serializer) AddrPortLen(destination Socksaddr) int { } } -func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) { - af, err := rw.ReadByte(reader) +func (s *Serializer) ReadAddress(rawRedaer io.Reader) (Socksaddr, error) { + reader := varbin.NewReader(rawRedaer) + af, err := reader.ReadByte() if err != nil { return Socksaddr{}, err } @@ -164,11 +165,12 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) { } func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) { - port, err := rw.ReadBytes(reader, 2) + var port uint16 + err := binary.Read(reader, binary.BigEndian, &port) if err != nil { return 0, E.Cause(err, "read port") } - return binary.BigEndian.Uint16(port), nil + return port, nil } func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err error) { @@ -194,12 +196,17 @@ func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err return addr, nil } -func ReadSockString(reader io.Reader) (string, error) { - strLen, err := rw.ReadByte(reader) +func ReadSockString(reader varbin.Reader) (string, error) { + strLen, err := reader.ReadByte() + if err != nil { + return "", err + } + strBytes := make([]byte, strLen) + _, err = io.ReadFull(reader, strBytes) if err != nil { return "", err } - return rw.ReadString(reader, int(strLen)) + return string(strBytes), nil } func WriteSocksString(buffer *buf.Buffer, str string) error { diff --git a/protocol/socks/client.go b/protocol/socks/client.go index 9e4a1f552..45b2affac 100644 --- a/protocol/socks/client.go +++ b/protocol/socks/client.go @@ -1,12 +1,15 @@ package socks import ( + std_bufio "bufio" "context" "net" "net/url" "os" "strings" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -118,31 +121,53 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock return nil, err } if c.version == Version4 && address.IsFqdn() { - tcpAddr, err := net.ResolveTCPAddr(network, address.String()) + var tcpAddr *net.TCPAddr + tcpAddr, err = net.ResolveTCPAddr(network, address.String()) if err != nil { tcpConn.Close() return nil, err } address = M.SocksaddrFromNet(tcpAddr) } + reader := std_bufio.NewReader(tcpConn) switch c.version { case Version4, Version4A: - _, err = ClientHandshake4(tcpConn, command, address, c.username) + _, err = ClientHandshake4(reader, tcpConn, command, address, c.username) if err != nil { tcpConn.Close() return nil, err } + if reader.Buffered() > 0 { + buffer := buf.NewSize(reader.Buffered()) + _, err = buffer.ReadFullFrom(reader, reader.Buffered()) + if err != nil { + tcpConn.Close() + return nil, err + } + return bufio.NewCachedConn(tcpConn, buffer), nil + } return tcpConn, nil case Version5: - response, err := ClientHandshake5(tcpConn, command, address, c.username, c.password) + var response socks5.Response + response, err = ClientHandshake5(reader, tcpConn, command, address, c.username, c.password) if err != nil { tcpConn.Close() return nil, err } if command == socks5.CommandConnect { + if reader.Buffered() > 0 { + buffer := buf.NewSize(reader.Buffered()) + _, err = buffer.ReadFullFrom(reader, reader.Buffered()) + if err != nil { + tcpConn.Close() + return nil, err + } + return bufio.NewCachedConn(tcpConn, buffer), nil + } return tcpConn, nil } - udpConn, err := c.dialer.DialContext(ctx, N.NetworkUDP, response.Bind) + var udpConn net.Conn + udpConn, err = c.dialer.DialContext(ctx, N.NetworkUDP, response.Bind) if err != nil { tcpConn.Close() return nil, err @@ -166,16 +191,17 @@ func (c *Client) BindContext(ctx context.Context, address M.Socksaddr) (net.Conn if err != nil { return nil, err } + reader := std_bufio.NewReader(tcpConn) switch c.version { case Version4, Version4A: - _, err = ClientHandshake4(tcpConn, socks4.CommandBind, address, c.username) + _, err = ClientHandshake4(reader, tcpConn, socks4.CommandBind, address, c.username) if err != nil { tcpConn.Close() return nil, err } return tcpConn, nil case Version5: - _, err = ClientHandshake5(tcpConn, socks5.CommandBind, address, c.username, c.password) + _, err = ClientHandshake5(reader, tcpConn, socks5.CommandBind, address, c.username, c.password) if err != nil { tcpConn.Close() return nil, err diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 9e742a27c..4c5909b53 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -1,6 +1,7 @@ package socks import ( + std_bufio "bufio" "context" "io" "net" @@ -13,7 +14,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" "github.com/sagernet/sing/protocol/socks/socks4" "github.com/sagernet/sing/protocol/socks/socks5" ) @@ -23,8 +24,8 @@ type Handler interface { N.UDPConnectionHandler } -func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, username string) (socks4.Response, error) { - err := socks4.WriteRequest(conn, socks4.Request{ +func ClientHandshake4(reader varbin.Reader, writer io.Writer, command byte, destination M.Socksaddr, username string) (socks4.Response, error) { + err := socks4.WriteRequest(writer, socks4.Request{ Command: command, Destination: destination, Username: username, @@ -32,7 +33,7 @@ func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, if err != nil { return socks4.Response{}, err } - response, err := socks4.ReadResponse(conn) + response, err := socks4.ReadResponse(reader) if err != nil { return socks4.Response{}, err } @@ -42,32 +43,32 @@ func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, return response, err } -func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr, username string, password string) (socks5.Response, error) { +func ClientHandshake5(reader varbin.Reader, writer io.Writer, command byte, destination M.Socksaddr, username string, password string) (socks5.Response, error) { var method byte if username == "" { method = socks5.AuthTypeNotRequired } else { method = socks5.AuthTypeUsernamePassword } - err := socks5.WriteAuthRequest(conn, socks5.AuthRequest{ + err := socks5.WriteAuthRequest(writer, socks5.AuthRequest{ Methods: []byte{method}, }) if err != nil { return socks5.Response{}, err } - authResponse, err := socks5.ReadAuthResponse(conn) + authResponse, err := socks5.ReadAuthResponse(reader) if err != nil { return socks5.Response{}, err } if authResponse.Method == socks5.AuthTypeUsernamePassword { - err = socks5.WriteUsernamePasswordAuthRequest(conn, socks5.UsernamePasswordAuthRequest{ + err = socks5.WriteUsernamePasswordAuthRequest(writer, socks5.UsernamePasswordAuthRequest{ Username: username, Password: password, }) if err != nil { return socks5.Response{}, err } - usernamePasswordResponse, err := socks5.ReadUsernamePasswordAuthResponse(conn) + usernamePasswordResponse, err := socks5.ReadUsernamePasswordAuthResponse(reader) if err != nil { return socks5.Response{}, err } @@ -77,14 +78,14 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr, } else if authResponse.Method != socks5.AuthTypeNotRequired { return socks5.Response{}, E.New("socks5: unsupported auth method: ", authResponse.Method) } - err = socks5.WriteRequest(conn, socks5.Request{ + err = socks5.WriteRequest(writer, socks5.Request{ Command: command, Destination: destination, }) if err != nil { return socks5.Response{}, err } - response, err := socks5.ReadResponse(conn) + response, err := socks5.ReadResponse(reader) if err != nil { return socks5.Response{}, err } @@ -94,18 +95,14 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr, return response, err } -func HandleConnection(ctx context.Context, conn net.Conn, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { - version, err := rw.ReadByte(conn) +func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { + version, err := reader.ReadByte() if err != nil { return err } - return HandleConnection0(ctx, conn, version, authenticator, handler, metadata) -} - -func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { switch version { case socks4.Version: - request, err := socks4.ReadRequest0(conn) + request, err := socks4.ReadRequest0(reader) if err != nil { return err } @@ -142,7 +139,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent return E.New("socks4: unsupported command ", request.Command) } case socks5.Version: - authRequest, err := socks5.ReadAuthRequest0(conn) + authRequest, err := socks5.ReadAuthRequest0(reader) if err != nil { return err } @@ -167,7 +164,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent return err } if authMethod == socks5.AuthTypeUsernamePassword { - usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(conn) + usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(reader) if err != nil { return err } @@ -186,7 +183,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent return E.New("socks5: authentication failed, username=", usernamePasswordAuthRequest.Username, ", password=", usernamePasswordAuthRequest.Password) } } - request, err := socks5.ReadRequest(conn) + request, err := socks5.ReadRequest(reader) if err != nil { return err } diff --git a/protocol/socks/socks4/protocol.go b/protocol/socks/socks4/protocol.go index 8b3879d5c..d8b2ae001 100644 --- a/protocol/socks/socks4/protocol.go +++ b/protocol/socks/socks4/protocol.go @@ -10,7 +10,7 @@ import ( "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) const ( @@ -31,8 +31,8 @@ type Request struct { Username string } -func ReadRequest(reader io.Reader) (request Request, err error) { - version, err := rw.ReadByte(reader) +func ReadRequest(reader varbin.Reader) (request Request, err error) { + version, err := reader.ReadByte() if err != nil { return } @@ -43,8 +43,8 @@ func ReadRequest(reader io.Reader) (request Request, err error) { return ReadRequest0(reader) } -func ReadRequest0(reader io.Reader) (request Request, err error) { - request.Command, err = rw.ReadByte(reader) +func ReadRequest0(reader varbin.Reader) (request Request, err error) { + request.Command, err = reader.ReadByte() if err != nil { return } @@ -108,7 +108,7 @@ func WriteRequest(writer io.Writer, request Request) error { common.Must1(buffer.WriteString(request.Destination.AddrString())) common.Must(buffer.WriteZero()) } - return rw.WriteBytes(writer, buffer.Bytes()) + return common.Error(writer.Write(buffer.Bytes())) } type Response struct { @@ -116,8 +116,8 @@ type Response struct { Destination M.Socksaddr } -func ReadResponse(reader io.Reader) (response Response, err error) { - version, err := rw.ReadByte(reader) +func ReadResponse(reader varbin.Reader) (response Response, err error) { + version, err := reader.ReadByte() if err != nil { return } @@ -125,7 +125,7 @@ func ReadResponse(reader io.Reader) (response Response, err error) { err = E.New("excepted socks4 response version 0, got ", version) return } - response.ReplyCode, err = rw.ReadByte(reader) + response.ReplyCode, err = reader.ReadByte() if err != nil { return } @@ -151,13 +151,13 @@ func WriteResponse(writer io.Writer, response Response) error { binary.Write(buffer, binary.BigEndian, response.Destination.Port), common.Error(buffer.Write(response.Destination.Addr.AsSlice())), ) - return rw.WriteBytes(writer, buffer.Bytes()) + return common.Error(writer.Write(buffer.Bytes())) } -func readString(reader io.Reader) (string, error) { +func readString(reader varbin.Reader) (string, error) { buffer := bytes.Buffer{} for { - b, err := rw.ReadByte(reader) + b, err := reader.ReadByte() if err != nil { return "", err } diff --git a/protocol/socks/socks5/protocol.go b/protocol/socks/socks5/protocol.go index 67d9797f7..29ff3db58 100644 --- a/protocol/socks/socks5/protocol.go +++ b/protocol/socks/socks5/protocol.go @@ -8,7 +8,7 @@ import ( "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/varbin" ) const ( @@ -55,11 +55,11 @@ func WriteAuthRequest(writer io.Writer, request AuthRequest) error { buffer.WriteByte(byte(len(request.Methods))), common.Error(buffer.Write(request.Methods)), ) - return rw.WriteBytes(writer, buffer.Bytes()) + return common.Error(writer.Write(buffer.Bytes())) } -func ReadAuthRequest(reader io.Reader) (request AuthRequest, err error) { - version, err := rw.ReadByte(reader) +func ReadAuthRequest(reader varbin.Reader) (request AuthRequest, err error) { + version, err := reader.ReadByte() if err != nil { return } @@ -70,12 +70,13 @@ func ReadAuthRequest(reader io.Reader) (request AuthRequest, err error) { return ReadAuthRequest0(reader) } -func ReadAuthRequest0(reader io.Reader) (request AuthRequest, err error) { - methodLen, err := rw.ReadByte(reader) +func ReadAuthRequest0(reader varbin.Reader) (request AuthRequest, err error) { + methodLen, err := reader.ReadByte() if err != nil { return } - request.Methods, err = rw.ReadBytes(reader, int(methodLen)) + request.Methods = make([]byte, methodLen) + _, err = io.ReadFull(reader, request.Methods) return } @@ -90,11 +91,11 @@ type AuthResponse struct { } func WriteAuthResponse(writer io.Writer, response AuthResponse) error { - return rw.WriteBytes(writer, []byte{Version, response.Method}) + return common.Error(writer.Write([]byte{Version, response.Method})) } -func ReadAuthResponse(reader io.Reader) (response AuthResponse, err error) { - version, err := rw.ReadByte(reader) +func ReadAuthResponse(reader varbin.Reader) (response AuthResponse, err error) { + version, err := reader.ReadByte() if err != nil { return } @@ -102,7 +103,7 @@ func ReadAuthResponse(reader io.Reader) (response AuthResponse, err error) { err = E.New("expected socks version 5, got ", version) return } - response.Method, err = rw.ReadByte(reader) + response.Method, err = reader.ReadByte() return } @@ -125,11 +126,11 @@ func WriteUsernamePasswordAuthRequest(writer io.Writer, request UsernamePassword M.WriteSocksString(buffer, request.Username), M.WriteSocksString(buffer, request.Password), ) - return rw.WriteBytes(writer, buffer.Bytes()) + return common.Error(writer.Write(buffer.Bytes())) } -func ReadUsernamePasswordAuthRequest(reader io.Reader) (request UsernamePasswordAuthRequest, err error) { - version, err := rw.ReadByte(reader) +func ReadUsernamePasswordAuthRequest(reader varbin.Reader) (request UsernamePasswordAuthRequest, err error) { + version, err := reader.ReadByte() if err != nil { return } @@ -159,11 +160,11 @@ type UsernamePasswordAuthResponse struct { } func WriteUsernamePasswordAuthResponse(writer io.Writer, response UsernamePasswordAuthResponse) error { - return rw.WriteBytes(writer, []byte{1, response.Status}) + return common.Error(writer.Write([]byte{1, response.Status})) } -func ReadUsernamePasswordAuthResponse(reader io.Reader) (response UsernamePasswordAuthResponse, err error) { - version, err := rw.ReadByte(reader) +func ReadUsernamePasswordAuthResponse(reader varbin.Reader) (response UsernamePasswordAuthResponse, err error) { + version, err := reader.ReadByte() if err != nil { return } @@ -171,7 +172,7 @@ func ReadUsernamePasswordAuthResponse(reader io.Reader) (response UsernamePasswo err = E.New("excepted password request version 1, got ", version) return } - response.Status, err = rw.ReadByte(reader) + response.Status, err = reader.ReadByte() return } @@ -198,11 +199,11 @@ func WriteRequest(writer io.Writer, request Request) error { if err != nil { return err } - return rw.WriteBytes(writer, buffer.Bytes()) + return common.Error(writer.Write(buffer.Bytes())) } -func ReadRequest(reader io.Reader) (request Request, err error) { - version, err := rw.ReadByte(reader) +func ReadRequest(reader varbin.Reader) (request Request, err error) { + version, err := reader.ReadByte() if err != nil { return } @@ -210,11 +211,11 @@ func ReadRequest(reader io.Reader) (request Request, err error) { err = E.New("expected socks version 5, got ", version) return } - request.Command, err = rw.ReadByte(reader) + request.Command, err = reader.ReadByte() if err != nil { return } - err = rw.Skip(reader) + _, err = reader.ReadByte() if err != nil { return } @@ -252,11 +253,11 @@ func WriteResponse(writer io.Writer, response Response) error { if err != nil { return err } - return rw.WriteBytes(writer, buffer.Bytes()) + return common.Error(writer.Write(buffer.Bytes())) } -func ReadResponse(reader io.Reader) (response Response, err error) { - version, err := rw.ReadByte(reader) +func ReadResponse(reader varbin.Reader) (response Response, err error) { + version, err := reader.ReadByte() if err != nil { return } @@ -264,11 +265,11 @@ func ReadResponse(reader io.Reader) (response Response, err error) { err = E.New("expected socks version 5, got ", version) return } - response.ReplyCode, err = rw.ReadByte(reader) + response.ReplyCode, err = reader.ReadByte() if err != nil { return } - err = rw.Skip(reader) + _, err = reader.ReadByte() if err != nil { return } diff --git a/service/filemanager/default.go b/service/filemanager/default.go index c5744048c..372bc0162 100644 --- a/service/filemanager/default.go +++ b/service/filemanager/default.go @@ -44,7 +44,7 @@ func (m *defaultManager) BasePath(name string) string { func (m *defaultManager) OpenFile(name string, flag int, perm os.FileMode) (*os.File, error) { name = m.BasePath(name) - willCreate := flag&os.O_CREATE != 0 && !rw.FileExists(name) + willCreate := flag&os.O_CREATE != 0 && !rw.IsFile(name) file, err := os.OpenFile(name, flag, perm) if err != nil { return nil, err