From 94f0582769591e954287566ac72bd622982d4f62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 23 Oct 2024 13:30:48 +0800 Subject: [PATCH] udpnat2: Add SetHandler --- common/bufio/cache.go | 4 +++- common/network/conn.go | 7 +------ common/network/packet.go | 35 +++++++++++++++++++++++++++++++ common/udpnat2/conn.go | 44 ++++++++++++++++++++++++++------------- common/udpnat2/packet.go | 28 ------------------------- common/udpnat2/service.go | 25 +++++++++++++--------- 6 files changed, 84 insertions(+), 59 deletions(-) create mode 100644 common/network/packet.go delete mode 100644 common/udpnat2/packet.go diff --git a/common/bufio/cache.go b/common/bufio/cache.go index ace72597..ce62d4d3 100644 --- a/common/bufio/cache.go +++ b/common/bufio/cache.go @@ -184,10 +184,12 @@ func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer { if buffer != nil { buffer.DecRef() } - return &N.PacketBuffer{ + packet := N.NewPacketBuffer() + *packet = N.PacketBuffer{ Buffer: buffer, Destination: c.destination, } + return packet } func (c *CachedPacketConn) Upstream() any { diff --git a/common/network/conn.go b/common/network/conn.go index c795a19d..c289bf61 100644 --- a/common/network/conn.go +++ b/common/network/conn.go @@ -124,7 +124,7 @@ type UDPHandler interface { } type UDPHandlerEx interface { - NewPacketEx(buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr) + NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) } // Deprecated: Use UDPConnectionHandlerEx instead. @@ -146,11 +146,6 @@ type CachedPacketReader interface { ReadCachedPacket() *PacketBuffer } -type PacketBuffer struct { - Buffer *buf.Buffer - Destination M.Socksaddr -} - type WithUpstreamReader interface { UpstreamReader() any } diff --git a/common/network/packet.go b/common/network/packet.go new file mode 100644 index 00000000..5b852144 --- /dev/null +++ b/common/network/packet.go @@ -0,0 +1,35 @@ +package network + +import ( + "sync" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" +) + +type PacketBuffer struct { + Buffer *buf.Buffer + Destination M.Socksaddr +} + +var packetPool = sync.Pool{ + New: func() any { + return new(PacketBuffer) + }, +} + +func NewPacketBuffer() *PacketBuffer { + return packetPool.Get().(*PacketBuffer) +} + +func PutPacketBuffer(packet *PacketBuffer) { + *packet = PacketBuffer{} + packetPool.Put(packet) +} + +func ReleaseMultiPacketBuffer(packetBuffers []*PacketBuffer) { + for _, packet := range packetBuffers { + packet.Buffer.Release() + PutPacketBuffer(packet) + } +} diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index a5ca8ac2..a96f4c8e 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -12,22 +12,23 @@ import ( "github.com/sagernet/sing/common/pipe" ) -type natConn struct { +type Conn struct { writer N.PacketWriter localAddr M.Socksaddr - packetChan chan *Packet + handler N.UDPHandlerEx + packetChan chan *N.PacketBuffer doneChan chan struct{} readDeadline pipe.Deadline readWaitOptions N.ReadWaitOptions } -func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { +func (c *Conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { select { case p := <-c.packetChan: _, err = buffer.ReadOnceFrom(p.Buffer) destination := p.Destination p.Buffer.Release() - PutPacket(p) + N.PutPacketBuffer(p) return destination, err case <-c.doneChan: return M.Socksaddr{}, io.ErrClosedPipe @@ -36,21 +37,36 @@ func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { } } -func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { +func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return c.writer.WritePacket(buffer, destination) } -func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { +func (c *Conn) SetHandler(handler N.UDPHandlerEx) { + c.handler = handler +fetch: + for { + select { + case packet := <-c.packetChan: + c.handler.NewPacketEx(packet.Buffer, packet.Destination) + N.PutPacketBuffer(packet) + continue fetch + default: + break fetch + } + } +} + +func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { c.readWaitOptions = options return false } -func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { +func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { select { case packet := <-c.packetChan: buffer = c.readWaitOptions.Copy(packet.Buffer) destination = packet.Destination - PutPacket(packet) + N.PutPacketBuffer(packet) return case <-c.doneChan: return nil, M.Socksaddr{}, io.ErrClosedPipe @@ -59,7 +75,7 @@ func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, } } -func (c *natConn) Close() error { +func (c *Conn) Close() error { select { case <-c.doneChan: default: @@ -68,23 +84,23 @@ func (c *natConn) Close() error { return nil } -func (c *natConn) LocalAddr() net.Addr { +func (c *Conn) LocalAddr() net.Addr { return c.localAddr } -func (c *natConn) RemoteAddr() net.Addr { +func (c *Conn) RemoteAddr() net.Addr { return M.Socksaddr{} } -func (c *natConn) SetDeadline(t time.Time) error { +func (c *Conn) SetDeadline(t time.Time) error { return os.ErrInvalid } -func (c *natConn) SetReadDeadline(t time.Time) error { +func (c *Conn) SetReadDeadline(t time.Time) error { c.readDeadline.Set(t) return nil } -func (c *natConn) SetWriteDeadline(t time.Time) error { +func (c *Conn) SetWriteDeadline(t time.Time) error { return os.ErrInvalid } diff --git a/common/udpnat2/packet.go b/common/udpnat2/packet.go deleted file mode 100644 index 1d56ff42..00000000 --- a/common/udpnat2/packet.go +++ /dev/null @@ -1,28 +0,0 @@ -package udpnat - -import ( - "sync" - - "github.com/sagernet/sing/common/buf" - M "github.com/sagernet/sing/common/metadata" -) - -var packetPool = sync.Pool{ - New: func() any { - return new(Packet) - }, -} - -type Packet struct { - Buffer *buf.Buffer - Destination M.Socksaddr -} - -func NewPacket() *Packet { - return packetPool.Get().(*Packet) -} - -func PutPacket(packet *Packet) { - *packet = Packet{} - packetPool.Put(packet) -} diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 85b36417..8c8afc9a 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -14,7 +14,7 @@ import ( ) type Service struct { - nat *freelru.LRU[netip.AddrPort, *natConn] + nat *freelru.LRU[netip.AddrPort, *Conn] handler N.UDPConnectionHandlerEx prepare PrepareFunc metrics Metrics @@ -30,9 +30,9 @@ type Metrics struct { } func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration) *Service { - nat := common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + nat := common.Must1(freelru.New[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) nat.SetLifetime(timeout) - nat.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool { + nat.SetHealthCheck(func(port netip.AddrPort, conn *Conn) bool { select { case <-conn.doneChan: return false @@ -40,7 +40,7 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur return true } }) - nat.SetOnEvict(func(_ netip.AddrPort, conn *natConn) { + nat.SetOnEvict(func(_ netip.AddrPort, conn *Conn) { conn.Close() }) return &Service{ @@ -55,26 +55,31 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati if !loaded { ok, ctx, writer, onClose := s.prepare(source, destination, userData) if !ok { + println(2) s.metrics.Rejects++ return } - conn = &natConn{ + conn = &Conn{ writer: writer, localAddr: source, - packetChan: make(chan *Packet, 64), + packetChan: make(chan *N.PacketBuffer, 64), doneChan: make(chan struct{}), readDeadline: pipe.MakeDeadline(), } s.nat.Add(source.AddrPort(), conn) - s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose) + go s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose) s.metrics.Creates++ } - packet := NewPacket() buffer := conn.readWaitOptions.NewPacketBuffer() for _, bufferSlice := range bufferSlices { buffer.Write(bufferSlice) } - *packet = Packet{ + if conn.handler != nil { + conn.handler.NewPacketEx(buffer, destination) + return + } + packet := N.NewPacketBuffer() + *packet = N.PacketBuffer{ Buffer: buffer, Destination: destination, } @@ -83,7 +88,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati s.metrics.Inputs++ default: packet.Buffer.Release() - PutPacket(packet) + N.PutPacketBuffer(packet) s.metrics.Drops++ } }