diff --git a/connection_impl.go b/connection_impl.go index beb11d23..def1d97c 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -50,9 +50,11 @@ type connection struct { bookSize int // The size of data that can be read at once. } -var _ Connection = &connection{} -var _ Reader = &connection{} -var _ Writer = &connection{} +var ( + _ Connection = &connection{} + _ Reader = &connection{} + _ Writer = &connection{} +) // Reader implements Connection. func (c *connection) Reader() Reader { @@ -168,7 +170,7 @@ func (c *connection) Until(delim byte) (line []byte, err error) { l = c.inputBuffer.Len() i := c.inputBuffer.indexByte(delim, n) if i < 0 { - n = l //skip all exists bytes + n = l // skip all exists bytes continue } return c.Next(i + 1) @@ -297,6 +299,12 @@ func (c *connection) Close() error { return c.onClose() } +// Detach detaches the connection from poller but doesn't close it. +func (c *connection) Detach() error { + c.detaching = true + return c.onClose() +} + // ------------------------------------------ private ------------------------------------------ var barrierPool = sync.Pool{ @@ -368,8 +376,6 @@ func (c *connection) initFDOperator() { func (c *connection) initFinalizer() { c.AddCloseCallback(func(connection Connection) (err error) { c.stop(flushing) - // stop the finalizing state to prevent conn.fill function to be performed - c.stop(finalizing) c.operator.Free() if err = c.netFD.Close(); err != nil { logger.Printf("NETPOLL: netFD close failed: %v", err) @@ -405,15 +411,10 @@ func (c *connection) waitRead(n int) (err error) { } // wait full n for c.inputBuffer.Len() < n { - if c.IsActive() { - <-c.readTrigger - continue - } - // confirm that fd is still valid. - if atomic.LoadUint32(&c.netFD.closed) == 0 { - return c.fill(n) + if !c.IsActive() { + return Exception(ErrConnClosed, "wait read") } - return Exception(ErrConnClosed, "wait read") + <-c.readTrigger } return nil } @@ -430,12 +431,7 @@ func (c *connection) waitReadWithTimeout(n int) (err error) { for c.inputBuffer.Len() < n { if !c.IsActive() { // cannot return directly, stop timer before ! - // confirm that fd is still valid. - if atomic.LoadUint32(&c.netFD.closed) == 0 { - err = c.fill(n) - } else { - err = Exception(ErrConnClosed, "wait read") - } + err = Exception(ErrConnClosed, "wait read") break } @@ -458,39 +454,6 @@ func (c *connection) waitReadWithTimeout(n int) (err error) { return err } -// fill data after connection is closed. -func (c *connection) fill(need int) (err error) { - if !c.lock(finalizing) { - return ErrConnClosed - } - defer c.unlock(finalizing) - - var n int - var bs [][]byte - for { - bs = c.inputs(c.inputBarrier.bs) - TryRead: - n, err = readv(c.fd, bs, c.inputBarrier.ivs) - if err != nil { - if err == syscall.EINTR { - // if err == EINTR, we must reuse bs that has been booked - // otherwise will mess the input buffer - goto TryRead - } - break - } - if n == 0 { - err = Exception(ErrEOF, "") - break - } - c.inputAck(n) - } - if c.inputBuffer.Len() >= need { - return nil - } - return err -} - // flush write data directly. func (c *connection) flush() error { if c.outputBuffer.IsEmpty() { diff --git a/connection_lock.go b/connection_lock.go index e036de4b..2dce6622 100644 --- a/connection_lock.go +++ b/connection_lock.go @@ -47,7 +47,6 @@ const ( closing key = iota processing flushing - finalizing // total must be at the bottom. total ) diff --git a/connection_onevent.go b/connection_onevent.go index e8567ba9..f8351f32 100644 --- a/connection_onevent.go +++ b/connection_onevent.go @@ -90,7 +90,7 @@ func (c *connection) AddCloseCallback(callback CloseCallback) error { return nil } -// OnPrepare supports close connection, but not read/write data. +// onPrepare supports close connection, but not read/write data. // connection will be registered by this call after preparing. func (c *connection) onPrepare(opts *options) (err error) { if opts != nil { diff --git a/connection_test.go b/connection_test.go index 655777c5..3d8fe160 100644 --- a/connection_test.go +++ b/connection_test.go @@ -22,6 +22,8 @@ import ( "errors" "fmt" "math/rand" + "net" + "os" "runtime" "sync" "sync/atomic" @@ -395,7 +397,7 @@ func TestConnectionUntil(t *testing.T) { buf, err := rconn.Reader().Until('\n') Equal(t, len(buf), 100) - MustTrue(t, errors.Is(err, ErrEOF)) + Assert(t, errors.Is(err, ErrConnClosed), err) } func TestBookSizeLargerThanMaxSize(t *testing.T) { @@ -432,3 +434,60 @@ func TestBookSizeLargerThanMaxSize(t *testing.T) { wg.Wait() rconn.Close() } + +func TestConnDetach(t *testing.T) { + ln, err := CreateListener("tcp", ":1234") + MustNil(t, err) + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + if conn == nil { + continue + } + go func() { + buf := make([]byte, 1024) + // slow read + for { + _, err := conn.Read(buf) + if err != nil { + return + } + time.Sleep(100 * time.Millisecond) + _, err = conn.Write(buf) + if err != nil { + return + } + } + }() + } + }() + + c, err := DialConnection("tcp", ":1234", time.Second) + MustNil(t, err) + + conn := c.(*TCPConnection) + + err = conn.Detach() + MustNil(t, err) + + f := os.NewFile(uintptr(conn.fd), "netpoll-connection") + defer f.Close() + + gonetconn, err := net.FileConn(f) + MustNil(t, err) + buf := make([]byte, 1024) + _, err = gonetconn.Write(buf) + MustNil(t, err) + _, err = gonetconn.Read(buf) + MustNil(t, err) + + err = gonetconn.Close() + MustNil(t, err) + + err = ln.Close() + MustNil(t, err) +} diff --git a/mux/shard_queue.go b/mux/shard_queue.go index 7c7c1261..364fabae 100644 --- a/mux/shard_queue.go +++ b/mux/shard_queue.go @@ -111,7 +111,7 @@ func (q *ShardQueue) Close() error { // wait for all tasks finished for atomic.LoadInt32(&q.state) != closed { if atomic.LoadInt32(&q.trigger) == 0 { - atomic.StoreInt32(&q.trigger, closed) + atomic.StoreInt32(&q.state, closed) return nil } runtime.Gosched() diff --git a/net_io.go b/net_io.go new file mode 100644 index 00000000..c7322fd2 --- /dev/null +++ b/net_io.go @@ -0,0 +1,42 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netpoll + +import "syscall" + +// return value: +// - n: n == 0 but err == nil, retry syscall +// - err: if not nil, connection should be closed. +func ioread(fd int, bs [][]byte, ivs []syscall.Iovec) (n int, err error) { + n, err = readv(fd, bs, ivs) + if n == 0 && err == nil { // means EOF + return 0, Exception(ErrEOF, "") + } + if err == syscall.EINTR || err == syscall.EAGAIN { + return 0, nil + } + return n, err +} + +// return value: +// - n: n == 0 but err == nil, retry syscall +// - err: if not nil, connection should be closed. +func iosend(fd int, bs [][]byte, ivs []syscall.Iovec, zerocopy bool) (n int, err error) { + n, err = sendmsg(fd, bs, ivs, zerocopy) + if err == syscall.EAGAIN { + return 0, nil + } + return n, err +} diff --git a/net_netfd.go b/net_netfd.go index cea01a77..96bc0945 100644 --- a/net_netfd.go +++ b/net_netfd.go @@ -50,6 +50,8 @@ type netFD struct { network string // tcp tcp4 tcp6, udp, udp4, udp6, ip, ip4, ip6, unix, unixgram, unixpacket localAddr net.Addr remoteAddr net.Addr + // for detaching conn from poller + detaching bool } func newNetFD(fd, family, sotype int, net string) *netFD { diff --git a/net_netfd_conn.go b/net_netfd_conn.go index c2ab43e0..cd6922d4 100644 --- a/net_netfd_conn.go +++ b/net_netfd_conn.go @@ -59,7 +59,7 @@ func (c *netFD) Close() (err error) { if atomic.AddUint32(&c.closed, 1) != 1 { return nil } - if c.fd > 0 { + if !c.detaching && c.fd > 2 { err = syscall.Close(c.fd) if err != nil { logger.Printf("NETPOLL: netFD[%d] close error: %s", c.fd, err.Error()) diff --git a/netpoll_options.go b/netpoll_options.go index f2effafa..ec384f54 100644 --- a/netpoll_options.go +++ b/netpoll_options.go @@ -29,14 +29,15 @@ import ( // Experience recommends assigning a poller every 20c. // // You can only use SetNumLoops before any connection is created. An example usage: -// func init() { -// netpoll.SetNumLoops(...) -// } +// +// func init() { +// netpoll.SetNumLoops(...) +// } func SetNumLoops(numLoops int) error { return setNumLoops(numLoops) } -// LoadBalance sets the load balancing method. Load balancing is always a best effort to attempt +// SetLoadBalance sets the load balancing method. Load balancing is always a best effort to attempt // to distribute the incoming connections between multiple polls. // This option only works when NumLoops is set. func SetLoadBalance(lb LoadBalance) error { diff --git a/netpoll_test.go b/netpoll_test.go index 933566aa..cedf6226 100644 --- a/netpoll_test.go +++ b/netpoll_test.go @@ -21,6 +21,9 @@ import ( "context" "errors" "math/rand" + "runtime" + "sync" + "sync/atomic" "testing" "time" ) @@ -248,9 +251,10 @@ func TestCloseCallbackWhenOnConnect(t *testing.T) { MustNil(t, err) } -func TestCloseAndWrite(t *testing.T) { +func TestServerReadAndClose(t *testing.T) { var network, address = "tcp", ":18888" var sendMsg = []byte("hello") + var closed int32 var loop = newTestEventLoop(network, address, func(ctx context.Context, connection Connection) error { _, err := connection.Reader().Next(len(sendMsg)) @@ -258,6 +262,7 @@ func TestCloseAndWrite(t *testing.T) { err = connection.Close() MustNil(t, err) + atomic.AddInt32(&closed, 1) return nil }, ) @@ -269,7 +274,10 @@ func TestCloseAndWrite(t *testing.T) { err = conn.Writer().Flush() MustNil(t, err) - time.Sleep(time.Millisecond * 100) // wait for poller close connection + for atomic.LoadInt32(&closed) == 0 { + runtime.Gosched() // wait for poller close connection + } + time.Sleep(time.Millisecond * 50) _, err = conn.Writer().WriteBinary(sendMsg) MustNil(t, err) err = conn.Writer().Flush() @@ -279,9 +287,59 @@ func TestCloseAndWrite(t *testing.T) { MustNil(t, err) } +func TestClientWriteAndClose(t *testing.T) { + var ( + network, address = "tcp", ":18889" + connnum = 10 + packetsize, packetnum = 1000 * 5, 1 + recvbytes int32 = 0 + ) + var loop = newTestEventLoop(network, address, + func(ctx context.Context, connection Connection) error { + buf, err := connection.Reader().Next(connection.Reader().Len()) + if errors.Is(err, ErrConnClosed) { + return err + } + MustNil(t, err) + atomic.AddInt32(&recvbytes, int32(len(buf))) + return nil + }, + ) + var wg sync.WaitGroup + for i := 0; i < connnum; i++ { + wg.Add(1) + go func() { + defer wg.Done() + var conn, err = DialConnection(network, address, time.Second) + MustNil(t, err) + sendMsg := make([]byte, packetsize) + for j := 0; j < packetnum; j++ { + _, err = conn.Write(sendMsg) + MustNil(t, err) + } + err = conn.Close() + MustNil(t, err) + }() + } + wg.Wait() + exceptbytes := int32(packetsize * packetnum * connnum) + for atomic.LoadInt32(&recvbytes) != exceptbytes { + t.Logf("left %d bytes not received", exceptbytes-atomic.LoadInt32(&recvbytes)) + runtime.Gosched() + } + err := loop.Shutdown(context.Background()) + MustNil(t, err) +} + func newTestEventLoop(network, address string, onRequest OnRequest, opts ...Option) EventLoop { - var listener, _ = CreateListener(network, address) - var eventLoop, _ = NewEventLoop(onRequest, opts...) - go eventLoop.Serve(listener) - return eventLoop + ln, err := CreateListener(network, address) + if err != nil { + panic(err) + } + elp, err := NewEventLoop(onRequest, opts...) + if err != nil { + panic(err) + } + go elp.Serve(ln) + return elp } diff --git a/nocopy_linkbuffer.go b/nocopy_linkbuffer.go index 6200bf5b..59cc6530 100644 --- a/nocopy_linkbuffer.go +++ b/nocopy_linkbuffer.go @@ -475,7 +475,7 @@ func (b *LinkBuffer) WriteDirect(p []byte, remainLen int) error { // find origin origin := b.flush malloc := b.mallocSize - remainLen // calculate the remaining malloc length - for t := origin.malloc - len(origin.buf); t <= malloc; t = origin.malloc - len(origin.buf) { + for t := origin.malloc - len(origin.buf); t < malloc; t = origin.malloc - len(origin.buf) { malloc -= t origin = origin.next } @@ -486,18 +486,24 @@ func (b *LinkBuffer) WriteDirect(p []byte, remainLen int) error { dataNode := newLinkBufferNode(0) dataNode.buf, dataNode.malloc = p[:0], n - newNode := newLinkBufferNode(0) - newNode.off = malloc - newNode.buf = origin.buf[:malloc] - newNode.malloc = origin.malloc - newNode.readonly = false - origin.malloc = malloc - origin.readonly = true - - // link nodes - dataNode.next = newNode - newNode.next = origin.next - origin.next = dataNode + if remainLen > 0 { + newNode := newLinkBufferNode(0) + newNode.off = malloc + newNode.buf = origin.buf[:malloc] + newNode.malloc = origin.malloc + newNode.readonly = false + origin.malloc = malloc + origin.readonly = true + + // link nodes + dataNode.next = newNode + newNode.next = origin.next + origin.next = dataNode + } else { + // link nodes + dataNode.next = origin.next + origin.next = dataNode + } // adjust b.write for b.write.next != nil { diff --git a/nocopy_linkbuffer_race.go b/nocopy_linkbuffer_race.go index 7f2f274f..a785aa15 100644 --- a/nocopy_linkbuffer_race.go +++ b/nocopy_linkbuffer_race.go @@ -513,7 +513,7 @@ func (b *LinkBuffer) WriteDirect(p []byte, remainLen int) error { // find origin origin := b.flush malloc := b.mallocSize - remainLen // calculate the remaining malloc length - for t := origin.malloc - len(origin.buf); t <= malloc; t = origin.malloc - len(origin.buf) { + for t := origin.malloc - len(origin.buf); t < malloc; t = origin.malloc - len(origin.buf) { malloc -= t origin = origin.next } @@ -524,18 +524,24 @@ func (b *LinkBuffer) WriteDirect(p []byte, remainLen int) error { dataNode := newLinkBufferNode(0) dataNode.buf, dataNode.malloc = p[:0], n - newNode := newLinkBufferNode(0) - newNode.off = malloc - newNode.buf = origin.buf[:malloc] - newNode.malloc = origin.malloc - newNode.readonly = false - origin.malloc = malloc - origin.readonly = true - - // link nodes - dataNode.next = newNode - newNode.next = origin.next - origin.next = dataNode + if remainLen > 0 { + newNode := newLinkBufferNode(0) + newNode.off = malloc + newNode.buf = origin.buf[:malloc] + newNode.malloc = origin.malloc + newNode.readonly = false + origin.malloc = malloc + origin.readonly = true + + // link nodes + dataNode.next = newNode + newNode.next = origin.next + origin.next = dataNode + } else { + // link nodes + dataNode.next = origin.next + origin.next = dataNode + } // adjust b.write for b.write.next != nil { diff --git a/nocopy_linkbuffer_test.go b/nocopy_linkbuffer_test.go index a2f68fb4..c3f9b9d8 100644 --- a/nocopy_linkbuffer_test.go +++ b/nocopy_linkbuffer_test.go @@ -454,9 +454,11 @@ func TestWriteDirect(t *testing.T) { buf.WriteDirect([]byte("nopqrst"), 28) bt[4] = 'u' buf.WriteDirect([]byte("vwxyz"), 27) + copy(bt[5:], "abcdefghijklmnopqrstuvwxyza") + buf.WriteDirect([]byte("abcdefghijklmnopqrstuvwxyz"), 0) buf.Flush() bs := buf.Bytes() - str := "abcdefghijklmnopqrstuvwxyz" + str := "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzaabcdefghijklmnopqrstuvwxyz" for i := 0; i < len(str); i++ { if bs[i] != str[i] { t.Error("not equal!") diff --git a/poll_default.go b/poll_default.go new file mode 100644 index 00000000..e9aaa093 --- /dev/null +++ b/poll_default.go @@ -0,0 +1,75 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netpoll + +func (p *defaultPoll) Alloc() (operator *FDOperator) { + op := p.opcache.alloc() + op.poll = p + return op +} + +func (p *defaultPoll) Free(operator *FDOperator) { + p.opcache.freeable(operator) +} + +func (p *defaultPoll) appendHup(operator *FDOperator) { + p.hups = append(p.hups, operator.OnHup) + p.detach(operator) + operator.done() +} + +func (p *defaultPoll) detach(operator *FDOperator) { + if err := operator.Control(PollDetach); err != nil { + logger.Printf("NETPOLL: poller detach operator failed: %v", err) + } +} + +func (p *defaultPoll) onhups() { + if len(p.hups) == 0 { + return + } + hups := p.hups + p.hups = nil + go func(onhups []func(p Poll) error) { + for i := range onhups { + if onhups[i] != nil { + onhups[i](p) + } + } + }(hups) +} + +// readall read all left data before close connection +func readall(op *FDOperator, br barrier) (err error) { + var bs = br.bs + var ivs = br.ivs + var n int + for { + bs = op.Inputs(br.bs) + if len(bs) == 0 { + return nil + } + + TryRead: + n, err = ioread(op.FD, bs, ivs) + op.InputAck(n) + if err != nil { + return err + } + if n == 0 { + goto TryRead + } + } +} diff --git a/poll_default_bsd.go b/poll_default_bsd.go index 62911bbd..3312e435 100644 --- a/poll_default_bsd.go +++ b/poll_default_bsd.go @@ -12,27 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build (darwin || netbsd || freebsd || openbsd || dragonfly) && !race +//go:build darwin || netbsd || freebsd || openbsd || dragonfly // +build darwin netbsd freebsd openbsd dragonfly -// +build !race package netpoll import ( + "sync" "sync/atomic" "syscall" "unsafe" ) -func openPoll() Poll { +func openPoll() (Poll, error) { return openDefaultPoll() } -func openDefaultPoll() *defaultPoll { +func openDefaultPoll() (*defaultPoll, error) { l := new(defaultPoll) p, err := syscall.Kqueue() if err != nil { - panic(err) + return nil, err } l.fd = p _, err = syscall.Kevent(l.fd, []syscall.Kevent_t{{ @@ -41,15 +41,17 @@ func openDefaultPoll() *defaultPoll { Flags: syscall.EV_ADD | syscall.EV_CLEAR, }}, nil, nil) if err != nil { - panic(err) + syscall.Close(l.fd) + return nil, err } l.opcache = newOperatorCache() - return l + return l, nil } type defaultPoll struct { fd int trigger uint32 + m sync.Map // only used in go:race opcache *operatorCache // operator cache hups []func(p Poll) error } @@ -64,6 +66,7 @@ func (p *defaultPoll) Wait() error { barriers[i].ivs = make([]syscall.Iovec, caps) } // wait + var triggerRead, triggerWrite, triggerHup bool for { n, err := syscall.Kevent(p.fd, nil, events, nil) if err != nil && err != syscall.EINTR { @@ -74,19 +77,24 @@ func (p *defaultPoll) Wait() error { return err } for i := 0; i < n; i++ { + var fd = int(events[i].Ident) // trigger - if events[i].Ident == 0 { + if fd == 0 { // clean trigger atomic.StoreUint32(&p.trigger, 0) continue } - var operator = *(**FDOperator)(unsafe.Pointer(&events[i].Udata)) - if !operator.do() { + var operator = p.getOperator(fd, unsafe.Pointer(&events[i].Udata)) + if operator == nil || !operator.do() { continue } - // check poll in - if events[i].Filter == syscall.EVFILT_READ && events[i].Flags&syscall.EV_ENABLE != 0 { + evt := events[i] + triggerRead = evt.Filter == syscall.EVFILT_READ && evt.Flags&syscall.EV_ENABLE != 0 + triggerWrite = evt.Filter == syscall.EVFILT_WRITE && evt.Flags&syscall.EV_ENABLE != 0 + triggerHup = evt.Flags&syscall.EV_EOF != 0 + + if triggerRead { if operator.OnRead != nil { // for non-connection operator.OnRead(p) @@ -94,25 +102,25 @@ func (p *defaultPoll) Wait() error { // only for connection var bs = operator.Inputs(barriers[i].bs) if len(bs) > 0 { - var n, err = readv(operator.FD, bs, barriers[i].ivs) + var n, err = ioread(operator.FD, bs, barriers[i].ivs) operator.InputAck(n) - if err != nil && err != syscall.EAGAIN && err != syscall.EINTR { - logger.Printf("NETPOLL: readv(fd=%d) failed: %s", operator.FD, err.Error()) + if err != nil { p.appendHup(operator) continue } } } } - - // check hup - if events[i].Flags&syscall.EV_EOF != 0 { + if triggerHup && triggerRead && operator.Inputs != nil { // read all left data if peer send and close + if err = readall(operator, barriers[i]); err != nil { + logger.Printf("NETPOLL: readall(fd=%d) before close: %s", operator.FD, err.Error()) + } + } + if triggerHup { p.appendHup(operator) continue } - - // check poll out - if events[i].Filter == syscall.EVFILT_WRITE && events[i].Flags&syscall.EV_ENABLE != 0 { + if triggerWrite { if operator.OnWrite != nil { // for non-connection operator.OnWrite(p) @@ -121,10 +129,9 @@ func (p *defaultPoll) Wait() error { var bs, supportZeroCopy = operator.Outputs(barriers[i].bs) if len(bs) > 0 { // TODO: Let the upper layer pass in whether to use ZeroCopy. - var n, err = sendmsg(operator.FD, bs, barriers[i].ivs, false && supportZeroCopy) + var n, err = iosend(operator.FD, bs, barriers[i].ivs, false && supportZeroCopy) operator.OutputAck(n) - if err != nil && err != syscall.EAGAIN { - logger.Printf("NETPOLL: sendmsg(fd=%d) failed: %s", operator.FD, err.Error()) + if err != nil { p.appendHup(operator) continue } @@ -134,7 +141,7 @@ func (p *defaultPoll) Wait() error { operator.done() } // hup conns together to avoid blocking the poll. - p.detaches() + p.onhups() p.opcache.free() } } @@ -162,7 +169,7 @@ func (p *defaultPoll) Trigger() error { func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { var evs = make([]syscall.Kevent_t, 1) evs[0].Ident = uint64(operator.FD) - *(**FDOperator)(unsafe.Pointer(&evs[0].Udata)) = operator + p.setOperator(unsafe.Pointer(&evs[0].Udata), operator) switch event { case PollReadable, PollModReadable: operator.inuse() @@ -171,6 +178,7 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { operator.inuse() evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE|syscall.EV_ONESHOT case PollDetach: + p.delOperator(operator) evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE|syscall.EV_ONESHOT case PollR2RW: evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE @@ -180,34 +188,3 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { _, err := syscall.Kevent(p.fd, evs, nil, nil) return err } - -func (p *defaultPoll) Alloc() (operator *FDOperator) { - op := p.opcache.alloc() - op.poll = p - return op -} - -func (p *defaultPoll) Free(operator *FDOperator) { - p.opcache.freeable(operator) -} - -func (p *defaultPoll) appendHup(operator *FDOperator) { - p.hups = append(p.hups, operator.OnHup) - operator.Control(PollDetach) - operator.done() -} - -func (p *defaultPoll) detaches() { - if len(p.hups) == 0 { - return - } - hups := p.hups - p.hups = nil - go func(onhups []func(p Poll) error) { - for i := range onhups { - if onhups[i] != nil { - onhups[i](p) - } - } - }(hups) -} diff --git a/poll_default_bsd_norace.go b/poll_default_bsd_norace.go new file mode 100644 index 00000000..8a0266d0 --- /dev/null +++ b/poll_default_bsd_norace.go @@ -0,0 +1,33 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build (darwin || netbsd || freebsd || openbsd || dragonfly) && !race +// +build darwin netbsd freebsd openbsd dragonfly +// +build !race + +package netpoll + +import "unsafe" + +func (p *defaultPoll) getOperator(fd int, ptr unsafe.Pointer) *FDOperator { + return *(**FDOperator)(ptr) +} + +func (p *defaultPoll) setOperator(ptr unsafe.Pointer, operator *FDOperator) { + *(**FDOperator)(ptr) = operator +} + +func (p *defaultPoll) delOperator(operator *FDOperator) { + +} diff --git a/poll_default_bsd_race.go b/poll_default_bsd_race.go new file mode 100644 index 00000000..30baf6e0 --- /dev/null +++ b/poll_default_bsd_race.go @@ -0,0 +1,37 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build (darwin || netbsd || freebsd || openbsd || dragonfly) && race +// +build darwin netbsd freebsd openbsd dragonfly +// +build race + +package netpoll + +import "unsafe" + +func (p *defaultPoll) getOperator(fd int, ptr unsafe.Pointer) *FDOperator { + tmp, _ := p.m.Load(fd) + if tmp == nil { + return nil + } + return tmp.(*FDOperator) +} + +func (p *defaultPoll) setOperator(ptr unsafe.Pointer, operator *FDOperator) { + p.m.Store(operator.FD, operator) +} + +func (p *defaultPoll) delOperator(operator *FDOperator) { + p.m.Delete(operator.FD) +} diff --git a/poll_default_linux.go b/poll_default_linux.go index 290f33f2..8da7d55b 100644 --- a/poll_default_linux.go +++ b/poll_default_linux.go @@ -12,44 +12,48 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build !race -// +build !race - package netpoll import ( "runtime" + "sync" "sync/atomic" "syscall" "unsafe" ) -// Includes defaultPoll/multiPoll/uringPoll... -func openPoll() Poll { +func openPoll() (Poll, error) { return openDefaultPoll() } -func openDefaultPoll() *defaultPoll { - var poll = defaultPoll{} +func openDefaultPoll() (*defaultPoll, error) { + var poll = new(defaultPoll) + poll.buf = make([]byte, 8) var p, err = EpollCreate(0) if err != nil { - panic(err) + return nil, err } poll.fd = p + var r0, _, e0 = syscall.Syscall(syscall.SYS_EVENTFD2, 0, 0, 0) if e0 != 0 { - syscall.Close(p) - panic(err) + _ = syscall.Close(poll.fd) + return nil, e0 } poll.Reset = poll.reset poll.Handler = poll.handler - poll.wop = &FDOperator{FD: int(r0)} - poll.Control(poll.wop, PollReadable) + + if err = poll.Control(poll.wop, PollReadable); err != nil { + _ = syscall.Close(poll.wop.FD) + _ = syscall.Close(poll.fd) + return nil, err + } + poll.opcache = newOperatorCache() - return &poll + return poll, nil } type defaultPoll struct { @@ -58,6 +62,7 @@ type defaultPoll struct { wop *FDOperator // eventfd, wake epoll_wait buf []byte // read wfd trigger msg trigger uint32 // trigger flag + m sync.Map // only used in go:race opcache *operatorCache // operator cache // fns for handle events Reset func(size, caps int) @@ -110,11 +115,19 @@ func (p *defaultPoll) Wait() (err error) { } func (p *defaultPoll) handler(events []epollevent) (closed bool) { + var triggerRead, triggerWrite, triggerHup, triggerError bool for i := range events { - var operator = *(**FDOperator)(unsafe.Pointer(&events[i].data)) - if !operator.do() { + operator := p.getOperator(0, unsafe.Pointer(&events[i].data)) + if operator == nil || !operator.do() { continue } + + evt := events[i].events + triggerRead = evt&syscall.EPOLLIN != 0 + triggerWrite = evt&syscall.EPOLLOUT != 0 + triggerHup = evt&(syscall.EPOLLHUP|syscall.EPOLLRDHUP) != 0 + triggerError = evt&syscall.EPOLLERR != 0 + // trigger or exit gracefully if operator.FD == p.wop.FD { // must clean trigger first @@ -131,9 +144,7 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { continue } - evt := events[i].events - // check poll in - if evt&syscall.EPOLLIN != 0 { + if triggerRead { if operator.OnRead != nil { // for non-connection operator.OnRead(p) @@ -141,10 +152,9 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { // for connection var bs = operator.Inputs(p.barriers[i].bs) if len(bs) > 0 { - var n, err = readv(operator.FD, bs, p.barriers[i].ivs) + var n, err = ioread(operator.FD, bs, p.barriers[i].ivs) operator.InputAck(n) - if err != nil && err != syscall.EAGAIN && err != syscall.EINTR { - logger.Printf("NETPOLL: readv(fd=%d) failed: %s", operator.FD, err.Error()) + if err != nil { p.appendHup(operator) continue } @@ -153,13 +163,16 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { logger.Printf("NETPOLL: operator has critical problem! event=%d operator=%v", evt, operator) } } - - // check hup - if evt&(syscall.EPOLLHUP|syscall.EPOLLRDHUP) != 0 { + if triggerHup && triggerRead && operator.Inputs != nil { // read all left data if peer send and close + if err := readall(operator, p.barriers[i]); err != nil { + logger.Printf("NETPOLL: readall(fd=%d) before close: %s", operator.FD, err.Error()) + } + } + if triggerHup { p.appendHup(operator) continue } - if evt&syscall.EPOLLERR != 0 { + if triggerError { // Under block-zerocopy, the kernel may give an error callback, which is not a real error, just an EAGAIN. // So here we need to check this error, if it is EAGAIN then do nothing, otherwise still mark as hup. if _, _, _, _, err := syscall.Recvmsg(operator.FD, nil, nil, syscall.MSG_ERRQUEUE); err != syscall.EAGAIN { @@ -169,8 +182,7 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { } continue } - // check poll out - if evt&syscall.EPOLLOUT != 0 { + if triggerWrite { if operator.OnWrite != nil { // for non-connection operator.OnWrite(p) @@ -179,10 +191,9 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { var bs, supportZeroCopy = operator.Outputs(p.barriers[i].bs) if len(bs) > 0 { // TODO: Let the upper layer pass in whether to use ZeroCopy. - var n, err = sendmsg(operator.FD, bs, p.barriers[i].ivs, false && supportZeroCopy) + var n, err = iosend(operator.FD, bs, p.barriers[i].ivs, false && supportZeroCopy) operator.OutputAck(n) - if err != nil && err != syscall.EAGAIN { - logger.Printf("NETPOLL: sendmsg(fd=%d) failed: %s", operator.FD, err.Error()) + if err != nil { p.appendHup(operator) continue } @@ -194,7 +205,7 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { operator.done() } // hup conns together to avoid blocking the poll. - p.detaches() + p.onhups() return false } @@ -218,7 +229,7 @@ func (p *defaultPoll) Trigger() error { func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { var op int var evt epollevent - *(**FDOperator)(unsafe.Pointer(&evt.data)) = operator + p.setOperator(unsafe.Pointer(&evt.data), operator) switch event { case PollReadable: // server accept a new connection and wait read operator.inuse() @@ -229,6 +240,7 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { case PollModReadable: // client wait read/write op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollDetach: // deregister + p.delOperator(operator) op, evt.events = syscall.EPOLL_CTL_DEL, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollR2RW: // connection wait read/write op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR @@ -237,36 +249,3 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { } return EpollCtl(p.fd, op, operator.FD, &evt) } - -func (p *defaultPoll) Alloc() (operator *FDOperator) { - op := p.opcache.alloc() - op.poll = p - return op -} - -func (p *defaultPoll) Free(operator *FDOperator) { - p.opcache.freeable(operator) -} - -func (p *defaultPoll) appendHup(operator *FDOperator) { - p.hups = append(p.hups, operator.OnHup) - if err := operator.Control(PollDetach); err != nil { - logger.Printf("NETPOLL: poller detach operator failed: %v", err) - } - operator.done() -} - -func (p *defaultPoll) detaches() { - if len(p.hups) == 0 { - return - } - hups := p.hups - p.hups = nil - go func(onhups []func(p Poll) error) { - for i := range onhups { - if onhups[i] != nil { - onhups[i](p) - } - } - }(hups) -} diff --git a/poll_default_linux_norace.go b/poll_default_linux_norace.go new file mode 100644 index 00000000..29d5e6be --- /dev/null +++ b/poll_default_linux_norace.go @@ -0,0 +1,32 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux && !race +// +build linux,!race + +package netpoll + +import "unsafe" + +func (p *defaultPoll) getOperator(fd int, ptr unsafe.Pointer) *FDOperator { + return *(**FDOperator)(ptr) +} + +func (p *defaultPoll) setOperator(ptr unsafe.Pointer, operator *FDOperator) { + *(**FDOperator)(ptr) = operator +} + +func (p *defaultPoll) delOperator(operator *FDOperator) { + +} diff --git a/poll_default_linux_race.go b/poll_default_linux_race.go new file mode 100644 index 00000000..775b587b --- /dev/null +++ b/poll_default_linux_race.go @@ -0,0 +1,43 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux && race +// +build linux,race + +package netpoll + +import "unsafe" + +type eventdata struct { + fd int32 + pad int32 +} + +func (p *defaultPoll) getOperator(fd int, ptr unsafe.Pointer) *FDOperator { + data := *(*eventdata)(ptr) + tmp, _ := p.m.Load(int(data.fd)) + if tmp == nil { + return nil + } + return tmp.(*FDOperator) +} + +func (p *defaultPoll) setOperator(ptr unsafe.Pointer, operator *FDOperator) { + *(*eventdata)(ptr) = eventdata{fd: int32(operator.FD)} + p.m.Store(operator.FD, operator) +} + +func (p *defaultPoll) delOperator(operator *FDOperator) { + p.m.Delete(operator.FD) +} diff --git a/poll_default_linux_test.go b/poll_default_linux_test.go index 4cb72eed..acd0afc9 100644 --- a/poll_default_linux_test.go +++ b/poll_default_linux_test.go @@ -18,6 +18,7 @@ package netpoll import ( + "errors" "syscall" "testing" @@ -167,7 +168,7 @@ func TestEpollETClose(t *testing.T) { events := make([]epollevent, 128) eventdata := [8]byte{0, 0, 0, 0, 0, 0, 0, 1} event := &epollevent{ - events: EPOLLET | syscall.EPOLLOUT | syscall.EPOLLRDHUP | syscall.EPOLLERR, + events: EPOLLET | syscall.EPOLLIN | syscall.EPOLLOUT | syscall.EPOLLRDHUP | syscall.EPOLLERR, data: eventdata, } @@ -175,6 +176,7 @@ func TestEpollETClose(t *testing.T) { err = EpollCtl(epollfd, unix.EPOLL_CTL_ADD, rfd, event) _, err = EpollWait(epollfd, events, -1) MustNil(t, err) + Assert(t, events[0].events&syscall.EPOLLIN == 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) Assert(t, events[0].events&syscall.EPOLLRDHUP == 0) Assert(t, events[0].events&syscall.EPOLLERR == 0) @@ -190,7 +192,7 @@ func TestEpollETClose(t *testing.T) { MustNil(t, err) // EPOLL: close peer fd - // EPOLLOUT + // EPOLLIN and EPOLLOUT rfd, wfd = GetSysFdPairs() err = EpollCtl(epollfd, unix.EPOLL_CTL_ADD, rfd, event) err = syscall.Close(wfd) @@ -198,9 +200,14 @@ func TestEpollETClose(t *testing.T) { n, err = EpollWait(epollfd, events, 100) MustNil(t, err) Assert(t, n == 1, n) + Assert(t, events[0].events&syscall.EPOLLIN != 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) Assert(t, events[0].events&syscall.EPOLLRDHUP != 0) Assert(t, events[0].events&syscall.EPOLLERR == 0) + buf := make([]byte, 1024) + ivs := make([]syscall.Iovec, 1) + n, err = ioread(rfd, [][]byte{buf}, ivs) // EOF + Assert(t, n == 0 && errors.Is(err, ErrEOF), n, err) } func TestEpollETDel(t *testing.T) { diff --git a/poll_manager.go b/poll_manager.go index 2c2e8097..119187c0 100644 --- a/poll_manager.go +++ b/poll_manager.go @@ -107,13 +107,24 @@ func (m *manager) Close() error { } // Run all pollers. -func (m *manager) Run() error { +func (m *manager) Run() (err error) { + defer func() { + if err != nil { + _ = m.Close() + } + }() + // new poll to fill delta. for idx := len(m.polls); idx < m.NumLoops; idx++ { - var poll = openPoll() + var poll Poll + poll, err = openPoll() + if err != nil { + return + } m.polls = append(m.polls, poll) go poll.Wait() } + // LoadBalance must be set before calling Run, otherwise it will panic. m.balance.Rebalance(m.polls) return nil diff --git a/poll_race_bsd.go b/poll_race_bsd.go deleted file mode 100644 index 5caf393d..00000000 --- a/poll_race_bsd.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2022 CloudWeGo Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build (darwin || netbsd || freebsd || openbsd || dragonfly) && race -// +build darwin netbsd freebsd openbsd dragonfly -// +build race - -package netpoll - -import ( - "sync" - "sync/atomic" - "syscall" -) - -// mock no race poll -func openPoll() Poll { - return openDefaultPoll() -} - -func openDefaultPoll() *defaultPoll { - l := new(defaultPoll) - p, err := syscall.Kqueue() - if err != nil { - panic(err) - } - l.fd = p - _, err = syscall.Kevent(l.fd, []syscall.Kevent_t{{ - Ident: 0, - Filter: syscall.EVFILT_USER, - Flags: syscall.EV_ADD | syscall.EV_CLEAR, - }}, nil, nil) - if err != nil { - panic(err) - } - l.opcache = newOperatorCache() - return l -} - -type defaultPoll struct { - fd int - trigger uint32 - m sync.Map - opcache *operatorCache // operator cache - hups []func(p Poll) error -} - -// Wait implements Poll. -func (p *defaultPoll) Wait() error { - // init - var size, caps = 1024, barriercap - var events, barriers = make([]syscall.Kevent_t, size), make([]barrier, size) - for i := range barriers { - barriers[i].bs = make([][]byte, caps) - barriers[i].ivs = make([]syscall.Iovec, caps) - } - // wait - for { - n, err := syscall.Kevent(p.fd, nil, events, nil) - if err != nil && err != syscall.EINTR { - // exit gracefully - if err == syscall.EBADF { - return nil - } - return err - } - for i := 0; i < n; i++ { - var fd = int(events[i].Ident) - // trigger - if fd == 0 { - // clean trigger - atomic.StoreUint32(&p.trigger, 0) - continue - } - tmp, ok := p.m.Load(fd) - if !ok { - continue - } - operator := tmp.(*FDOperator) - if !operator.do() { - continue - } - - // check poll in - if events[i].Filter == syscall.EVFILT_READ && events[i].Flags&syscall.EV_ENABLE != 0 { - if operator.OnRead != nil { - // for non-connection - operator.OnRead(p) - } else { - // only for connection - var bs = operator.Inputs(barriers[i].bs) - if len(bs) > 0 { - var n, err = readv(operator.FD, bs, barriers[i].ivs) - operator.InputAck(n) - if err != nil && err != syscall.EAGAIN && err != syscall.EINTR { - logger.Printf("NETPOLL: readv(fd=%d) failed: %s", operator.FD, err.Error()) - p.appendHup(operator) - continue - } - } - } - } - - // check hup - if events[i].Flags&syscall.EV_EOF != 0 { - p.appendHup(operator) - continue - } - - // check poll out - if events[i].Filter == syscall.EVFILT_WRITE && events[i].Flags&syscall.EV_ENABLE != 0 { - if operator.OnWrite != nil { - // for non-connection - operator.OnWrite(p) - } else { - // only for connection - var bs, supportZeroCopy = operator.Outputs(barriers[i].bs) - if len(bs) > 0 { - // TODO: Let the upper layer pass in whether to use ZeroCopy. - var n, err = sendmsg(operator.FD, bs, barriers[i].ivs, false && supportZeroCopy) - operator.OutputAck(n) - if err != nil && err != syscall.EAGAIN { - logger.Printf("NETPOLL: sendmsg(fd=%d) failed: %s", operator.FD, err.Error()) - p.appendHup(operator) - continue - } - } - } - } - operator.done() - } - // hup conns together to avoid blocking the poll. - p.detaches() - p.opcache.free() - } -} - -// TODO: Close will bad file descriptor here -func (p *defaultPoll) Close() error { - var err = syscall.Close(p.fd) - // delete all *FDOperator - p.m.Range(func(key, value interface{}) bool { - var operator, _ = value.(*FDOperator) - if operator.OnHup != nil { - operator.OnHup(p) - } - return true - }) - return err -} - -// Trigger implements Poll. -func (p *defaultPoll) Trigger() error { - if atomic.AddUint32(&p.trigger, 1) > 1 { - return nil - } - _, err := syscall.Kevent(p.fd, []syscall.Kevent_t{{ - Ident: 0, - Filter: syscall.EVFILT_USER, - Fflags: syscall.NOTE_TRIGGER, - }}, nil, nil) - return err -} - -// Control implements Poll. -func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { - var evs = make([]syscall.Kevent_t, 1) - evs[0].Ident = uint64(operator.FD) - switch event { - case PollReadable, PollModReadable: - operator.inuse() - p.m.Store(operator.FD, operator) - evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE - case PollDetach: - p.m.Delete(operator.FD) - evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE|syscall.EV_ONESHOT - case PollWritable: - operator.inuse() - p.m.Store(operator.FD, operator) - evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE|syscall.EV_ONESHOT - case PollR2RW: - evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE - case PollRW2R: - evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE|syscall.EV_ONESHOT - } - _, err := syscall.Kevent(p.fd, evs, nil, nil) - return err -} - -func (p *defaultPoll) Alloc() (operator *FDOperator) { - op := p.opcache.alloc() - op.poll = p - return op -} - -func (p *defaultPoll) Free(operator *FDOperator) { - p.opcache.freeable(operator) -} - -func (p *defaultPoll) appendHup(operator *FDOperator) { - p.hups = append(p.hups, operator.OnHup) - if err := operator.Control(PollDetach); err != nil { - logger.Printf("NETPOLL: poller detach operator failed: %v", err) - } - operator.done() -} - -func (p *defaultPoll) detaches() { - if len(p.hups) == 0 { - return - } - hups := p.hups - p.hups = nil - go func(onhups []func(p Poll) error) { - for i := range onhups { - if onhups[i] != nil { - onhups[i](p) - } - } - }(hups) -} diff --git a/poll_race_linux.go b/poll_race_linux.go deleted file mode 100644 index 254e5c89..00000000 --- a/poll_race_linux.go +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright 2022 CloudWeGo Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build race -// +build race - -package netpoll - -import ( - "runtime" - "sync" - "sync/atomic" - "syscall" -) - -// mock no race poll -func openPoll() Poll { - return openDefaultPoll() -} - -func openDefaultPoll() *defaultPoll { - var poll = defaultPoll{} - poll.buf = make([]byte, 8) - var p, err = EpollCreate(0) - if err != nil { - panic(err) - } - poll.fd = p - var r0, _, e0 = syscall.Syscall(syscall.SYS_EVENTFD2, 0, 0, 0) - if e0 != 0 { - syscall.Close(p) - panic(err) - } - poll.wfd = int(r0) - poll.Control(&FDOperator{FD: poll.wfd}, PollReadable) - poll.opcache = newOperatorCache() - return &poll -} - -type defaultPoll struct { - pollArgs - fd int // epoll fd - wfd int // wake epoll wait - buf []byte // read wfd trigger msg - trigger uint32 // trigger flag - m sync.Map - opcache *operatorCache // operator cache -} - -type pollArgs struct { - size int - caps int - events []syscall.EpollEvent - barriers []barrier - hups []func(p Poll) error -} - -func (a *pollArgs) reset(size, caps int) { - a.size, a.caps = size, caps - a.events, a.barriers = make([]syscall.EpollEvent, size), make([]barrier, size) - for i := range a.barriers { - a.barriers[i].bs = make([][]byte, a.caps) - a.barriers[i].ivs = make([]syscall.Iovec, a.caps) - } -} - -// Wait implements Poll. -func (p *defaultPoll) Wait() (err error) { - // init - var caps, msec, n = barriercap, -1, 0 - p.reset(128, caps) - // wait - for { - if n == p.size && p.size < 128*1024 { - p.reset(p.size<<1, caps) - } - n, err = syscall.EpollWait(p.fd, p.events, msec) - if err != nil && err != syscall.EINTR { - return err - } - if n <= 0 { - msec = -1 - runtime.Gosched() - continue - } - msec = 0 - if p.handler(p.events[:n]) { - return nil - } - p.opcache.free() - } -} - -func (p *defaultPoll) handler(events []syscall.EpollEvent) (closed bool) { - for i := range events { - var fd = int(events[i].Fd) - // trigger or exit gracefully - if fd == p.wfd { - // must clean trigger first - syscall.Read(p.wfd, p.buf) - atomic.StoreUint32(&p.trigger, 0) - // if closed & exit - if p.buf[0] > 0 { - syscall.Close(p.wfd) - syscall.Close(p.fd) - return true - } - continue - } - tmp, ok := p.m.Load(fd) - if !ok { - continue - } - operator := tmp.(*FDOperator) - if !operator.do() { - continue - } - - evt := events[i].Events - // check poll in - if evt&syscall.EPOLLIN != 0 { - if operator.OnRead != nil { - // for non-connection - operator.OnRead(p) - } else if operator.Inputs != nil { - // for connection - var bs = operator.Inputs(p.barriers[i].bs) - if len(bs) > 0 { - var n, err = readv(operator.FD, bs, p.barriers[i].ivs) - operator.InputAck(n) - if err != nil && err != syscall.EAGAIN && err != syscall.EINTR { - logger.Printf("NETPOLL: readv(fd=%d) failed: %s", operator.FD, err.Error()) - p.appendHup(operator) - continue - } - } - } else { - logger.Printf("NETPOLL: operator has critical problem! event=%d operator=%v", evt, operator) - } - } - - // check hup - if evt&(syscall.EPOLLHUP|syscall.EPOLLRDHUP) != 0 { - p.appendHup(operator) - continue - } - if evt&syscall.EPOLLERR != 0 { - // Under block-zerocopy, the kernel may give an error callback, which is not a real error, just an EAGAIN. - // So here we need to check this error, if it is EAGAIN then do nothing, otherwise still mark as hup. - if _, _, _, _, err := syscall.Recvmsg(operator.FD, nil, nil, syscall.MSG_ERRQUEUE); err != syscall.EAGAIN { - p.appendHup(operator) - } else { - operator.done() - } - continue - } - - // check poll out - if evt&syscall.EPOLLOUT != 0 { - if operator.OnWrite != nil { - // for non-connection - operator.OnWrite(p) - } else if operator.Outputs != nil { - // for connection - var bs, supportZeroCopy = operator.Outputs(p.barriers[i].bs) - if len(bs) > 0 { - // TODO: Let the upper layer pass in whether to use ZeroCopy. - var n, err = sendmsg(operator.FD, bs, p.barriers[i].ivs, false && supportZeroCopy) - operator.OutputAck(n) - if err != nil && err != syscall.EAGAIN { - logger.Printf("NETPOLL: sendmsg(fd=%d) failed: %s", operator.FD, err.Error()) - p.appendHup(operator) - continue - } - } - } else { - logger.Printf("NETPOLL: operator has critical problem! event=%d operator=%v", evt, operator) - } - } - operator.done() - } - // hup conns together to avoid blocking the poll. - p.detaches() - return false -} - -// Close will write 10000000 -func (p *defaultPoll) Close() error { - _, err := syscall.Write(p.wfd, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - // delete all *FDOperator - p.m.Range(func(key, value interface{}) bool { - var operator, _ = value.(*FDOperator) - if operator.OnHup != nil { - operator.OnHup(p) - } - return true - }) - return err -} - -// Trigger implements Poll. -func (p *defaultPoll) Trigger() error { - if atomic.AddUint32(&p.trigger, 1) > 1 { - return nil - } - // MAX(eventfd) = 0xfffffffffffffffe - _, err := syscall.Write(p.wfd, []byte{0, 0, 0, 0, 0, 0, 0, 1}) - return err -} - -// Control implements Poll. -func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { - var op int - var evt syscall.EpollEvent - evt.Fd = int32(operator.FD) - switch event { - case PollReadable: - operator.inuse() - p.m.Store(operator.FD, operator) - op, evt.Events = syscall.EPOLL_CTL_ADD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR - case PollWritable: - operator.inuse() - p.m.Store(operator.FD, operator) - op, evt.Events = syscall.EPOLL_CTL_ADD, EPOLLET|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR - case PollModReadable: - p.m.Store(operator.FD, operator) - op, evt.Events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR - case PollDetach: - p.m.Delete(operator.FD) - op, evt.Events = syscall.EPOLL_CTL_DEL, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR - case PollR2RW: - op, evt.Events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR - case PollRW2R: - op, evt.Events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR - } - return syscall.EpollCtl(p.fd, op, operator.FD, &evt) -} - -func (p *defaultPoll) Alloc() (operator *FDOperator) { - op := p.opcache.alloc() - op.poll = p - return op -} - -func (p *defaultPoll) Free(operator *FDOperator) { - p.opcache.freeable(operator) -} - -func (p *defaultPoll) appendHup(operator *FDOperator) { - p.hups = append(p.hups, operator.OnHup) - operator.Control(PollDetach) - operator.done() -} - -func (p *defaultPoll) detaches() { - if len(p.hups) == 0 { - return - } - hups := p.hups - p.hups = nil - go func(onhups []func(p Poll) error) { - for i := range onhups { - if onhups[i] != nil { - onhups[i](p) - } - } - }(hups) -} diff --git a/poll_test.go b/poll_test.go index c30dab83..5980dde7 100644 --- a/poll_test.go +++ b/poll_test.go @@ -31,7 +31,9 @@ func TestPollTrigger(t *testing.T) { t.Skip() var trigger int var stop = make(chan error) - var p = openDefaultPoll() + var p, err = openDefaultPoll() + MustNil(t, err) + go func() { stop <- p.Wait() }() @@ -46,7 +48,7 @@ func TestPollTrigger(t *testing.T) { Equal(t, trigger, 2) p.Close() - err := <-stop + err = <-stop MustNil(t, err) } @@ -65,7 +67,8 @@ func TestPollMod(t *testing.T) { return nil } var stop = make(chan error) - var p = openDefaultPoll() + var p, err = openDefaultPoll() + MustNil(t, err) go func() { stop <- p.Wait() }() @@ -73,7 +76,6 @@ func TestPollMod(t *testing.T) { var rfd, wfd = GetSysFdPairs() var rop = &FDOperator{FD: rfd, OnRead: read, OnWrite: write, OnHup: hup, poll: p} var wop = &FDOperator{FD: wfd, OnRead: read, OnWrite: write, OnHup: hup, poll: p} - var err error var r, w, h int32 r, w, h = atomic.LoadInt32(&rn), atomic.LoadInt32(&wn), atomic.LoadInt32(&hn) Assert(t, r == 0 && w == 0 && h == 0, r, w, h) @@ -113,7 +115,8 @@ func TestPollMod(t *testing.T) { } func TestPollClose(t *testing.T) { - var p = openDefaultPoll() + var p, err = openDefaultPoll() + MustNil(t, err) var wg sync.WaitGroup wg.Add(1) go func() { @@ -126,7 +129,7 @@ func TestPollClose(t *testing.T) { func BenchmarkPollMod(b *testing.B) { b.StopTimer() - var p = openDefaultPoll() + var p, _ = openDefaultPoll() r, _ := GetSysFdPairs() var operator = &FDOperator{FD: r} p.Control(operator, PollReadable)