Skip to content

Commit

Permalink
Implement read waiter for UDP
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 7, 2023
1 parent c098d42 commit eb3cbf9
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 110 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.20
require (
github.com/gofrs/uuid/v5 v5.0.0
github.com/sagernet/quic-go v0.40.0
github.com/sagernet/sing v0.2.18
github.com/sagernet/sing v0.2.19-0.20231207032540-dbccc28f8194
golang.org/x/crypto v0.16.0
golang.org/x/exp v0.0.0-20231127185646-65229373498e
)
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5
github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/sagernet/quic-go v0.40.0 h1:DvQNPb72lzvNQDe9tcUyHTw8eRv6PLtM2mNYmdlzUMo=
github.com/sagernet/quic-go v0.40.0/go.mod h1:VqtdhlbkeeG5Okhb3eDMb/9o0EoglReHunNT9ukrJAI=
github.com/sagernet/sing v0.2.18 h1:2Ce4dl0pkWft+4914NGXPb8OiQpgA8UHQ9xFOmgvKuY=
github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
github.com/sagernet/sing v0.2.19-0.20231207032540-dbccc28f8194 h1:lphv+waf4VhMIPkOiTewsHaCrBC7Jyrkt/uOKgjLnso=
github.com/sagernet/sing v0.2.19-0.20231207032540-dbccc28f8194/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
Expand Down
48 changes: 13 additions & 35 deletions hysteria/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
)

Expand Down Expand Up @@ -118,17 +119,18 @@ func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
}

type udpPacketConn struct {
ctx context.Context
cancel common.ContextCancelCauseFunc
sessionID uint32
quicConn quic.Connection
data chan *udpMessage
udpMTU int
udpMTUTime time.Time
packetId atomic.Uint32
closeOnce sync.Once
defragger *udpDefragger
onDestroy func()
ctx context.Context
cancel common.ContextCancelCauseFunc
sessionID uint32
quicConn quic.Connection
data chan *udpMessage
udpMTU int
udpMTUTime time.Time
packetId atomic.Uint32
closeOnce sync.Once
defragger *udpDefragger
onDestroy func()
readWaitOptions N.ReadWaitOptions
}

func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy func()) *udpPacketConn {
Expand All @@ -143,18 +145,6 @@ func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy f
}
}

func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case p := <-c.data:
buffer = p.data
destination = M.ParseSocksaddrHostPort(p.host, p.port)
p.release()
return
case <-c.ctx.Done():
return nil, M.Socksaddr{}, io.ErrClosedPipe
}
}

func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case p := <-c.data:
Expand All @@ -167,18 +157,6 @@ func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr,
}
}

func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case p := <-c.data:
_, err = newBuffer().ReadOnceFrom(p.data)
destination = M.ParseSocksaddrHostPort(p.host, p.port)
p.releaseMessage()
return
case <-c.ctx.Done():
return M.Socksaddr{}, io.ErrClosedPipe
}
}

func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case pkt := <-c.data:
Expand Down
37 changes: 37 additions & 0 deletions hysteria/packet_wait.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package hysteria

import (
"io"

"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

func (c *udpPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return options.NeedHeadroom()
}

func (c *udpPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case p := <-c.data:
destination = M.ParseSocksaddrHostPort(p.host, p.port)
if c.readWaitOptions.NeedHeadroom() {
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.Write(p.data.Bytes())
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
p.releaseMessage()
c.readWaitOptions.PostReturn(buffer)
} else {
buffer = p.data
p.release()
}
return
case <-c.ctx.Done():
return nil, M.Socksaddr{}, io.ErrClosedPipe
}
}
48 changes: 13 additions & 35 deletions hysteria2/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/cache"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

var udpMessagePool = sync.Pool{
Expand Down Expand Up @@ -114,17 +115,18 @@ func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
}

type udpPacketConn struct {
ctx context.Context
cancel common.ContextCancelCauseFunc
sessionID uint32
quicConn quic.Connection
data chan *udpMessage
udpMTU int
udpMTUTime time.Time
packetId atomic.Uint32
closeOnce sync.Once
defragger *udpDefragger
onDestroy func()
ctx context.Context
cancel common.ContextCancelCauseFunc
sessionID uint32
quicConn quic.Connection
data chan *udpMessage
udpMTU int
udpMTUTime time.Time
packetId atomic.Uint32
closeOnce sync.Once
defragger *udpDefragger
onDestroy func()
readWaitOptions N.ReadWaitOptions
}

func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy func()) *udpPacketConn {
Expand All @@ -139,18 +141,6 @@ func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy f
}
}

func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case p := <-c.data:
buffer = p.data
destination = M.ParseSocksaddr(p.destination)
p.release()
return
case <-c.ctx.Done():
return nil, M.Socksaddr{}, io.ErrClosedPipe
}
}

func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case p := <-c.data:
Expand All @@ -163,18 +153,6 @@ func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr,
}
}

func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case p := <-c.data:
_, err = newBuffer().ReadOnceFrom(p.data)
destination = M.ParseSocksaddr(p.destination)
p.releaseMessage()
return
case <-c.ctx.Done():
return M.Socksaddr{}, io.ErrClosedPipe
}
}

func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case pkt := <-c.data:
Expand Down
37 changes: 37 additions & 0 deletions hysteria2/packet_wait.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package hysteria2

import (
"io"

"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

func (c *udpPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return options.NeedHeadroom()
}

func (c *udpPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case p := <-c.data:
destination = M.ParseSocksaddr(p.destination)
if c.readWaitOptions.NeedHeadroom() {
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.Write(p.data.Bytes())
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
p.releaseMessage()
c.readWaitOptions.PostReturn(buffer)
} else {
buffer = p.data
p.release()
}
return
case <-c.ctx.Done():
return nil, M.Socksaddr{}, io.ErrClosedPipe
}
}
57 changes: 20 additions & 37 deletions tuic/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

var udpMessagePool = sync.Pool{
Expand Down Expand Up @@ -114,20 +115,26 @@ func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
return fragments
}

var (
_ N.NetPacketConn = (*udpPacketConn)(nil)
_ N.PacketReadWaiter = (*udpPacketConn)(nil)
)

type udpPacketConn struct {
ctx context.Context
cancel common.ContextCancelCauseFunc
sessionID uint16
quicConn quic.Connection
data chan *udpMessage
udpStream bool
udpMTU int
udpMTUTime time.Time
packetId atomic.Uint32
closeOnce sync.Once
isServer bool
defragger *udpDefragger
onDestroy func()
ctx context.Context
cancel common.ContextCancelCauseFunc
sessionID uint16
quicConn quic.Connection
data chan *udpMessage
udpStream bool
udpMTU int
udpMTUTime time.Time
packetId atomic.Uint32
closeOnce sync.Once
isServer bool
defragger *udpDefragger
onDestroy func()
readWaitOptions N.ReadWaitOptions
}

func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, udpStream bool, isServer bool, onDestroy func()) *udpPacketConn {
Expand All @@ -144,18 +151,6 @@ func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, udpStream b
}
}

func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case p := <-c.data:
buffer = p.data
destination = p.destination
p.release()
return
case <-c.ctx.Done():
return nil, M.Socksaddr{}, io.ErrClosedPipe
}
}

func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case p := <-c.data:
Expand All @@ -168,18 +163,6 @@ func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr,
}
}

func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case p := <-c.data:
_, err = newBuffer().ReadOnceFrom(p.data)
destination = p.destination
p.releaseMessage()
return
case <-c.ctx.Done():
return M.Socksaddr{}, io.ErrClosedPipe
}
}

func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case pkt := <-c.data:
Expand Down
37 changes: 37 additions & 0 deletions tuic/packet_wait.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package tuic

import (
"io"

"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

func (c *udpPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return options.NeedHeadroom()
}

func (c *udpPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case p := <-c.data:
destination = p.destination
if c.readWaitOptions.NeedHeadroom() {
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.Write(p.data.Bytes())
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
p.releaseMessage()
c.readWaitOptions.PostReturn(buffer)
} else {
buffer = p.data
p.release()
}
return
case <-c.ctx.Done():
return nil, M.Socksaddr{}, io.ErrClosedPipe
}
}

0 comments on commit eb3cbf9

Please sign in to comment.