Skip to content

Commit

Permalink
Merge pull request #416 from lesismal/rawconn
Browse files Browse the repository at this point in the history
add Engine.OnUDPListen/Conn.SyscallConn
  • Loading branch information
lesismal authored Apr 15, 2024
2 parents 2181991 + 074e29e commit 7e75f32
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 18 deletions.
12 changes: 12 additions & 0 deletions conn_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"io"
"net"
"sync"
"syscall"
"time"

"github.com/lesismal/nbio/timer"
Expand Down Expand Up @@ -45,6 +46,8 @@ type Conn struct {
cache *bytes.Buffer

dataHandler func(c *Conn, data []byte)

onConnected func(c *Conn, err error)
}

// Hash returns a hashcode.
Expand Down Expand Up @@ -546,3 +549,12 @@ func (u *udpConn) getConn(p *poller, rAddr *net.UDPAddr) (*Conn, bool) {

return c, ok
}

func (c *Conn) SyscallConn() (syscall.RawConn, error) {
if rc, ok := c.conn.(interface {
SyscallConn() (syscall.RawConn, error)
}); ok {
return rc.SyscallConn()
}
return nil, ErrUnsupported
}
23 changes: 23 additions & 0 deletions conn_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ type Conn struct {
readEvents int32

dataHandler func(c *Conn, data []byte)

onConnected func(c *Conn, err error)
}

// Hash returns a hash code of this connection.
Expand Down Expand Up @@ -1059,3 +1061,24 @@ func getUDPNetAddr(sa syscall.Sockaddr) *net.UDPAddr {
}
return ret
}

func (c *Conn) SyscallConn() (syscall.RawConn, error) {
return &rawConn{fd: c.fd}, nil
}

type rawConn struct {
fd int
}

func (c *rawConn) Control(f func(fd uintptr)) error {
f(uintptr(c.fd))
return nil
}

func (c *rawConn) Read(f func(fd uintptr) (done bool)) error {
return ErrUnsupported
}

func (c *rawConn) Write(f func(fd uintptr) (done bool)) error {
return ErrUnsupported
}
33 changes: 22 additions & 11 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ type Engine struct {
listeners []*poller
pollers []*poller

// onUDPListen for udp listener created.
onUDPListen func(c *Conn)
// callback for new connection connected.
onOpen func(c *Conn)
// callback for connection closed.
Expand Down Expand Up @@ -251,10 +253,18 @@ func (g *Engine) AddConn(conn net.Conn) (*Conn, error) {
return c, nil
}

// OnOpen registers callback for new connection.
func (g *Engine) OnUDPListen(h func(c *Conn)) {
if h == nil {
panic("invalid handler: nil")
}
g.onUDPListen = h
}

// OnOpen registers callback for new connection.
func (g *Engine) OnOpen(h func(c *Conn)) {
if h == nil {
panic("invalid nil handler")
panic("invalid handler: nil")
}
g.onOpen = func(c *Conn) {
g.wgConn.Add(1)
Expand All @@ -265,7 +275,7 @@ func (g *Engine) OnOpen(h func(c *Conn)) {
// OnClose registers callback for disconnected.
func (g *Engine) OnClose(h func(c *Conn, err error)) {
if h == nil {
panic("invalid nil handler")
panic("invalid handler: nil")
}
g.onClose = func(c *Conn, err error) {
g.Async(func() {
Expand All @@ -283,7 +293,7 @@ func (g *Engine) OnRead(h func(c *Conn)) {
// OnData registers callback for data.
func (g *Engine) OnData(h func(c *Conn, data []byte)) {
if h == nil {
panic("invalid nil handler")
panic("invalid handler: nil")
}
g.onData = h
}
Expand All @@ -293,23 +303,23 @@ func (g *Engine) OnData(h func(c *Conn, data []byte)) {
// else it's operating by Sendfile.
func (g *Engine) OnWrittenSize(h func(c *Conn, b []byte, n int)) {
if h == nil {
panic("invalid nil handler")
panic("invalid handler: nil")
}
g.onWrittenSize = h
}

// OnReadBufferAlloc registers callback for memory allocating.
func (g *Engine) OnReadBufferAlloc(h func(c *Conn) []byte) {
if h == nil {
panic("invalid nil handler")
panic("invalid handler: nil")
}
g.onReadBufferAlloc = h
}

// OnReadBufferFree registers callback for memory release.
func (g *Engine) OnReadBufferFree(h func(c *Conn, b []byte)) {
if h == nil {
panic("invalid nil handler")
panic("invalid handler: nil")
}
g.onReadBufferFree = h
}
Expand All @@ -318,7 +328,7 @@ func (g *Engine) OnReadBufferFree(h func(c *Conn, b []byte)) {
// OnWriteBufferRelease registers callback for write buffer memory release.
// func (g *Engine) OnWriteBufferRelease(h func(c *Conn, b []byte)) {
// if h == nil {
// panic("invalid nil handler")
// panic("invalid handler: nil")
// }
// g.onWriteBufferFree = h
// }
Expand All @@ -327,7 +337,7 @@ func (g *Engine) OnReadBufferFree(h func(c *Conn, b []byte)) {
// the handler would be called on windows.
// func (g *Engine) BeforeRead(h func(c *Conn)) {
// if h == nil {
// panic("invalid nil handler")
// panic("invalid handler: nil")
// }
// g.beforeRead = h
// }
Expand All @@ -337,7 +347,7 @@ func (g *Engine) OnReadBufferFree(h func(c *Conn, b []byte)) {
// the handler would be called on *nix.
// func (g *Engine) AfterRead(h func(c *Conn)) {
// if h == nil {
// panic("invalid nil handler")
// panic("invalid handler: nil")
// }
// g.afterRead = h
// }
Expand All @@ -347,15 +357,15 @@ func (g *Engine) OnReadBufferFree(h func(c *Conn, b []byte)) {
// the handler would be called on windows.
// func (g *Engine) BeforeWrite(h func(c *Conn)) {
// if h == nil {
// panic("invalid nil handler")
// panic("invalid handler: nil")
// }
// g.beforeWrite = h
// }

// OnStop registers callback before Engine is stopped.
func (g *Engine) OnStop(h func()) {
if h == nil {
panic("invalid nil handler")
panic("invalid handler: nil")
}
g.onStop = h
}
Expand Down Expand Up @@ -383,6 +393,7 @@ func (g *Engine) initHandlers() {
// g.BeforeRead(func(c *Conn) {})
// g.AfterRead(func(c *Conn) {})
// g.BeforeWrite(func(c *Conn) {})
g.OnUDPListen(func(*Conn) {})
g.OnStop(func() {})

if g.Execute == nil {
Expand Down
12 changes: 8 additions & 4 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ import (
)

var (
ErrReadTimeout = errors.New("read timeout")
errReadTimeout = ErrReadTimeout
ErrReadTimeout = errors.New("read timeout")
errReadTimeout = ErrReadTimeout

ErrWriteTimeout = errors.New("write timeout")
errWriteTimeout = ErrWriteTimeout
ErrOverflow = errors.New("write overflow")
errOverflow = ErrOverflow

ErrOverflow = errors.New("write overflow")
errOverflow = ErrOverflow

ErrUnsupported = errors.New("unsupported operation")
)
6 changes: 3 additions & 3 deletions nbio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ func init() {

if len(data) == 8 && string(data) == "sendfile" {
wsess.isFile = true
fd, err := os.Open(testfile)
file, err := os.Open(testfile)
if err != nil {
log.Panicf("open file failed: %v", err)
}

if _, err = c.Sendfile(fd, 0); err != nil {
if _, err = c.Sendfile(file, 0); err != nil {
panic(err)
}

if err := fd.Close(); err != nil {
if err := file.Close(); err != nil {
log.Panicf("close file failed: %v", err)
}
} else {
Expand Down
5 changes: 5 additions & 0 deletions poller_epoll.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ func (p *poller) addConn(c *Conn) {
c.p = p
if c.typ != ConnTypeUDPServer {
p.g.onOpen(c)
} else {
p.g.onUDPListen(c)
}
p.g.connsUnix[fd] = c
err := p.addRead(fd)
Expand Down Expand Up @@ -221,6 +223,9 @@ func (p *poller) readWriteLoop() {
}

if ev.Events&epollEventsRead != 0 {
if c.onConnected != nil {
c.onConnected(c, nil)
}
if g.onRead == nil {
if asyncReadEnabled {
c.AsyncRead()
Expand Down
5 changes: 5 additions & 0 deletions poller_kqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ func (p *poller) addConn(c *Conn) {
c.p = p
if c.typ != ConnTypeUDPServer {
p.g.onOpen(c)
} else {
p.g.onUDPListen(c)
}
p.g.connsUnix[fd] = c
p.addRead(fd)
Expand Down Expand Up @@ -140,6 +142,9 @@ func (p *poller) readWrite(ev *syscall.Kevent_t) {
c := p.getConn(fd)
if c != nil {
if ev.Filter&syscall.EVFILT_READ == syscall.EVFILT_READ {
if c.onConnected != nil {
c.onConnected(c, nil)
}
if p.g.onRead == nil {
for {
buffer := p.g.borrow(c)
Expand Down
2 changes: 2 additions & 0 deletions poller_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ func (p *poller) addConn(c *Conn) error {
// should not call onOpen for udp server conn
if c.typ != ConnTypeUDPServer {
p.g.onOpen(c)
} else {
p.g.onUDPListen(c)
}
// should not read udp client from reading udp server conn
if c.typ != ConnTypeUDPClientFromRead {
Expand Down

0 comments on commit 7e75f32

Please sign in to comment.