diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..a9208b2 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +client-dist/* linguist-generated \ No newline at end of file diff --git a/go.mod b/go.mod index eb119e2..4ab4423 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,9 @@ go 1.22.2 require ( github.com/andybalholm/brotli v1.1.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/zishang520/engine.io-go-parser v1.2.4 - github.com/zishang520/engine.io/v2 v2.0.8 - github.com/zishang520/socket.io-go-parser/v2 v2.0.7 + github.com/zishang520/engine.io-go-parser v1.2.5 + github.com/zishang520/engine.io/v2 v2.1.0 + github.com/zishang520/socket.io-go-parser/v2 v2.1.0 ) require ( @@ -17,17 +17,17 @@ require ( github.com/gorilla/websocket v1.5.1 // indirect github.com/onsi/ginkgo/v2 v2.12.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect - github.com/quic-go/quic-go v0.43.0 // indirect + github.com/quic-go/quic-go v0.44.0 // indirect github.com/quic-go/webtransport-go v0.8.0 // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect go.uber.org/mock v0.4.0 // indirect - golang.org/x/crypto v0.21.0 // indirect - golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 // indirect - golang.org/x/mod v0.12.0 // indirect - golang.org/x/net v0.23.0 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 // indirect + golang.org/x/crypto v0.23.0 // indirect + golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/net v0.25.0 // indirect + golang.org/x/sys v0.20.0 // indirect + golang.org/x/text v0.15.0 // indirect + golang.org/x/tools v0.21.0 // indirect ) diff --git a/go.sum b/go.sum index 278303c..94d6eef 100644 --- a/go.sum +++ b/go.sum @@ -11,8 +11,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEe github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f h1:pDhu5sgp8yJlEF/g6osliIIpF9K4F5jvkULXa4daRDQ= github.com/google/pprof v0.0.0-20230821062121-407c9e7a662f/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= @@ -29,8 +29,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/quic-go v0.43.0 h1:sjtsTKWX0dsHpuMJvLxGqoQdtgJnbAPWY+W+5vjYW/g= -github.com/quic-go/quic-go v0.43.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M= +github.com/quic-go/quic-go v0.44.0 h1:So5wOr7jyO4vzL2sd8/pD9Kesciv91zSk8BoFngItQ0= +github.com/quic-go/quic-go v0.44.0/go.mod h1:z4cx/9Ny9UtGITIPzmPTXh1ULfOyWh4qGQlpnPcWmek= github.com/quic-go/webtransport-go v0.8.0 h1:HxSrwun11U+LlmwpgM1kEqIqH90IT4N8auv/cD7QFJg= github.com/quic-go/webtransport-go v0.8.0/go.mod h1:N99tjprW432Ut5ONql/aUhSLT0YVSlwHohQsuac9WaM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -43,30 +43,32 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 h1:QldyIu/L63oPpyvQmHgvgickp1Yw510KJOqX7H24mg8= github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= -github.com/zishang520/engine.io-go-parser v1.2.4 h1:37h7Mt3Msc3aqub6hl+mYlG6wB81O4zcynrZQIjG41s= -github.com/zishang520/engine.io-go-parser v1.2.4/go.mod h1:G1DciRIGH4/S7x01DIdZQaXrk09ZeRgEw5e/Z9ms4Is= -github.com/zishang520/engine.io/v2 v2.0.8 h1:84rkbpWPzblAMj62uYsaD+XuZQTJTempSTCaxzemNSA= -github.com/zishang520/engine.io/v2 v2.0.8/go.mod h1:z9wFZLzqW1ykzWA84jt//1x0dQjMSim1G3SzIPovdHw= -github.com/zishang520/socket.io-go-parser/v2 v2.0.7 h1:Pcv668c8PYhyeQpaw5/MqV+D9x4p01p5K9ygSYOnYp8= -github.com/zishang520/socket.io-go-parser/v2 v2.0.7/go.mod h1:O/6sR1SjIm8bZvMS3GqwT29TvxdxGYvugvBbRA+a/Zg= +github.com/zishang520/engine.io-go-parser v1.2.5 h1:Disf4rvNQzDsgoC+3yuwuFx5A7JNWlPp+QLUW32WDtc= +github.com/zishang520/engine.io-go-parser v1.2.5/go.mod h1:G1DciRIGH4/S7x01DIdZQaXrk09ZeRgEw5e/Z9ms4Is= +github.com/zishang520/engine.io/v2 v2.1.0 h1:dh3O7OcAfqfhg7AhqlqPRM/6pfdAcoRlEmNbe2wv8qE= +github.com/zishang520/engine.io/v2 v2.1.0/go.mod h1:FnXtT+k/6g2uOb9MpqY71DhV7COwlCH5DCbczn6Q3K8= +github.com/zishang520/socket.io-go-parser/v2 v2.1.0 h1:YaTul861UxdTtq/v7XKmF52gWmDOqwugKBlFyiifKCE= +github.com/zishang520/socket.io-go-parser/v2 v2.1.0/go.mod h1:zmToGML+lCjSjyGZMuVtnvgnFOnDuAxJZKwfDDDHiqI= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= -golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= -golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= -golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= -golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 h1:Vve/L0v7CXXuxUmaMGIEK/dEeq7uiqb5qBgQrZzIE7E= -golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= +golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= +golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/socket/adapter.go b/socket/adapter.go index 0c5e87d..a391645 100644 --- a/socket/adapter.go +++ b/socket/adapter.go @@ -187,10 +187,10 @@ func (a *adapter) BroadcastWithAck(packet *parser.Packet, opts *BroadcastOptions id := a.nsp.Ids() packet.Id = &id encodedPackets := a._encode(packet, packetOpts) - clientCount := uint64(0) + var clientCount atomic.Uint64 a.apply(opts, func(socket *Socket) { // track the total number of acknowledgements that are expected - atomic.AddUint64(&clientCount, 1) + clientCount.Add(1) // call the ack callback for each client response socket.Acks().Store(*packet.Id, ack) if notifyOutgoingListeners := socket.NotifyOutgoingListeners(); notifyOutgoingListeners != nil { @@ -198,7 +198,7 @@ func (a *adapter) BroadcastWithAck(packet *parser.Packet, opts *BroadcastOptions } socket.Client().WriteToEngine(encodedPackets, packetOpts) }) - clientCountCallback(atomic.LoadUint64(&clientCount)) + clientCountCallback(clientCount.Load()) } func (a *adapter) _encode(packet *parser.Packet, packetOpts *WriteOptions) []_types.BufferInterface { diff --git a/socket/broadcast-operator.go b/socket/broadcast-operator.go index e205144..ce1302d 100644 --- a/socket/broadcast-operator.go +++ b/socket/broadcast-operator.go @@ -3,7 +3,6 @@ package socket import ( "errors" "fmt" - "sync" "sync/atomic" "time" @@ -200,9 +199,8 @@ func (b *BroadcastOperator) Emit(ev string, args ...any) error { packet.Data = data[:data_len-1] - timedOut := uint32(0) - responses := []any{} - var responsesMu sync.RWMutex + var timedOut atomic.Bool + responses := types.NewSlice[any]() var timeout time.Duration if time := b.flags.Timeout; time != nil { @@ -210,31 +208,26 @@ func (b *BroadcastOperator) Emit(ev string, args ...any) error { } timer := utils.SetTimeout(func() { - atomic.StoreUint32(&timedOut, 1) + timedOut.Store(true) if b.flags.ExpectSingleResponse { ack(nil, errors.New("operation has timed out")) } else { - responsesMu.RLock() - defer responsesMu.RUnlock() - - ack(responses, errors.New("operation has timed out")) + ack(responses.All(), errors.New("operation has timed out")) } }, timeout) expectedServerCount := int64(-1) - actualServerCount := int64(0) - expectedClientCount := uint64(0) + var actualServerCount atomic.Int64 + var expectedClientCount atomic.Uint64 checkCompleteness := func() { - responsesMu.RLock() - defer responsesMu.RUnlock() - - if 0 == atomic.LoadUint32(&timedOut) && expectedServerCount == atomic.LoadInt64(&actualServerCount) && uint64(len(responses)) == atomic.LoadUint64(&expectedClientCount) { + if !timedOut.Load() && expectedServerCount == actualServerCount.Load() && uint64(responses.Len()) == expectedClientCount.Load() { utils.ClearTimeout(timer) if b.flags.ExpectSingleResponse { - ack(responses[0].([]any), nil) + data, _ := responses.Get(0) + ack(data.([]any), nil) } else { - ack(responses, nil) + ack(responses.All(), nil) } } } @@ -245,14 +238,12 @@ func (b *BroadcastOperator) Emit(ev string, args ...any) error { Flags: b.flags, }, func(clientCount uint64) { // each Socket.IO server in the cluster sends the number of clients that were notified - atomic.AddUint64(&expectedClientCount, clientCount) - atomic.AddInt64(&actualServerCount, 1) + expectedClientCount.Add(clientCount) + actualServerCount.Add(1) checkCompleteness() }, func(clientResponse []any, _ error) { // each client sends an acknowledgement - responsesMu.Lock() - responses = append(responses, clientResponse...) - responsesMu.Unlock() + responses.Push(clientResponse...) checkCompleteness() }) expectedServerCount = b.adapter.ServerCount() diff --git a/socket/client.go b/socket/client.go index 5ed859d..f80b42e 100644 --- a/socket/client.go +++ b/socket/client.go @@ -2,7 +2,7 @@ package socket import ( "net/url" - "sync" + "sync/atomic" _types "github.com/zishang520/engine.io-go-parser/types" "github.com/zishang520/engine.io/v2/engine" @@ -17,14 +17,13 @@ var client_log = log.NewLog("socket.io:client") type Client struct { conn engine.Socket - id string - server *Server - encoder parser.Encoder - decoder parser.Decoder - sockets *types.Map[SocketId, *Socket] - nsps *types.Map[string, *Socket] - connectTimeout *utils.Timer - connectTimeout_mu sync.Mutex + id string + server *Server + encoder parser.Encoder + decoder parser.Decoder + sockets *types.Map[SocketId, *Socket] + nsps *types.Map[string, *Socket] + connectTimeout atomic.Pointer[utils.Timer] } func MakeClient() *Client { @@ -74,17 +73,14 @@ func (c *Client) setup() { c.conn.On("error", c.onerror) c.conn.On("close", c.onclose) - c.connectTimeout_mu.Lock() - defer c.connectTimeout_mu.Unlock() - - c.connectTimeout = utils.SetTimeout(func() { + c.connectTimeout.Store(utils.SetTimeout(func() { if c.nsps.Len() == 0 { client_log.Debug("no namespace joined yet, close the client") c.close() } else { client_log.Debug("the client has already joined a namespace, nothing to do") } - }, c.server._connectTimeout) + }, c.server._connectTimeout)) } // Connects a client to a namespace. @@ -124,11 +120,9 @@ func (c *Client) doConnect(name string, auth any) { nsp.Add(c, auth, func(socket *Socket) { c.sockets.Store(socket.Id(), socket) c.nsps.Store(nsp.Name(), socket) - c.connectTimeout_mu.Lock() - defer c.connectTimeout_mu.Unlock() - if c.connectTimeout != nil { - utils.ClearTimeout(c.connectTimeout) - c.connectTimeout = nil + if connectTimeout := c.connectTimeout.Load(); connectTimeout != nil { + utils.ClearTimeout(connectTimeout) + c.connectTimeout.Store(nil) } }) } @@ -259,10 +253,8 @@ func (c *Client) destroy() { c.conn.RemoveListener("close", c.onclose) c.decoder.RemoveListener("decoded", c.ondecoded) - c.connectTimeout_mu.Lock() - defer c.connectTimeout_mu.Unlock() - if c.connectTimeout != nil { - utils.ClearTimeout(c.connectTimeout) - c.connectTimeout = nil + if connectTimeout := c.connectTimeout.Load(); connectTimeout != nil { + utils.ClearTimeout(connectTimeout) + c.connectTimeout.Store(nil) } } diff --git a/socket/namespace.go b/socket/namespace.go index d49fd01..f7ef847 100644 --- a/socket/namespace.go +++ b/socket/namespace.go @@ -3,7 +3,6 @@ package socket import ( "errors" "fmt" - "sync" "sync/atomic" "time" @@ -68,9 +67,7 @@ var ( // // ensure the socket has access to the "users" namespace // }) type Namespace struct { - // _ids has to be first in the struct to guarantee alignment for atomic - // operations. http://golang.org/pkg/sync/atomic/#pkg-note-BUG - _ids uint64 + _ids atomic.Uint64 *StrictEventEmitter @@ -89,8 +86,7 @@ type Namespace struct { server *Server - _fns []func(*Socket, func(*ExtendedError)) - _fns_mu sync.RWMutex + _fns *types.Slice[func(*Socket, func(*ExtendedError))] _remove func(socket *Socket) } @@ -100,8 +96,7 @@ func MakeNamespace() *Namespace { StrictEventEmitter: NewStrictEventEmitter(), sockets: &types.Map[SocketId, *Socket]{}, - _fns: []func(*Socket, func(*ExtendedError)){}, - _ids: 0, + _fns: types.NewSlice[func(*Socket, func(*ExtendedError))](), } n._remove = n.namespace_remove @@ -147,20 +142,7 @@ func (n *Namespace) Name() string { } func (n *Namespace) Ids() uint64 { - return atomic.AddUint64(&n._ids, 1) -} - -func (n *Namespace) fns() []func(*Socket, func(*ExtendedError)) { - n._fns_mu.RLock() - defer n._fns_mu.RUnlock() - - return n._fns -} -func (n *Namespace) useFns(_fns []func(*Socket, func(*ExtendedError))) { - n._fns_mu.Lock() - defer n._fns_mu.Unlock() - - n._fns = _fns + return n._ids.Add(1) } func (n *Namespace) Construct(server *Server, name string) { @@ -189,10 +171,7 @@ func (n *Namespace) InitAdapter() { // // Param: func(*ExtendedError) - the middleware function func (n *Namespace) Use(fn func(*Socket, func(*ExtendedError))) NamespaceInterface { - n._fns_mu.Lock() - defer n._fns_mu.Unlock() - - n._fns = append(n._fns, fn) + n._fns.Push(fn) return n } @@ -202,10 +181,7 @@ func (n *Namespace) Use(fn func(*Socket, func(*ExtendedError))) NamespaceInterfa // // Param: fn - last fn call in the middleware func (n *Namespace) run(socket *Socket, fn func(err *ExtendedError)) { - n._fns_mu.RLock() - fns := make([]func(*Socket, func(*ExtendedError)), len(n._fns)) - copy(fns, n._fns) - n._fns_mu.RUnlock() + fns := n._fns.All() if length := len(fns); length > 0 { var run func(i int) run = func(i int) { diff --git a/socket/parent-namespace.go b/socket/parent-namespace.go index 71f9b72..3841a72 100644 --- a/socket/parent-namespace.go +++ b/socket/parent-namespace.go @@ -12,7 +12,7 @@ import ( var ( parent_namespace_log = log.NewLog("socket.io:parent-namespace") - count uint64 = 0 + count atomic.Uint64 ) // A parent namespace is a special [Namespace] that holds a list of child namespaces which were created either @@ -48,7 +48,7 @@ func MakeParentNamespace() *ParentNamespace { func NewParentNamespace(server *Server) *ParentNamespace { n := MakeParentNamespace() - n.Construct(server, "/_"+strconv.FormatUint(atomic.AddUint64(&count, 1)-1, 10)) + n.Construct(server, "/_"+strconv.FormatUint(count.Add(1)-1, 10)) return n } @@ -68,10 +68,7 @@ func (p *ParentNamespace) CreateChild(name string) *Namespace { parent_namespace_log.Debug("creating child namespace %s", name) namespace := NewNamespace(p.Server(), name) - _p_fns := p.fns() - _fns := make([]func(*Socket, func(*ExtendedError)), len(_p_fns)) - copy(_fns, _p_fns) - namespace.useFns(_fns) + namespace._fns.Replace(p._fns.All()) namespace.AddListener("connect", p.Listeners("connect")...) namespace.AddListener("connection", p.Listeners("connection")...) diff --git a/socket/session-aware-adapter.go b/socket/session-aware-adapter.go index 5e2dff2..2510317 100644 --- a/socket/session-aware-adapter.go +++ b/socket/session-aware-adapter.go @@ -1,7 +1,6 @@ package socket import ( - "sync" "time" "github.com/zishang520/engine.io/v2/types" @@ -19,9 +18,8 @@ type ( maxDisconnectionDuration int64 - sessions *types.Map[PrivateSessionId, *SessionWithTimestamp] - packets []*PersistedPacket - mu_packets sync.RWMutex + sessions *types.Map[PrivateSessionId, *SessionWithTimestamp] + packets *types.Slice[*PersistedPacket] } ) @@ -34,7 +32,7 @@ func MakeSessionAwareAdapter() Adapter { Adapter: MakeAdapter(), sessions: &types.Map[PrivateSessionId, *SessionWithTimestamp]{}, - packets: []*PersistedPacket{}, + packets: types.NewSlice[*PersistedPacket](), } s.Prototype(s) @@ -62,16 +60,9 @@ func (s *sessionAwareAdapter) Construct(nsp NamespaceInterface) { } return true }) - s.mu_packets.Lock() - defer s.mu_packets.Unlock() - - for i, packet := range s.packets { - if packet.EmittedAt < threshold { - copy(s.packets, s.packets[i+1:]) - s.packets = s.packets[:len(s.packets)-i-1] - break - } - } + s.packets.RangeAndSplice(func(packet *PersistedPacket, i int) (bool, int, int, []*PersistedPacket) { + return packet.EmittedAt < threshold, 0, i + 1, nil + }, true) }, 60*1000*time.Millisecond) // prevents the timer from keeping the process alive timer.Unref() @@ -96,28 +87,25 @@ func (s *sessionAwareAdapter) RestoreSession(pid PrivateSessionId, offset string return nil, nil } - s.mu_packets.RLock() - defer s.mu_packets.RUnlock() - // Find the index of the packet with the given offset - index := -1 - for i, packet := range s.packets { - if packet.Id == offset { - index = i - break - } - } + index := s.packets.FindIndex(func(packet *PersistedPacket) bool { + return packet.Id == offset + }) if index == -1 { + // the offset may be too old return nil, nil } // Use a pre-allocated slice to avoid memory allocation in the loop - missedPackets := make([]any, 0, len(s.packets)-index-1) + missedPackets := make([]any, 0, s.packets.Len()-index-1) missedNum := 0 // Iterate over the packets and append the data of those that should be included - for i := index + 1; i < len(s.packets); i++ { - packet := s.packets[i] + for i := index + 1; i < s.packets.Len(); i++ { + packet, err := s.packets.Get(i) + if err != nil { + break + } if shouldIncludePacket(session.Rooms, packet.Opts) { missedPackets = append(missedPackets, packet.Data) missedNum++ @@ -143,10 +131,7 @@ func (s *sessionAwareAdapter) Broadcast(packet *parser.Packet, opts *BroadcastOp // processed (and the format is backward-compatible) packet.Data = append(packet.Data.([]any), id) - s.mu_packets.Lock() - defer s.mu_packets.Unlock() - - s.packets = append(s.packets, &PersistedPacket{ + s.packets.Push(&PersistedPacket{ Id: id, EmittedAt: time.Now().UnixMilli(), Data: packet.Data, diff --git a/socket/socket.go b/socket/socket.go index 5237a0f..adab04d 100644 --- a/socket/socket.go +++ b/socket/socket.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "sync" + "sync/atomic" "time" "github.com/zishang520/engine.io/v2/engine" @@ -94,8 +95,7 @@ type ( // Additional information that can be attached to the Socket instance and which will be used in the // [Server.fetchSockets()] method. - data any - data_mu sync.RWMutex + data atomic.Pointer[any] // Whether the socket is currently connected or not. // @@ -108,8 +108,7 @@ type ( // socket := args[0].(*socket.Socket) // fmt.Println(socket.Connected()) // true // }) - connected bool - connected_mu sync.RWMutex + connected atomic.Bool // The session ID, which must not be shared (unlike [id]). pid PrivateSessionId @@ -118,18 +117,12 @@ type ( server *Server adapter Adapter acks *types.Map[uint64, func([]any, error)] - fns []func([]any, func(error)) - flags *BroadcastFlags - _anyListeners []events.Listener - _anyOutgoingListeners []events.Listener - - canJoin bool - canJoin_mu sync.RWMutex - - fns_mu sync.RWMutex - flags_mu sync.RWMutex - _anyListeners_mu sync.RWMutex - _anyOutgoingListeners_mu sync.RWMutex + fns *types.Slice[func([]any, func(error))] + flags atomic.Pointer[BroadcastFlags] + _anyListeners *types.Slice[events.Listener] + _anyOutgoingListeners *types.Slice[events.Listener] + + canJoin atomic.Bool } ) @@ -138,11 +131,13 @@ func MakeSocket() *Socket { StrictEventEmitter: NewStrictEventEmitter(), // Initialize default value - acks: &types.Map[uint64, func([]any, error)]{}, - fns: []func([]any, func(error)){}, - flags: &BroadcastFlags{}, - canJoin: true, + acks: &types.Map[uint64, func([]any, error)]{}, + fns: types.NewSlice[func([]any, func(error))](), + _anyListeners: types.NewSlice[events.Listener](), + _anyOutgoingListeners: types.NewSlice[events.Listener](), } + s.flags.Store(&BroadcastFlags{}) + s.canJoin.Store(true) return s } @@ -174,16 +169,13 @@ func (s *Socket) Handshake() *Handshake { // Additional information that can be attached to the Socket instance and which will be used in the // [Server.fetchSockets()] method. func (s *Socket) SetData(data any) { - s.data_mu.Lock() - defer s.data_mu.Unlock() - - s.data = data + s.data.Store(&data) } func (s *Socket) Data() any { - s.data_mu.RLock() - defer s.data_mu.RUnlock() - - return s.data + if data := s.data.Load(); data != nil { + return *data + } + return nil } // Whether the socket is currently connected or not. @@ -198,10 +190,7 @@ func (s *Socket) Data() any { // fmt.Println(socket.Connected()) // true // }) func (s *Socket) Connected() bool { - s.connected_mu.RLock() - defer s.connected_mu.RUnlock() - - return s.connected + return s.connected.Load() } func (s *Socket) Acks() *types.Map[uint64, func([]any, error)] { @@ -307,10 +296,8 @@ func (s *Socket) Emit(ev string, args ...any) error { s.registerAckCallback(id, fn) packet.Id = &id } - s.flags_mu.Lock() - flags := *s.flags - s.flags = &BroadcastFlags{} - s.flags_mu.Unlock() + flags := *s.flags.Load() + s.flags.Store(&BroadcastFlags{}) if s.nsp.Server().Opts().GetRawConnectionStateRecovery() != nil { // this ensures the packet is stored and can be transmitted upon reconnection @@ -358,9 +345,7 @@ func (s *Socket) EmitWithAck(ev string, args ...any) func(func([]any, error)) { } func (s *Socket) registerAckCallback(id uint64, ack func([]any, error)) { - s.flags_mu.RLock() - timeout := s.flags.Timeout - s.flags_mu.RUnlock() + timeout := s.flags.Load().Timeout if timeout == nil { s.acks.Store(id, ack) return @@ -488,12 +473,9 @@ func (s *Socket) packet(packet *parser.Packet, opts *BroadcastFlags) { // // Param: Room - a `Room`, or a `Room` slice to expand func (s *Socket) Join(rooms ...Room) { - s.canJoin_mu.RLock() - if !s.canJoin { - defer s.canJoin_mu.RUnlock() + if !s.canJoin.Load() { return } - s.canJoin_mu.RUnlock() socket_log.Debug("join room %s", rooms) s.adapter.AddAll(s.id, types.NewSet(rooms...)) @@ -529,9 +511,7 @@ func (s *Socket) leaveAll() { func (s *Socket) _onconnect() { socket_log.Debug("socket connected - writing packet") - s.connected_mu.Lock() - s.connected = true - s.connected_mu.Unlock() + s.connected.Store(true) s.Join(Room(s.id)) if s.Conn().Protocol() == 3 { @@ -576,16 +556,8 @@ func (s *Socket) onevent(packet *parser.Packet) { socket_log.Debug("attaching ack callback to event") args = append(args, s.ack(*packet.Id)) } - s._anyListeners_mu.RLock() - if s._anyListeners != nil && len(s._anyListeners) > 0 { - listeners := make([]events.Listener, len(s._anyListeners)) - copy(listeners, s._anyListeners) - s._anyListeners_mu.RUnlock() - for _, listener := range listeners { - listener(args...) - } - } else { - s._anyListeners_mu.RUnlock() + for _, listener := range s._anyListeners.All() { + listener(args...) } s.dispatch(args) } @@ -660,9 +632,7 @@ func (s *Socket) _onclose(args ...any) *Socket { } s._cleanup() s.client._remove(s) - s.connected_mu.Lock() - s.connected = false - s.connected_mu.Unlock() + s.connected.Store(false) s.EmitReserved("disconnect", args...) return nil } @@ -671,10 +641,7 @@ func (s *Socket) _onclose(args ...any) *Socket { func (s *Socket) _cleanup() { s.leaveAll() s.nsp.remove(s) - s.canJoin_mu.Lock() - s.canJoin = false - s.canJoin_mu.Unlock() - + s.canJoin.Store(false) } // Produces an `error` packet. @@ -721,9 +688,7 @@ func (s *Socket) Disconnect(status bool) *Socket { // // Param: compress - if `true`, compresses the sending data func (s *Socket) Compress(compress bool) *Socket { - s.flags_mu.Lock() - s.flags.Compress = compress - s.flags_mu.Unlock() + s.flags.Load().Compress = compress return s } @@ -736,9 +701,7 @@ func (s *Socket) Compress(compress bool) *Socket { // socket.Volatile().Emit("hello") // the client may or may not receive it // }) func (s *Socket) Volatile() *Socket { - s.flags_mu.Lock() - s.flags.Volatile = true - s.flags_mu.Unlock() + s.flags.Load().Volatile = true return s } @@ -781,9 +744,7 @@ func (s *Socket) Local() *BroadcastOperator { // }) // }) func (s *Socket) Timeout(timeout time.Duration) *Socket { - s.flags_mu.Lock() - s.flags.Timeout = &timeout - s.flags_mu.Unlock() + s.flags.Load().Timeout = &timeout return s } @@ -827,10 +788,7 @@ func (s *Socket) dispatch(event []any) { // // Param: fn - middleware function (event, next) func (s *Socket) Use(fn func([]any, func(error))) *Socket { - s.fns_mu.Lock() - defer s.fns_mu.Unlock() - - s.fns = append(s.fns, fn) + s.fns.Push(fn) return s } @@ -840,10 +798,7 @@ func (s *Socket) Use(fn func([]any, func(error))) *Socket { // // Pparam: fn - last fn call in the middleware func (s *Socket) run(event []any, fn func(error)) { - s.fns_mu.RLock() - fns := make([]func([]any, func(error)), len(s.fns)) - copy(fns, s.fns) - s.fns_mu.RUnlock() + fns := s.fns.All() if length := len(fns); length > 0 { var run func(i int) run = func(i int) { @@ -921,13 +876,7 @@ func (s *Socket) Rooms() *types.Set[Room] { // // Param: events.Listener func (s *Socket) OnAny(listener events.Listener) *Socket { - s._anyListeners_mu.Lock() - defer s._anyListeners_mu.Unlock() - - if s._anyListeners == nil { - s._anyListeners = []events.Listener{} - } - s._anyListeners = append(s._anyListeners, listener) + s._anyListeners.Push(listener) return s } @@ -936,13 +885,7 @@ func (s *Socket) OnAny(listener events.Listener) *Socket { // // Param: events.Listener func (s *Socket) PrependAny(listener events.Listener) *Socket { - s._anyListeners_mu.Lock() - defer s._anyListeners_mu.Unlock() - - if s._anyListeners == nil { - s._anyListeners = []events.Listener{} - } - s._anyListeners = append([]events.Listener{listener}, s._anyListeners...) + s._anyListeners.Unshift(listener) return s } @@ -965,23 +908,13 @@ func (s *Socket) PrependAny(listener events.Listener) *Socket { // // Param: events.Listener func (s *Socket) OffAny(listener events.Listener) *Socket { - s._anyListeners_mu.Lock() - defer s._anyListeners_mu.Unlock() - - if len(s._anyListeners) == 0 { - return s - } if listener != nil { - listenerPointer := reflect.ValueOf(listener).Pointer() - for i, _listener := range s._anyListeners { - if listenerPointer == reflect.ValueOf(_listener).Pointer() { - copy(s._anyListeners[i:], s._anyListeners[i+1:]) - s._anyListeners = s._anyListeners[:len(s._anyListeners)-1] - return s - } - } + anyListeners := reflect.ValueOf(listener).Pointer() + s._anyListeners.RangeAndSplice(func(listener events.Listener, i int) (bool, int, int, []events.Listener) { + return reflect.ValueOf(listener).Pointer() == anyListeners, i, 1, nil + }) } else { - s._anyListeners = []events.Listener{} + s._anyListeners.Clear() } return s } @@ -989,13 +922,7 @@ func (s *Socket) OffAny(listener events.Listener) *Socket { // Returns an array of listeners that are listening for any event that is specified. This array can be manipulated, // e.g. to remove listeners. func (s *Socket) ListenersAny() []events.Listener { - s._anyListeners_mu.Lock() - defer s._anyListeners_mu.Unlock() - - if s._anyListeners == nil { - s._anyListeners = []events.Listener{} - } - return s._anyListeners + return s._anyListeners.All() } // Adds a listener that will be fired when any event is sent. The event name is passed as the first argument to @@ -1012,13 +939,7 @@ func (s *Socket) ListenersAny() []events.Listener { // // Param: events.Listener func (s *Socket) OnAnyOutgoing(listener events.Listener) *Socket { - s._anyOutgoingListeners_mu.Lock() - defer s._anyOutgoingListeners_mu.Unlock() - - if s._anyOutgoingListeners == nil { - s._anyOutgoingListeners = []events.Listener{} - } - s._anyOutgoingListeners = append(s._anyOutgoingListeners, listener) + s._anyOutgoingListeners.Push(listener) return s } @@ -1032,13 +953,7 @@ func (s *Socket) OnAnyOutgoing(listener events.Listener) *Socket { // }) // }) func (s *Socket) PrependAnyOutgoing(listener events.Listener) *Socket { - s._anyOutgoingListeners_mu.Lock() - defer s._anyOutgoingListeners_mu.Unlock() - - if s._anyOutgoingListeners == nil { - s._anyOutgoingListeners = []events.Listener{} - } - s._anyOutgoingListeners = append([]events.Listener{listener}, s._anyOutgoingListeners...) + s._anyOutgoingListeners.Unshift(listener) return s } @@ -1061,23 +976,13 @@ func (s *Socket) PrependAnyOutgoing(listener events.Listener) *Socket { // // Param: events.Listener - the catch-all listener func (s *Socket) OffAnyOutgoing(listener events.Listener) *Socket { - s._anyOutgoingListeners_mu.Lock() - defer s._anyOutgoingListeners_mu.Unlock() - - if s._anyOutgoingListeners == nil { - return s - } if listener != nil { listenerPointer := reflect.ValueOf(listener).Pointer() - for i, _listener := range s._anyOutgoingListeners { - if listenerPointer == reflect.ValueOf(_listener).Pointer() { - copy(s._anyOutgoingListeners[i:], s._anyOutgoingListeners[i+1:]) - s._anyOutgoingListeners = s._anyOutgoingListeners[:len(s._anyOutgoingListeners)-1] - return s - } - } + s._anyOutgoingListeners.RangeAndSplice(func(listener events.Listener, i int) (bool, int, int, []events.Listener) { + return reflect.ValueOf(listener).Pointer() == listenerPointer, i, 1, nil + }) } else { - s._anyOutgoingListeners = []events.Listener{} + s._anyOutgoingListeners.Clear() } return s } @@ -1085,31 +990,17 @@ func (s *Socket) OffAnyOutgoing(listener events.Listener) *Socket { // Returns an array of listeners that are listening for any event that is specified. This array can be manipulated, // e.g. to remove listeners. func (s *Socket) ListenersAnyOutgoing() []events.Listener { - s._anyOutgoingListeners_mu.Lock() - defer s._anyOutgoingListeners_mu.Unlock() - - if s._anyOutgoingListeners == nil { - s._anyOutgoingListeners = []events.Listener{} - } - return s._anyOutgoingListeners + return s._anyOutgoingListeners.All() } // Notify the listeners for each packet sent (emit or broadcast) func (s *Socket) notifyOutgoingListeners(packet *parser.Packet) { - s._anyOutgoingListeners_mu.RLock() - if s._anyOutgoingListeners != nil && len(s._anyOutgoingListeners) > 0 { - listeners := make([]events.Listener, len(s._anyOutgoingListeners)) - copy(listeners, s._anyOutgoingListeners) - s._anyOutgoingListeners_mu.RUnlock() - for _, listener := range listeners { - if args, ok := packet.Data.([]any); ok { - listener(args...) - } else { - listener(packet.Data) - } + for _, listener := range s._anyOutgoingListeners.All() { + if args, ok := packet.Data.([]any); ok { + listener(args...) + } else { + listener(packet.Data) } - } else { - s._anyOutgoingListeners_mu.RUnlock() } } func (s *Socket) NotifyOutgoingListeners() func(*parser.Packet) { @@ -1117,9 +1008,7 @@ func (s *Socket) NotifyOutgoingListeners() func(*parser.Packet) { } func (s *Socket) newBroadcastOperator() *BroadcastOperator { - s.flags_mu.Lock() - flags := *s.flags - s.flags = &BroadcastFlags{} - s.flags_mu.Unlock() + flags := *s.flags.Load() + s.flags.Store(&BroadcastFlags{}) return NewBroadcastOperator(s.adapter, types.NewSet[Room](), types.NewSet(Room(s.id)), &flags) }