Skip to content

Commit

Permalink
Improve timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 27, 2024
1 parent 7a6e2d8 commit 718380c
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 345 deletions.
8 changes: 7 additions & 1 deletion protocol/wireguard/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,18 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL
if err != nil {
return nil, err
}
var udpTimeout time.Duration
if options.UDPTimeout != 0 {
udpTimeout = time.Duration(options.UDPTimeout)
} else {
udpTimeout = C.UDPTimeout
}
wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
Context: ctx,
Logger: logger,
System: options.System,
Handler: ep,
UDPTimeout: time.Duration(options.UDPTimeout),
UDPTimeout: udpTimeout,
Dialer: outboundDialer,
CreateDialer: func(interfaceName string) N.Dialer {
return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{
Expand Down
270 changes: 97 additions & 173 deletions route/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"

Expand All @@ -18,31 +19,35 @@ import (
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
)

var _ adapter.ConnectionManager = (*ConnectionManager)(nil)

type ConnectionManager struct {
logger logger.ContextLogger
monitor *ConnectionMonitor
logger logger.ContextLogger
access sync.Mutex
connections list.List[io.Closer]
}

func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
return &ConnectionManager{
logger: logger,
monitor: NewConnectionMonitor(),
logger: logger,
}
}

func (m *ConnectionManager) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateInitialize {
return nil
}
return m.monitor.Start()
return nil
}

func (m *ConnectionManager) Close() error {
return m.monitor.Close()
m.access.Lock()
defer m.access.Unlock()
for element := m.connections.Front(); element != nil; element = element.Next() {
common.Close(element.Value)
}
m.connections.Init()
return nil
}

func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
Expand All @@ -57,95 +62,32 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co
remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
}
if err != nil {
err = E.Cause(err, "open outbound connection")
N.CloseOnHandshakeFailure(conn, onClose, err)
m.logger.ErrorContext(ctx, "open outbound connection: ", err)
m.logger.ErrorContext(ctx, err)
return
}
err = N.ReportConnHandshakeSuccess(conn, remoteConn)
if err != nil {
err = E.Cause(err, "report handshake success")
remoteConn.Close()
N.CloseOnHandshakeFailure(conn, onClose, err)
m.logger.ErrorContext(ctx, "report handshake success: ", err)
m.logger.ErrorContext(ctx, err)
return
}
m.access.Lock()
element := m.connections.PushBack(conn)
m.access.Unlock()
onClose = N.AppendClose(onClose, func(it error) {
m.access.Lock()
defer m.access.Unlock()
m.connections.Remove(element)
})
var done atomic.Bool
if ctx.Done() != nil {
onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
}
go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
}

func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
originSource := source
originDestination := destination
var readCounters, writeCounters []N.CountFunc
for {
source, readCounters = N.UnwrapCountReader(source, readCounters)
destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
if cachedSrc, isCached := source.(N.CachedReader); isCached {
cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil {
dataLen := cachedBuffer.Len()
_, err := destination.Write(cachedBuffer.Bytes())
cachedBuffer.Release()
if err != nil {
m.logger.ErrorContext(ctx, "connection upload payload: ", err)
if done.Swap(true) {
if onClose != nil {
onClose(err)
}
}
common.Close(originSource, originDestination)
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
}
continue
}
break
}
_, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
if err != nil {
common.Close(originSource, originDestination)
} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
err = duplexDst.CloseWrite()
if err != nil {
common.Close(originSource, originDestination)
}
} else {
common.Close(originDestination)
}
if done.Swap(true) {
if onClose != nil {
onClose(err)
}
common.Close(originSource, originDestination)
}
if !direction {
if err == nil {
m.logger.DebugContext(ctx, "connection upload finished")
} else if !E.IsClosedOrCanceled(err) {
m.logger.ErrorContext(ctx, "connection upload closed: ", err)
} else {
m.logger.TraceContext(ctx, "connection upload closed")
}
} else {
if err == nil {
m.logger.DebugContext(ctx, "connection download finished")
} else if !E.IsClosedOrCanceled(err) {
m.logger.ErrorContext(ctx, "connection download closed: ", err)
} else {
m.logger.TraceContext(ctx, "connection download closed")
}
}
}

func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
ctx = adapter.WithContext(ctx, &metadata)
var (
Expand Down Expand Up @@ -227,58 +169,91 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout)
}
destination := bufio.NewPacketConn(remotePacketConn)
m.access.Lock()
element := m.connections.PushBack(conn)
m.access.Unlock()
onClose = N.AppendClose(onClose, func(it error) {
m.access.Lock()
defer m.access.Unlock()
m.connections.Remove(element)
})
var done atomic.Bool
if ctx.Done() != nil {
onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
}
go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
}

func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
_, err := bufio.CopyPacket(destination, source)
/*var readCounters, writeCounters []N.CountFunc
var cachedPackets []*N.PacketBuffer
func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
originSource := source
originDestination := destination
var readCounters, writeCounters []N.CountFunc
for {
source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
destination, writeCounters = N.UnwrapCountPacketWriter(destination, writeCounters)
if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
packet := cachedReader.ReadCachedPacket()
if packet != nil {
cachedPackets = append(cachedPackets, packet)
continue
source, readCounters = N.UnwrapCountReader(source, readCounters)
destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
if cachedSrc, isCached := source.(N.CachedReader); isCached {
cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil {
dataLen := cachedBuffer.Len()
_, err := destination.Write(cachedBuffer.Bytes())
cachedBuffer.Release()
if err != nil {
if done.Swap(true) {
onClose(err)
}
common.Close(originSource, originDestination)
if !direction {
m.logger.ErrorContext(ctx, "connection upload payload: ", err)
} else {
m.logger.ErrorContext(ctx, "connection download payload: ", err)
}
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
}
continue
}
break
}
var handled bool
if natConn, isNatConn := source.(udpnat.Conn); isNatConn {
natConn.SetHandler(&udpHijacker{
ctx: ctx,
logger: m.logger,
source: natConn,
destination: destination,
direction: direction,
readCounters: readCounters,
writeCounters: writeCounters,
done: done,
onClose: onClose,
})
handled = true
}
if cachedPackets != nil {
_, err := bufio.WritePacketWithPool(originSource, destination, cachedPackets, readCounters, writeCounters)
_, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
if err != nil {
common.Close(originDestination)
} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
err = duplexDst.CloseWrite()
if err != nil {
common.Close(source, destination)
m.logger.ErrorContext(ctx, "packet upload payload: ", err)
return
common.Close(originSource, originDestination)
}
} else {
common.Close(originDestination)
}
if handled {
return
if done.Swap(true) {
onClose(err)
common.Close(originSource, originDestination)
}
if !direction {
if err == nil {
m.logger.DebugContext(ctx, "connection upload finished")
} else if !E.IsClosedOrCanceled(err) {
m.logger.ErrorContext(ctx, "connection upload closed: ", err)
} else {
m.logger.TraceContext(ctx, "connection upload closed")
}
} else {
if err == nil {
m.logger.DebugContext(ctx, "connection download finished")
} else if !E.IsClosedOrCanceled(err) {
m.logger.ErrorContext(ctx, "connection download closed: ", err)
} else {
m.logger.TraceContext(ctx, "connection download closed")
}
}
_, err := bufio.CopyPacketWithCounters(destination, source, originSource, readCounters, writeCounters)*/
}

func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
_, err := bufio.CopyPacket(destination, source)
if !direction {
if E.IsClosedOrCanceled(err) {
m.logger.TraceContext(ctx, "packet upload closed")
Expand All @@ -293,58 +268,7 @@ func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.P
}
}
if !done.Swap(true) {
if onClose != nil {
onClose(err)
}
onClose(err)
}
common.Close(source, destination)
}

/*type udpHijacker struct {
ctx context.Context
logger logger.ContextLogger
source io.Closer
destination N.PacketWriter
direction bool
readCounters []N.CountFunc
writeCounters []N.CountFunc
done *atomic.Bool
onClose N.CloseHandlerFunc
}
func (u *udpHijacker) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
dataLen := buffer.Len()
for _, counter := range u.readCounters {
counter(int64(dataLen))
}
err := u.destination.WritePacket(buffer, source)
if err != nil {
common.Close(u.source, u.destination)
u.logger.DebugContext(u.ctx, "packet upload closed: ", err)
return
}
for _, counter := range u.writeCounters {
counter(int64(dataLen))
}
}
func (u *udpHijacker) Close() error {
var err error
if !u.done.Swap(true) {
err = common.Close(u.source, u.destination)
if u.onClose != nil {
u.onClose(net.ErrClosed)
}
}
if u.direction {
u.logger.TraceContext(u.ctx, "packet download closed")
} else {
u.logger.TraceContext(u.ctx, "packet upload closed")
}
return err
}
func (u *udpHijacker) Upstream() any {
return u.destination
}
*/
Loading

0 comments on commit 718380c

Please sign in to comment.