Skip to content

Commit

Permalink
Remove bad rw usages
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Jun 23, 2024
1 parent 30bcee1 commit 9e478d9
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 77 deletions.
2 changes: 2 additions & 0 deletions common/cond.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,10 +363,12 @@ func Close(closers ...any) error {
return retErr
}

// Deprecated: wtf is this?
type Starter interface {
Start() error
}

// Deprecated: wtf is this?
func Start(starters ...any) error {
for _, rawStarter := range starters {
if rawStarter == nil {
Expand Down
25 changes: 16 additions & 9 deletions common/metadata/serializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/varbin"
)

const (
Expand Down Expand Up @@ -116,7 +116,7 @@ func (s *Serializer) WriteAddrPort(writer io.Writer, destination Socksaddr) erro
return err
}
if !isBuffer {
err = rw.WriteBytes(writer, buffer.Bytes())
err = common.Error(writer.Write(buffer.Bytes()))
}
return err
}
Expand All @@ -129,8 +129,9 @@ func (s *Serializer) AddrPortLen(destination Socksaddr) int {
}
}

func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) {
af, err := rw.ReadByte(reader)
func (s *Serializer) ReadAddress(rawRedaer io.Reader) (Socksaddr, error) {
reader := varbin.NewReader(rawRedaer)
af, err := reader.ReadByte()
if err != nil {
return Socksaddr{}, err
}
Expand Down Expand Up @@ -164,11 +165,12 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) {
}

func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) {
port, err := rw.ReadBytes(reader, 2)
var port uint16
err := binary.Read(reader, binary.BigEndian, &port)
if err != nil {
return 0, E.Cause(err, "read port")
}
return binary.BigEndian.Uint16(port), nil
return port, nil
}

func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err error) {
Expand All @@ -194,12 +196,17 @@ func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err
return addr, nil
}

func ReadSockString(reader io.Reader) (string, error) {
strLen, err := rw.ReadByte(reader)
func ReadSockString(reader varbin.Reader) (string, error) {
strLen, err := reader.ReadByte()
if err != nil {
return "", err
}
strBytes := make([]byte, strLen)
_, err = io.ReadFull(reader, strBytes)
if err != nil {
return "", err
}
return rw.ReadString(reader, int(strLen))
return string(strBytes), nil
}

func WriteSocksString(buffer *buf.Buffer, str string) error {
Expand Down
38 changes: 32 additions & 6 deletions protocol/socks/client.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package socks

import (
std_bufio "bufio"
"context"
"net"
"net/url"
"os"
"strings"

"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
Expand Down Expand Up @@ -118,31 +121,53 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock
return nil, err
}
if c.version == Version4 && address.IsFqdn() {
tcpAddr, err := net.ResolveTCPAddr(network, address.String())
var tcpAddr *net.TCPAddr
tcpAddr, err = net.ResolveTCPAddr(network, address.String())
if err != nil {
tcpConn.Close()
return nil, err
}
address = M.SocksaddrFromNet(tcpAddr)
}
reader := std_bufio.NewReader(tcpConn)
switch c.version {
case Version4, Version4A:
_, err = ClientHandshake4(tcpConn, command, address, c.username)
_, err = ClientHandshake4(reader, tcpConn, command, address, c.username)
if err != nil {
tcpConn.Close()
return nil, err
}
if reader.Buffered() > 0 {
buffer := buf.NewSize(reader.Buffered())
_, err = buffer.ReadFullFrom(reader, reader.Buffered())
if err != nil {
tcpConn.Close()
return nil, err
}
return bufio.NewCachedConn(tcpConn, buffer), nil
}
return tcpConn, nil
case Version5:
response, err := ClientHandshake5(tcpConn, command, address, c.username, c.password)
var response socks5.Response
response, err = ClientHandshake5(reader, tcpConn, command, address, c.username, c.password)
if err != nil {
tcpConn.Close()
return nil, err
}
if command == socks5.CommandConnect {
if reader.Buffered() > 0 {
buffer := buf.NewSize(reader.Buffered())
_, err = buffer.ReadFullFrom(reader, reader.Buffered())
if err != nil {
tcpConn.Close()
return nil, err
}
return bufio.NewCachedConn(tcpConn, buffer), nil
}
return tcpConn, nil
}
udpConn, err := c.dialer.DialContext(ctx, N.NetworkUDP, response.Bind)
var udpConn net.Conn
udpConn, err = c.dialer.DialContext(ctx, N.NetworkUDP, response.Bind)
if err != nil {
tcpConn.Close()
return nil, err
Expand All @@ -166,16 +191,17 @@ func (c *Client) BindContext(ctx context.Context, address M.Socksaddr) (net.Conn
if err != nil {
return nil, err
}
reader := std_bufio.NewReader(tcpConn)
switch c.version {
case Version4, Version4A:
_, err = ClientHandshake4(tcpConn, socks4.CommandBind, address, c.username)
_, err = ClientHandshake4(reader, tcpConn, socks4.CommandBind, address, c.username)
if err != nil {
tcpConn.Close()
return nil, err
}
return tcpConn, nil
case Version5:
_, err = ClientHandshake5(tcpConn, socks5.CommandBind, address, c.username, c.password)
_, err = ClientHandshake5(reader, tcpConn, socks5.CommandBind, address, c.username, c.password)
if err != nil {
tcpConn.Close()
return nil, err
Expand Down
39 changes: 18 additions & 21 deletions protocol/socks/handshake.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package socks

import (
std_bufio "bufio"
"context"
"io"
"net"
Expand All @@ -13,7 +14,7 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/varbin"
"github.com/sagernet/sing/protocol/socks/socks4"
"github.com/sagernet/sing/protocol/socks/socks5"
)
Expand All @@ -23,16 +24,16 @@ type Handler interface {
N.UDPConnectionHandler
}

func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, username string) (socks4.Response, error) {
err := socks4.WriteRequest(conn, socks4.Request{
func ClientHandshake4(reader varbin.Reader, writer io.Writer, command byte, destination M.Socksaddr, username string) (socks4.Response, error) {
err := socks4.WriteRequest(writer, socks4.Request{
Command: command,
Destination: destination,
Username: username,
})
if err != nil {
return socks4.Response{}, err
}
response, err := socks4.ReadResponse(conn)
response, err := socks4.ReadResponse(reader)
if err != nil {
return socks4.Response{}, err
}
Expand All @@ -42,32 +43,32 @@ func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr,
return response, err
}

func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr, username string, password string) (socks5.Response, error) {
func ClientHandshake5(reader varbin.Reader, writer io.Writer, command byte, destination M.Socksaddr, username string, password string) (socks5.Response, error) {
var method byte
if username == "" {
method = socks5.AuthTypeNotRequired
} else {
method = socks5.AuthTypeUsernamePassword
}
err := socks5.WriteAuthRequest(conn, socks5.AuthRequest{
err := socks5.WriteAuthRequest(writer, socks5.AuthRequest{
Methods: []byte{method},
})
if err != nil {
return socks5.Response{}, err
}
authResponse, err := socks5.ReadAuthResponse(conn)
authResponse, err := socks5.ReadAuthResponse(reader)
if err != nil {
return socks5.Response{}, err
}
if authResponse.Method == socks5.AuthTypeUsernamePassword {
err = socks5.WriteUsernamePasswordAuthRequest(conn, socks5.UsernamePasswordAuthRequest{
err = socks5.WriteUsernamePasswordAuthRequest(writer, socks5.UsernamePasswordAuthRequest{
Username: username,
Password: password,
})
if err != nil {
return socks5.Response{}, err
}
usernamePasswordResponse, err := socks5.ReadUsernamePasswordAuthResponse(conn)
usernamePasswordResponse, err := socks5.ReadUsernamePasswordAuthResponse(reader)
if err != nil {
return socks5.Response{}, err
}
Expand All @@ -77,14 +78,14 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr,
} else if authResponse.Method != socks5.AuthTypeNotRequired {
return socks5.Response{}, E.New("socks5: unsupported auth method: ", authResponse.Method)
}
err = socks5.WriteRequest(conn, socks5.Request{
err = socks5.WriteRequest(writer, socks5.Request{
Command: command,
Destination: destination,
})
if err != nil {
return socks5.Response{}, err
}
response, err := socks5.ReadResponse(conn)
response, err := socks5.ReadResponse(reader)
if err != nil {
return socks5.Response{}, err
}
Expand All @@ -94,18 +95,14 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr,
return response, err
}

func HandleConnection(ctx context.Context, conn net.Conn, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
version, err := rw.ReadByte(conn)
func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
version, err := reader.ReadByte()
if err != nil {
return err
}
return HandleConnection0(ctx, conn, version, authenticator, handler, metadata)
}

func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
switch version {
case socks4.Version:
request, err := socks4.ReadRequest0(conn)
request, err := socks4.ReadRequest0(reader)
if err != nil {
return err
}
Expand Down Expand Up @@ -142,7 +139,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
return E.New("socks4: unsupported command ", request.Command)
}
case socks5.Version:
authRequest, err := socks5.ReadAuthRequest0(conn)
authRequest, err := socks5.ReadAuthRequest0(reader)
if err != nil {
return err
}
Expand All @@ -167,7 +164,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
return err
}
if authMethod == socks5.AuthTypeUsernamePassword {
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(conn)
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(reader)
if err != nil {
return err
}
Expand All @@ -186,7 +183,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
return E.New("socks5: authentication failed, username=", usernamePasswordAuthRequest.Username, ", password=", usernamePasswordAuthRequest.Password)
}
}
request, err := socks5.ReadRequest(conn)
request, err := socks5.ReadRequest(reader)
if err != nil {
return err
}
Expand Down
24 changes: 12 additions & 12 deletions protocol/socks/socks4/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/varbin"
)

const (
Expand All @@ -31,8 +31,8 @@ type Request struct {
Username string
}

func ReadRequest(reader io.Reader) (request Request, err error) {
version, err := rw.ReadByte(reader)
func ReadRequest(reader varbin.Reader) (request Request, err error) {
version, err := reader.ReadByte()
if err != nil {
return
}
Expand All @@ -43,8 +43,8 @@ func ReadRequest(reader io.Reader) (request Request, err error) {
return ReadRequest0(reader)
}

func ReadRequest0(reader io.Reader) (request Request, err error) {
request.Command, err = rw.ReadByte(reader)
func ReadRequest0(reader varbin.Reader) (request Request, err error) {
request.Command, err = reader.ReadByte()
if err != nil {
return
}
Expand Down Expand Up @@ -108,24 +108,24 @@ func WriteRequest(writer io.Writer, request Request) error {
common.Must1(buffer.WriteString(request.Destination.AddrString()))
common.Must(buffer.WriteZero())
}
return rw.WriteBytes(writer, buffer.Bytes())
return common.Error(writer.Write(buffer.Bytes()))
}

type Response struct {
ReplyCode byte
Destination M.Socksaddr
}

func ReadResponse(reader io.Reader) (response Response, err error) {
version, err := rw.ReadByte(reader)
func ReadResponse(reader varbin.Reader) (response Response, err error) {
version, err := reader.ReadByte()
if err != nil {
return
}
if version != 0 {
err = E.New("excepted socks4 response version 0, got ", version)
return
}
response.ReplyCode, err = rw.ReadByte(reader)
response.ReplyCode, err = reader.ReadByte()
if err != nil {
return
}
Expand All @@ -151,13 +151,13 @@ func WriteResponse(writer io.Writer, response Response) error {
binary.Write(buffer, binary.BigEndian, response.Destination.Port),
common.Error(buffer.Write(response.Destination.Addr.AsSlice())),
)
return rw.WriteBytes(writer, buffer.Bytes())
return common.Error(writer.Write(buffer.Bytes()))
}

func readString(reader io.Reader) (string, error) {
func readString(reader varbin.Reader) (string, error) {
buffer := bytes.Buffer{}
for {
b, err := rw.ReadByte(reader)
b, err := reader.ReadByte()
if err != nil {
return "", err
}
Expand Down
Loading

0 comments on commit 9e478d9

Please sign in to comment.