From 78dd6367e901068d403066fd72abad8192cb2a1a Mon Sep 17 00:00:00 2001 From: Luoyy <10894778+zishang520@users.noreply.github.com> Date: Thu, 5 Sep 2024 15:55:05 +0800 Subject: [PATCH 1/2] Fix nil pointer dereference in socket server close method - Added nil check for s.engine before calling Close() to prevent runtime panic - Resolved issue where a nil pointer dereference occurred when closing the socket server - This fixes the error: "runtime error: invalid memory address or nil pointer dereference" --- socket/server.go | 4 +++- socket/socket.go | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/socket/server.go b/socket/server.go index 4af2579..da2db52 100644 --- a/socket/server.go +++ b/socket/server.go @@ -602,7 +602,9 @@ func (s *Server) Close(fn func(error)) { if s.httpServer != nil { s.httpServer.Close(fn) } else { - s.engine.Close() + if s.engine != nil { + s.engine.Close() + } if fn != nil { fn(nil) } diff --git a/socket/socket.go b/socket/socket.go index 8f46330..b3f1ce2 100644 --- a/socket/socket.go +++ b/socket/socket.go @@ -661,7 +661,7 @@ func (s *Socket) _error(err any) { // socket.Disconnect(true) // }) // -// Param: status - if `true`, closes the underlying connection +// Param: status - if `true`, closes the underlying connection func (s *Socket) Disconnect(status bool) *Socket { if !s.Connected() { return s @@ -712,7 +712,7 @@ func (s *Socket) Volatile() *Socket { // socket.Broadcast().Emit("foo", "bar") // }) // -// Return: a new [BroadcastOperator] instance for chaining +// Return: a new [BroadcastOperator] instance for chaining func (s *Socket) Broadcast() *BroadcastOperator { return s.newBroadcastOperator() } @@ -725,7 +725,7 @@ func (s *Socket) Broadcast() *BroadcastOperator { // socket.Local().Emit("foo", "bar") // }) // -// Return: a new [BroadcastOperator] instance for chaining +// Return: a new [BroadcastOperator] instance for chaining func (s *Socket) Local() *BroadcastOperator { return s.newBroadcastOperator().Local() } From 97ddd02c7069c293f43c07e3fab04b2baf2d2e2a Mon Sep 17 00:00:00 2001 From: luoyy Date: Sun, 15 Sep 2024 00:48:32 +0800 Subject: [PATCH 2/2] refactor(adapter): enhanced cluster support and message handling - Refactored adapter for improved cluster support and message handling. - Fixed `ack` type for `SERVER_SIDE_EMIT` message in `ClusterAdapter`. - Opened up helper functions for reuse in other adapters. - Fixed null pointer issue that could occur when `EncodeOptions` and `DecodeOptions` parameters are `nil`. - Renamed `ClusterSocket` to `RemoteSocket`, and `NewClusterSocket` to `NewRemoteSocket`. - Refactored `ServerSideEmit` and `ServerSideEmitWithAck` methods for better structure. --- README.md | 4 +- adapter/adapter-type.go | 37 ++ adapter/adapter.go | 22 + adapter/cluster-adapter-options.go | 76 +++ adapter/cluster-adapter-type.go | 132 ++++ .../cluster-adapter-with-heartbeat-type.go | 18 + adapter/cluster-adapter-with-heartbeat.go | 332 ++++++++++ adapter/cluster-adapter.go | 570 ++++++++++++++++++ adapter/remote-socket.go | 50 ++ adapter/session-aware-adapter.go | 22 + adapter/util.go | 80 +++ go.mod | 22 +- go.sum | 48 +- socket/adapter-type.go | 12 +- socket/adapter-type_test.go | 167 +++++ socket/adapter.go | 18 +- socket/broadcast-operator.go | 109 +--- socket/client.go | 7 +- socket/namespace-type.go | 12 +- socket/namespace.go | 38 +- socket/parent-broadcast-adapter.go | 1 - socket/parent-namespace.go | 9 +- socket/remote-socket.go | 117 ++++ socket/server.go | 20 +- socket/session-aware-adapter.go | 4 +- socket/socket.go | 28 +- socket/type.go | 14 - 27 files changed, 1750 insertions(+), 219 deletions(-) create mode 100644 adapter/adapter-type.go create mode 100644 adapter/adapter.go create mode 100644 adapter/cluster-adapter-options.go create mode 100644 adapter/cluster-adapter-type.go create mode 100644 adapter/cluster-adapter-with-heartbeat-type.go create mode 100644 adapter/cluster-adapter-with-heartbeat.go create mode 100644 adapter/cluster-adapter.go create mode 100644 adapter/remote-socket.go create mode 100644 adapter/session-aware-adapter.go create mode 100644 adapter/util.go create mode 100644 socket/adapter-type_test.go create mode 100644 socket/remote-socket.go delete mode 100644 socket/type.go diff --git a/README.md b/README.md index e86a388..0936bed 100644 --- a/README.md +++ b/README.md @@ -212,7 +212,7 @@ func main() { client.Emit("auth", client.Handshake().Auth) client.On("message-with-ack", func(args ...interface{}) { - ack := args[len(args)-1].(func([]any, error)) + ack := args[len(args)-1].(socket.Ack) ack(args[:len(args)-1], nil) }) }) @@ -286,7 +286,7 @@ func main() { client.Emit("auth", client.Handshake().Auth) client.On("message-with-ack", func(args ...interface{}) { - ack := args[len(args)-1].(func([]any, error)) + ack := args[len(args)-1].(socket.Ack) ack(args[:len(args)-1], nil) }) }) diff --git a/adapter/adapter-type.go b/adapter/adapter-type.go new file mode 100644 index 0000000..3161252 --- /dev/null +++ b/adapter/adapter-type.go @@ -0,0 +1,37 @@ +package adapter + +import ( + "github.com/zishang520/socket.io/v2/socket" +) + +type ( + Adapter = socket.Adapter + + SessionAwareAdapter = socket.SessionAwareAdapter + + AdapterConstructor = socket.AdapterConstructor + + // A cluster-ready adapter. Any extending interface must: + // + // - implement [ClusterAdapter.DoPublish] and [ClusterAdapter.DoPublishResponse] + // + // - call [ClusterAdapter.OnMessage] and [ClusterAdapter.OnResponse] + ClusterAdapter interface { + Adapter + + Uid() ServerId + OnMessage(*ClusterMessage, Offset) + OnResponse(*ClusterResponse) + Publish(*ClusterMessage) + PublishAndReturnOffset(*ClusterMessage) (Offset, error) + DoPublish(*ClusterMessage) (Offset, error) + PublishResponse(ServerId, *ClusterResponse) + DoPublishResponse(ServerId, *ClusterResponse) error + } + + ClusterAdapterWithHeartbeat interface { + ClusterAdapter + + SetOpts(*ClusterAdapterOptions) + } +) diff --git a/adapter/adapter.go b/adapter/adapter.go new file mode 100644 index 0000000..c8d0d15 --- /dev/null +++ b/adapter/adapter.go @@ -0,0 +1,22 @@ +package adapter + +import ( + "github.com/zishang520/socket.io/v2/socket" +) + +type ( + AdapterBuilder struct { + } +) + +func (*AdapterBuilder) New(nsp socket.Namespace) Adapter { + return NewAdapter(nsp) +} + +func MakeAdapter() Adapter { + return socket.MakeAdapter() +} + +func NewAdapter(nsp socket.Namespace) Adapter { + return socket.NewAdapter(nsp) +} diff --git a/adapter/cluster-adapter-options.go b/adapter/cluster-adapter-options.go new file mode 100644 index 0000000..3f14d49 --- /dev/null +++ b/adapter/cluster-adapter-options.go @@ -0,0 +1,76 @@ +package adapter + +import ( + "time" +) + +type ( + ClusterAdapterOptionsInterface interface { + SetHeartbeatInterval(time.Duration) + GetRawHeartbeatInterval() *time.Duration + HeartbeatInterval() time.Duration + + SetHeartbeatTimeout(int64) + GetRawHeartbeatTimeout() *int64 + HeartbeatTimeout() int64 + } + + ClusterAdapterOptions struct { + // The number of ms between two heartbeats. + // + // Default: 5_000 * time.Millisecond + heartbeatInterval *time.Duration + + // The number of ms without heartbeat before we consider a node down. + // + // Default: 10_000 + heartbeatTimeout *int64 + } +) + +func DefaultClusterAdapterOptions() *ClusterAdapterOptions { + return &ClusterAdapterOptions{} +} + +func (s *ClusterAdapterOptions) Assign(data ClusterAdapterOptionsInterface) (ClusterAdapterOptionsInterface, error) { + if data == nil { + return s, nil + } + if s.GetRawHeartbeatInterval() == nil { + s.SetHeartbeatInterval(data.HeartbeatInterval()) + } + + if s.GetRawHeartbeatTimeout() == nil { + s.SetHeartbeatTimeout(data.HeartbeatTimeout()) + } + + return s, nil +} + +func (s *ClusterAdapterOptions) SetHeartbeatInterval(heartbeatInterval time.Duration) { + s.heartbeatInterval = &heartbeatInterval +} +func (s *ClusterAdapterOptions) GetRawHeartbeatInterval() *time.Duration { + return s.heartbeatInterval +} +func (s *ClusterAdapterOptions) HeartbeatInterval() time.Duration { + if s.heartbeatInterval == nil { + return time.Duration(5_000 * time.Millisecond) + } + + return *s.heartbeatInterval +} + +func (s *ClusterAdapterOptions) SetHeartbeatTimeout(heartbeatTimeout int64) { + s.heartbeatTimeout = &heartbeatTimeout +} +func (s *ClusterAdapterOptions) GetRawHeartbeatTimeout() *int64 { + return s.heartbeatTimeout +} +func (s *ClusterAdapterOptions) HeartbeatTimeout() int64 { + if s.heartbeatTimeout == nil { + return 10_000 + } + + return *s.heartbeatTimeout +} diff --git a/adapter/cluster-adapter-type.go b/adapter/cluster-adapter-type.go new file mode 100644 index 0000000..fca5ec9 --- /dev/null +++ b/adapter/cluster-adapter-type.go @@ -0,0 +1,132 @@ +package adapter + +import ( + "sync/atomic" + "time" + + "github.com/zishang520/engine.io/v2/types" + "github.com/zishang520/engine.io/v2/utils" + "github.com/zishang520/socket.io-go-parser/v2/parser" + "github.com/zishang520/socket.io/v2/socket" +) + +type ( + // The unique ID of a server + ServerId string + + // The unique ID of a message (for the connection state recovery feature) + Offset string + + MessageType int + + // Common fields for all messages + ClusterMessage struct { + Uid ServerId `json:"uid,omitempty" msgpack:"uid,omitempty"` + Nsp string `json:"nsp,omitempty" msgpack:"nsp,omitempty"` + Type MessageType `json:"type,omitempty" msgpack:"type,omitempty"` + Data any `json:"data,omitempty" msgpack:"data,omitempty"` // Data will hold the specific message data for different types + } + + // PacketOptions represents the options for broadcasting messages. + PacketOptions struct { + Rooms []socket.Room `json:"rooms,omitempty" msgpack:"rooms,omitempty"` + Except []socket.Room `json:"except,omitempty" msgpack:"except,omitempty"` + Flags *socket.BroadcastFlags `json:"flags,omitempty" msgpack:"flags,omitempty"` + } + + // Message for BROADCAST + BroadcastMessage struct { + Opts *PacketOptions `json:"opts,omitempty" msgpack:"opts,omitempty"` + Packet *parser.Packet `json:"packet,omitempty" msgpack:"packet,omitempty"` + RequestId *string `json:"requestId,omitempty" msgpack:"requestId,omitempty"` + } + + // Message for SOCKETS_JOIN, SOCKETS_LEAVE + SocketsJoinLeaveMessage struct { + Opts *PacketOptions `json:"opts,omitempty" msgpack:"opts,omitempty"` + Rooms []socket.Room `json:"rooms,omitempty" msgpack:"rooms,omitempty"` + } + + // Message for DISCONNECT_SOCKETS + DisconnectSocketsMessage struct { + Opts *PacketOptions `json:"opts,omitempty" msgpack:"opts,omitempty"` + Close bool `json:"close,omitempty" msgpack:"close,omitempty"` + } + + // Message for FETCH_SOCKETS + FetchSocketsMessage struct { + Opts *PacketOptions `json:"opts,omitempty" msgpack:"opts,omitempty"` + RequestId string `json:"requestId,omitempty" msgpack:"requestId,omitempty"` + } + + // Message for SERVER_SIDE_EMIT + ServerSideEmitMessage struct { + RequestId *string `json:"requestId,omitempty" msgpack:"requestId,omitempty"` + Packet []any `json:"packet,omitempty" msgpack:"packet,omitempty"` + } + + // ClusterRequest equivalent + ClusterRequest struct { + Type MessageType + Resolve func(*types.Slice[any]) + Timeout *atomic.Pointer[utils.Timer] + Expected int64 + Current *atomic.Int64 + Responses *types.Slice[any] + } + + ClusterResponse = ClusterMessage + + SocketResponse struct { + Id socket.SocketId `json:"id,omitempty" msgpack:"id,omitempty"` + Handshake *socket.Handshake `json:"handshake,omitempty" msgpack:"handshake,omitempty"` + Rooms []socket.Room `json:"rooms,omitempty" msgpack:"rooms,omitempty"` + Data any `json:"data,omitempty" msgpack:"data,omitempty"` + } + + FetchSocketsResponse struct { + RequestId string `json:"requestId,omitempty" msgpack:"requestId,omitempty"` + Sockets []*SocketResponse `json:"sockets,omitempty" msgpack:"sockets,omitempty"` + } + + ServerSideEmitResponse struct { + RequestId string `json:"requestId,omitempty" msgpack:"requestId,omitempty"` + Packet []any `json:"packet,omitempty" msgpack:"packet,omitempty"` + } + + BroadcastClientCount struct { + RequestId string `json:"requestId,omitempty" msgpack:"requestId,omitempty"` + ClientCount uint64 `json:"clientCount,omitempty" msgpack:"clientCount,omitempty"` + } + + BroadcastAck struct { + RequestId string `json:"requestId,omitempty" msgpack:"requestId,omitempty"` + Packet []any `json:"packet,omitempty" msgpack:"packet,omitempty"` + } + + ClusterAckRequest struct { + ClientCountCallback func(uint64) + Ack socket.Ack + } +) + +const ( + EMITTER_UID ServerId = "emitter" + DEFAULT_TIMEOUT time.Duration = 5_000 * time.Millisecond +) + +const ( + INITIAL_HEARTBEAT MessageType = iota + 1 + HEARTBEAT + BROADCAST + SOCKETS_JOIN + SOCKETS_LEAVE + DISCONNECT_SOCKETS + FETCH_SOCKETS + FETCH_SOCKETS_RESPONSE + SERVER_SIDE_EMIT + SERVER_SIDE_EMIT_RESPONSE + BROADCAST_CLIENT_COUNT + BROADCAST_ACK + ADAPTER_CLOSE +) diff --git a/adapter/cluster-adapter-with-heartbeat-type.go b/adapter/cluster-adapter-with-heartbeat-type.go new file mode 100644 index 0000000..a2006af --- /dev/null +++ b/adapter/cluster-adapter-with-heartbeat-type.go @@ -0,0 +1,18 @@ +package adapter + +import ( + "sync/atomic" + + "github.com/zishang520/engine.io/v2/types" + "github.com/zishang520/engine.io/v2/utils" +) + +type ( + CustomClusterRequest struct { + Type MessageType + Resolve func(*types.Slice[any]) + Timeout *atomic.Pointer[utils.Timer] + MissingUids *types.Set[ServerId] + Responses *types.Slice[any] + } +) diff --git a/adapter/cluster-adapter-with-heartbeat.go b/adapter/cluster-adapter-with-heartbeat.go new file mode 100644 index 0000000..adcca49 --- /dev/null +++ b/adapter/cluster-adapter-with-heartbeat.go @@ -0,0 +1,332 @@ +package adapter + +import ( + "fmt" + "sync/atomic" + "time" + + "github.com/zishang520/engine.io/v2/types" + "github.com/zishang520/engine.io/v2/utils" + "github.com/zishang520/socket.io/v2/socket" +) + +type ( + ClusterAdapterWithHeartbeatBuilder struct { + Opts *ClusterAdapterOptions + } + + clusterAdapterWithHeartbeat struct { + ClusterAdapter + + _opts *ClusterAdapterOptions + + heartbeatTimer atomic.Pointer[utils.Timer] + nodesMap *types.Map[ServerId, int64] // uid => timestamp of last message + cleanupTimer atomic.Pointer[utils.Timer] + customRequests *types.Map[string, *CustomClusterRequest] + } +) + +func (c *ClusterAdapterWithHeartbeatBuilder) New(nsp socket.Namespace) Adapter { + return NewClusterAdapterWithHeartbeat(nsp, c.Opts) +} + +func MakeClusterAdapterWithHeartbeat() ClusterAdapterWithHeartbeat { + c := &clusterAdapterWithHeartbeat{ + ClusterAdapter: MakeClusterAdapter(), + + _opts: DefaultClusterAdapterOptions(), + nodesMap: &types.Map[ServerId, int64]{}, + customRequests: &types.Map[string, *CustomClusterRequest]{}, + } + + c.Prototype(c) + + return c +} + +func NewClusterAdapterWithHeartbeat(nsp socket.Namespace, opts *ClusterAdapterOptions) ClusterAdapterWithHeartbeat { + c := MakeClusterAdapterWithHeartbeat() + + c.SetOpts(opts) + + c.Construct(nsp) + + return c +} + +func (a *clusterAdapterWithHeartbeat) SetOpts(opts *ClusterAdapterOptions) { + if opts == nil { + opts = DefaultClusterAdapterOptions() + } + a._opts.Assign(opts) +} + +func (a *clusterAdapterWithHeartbeat) Construct(nsp socket.Namespace) { + a.ClusterAdapter.Construct(nsp) + a.cleanupTimer.Store(utils.SetInterval(func() { + now := time.Now().UnixMilli() + a.nodesMap.Range(func(uid ServerId, lastSeen int64) bool { + if now-lastSeen > a._opts.HeartbeatTimeout() { + adapter_log.Debug("[%s] node %s seems down", a.Uid(), uid) + a.removeNode(uid) + } + return true + }) + }, 1_000*time.Millisecond)) +} + +func (a *clusterAdapterWithHeartbeat) Init() { + a.Publish(&ClusterMessage{ + Type: INITIAL_HEARTBEAT, + }) +} + +func (a *clusterAdapterWithHeartbeat) scheduleHeartbeat() { + if heartbeatTimer := a.heartbeatTimer.Load(); heartbeatTimer != nil { + heartbeatTimer.Refresh() + } else { + a.heartbeatTimer.Store(utils.SetTimeout(func() { + a.Publish(&ClusterMessage{ + Type: HEARTBEAT, + }) + }, a._opts.HeartbeatInterval())) + } +} + +func (a *clusterAdapterWithHeartbeat) Close() { + a.Publish(&ClusterMessage{ + Type: ADAPTER_CLOSE, + }) + utils.ClearTimeout(a.heartbeatTimer.Load()) + utils.ClearInterval(a.cleanupTimer.Load()) +} + +func (a *clusterAdapterWithHeartbeat) OnMessage(message *ClusterMessage, offset Offset) { + if message.Uid == a.Uid() { + adapter_log.Debug("[%s] ignore message from self", a.Uid()) + return + } + + if message.Uid != EMITTER_UID { + // we track the UID of each sender, in order to know how many servers there are in the cluster + a.nodesMap.Store(message.Uid, time.Now().UnixMilli()) + } + + adapter_log.Debug( + "[%s] new event of type %d from %s", + a.Uid(), + message.Type, + message.Uid, + ) + + switch message.Type { + case INITIAL_HEARTBEAT: + a.Publish(&ClusterMessage{Type: HEARTBEAT}) + case HEARTBEAT: + // Do nothing + case ADAPTER_CLOSE: + a.removeNode(message.Uid) + default: + a.ClusterAdapter.OnMessage(message, offset) + } +} + +func (a *clusterAdapterWithHeartbeat) ServerCount() int64 { + return int64(a.nodesMap.Len() + 1) +} + +func (a *clusterAdapterWithHeartbeat) Publish(message *ClusterMessage) { + a.scheduleHeartbeat() + + a.ClusterAdapter.Publish(message) +} + +func (a *clusterAdapterWithHeartbeat) ServerSideEmit(packet []any) error { + if len(packet) == 0 { + return fmt.Errorf("packet cannot be empty") + } + + data_len := len(packet) + ack, withAck := packet[data_len-1].(socket.Ack) + if !withAck { + a.Publish(&ClusterMessage{ + Type: SERVER_SIDE_EMIT, + Data: &ServerSideEmitMessage{ + Packet: packet, + }, + }) + return nil + } + expectedResponseCount := a.nodesMap.Len() + + adapter_log.Debug( + `[%s] waiting for %d responses to "serverSideEmit" request`, + a.Uid(), + expectedResponseCount, + ) + + if expectedResponseCount <= 0 { + ack(nil, nil) + return nil + } + + requestId, err := RandomId() + + if err != nil { + return err + } + + timeout := utils.SetTimeout(func() { + if storedRequest, ok := a.customRequests.Load(requestId); ok { + ack( + storedRequest.Responses.All(), + fmt.Errorf(`timeout reached: missing %d responses`, storedRequest.MissingUids.Len()), + ) + a.customRequests.Delete(requestId) + } + }, DEFAULT_TIMEOUT) + + a.customRequests.Store(requestId, &CustomClusterRequest{ + Type: SERVER_SIDE_EMIT, + Resolve: func(data *types.Slice[any]) { + ack(data.All(), nil) + }, + Timeout: Tap(&atomic.Pointer[utils.Timer]{}, func(t *atomic.Pointer[utils.Timer]) { + t.Store(timeout) + }), + MissingUids: types.NewSet(a.nodesMap.Keys()...), + Responses: types.NewSlice[any](), + }) + + a.Publish(&ClusterMessage{ + Type: SERVER_SIDE_EMIT, + Data: &ServerSideEmitMessage{ + RequestId: &requestId, // the presence of this attribute defines whether an acknowledgement is needed + Packet: packet[:data_len-1], + }, + }) + return nil +} + +func (a *clusterAdapterWithHeartbeat) FetchSockets(opts *socket.BroadcastOptions) func(func([]socket.SocketDetails, error)) { + if opts == nil { + opts = &socket.BroadcastOptions{ + Rooms: types.NewSet[socket.Room](), + Except: types.NewSet[socket.Room](), + } + } + return func(cb func([]socket.SocketDetails, error)) { + a.ClusterAdapter.FetchSockets(&socket.BroadcastOptions{ + Rooms: opts.Rooms, + Except: opts.Except, + Flags: &socket.BroadcastFlags{ + Local: true, + }, + })(func(localSockets []socket.SocketDetails, _ error) { + expectedResponseCount := a.ServerCount() - 1 + + if (opts != nil && opts.Flags != nil && opts.Flags.Local) || expectedResponseCount <= 0 { + cb(localSockets, nil) + return + } + + requestId, _ := RandomId() + + t := DEFAULT_TIMEOUT + if opts != nil && opts.Flags != nil && opts.Flags.Timeout != nil { + t = *opts.Flags.Timeout + } + + timeout := utils.SetTimeout(func() { + if storedRequest, ok := a.customRequests.Load(requestId); ok { + cb(nil, fmt.Errorf("timeout reached: missing %d responses", storedRequest.MissingUids.Len())) + a.customRequests.Delete(requestId) + } + }, t) + + a.customRequests.Store(requestId, &CustomClusterRequest{ + Type: FETCH_SOCKETS, + Resolve: func(data *types.Slice[any]) { + cb(SliceMap(data.All(), func(i any) socket.SocketDetails { + return i.(socket.SocketDetails) + }), nil) + }, + Timeout: Tap(&atomic.Pointer[utils.Timer]{}, func(t *atomic.Pointer[utils.Timer]) { + t.Store(timeout) + }), + MissingUids: types.NewSet(a.nodesMap.Keys()...), + Responses: types.NewSlice(SliceMap(localSockets, func(client socket.SocketDetails) any { + return client + })...), + }) + + a.Publish(&ClusterMessage{ + Type: FETCH_SOCKETS, + Data: &FetchSocketsMessage{ + Opts: EncodeOptions(opts), + RequestId: requestId, + }, + }) + }) + } +} + +func (a *clusterAdapterWithHeartbeat) OnResponse(response *ClusterResponse) { + switch response.Type { + case FETCH_SOCKETS_RESPONSE: + data, ok := response.Data.(*FetchSocketsResponse) + if !ok { + adapter_log.Debug("[%s] invalid data for FETCH_SOCKETS_RESPONSE message", a.Uid()) + return + } + adapter_log.Debug("[%s] received response %d to request %s", a.Uid(), response.Type, data.RequestId) + if request, ok := a.customRequests.Load(data.RequestId); ok { + request.Responses.Push(SliceMap(data.Sockets, func(client *SocketResponse) any { + return socket.SocketDetails(NewRemoteSocket(client)) + })...) + + request.MissingUids.Delete(response.Uid) + if request.MissingUids.Len() == 0 { + utils.ClearTimeout(request.Timeout.Load()) + request.Resolve(request.Responses) + a.customRequests.Delete(data.RequestId) + } + } + + case SERVER_SIDE_EMIT_RESPONSE: + data, ok := response.Data.(*ServerSideEmitResponse) + if !ok { + adapter_log.Debug("[%s] invalid data for FETCH_SOCKETS_RESPONSE message", a.Uid()) + return + } + adapter_log.Debug("[%s] received response %d to request %s", a.Uid(), response.Type, data.RequestId) + if request, ok := a.customRequests.Load(data.RequestId); ok { + request.Responses.Push(data.Packet) + + request.MissingUids.Delete(response.Uid) + if request.MissingUids.Len() == 0 { + utils.ClearTimeout(request.Timeout.Load()) + request.Resolve(request.Responses) + a.customRequests.Delete(data.RequestId) + } + } + + default: + a.ClusterAdapter.OnResponse(response) + } +} + +func (a *clusterAdapterWithHeartbeat) removeNode(uid ServerId) { + a.customRequests.Range(func(requestId string, request *CustomClusterRequest) bool { + request.MissingUids.Delete(uid) + if request.MissingUids.Len() == 0 { + utils.ClearTimeout(request.Timeout.Load()) + request.Resolve(request.Responses) + a.customRequests.Delete(requestId) + } + return true + }) + + a.nodesMap.Delete(uid) +} diff --git a/adapter/cluster-adapter.go b/adapter/cluster-adapter.go new file mode 100644 index 0000000..2e7b282 --- /dev/null +++ b/adapter/cluster-adapter.go @@ -0,0 +1,570 @@ +package adapter + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/zishang520/engine.io/v2/types" + "github.com/zishang520/engine.io/v2/utils" + "github.com/zishang520/socket.io-go-parser/v2/parser" + "github.com/zishang520/socket.io/v2/socket" +) + +type ( + // A cluster-ready adapter. Any extending interface must: + // + // - implement [ClusterAdapter.DoPublish] and [ClusterAdapter.DoPublishResponse] + // + // - call [ClusterAdapter.OnMessage] and [ClusterAdapter.OnResponse] + ClusterAdapterBuilder struct { + } + + clusterAdapter struct { + Adapter + + // @protected + uid ServerId + + requests *types.Map[string, *ClusterRequest] + ackRequests *types.Map[string, *ClusterAckRequest] + } +) + +func (cb *ClusterAdapterBuilder) New(nsp socket.Namespace) Adapter { + return NewClusterAdapter(nsp) +} + +func MakeClusterAdapter() ClusterAdapter { + c := &clusterAdapter{ + Adapter: MakeAdapter(), + + requests: &types.Map[string, *ClusterRequest]{}, + ackRequests: &types.Map[string, *ClusterAckRequest]{}, + } + + c.Prototype(c) + + return c +} + +func NewClusterAdapter(nsp socket.Namespace) ClusterAdapter { + c := MakeClusterAdapter() + + c.Construct(nsp) + + return c +} + +// @protected +func (c *clusterAdapter) Uid() ServerId { + return c.uid +} + +func (c *clusterAdapter) Construct(nsp socket.Namespace) { + c.Adapter.Construct(nsp) + + uid, _ := RandomId() + c.uid = ServerId(uid) +} + +// OnMessage handles incoming messages +// +// @protected +func (c *clusterAdapter) OnMessage(message *ClusterMessage, offset Offset) { + if message.Uid == c.uid { + adapter_log.Debug("[%s] ignore message from self", c.uid) + return + } + + adapter_log.Debug("[%s] new event of type %d from %s", c.uid, message.Type, message.Uid) + + switch message.Type { + case BROADCAST: + data, ok := message.Data.(*BroadcastMessage) + if !ok { + adapter_log.Debug("[%s] invalid data for BROADCAST message", c.uid) + return + } + + withAck := data.RequestId != nil + if withAck { + c.Adapter.BroadcastWithAck( + data.Packet, + DecodeOptions(data.Opts), + func(clientCount uint64) { + adapter_log.Debug("[%s] waiting for %d client acknowledgements", c.uid, clientCount) + c.PublishResponse(message.Uid, &ClusterResponse{ + Type: BROADCAST_CLIENT_COUNT, + Data: &BroadcastClientCount{ + RequestId: *data.RequestId, + ClientCount: clientCount, + }, + }) + }, + func(args []any, _ error) { + adapter_log.Debug("[%s] received acknowledgement with value %v", c.uid, args) + c.PublishResponse(message.Uid, &ClusterResponse{ + Type: BROADCAST_ACK, + Data: &BroadcastAck{ + RequestId: *data.RequestId, + Packet: args, + }, + }) + }, + ) + } else { + opts := DecodeOptions(data.Opts) + c.addOffsetIfNecessary(data.Packet, opts, offset) + c.Adapter.Broadcast(data.Packet, opts) + } + + case SOCKETS_JOIN: + data, ok := message.Data.(*SocketsJoinLeaveMessage) + if !ok { + adapter_log.Debug("[%s] invalid data for SOCKETS_JOIN message", c.uid) + return + } + c.Adapter.AddSockets(DecodeOptions(data.Opts), data.Rooms) + + case SOCKETS_LEAVE: + data, ok := message.Data.(*SocketsJoinLeaveMessage) + if !ok { + adapter_log.Debug("[%s] invalid data for SOCKETS_LEAVE message", c.uid) + return + } + c.Adapter.DelSockets(DecodeOptions(data.Opts), data.Rooms) + + case DISCONNECT_SOCKETS: + data, ok := message.Data.(*DisconnectSocketsMessage) + if !ok { + adapter_log.Debug("[%s] invalid data for DISCONNECT_SOCKETS message", c.uid) + return + } + c.Adapter.DisconnectSockets( + DecodeOptions(data.Opts), + data.Close, + ) + + case FETCH_SOCKETS: + data, ok := message.Data.(*FetchSocketsMessage) + if !ok { + adapter_log.Debug("[%s] invalid data for FETCH_SOCKETS message", c.uid) + return + } + adapter_log.Debug( + "[%s] calling fetchSockets with opts %v", + c.uid, + data.Opts, + ) + c.Adapter.FetchSockets(DecodeOptions(data.Opts))( + func(localSockets []socket.SocketDetails, err error) { + if err != nil { + adapter_log.Debug("FETCH_SOCKETS Adapter.OnMessage error: %s", err.Error()) + return + } + c.PublishResponse(message.Uid, &ClusterResponse{ + Type: FETCH_SOCKETS_RESPONSE, + Data: &FetchSocketsResponse{ + RequestId: data.RequestId, + Sockets: SliceMap(localSockets, func(client socket.SocketDetails) *SocketResponse { + return &SocketResponse{ + Id: client.Id(), + Handshake: client.Handshake(), + Rooms: client.Rooms().Keys(), + Data: client.Data(), + } + }), + }, + }, + ) + }, + ) + + case SERVER_SIDE_EMIT: + data, ok := message.Data.(*ServerSideEmitMessage) + if !ok { + adapter_log.Debug("[%s] invalid data for SERVER_SIDE_EMIT message", c.uid) + return + } + packet := data.Packet + if data.RequestId == nil { + c.Nsp().OnServerSideEmit(packet) + return + } + called := sync.Once{} + callback := socket.Ack(func(arg []any, _ error) { + // only one argument is expected + called.Do(func() { + adapter_log.Debug("[%s] calling acknowledgement with %v", c.uid, arg) + c.PublishResponse(message.Uid, &ClusterResponse{ + Type: SERVER_SIDE_EMIT_RESPONSE, + Data: &ServerSideEmitResponse{ + RequestId: *data.RequestId, + Packet: arg, + }, + }) + }) + }) + + c.Nsp().OnServerSideEmit(append(packet, callback)) + + case BROADCAST_CLIENT_COUNT, BROADCAST_ACK, FETCH_SOCKETS_RESPONSE, SERVER_SIDE_EMIT_RESPONSE: + // extending classes may not make a distinction between a ClusterMessage and a ClusterResponse payload and may + // always call the OnMessage() method + c.OnResponse(message) + default: + adapter_log.Debug("[%s] unknown message type: %d", c.uid, message.Type) + } +} + +// OnResponse handles incoming responses +// +// @protected +func (c *clusterAdapter) OnResponse(response *ClusterResponse) { + switch response.Type { + case BROADCAST_CLIENT_COUNT: + data, ok := response.Data.(*BroadcastClientCount) + if !ok { + adapter_log.Debug("[%s] invalid data for BROADCAST_CLIENT_COUNT message", c.uid) + return + } + adapter_log.Debug("[%s] received response %d to request %s", c.uid, response.Type, data.RequestId) + if ackRequest, ok := c.ackRequests.Load(data.RequestId); ok { + ackRequest.ClientCountCallback(data.ClientCount) + } + + case BROADCAST_ACK: + data, ok := response.Data.(*BroadcastAck) + if !ok { + adapter_log.Debug("[%s] invalid data for BROADCAST_ACK message", c.uid) + return + } + adapter_log.Debug("[%s] received response %d to request %s", c.uid, response.Type, data.RequestId) + if ackRequest, ok := c.ackRequests.Load(data.RequestId); ok { + ackRequest.Ack(data.Packet, nil) + } + + case FETCH_SOCKETS_RESPONSE: + data, ok := response.Data.(*FetchSocketsResponse) + if !ok { + adapter_log.Debug("[%s] invalid data for FETCH_SOCKETS_RESPONSE message", c.uid) + return + } + adapter_log.Debug("[%s] received response %d to request %s", c.uid, response.Type, data.RequestId) + if request, ok := c.requests.Load(data.RequestId); ok { + request.Current.Add(1) + request.Responses.Push(SliceMap(data.Sockets, func(client *SocketResponse) any { + return socket.SocketDetails(NewRemoteSocket(client)) + })...) + + if request.Current.Load() == request.Expected { + utils.ClearTimeout(request.Timeout.Load()) + request.Resolve(request.Responses) + c.requests.Delete(data.RequestId) + } + } + + case SERVER_SIDE_EMIT_RESPONSE: + data, ok := response.Data.(*ServerSideEmitResponse) + if !ok { + adapter_log.Debug("[%s] invalid data for SERVER_SIDE_EMIT_RESPONSE message", c.uid) + return + } + adapter_log.Debug("[%s] received response %d to request %s", c.uid, response.Type, data.RequestId) + if request, ok := c.requests.Load(data.RequestId); ok { + request.Current.Add(1) + request.Responses.Push(data.Packet) + + if request.Current.Load() == request.Expected { + utils.ClearTimeout(request.Timeout.Load()) + request.Resolve(request.Responses) + c.requests.Delete(data.RequestId) + } + } + default: + adapter_log.Debug("[%s] unknown response type: %d", c.uid, response.Type) + } +} + +func (c *clusterAdapter) Broadcast(packet *parser.Packet, opts *socket.BroadcastOptions) { + onlyLocal := opts != nil && opts.Flags != nil && opts.Flags.Local + + if !onlyLocal { + offset, err := c.PublishAndReturnOffset(&ClusterMessage{ + Type: BROADCAST, + Data: &BroadcastMessage{ + Packet: packet, + Opts: EncodeOptions(opts), + }, + }) + if err != nil { + adapter_log.Debug("[%s] error while broadcasting message: %s", c.uid, err.Error()) + return + } + c.addOffsetIfNecessary(packet, opts, offset) + } + + c.Adapter.Broadcast(packet, opts) +} + +// Adds an offset at the end of the data array in order to allow the client to receive any missed packets when it +// reconnects after a temporary disconnection. +func (c *clusterAdapter) addOffsetIfNecessary(packet *parser.Packet, opts *socket.BroadcastOptions, offset Offset) { + if c.Nsp().Server().Opts().GetRawConnectionStateRecovery() == nil { + return + } + + isEventPacket := packet.Type == parser.EVENT + // packets with acknowledgement are not stored because the acknowledgement function cannot be serialized and + // restored on another server upon reconnection + withoutAcknowledgement := packet.Id == nil + notVolatile := opts == nil || opts.Flags == nil || opts.Flags.Volatile == false + + if isEventPacket && withoutAcknowledgement && notVolatile { + packet.Data = append(packet.Data.([]any), offset) + } +} + +func (c *clusterAdapter) BroadcastWithAck(packet *parser.Packet, opts *socket.BroadcastOptions, clientCountCallback func(uint64), ack socket.Ack) { + onlyLocal := opts != nil && opts.Flags != nil && opts.Flags.Local + if !onlyLocal { + requestId, _ := RandomId() + + c.ackRequests.Store(requestId, &ClusterAckRequest{ + ClientCountCallback: clientCountCallback, + Ack: ack, + }) + + c.Publish(&ClusterMessage{ + Type: BROADCAST, + Data: &BroadcastMessage{ + Packet: packet, + RequestId: &requestId, + Opts: EncodeOptions(opts), + }, + }) + + timeout := time.Duration(0) + if opts != nil && opts.Flags != nil && opts.Flags.Timeout != nil { + timeout = *opts.Flags.Timeout + } + + // we have no way to know at this level whether the server has received an acknowledgement from each client, so we + // will simply clean up the ackRequests map after the given delay + utils.SetTimeout(func() { + c.ackRequests.Delete(requestId) + }, timeout) + } + + c.Adapter.BroadcastWithAck(packet, opts, clientCountCallback, ack) +} + +func (c *clusterAdapter) AddSockets(opts *socket.BroadcastOptions, rooms []socket.Room) { + onlyLocal := opts != nil && opts.Flags != nil && opts.Flags.Local + + if !onlyLocal { + _, err := c.PublishAndReturnOffset(&ClusterMessage{ + Type: SOCKETS_JOIN, + Data: &SocketsJoinLeaveMessage{ + Opts: EncodeOptions(opts), + Rooms: rooms, + }, + }) + if err != nil { + adapter_log.Debug("[%s] error while publishing message: %s", c.uid, err.Error()) + } + } + + c.Adapter.AddSockets(opts, rooms) +} + +func (c *clusterAdapter) DelSockets(opts *socket.BroadcastOptions, rooms []socket.Room) { + onlyLocal := opts != nil && opts.Flags != nil && opts.Flags.Local + + if !onlyLocal { + _, err := c.PublishAndReturnOffset(&ClusterMessage{ + Type: SOCKETS_LEAVE, + Data: &SocketsJoinLeaveMessage{ + Opts: EncodeOptions(opts), + Rooms: rooms, + }, + }) + if err != nil { + adapter_log.Debug("[%s] error while publishing message: %s", c.uid, err.Error()) + } + } + + c.Adapter.DelSockets(opts, rooms) +} + +func (c *clusterAdapter) DisconnectSockets(opts *socket.BroadcastOptions, state bool) { + onlyLocal := opts != nil && opts.Flags != nil && opts.Flags.Local + + if !onlyLocal { + _, err := c.PublishAndReturnOffset(&ClusterMessage{ + Type: DISCONNECT_SOCKETS, + Data: &DisconnectSocketsMessage{ + Opts: EncodeOptions(opts), + Close: state, + }, + }) + if err != nil { + adapter_log.Debug("[%s] error while publishing message: %s", c.uid, err.Error()) + } + } + + c.Adapter.DisconnectSockets(opts, state) +} + +func (c *clusterAdapter) FetchSockets(opts *socket.BroadcastOptions) func(func([]socket.SocketDetails, error)) { + return func(callback func([]socket.SocketDetails, error)) { + c.Adapter.FetchSockets(opts)(func(localSockets []socket.SocketDetails, _ error) { + expectedResponseCount := c.ServerCount() - 1 + + if (opts != nil && opts.Flags != nil && opts.Flags.Local) || expectedResponseCount <= 0 { + callback(localSockets, nil) + return + } + + requestId, _ := RandomId() + + t := DEFAULT_TIMEOUT + if opts != nil && opts.Flags != nil && opts.Flags.Timeout != nil { + t = *opts.Flags.Timeout + } + + timeout := utils.SetTimeout(func() { + if storedRequest, ok := c.requests.Load(requestId); ok { + callback(nil, fmt.Errorf("timeout reached: only %d responses received out of %d", storedRequest.Current.Load(), storedRequest.Expected)) + c.requests.Delete(requestId) + } + }, t) + + c.requests.Store(requestId, &ClusterRequest{ + Type: FETCH_SOCKETS, + Resolve: func(data *types.Slice[any]) { + callback(SliceMap(data.All(), func(i any) socket.SocketDetails { + return i.(socket.SocketDetails) + }), nil) + }, + Timeout: Tap(&atomic.Pointer[utils.Timer]{}, func(t *atomic.Pointer[utils.Timer]) { + t.Store(timeout) + }), + Current: &atomic.Int64{}, + Expected: expectedResponseCount, + Responses: types.NewSlice(SliceMap(localSockets, func(client socket.SocketDetails) any { + return client + })...), + }) + + c.Publish(&ClusterMessage{ + Type: FETCH_SOCKETS, + Data: &FetchSocketsMessage{ + Opts: EncodeOptions(opts), + RequestId: requestId, + }, + }) + }) + } +} + +func (c *clusterAdapter) ServerSideEmit(packet []any) error { + if len(packet) == 0 { + return fmt.Errorf("packet cannot be empty") + } + + data_len := len(packet) + ack, withAck := packet[data_len-1].(socket.Ack) + if !withAck { + c.Publish(&ClusterMessage{ + Type: SERVER_SIDE_EMIT, + Data: &ServerSideEmitMessage{ + Packet: packet, + }, + }) + return nil + } + + expectedResponseCount := c.ServerCount() - 1 + + adapter_log.Debug(`[%s] waiting for %d responses to "serverSideEmit" request`, c.uid, expectedResponseCount) + + if expectedResponseCount <= 0 { + ack(nil, nil) + return nil + } + + requestId, err := RandomId() + + if err != nil { + return err + } + + timeout := utils.SetTimeout(func() { + if storedRequest, ok := c.requests.Load(requestId); ok { + ack( + storedRequest.Responses.All(), + fmt.Errorf(`timeout reached: only %d responses received out of %d`, storedRequest.Current.Load(), storedRequest.Expected), + ) + c.requests.Delete(requestId) + } + }, DEFAULT_TIMEOUT) + + c.requests.Store(requestId, &ClusterRequest{ + Type: SERVER_SIDE_EMIT, + Resolve: func(data *types.Slice[any]) { + ack(data.All(), nil) + }, + Timeout: Tap(&atomic.Pointer[utils.Timer]{}, func(t *atomic.Pointer[utils.Timer]) { + t.Store(timeout) + }), + Current: &atomic.Int64{}, + Expected: expectedResponseCount, + Responses: types.NewSlice[any](), + }) + + c.Publish(&ClusterMessage{ + Type: SERVER_SIDE_EMIT, + Data: &ServerSideEmitMessage{ + RequestId: &requestId, // the presence of this attribute defines whether an acknowledgement is needed + Packet: packet[:data_len-1], + }, + }) + return nil +} + +func (c *clusterAdapter) Publish(message *ClusterMessage) { + _, err := c.PublishAndReturnOffset(message) + if err != nil { + adapter_log.Debug(`[%s] error while publishing message: %s`, c.uid, err.Error()) + } +} + +func (c *clusterAdapter) PublishAndReturnOffset(message *ClusterMessage) (Offset, error) { + message.Uid = c.uid + message.Nsp = c.Nsp().Name() + return c.Proto().(ClusterAdapter).DoPublish(message) +} + +// Send a message to the other members of the cluster. +func (c *clusterAdapter) DoPublish(message *ClusterMessage) (Offset, error) { + return "", errors.New("DoPublish() is not supported on parent ClusterAdapter") +} + +func (c *clusterAdapter) PublishResponse(requesterUid ServerId, response *ClusterResponse) { + response.Uid = c.uid + response.Nsp = c.Nsp().Name() + + err := c.Proto().(ClusterAdapter).DoPublishResponse(requesterUid, response) + if err != nil { + adapter_log.Debug(`[%s] error while publishing response: %s`, c.uid, err.Error()) + } +} + +// Send a response to the given member of the cluster. +func (c *clusterAdapter) DoPublishResponse(requesterUid ServerId, response *ClusterResponse) error { + return errors.New("DoPublishResponse() is not supported on parent ClusterAdapter") +} diff --git a/adapter/remote-socket.go b/adapter/remote-socket.go new file mode 100644 index 0000000..594a322 --- /dev/null +++ b/adapter/remote-socket.go @@ -0,0 +1,50 @@ +package adapter + +import ( + "github.com/zishang520/engine.io/v2/types" + "github.com/zishang520/socket.io/v2/socket" +) + +// Expose of subset of the attributes and methods of the Socket struct +type RemoteSocket struct { + id socket.SocketId + handshake *socket.Handshake + rooms *types.Set[socket.Room] + data any +} + +func MakeRemoteSocket() *RemoteSocket { + r := &RemoteSocket{} + return r +} + +func NewRemoteSocket(details *SocketResponse) *RemoteSocket { + r := MakeRemoteSocket() + + r.Construct(details) + + return r +} + +func (r *RemoteSocket) Id() socket.SocketId { + return r.id +} + +func (r *RemoteSocket) Handshake() *socket.Handshake { + return r.handshake +} + +func (r *RemoteSocket) Rooms() *types.Set[socket.Room] { + return r.rooms +} + +func (r *RemoteSocket) Data() any { + return r.data +} + +func (r *RemoteSocket) Construct(details *SocketResponse) { + r.id = details.Id + r.handshake = details.Handshake + r.rooms = types.NewSet(details.Rooms...) + r.data = details.Data +} diff --git a/adapter/session-aware-adapter.go b/adapter/session-aware-adapter.go new file mode 100644 index 0000000..cb44ef2 --- /dev/null +++ b/adapter/session-aware-adapter.go @@ -0,0 +1,22 @@ +package adapter + +import ( + "github.com/zishang520/socket.io/v2/socket" +) + +type ( + SessionAwareAdapterBuilder struct { + } +) + +func (*SessionAwareAdapterBuilder) New(nsp socket.Namespace) Adapter { + return NewSessionAwareAdapter(nsp) +} + +func MakeSessionAwareAdapter() SessionAwareAdapter { + return socket.MakeSessionAwareAdapter() +} + +func NewSessionAwareAdapter(nsp socket.Namespace) SessionAwareAdapter { + return socket.NewSessionAwareAdapter(nsp) +} diff --git a/adapter/util.go b/adapter/util.go new file mode 100644 index 0000000..d7783ba --- /dev/null +++ b/adapter/util.go @@ -0,0 +1,80 @@ +package adapter + +import ( + "crypto/rand" + "encoding/base64" + "encoding/hex" + + "github.com/zishang520/engine.io/v2/log" + "github.com/zishang520/engine.io/v2/types" + "github.com/zishang520/socket.io/v2/socket" +) + +var adapter_log = log.NewLog("socket.io-adapter") + +// Encode BroadcastOptions into PacketOptions +func EncodeOptions(opts *socket.BroadcastOptions) *PacketOptions { + p := &PacketOptions{} + if opts == nil { + return p + } + + if opts.Rooms != nil { + p.Rooms = opts.Rooms.Keys() // Convert the set to a slice of strings + } + if opts.Except != nil { + p.Except = opts.Except.Keys() // Convert the set to a slice of strings + } + if opts.Flags != nil { + p.Flags = opts.Flags // Pass flags as is + } + return p +} + +// Decode PacketOptions back into BroadcastOptions +func DecodeOptions(opts *PacketOptions) *socket.BroadcastOptions { + b := &socket.BroadcastOptions{ + Rooms: types.NewSet[socket.Room](), + Except: types.NewSet[socket.Room](), + } + if opts == nil { + return b + } + + b.Rooms.Add(opts.Rooms...) // Convert slice to set + b.Except.Add(opts.Except...) // Convert slice to set + b.Flags = opts.Flags // Pass flags as is + + return b +} + +func RandomId() (string, error) { + r := make([]byte, 8) + if _, err := rand.Read(r); err != nil { + return "", err + } + return hex.EncodeToString(r), nil +} + +func Uid2(length int) (string, error) { + r := make([]byte, length) + if _, err := rand.Read(r); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(r), nil +} + +func SliceMap[I any, O any](i []I, converter func(I) O) (o []O) { + for _, _i := range i { + o = append(o, converter(_i)) + } + return o +} + +// Tap calls the given function with the given value, then returns the value. +func Tap[T any](value T, callback func(T)) T { + if callback != nil { + callback(value) + } + return value +} diff --git a/go.mod b/go.mod index fa3997a..bd2260b 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,12 @@ module github.com/zishang520/socket.io/v2 -go 1.22.2 +go 1.23.1 require ( github.com/andybalholm/brotli v1.1.0 - github.com/zishang520/engine.io-go-parser v1.2.5 - github.com/zishang520/engine.io/v2 v2.2.2 - github.com/zishang520/socket.io-go-parser/v2 v2.2.0 + github.com/zishang520/engine.io-go-parser v1.2.6 + github.com/zishang520/engine.io/v2 v2.2.3 + github.com/zishang520/socket.io-go-parser/v2 v2.2.1 ) require ( @@ -15,18 +15,18 @@ require ( github.com/gookit/color v1.5.4 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/onsi/ginkgo/v2 v2.12.0 // indirect - github.com/quic-go/qpack v0.4.0 // indirect - github.com/quic-go/quic-go v0.45.1 // indirect + github.com/quic-go/qpack v0.5.1 // indirect + github.com/quic-go/quic-go v0.47.0 // indirect github.com/quic-go/webtransport-go v0.8.0 // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect go.uber.org/mock v0.4.0 // indirect - golang.org/x/crypto v0.23.0 // indirect + golang.org/x/crypto v0.26.0 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/net v0.25.0 // indirect - golang.org/x/sys v0.20.0 // indirect - golang.org/x/text v0.15.0 // indirect - golang.org/x/tools v0.21.0 // indirect + golang.org/x/net v0.28.0 // indirect + golang.org/x/sys v0.23.0 // indirect + golang.org/x/text v0.17.0 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect ) diff --git a/go.sum b/go.sum index 1d54530..3f9f729 100644 --- a/go.sum +++ b/go.sum @@ -25,48 +25,48 @@ github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= -github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/quic-go v0.45.1 h1:tPfeYCk+uZHjmDRwHHQmvHRYL2t44ROTujLeFVBmjCA= -github.com/quic-go/quic-go v0.45.1/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI= +github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= +github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= +github.com/quic-go/quic-go v0.47.0 h1:yXs3v7r2bm1wmPTYNLKAAJTHMYkPEsfYJmTazXrCZ7Y= +github.com/quic-go/quic-go v0.47.0/go.mod h1:3bCapYsJvXGZcipOHuu7plYtaV6tnF+z7wIFsU0WK9E= github.com/quic-go/webtransport-go v0.8.0 h1:HxSrwun11U+LlmwpgM1kEqIqH90IT4N8auv/cD7QFJg= github.com/quic-go/webtransport-go v0.8.0/go.mod h1:N99tjprW432Ut5ONql/aUhSLT0YVSlwHohQsuac9WaM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 h1:QldyIu/L63oPpyvQmHgvgickp1Yw510KJOqX7H24mg8= github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= -github.com/zishang520/engine.io-go-parser v1.2.5 h1:Disf4rvNQzDsgoC+3yuwuFx5A7JNWlPp+QLUW32WDtc= -github.com/zishang520/engine.io-go-parser v1.2.5/go.mod h1:G1DciRIGH4/S7x01DIdZQaXrk09ZeRgEw5e/Z9ms4Is= -github.com/zishang520/engine.io/v2 v2.2.2 h1:arZwVX0/55mo9Koc5JQeiEuFesXaibWkHez5UBkRTJA= -github.com/zishang520/engine.io/v2 v2.2.2/go.mod h1:1J6oMkUUIPbX+CENYMYpaX381fYQ6mwRgy2AOBweycY= -github.com/zishang520/socket.io-go-parser/v2 v2.2.0 h1:x2Ca9lF1kqM4OmCYDcfZ+yPJFYlx319W0LYkaQ4ab4I= -github.com/zishang520/socket.io-go-parser/v2 v2.2.0/go.mod h1:UcVUFESDZBbC9rJ9LZv0DHapH0rekQHtDab3a/9/UOY= +github.com/zishang520/engine.io-go-parser v1.2.6 h1:X8n0+udu7m/Zc+M+zRWiO0LcoPtyXMMMyrJckMGn+tA= +github.com/zishang520/engine.io-go-parser v1.2.6/go.mod h1:WRsjNz1Oi04dqGcvjpW0t6/B2KIuDSrTBvCZDs7r3XY= +github.com/zishang520/engine.io/v2 v2.2.3 h1:XByBVlcQwnn9MPV2q/0FAIzRl2p8bDdNmDd6tjQNRc4= +github.com/zishang520/engine.io/v2 v2.2.3/go.mod h1:C6CuEcQqHFsE2/rmQikiqww0JKfPBJjlCItMeURyhHU= +github.com/zishang520/socket.io-go-parser/v2 v2.2.1 h1:xtNT3ImCeb3bzDYaYJItmJe5ILALbuXbagQ2mVTRWv8= +github.com/zishang520/socket.io-go-parser/v2 v2.2.1/go.mod h1:ehqNzBXCP9zazabYhcXrVfOXzvhjijw16pJXEL+AD08= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= +golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= -golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/socket/adapter-type.go b/socket/adapter-type.go index af40fb3..7b17759 100644 --- a/socket/adapter-type.go +++ b/socket/adapter-type.go @@ -40,16 +40,16 @@ type ( } BroadcastOptions struct { - Rooms *types.Set[Room] - Except *types.Set[Room] - Flags *BroadcastFlags `json:"flags,omitempty" msgpack:"flags,omitempty"` + Rooms *types.Set[Room] `json:"rooms,omitempty" msgpack:"rooms,omitempty"` + Except *types.Set[Room] `json:"except,omitempty" msgpack:"except,omitempty"` + Flags *BroadcastFlags `json:"flags,omitempty" msgpack:"flags,omitempty"` } SessionToPersist struct { Sid SocketId `json:"sid" msgpack:"sid"` Pid PrivateSessionId `json:"pid" msgpack:"pid"` - Rooms *types.Set[Room] - Data any `json:"data" msgpack:"data"` + Rooms *types.Set[Room] `json:"rooms,omitempty" msgpack:"rooms,omitempty"` + Data any `json:"data" msgpack:"data"` } Session struct { @@ -118,7 +118,7 @@ type ( // - `Flags` {*BroadcastFlags} flags for this packet // - `Except` {*types.Set[Room]} sids that should be excluded // - `Rooms` {*types.Set[Room]} list of rooms to broadcast to - BroadcastWithAck(*parser.Packet, *BroadcastOptions, func(uint64), func([]any, error)) + BroadcastWithAck(*parser.Packet, *BroadcastOptions, func(uint64), Ack) // Gets a list of sockets by sid. Sockets(*types.Set[Room]) *types.Set[SocketId] diff --git a/socket/adapter-type_test.go b/socket/adapter-type_test.go new file mode 100644 index 0000000..b61de00 --- /dev/null +++ b/socket/adapter-type_test.go @@ -0,0 +1,167 @@ +package socket + +import ( + "testing" + "time" + + "github.com/zishang520/engine.io-go-parser/packet" + "github.com/zishang520/engine.io/v2/types" +) + +func TestBroadcastFlagsInheritance(t *testing.T) { + flags := BroadcastFlags{ + WriteOptions: WriteOptions{ + Volatile: true, + PreEncoded: false, + Options: packet.Options{ + Compress: true, + }, + }, + Local: true, + Broadcast: false, + Binary: true, + } + + if flags.Volatile != true { + t.Errorf("Expected Volatile to be true, got %v", flags.Volatile) + } + if flags.PreEncoded != false { + t.Errorf("Expected PreEncoded to be false, got %v", flags.PreEncoded) + } + if flags.Local != true { + t.Errorf("Expected Local to be true, got %v", flags.Local) + } + if flags.Broadcast != false { + t.Errorf("Expected Broadcast to be false, got %v", flags.Broadcast) + } + if flags.Binary != true { + t.Errorf("Expected Binary to be true, got %v", flags.Binary) + } + if flags.Options.Compress != true { + t.Errorf("Expected Options.Compress to be true, got %v", flags.Options.Compress) + } +} + +func TestBroadcastOptionsInheritance(t *testing.T) { + rooms := types.NewSet[Room]() + except := types.NewSet[Room]() + + flags := &BroadcastFlags{ + WriteOptions: WriteOptions{ + Volatile: false, + PreEncoded: true, + Options: packet.Options{}, + }, + Local: true, + Broadcast: true, + Binary: false, + } + + opts := BroadcastOptions{ + Rooms: rooms, + Except: except, + Flags: flags, + } + + if opts.Flags != flags { + t.Errorf("Expected Flags to be correctly embedded in BroadcastOptions, got %v", opts.Flags) + } + if opts.Rooms != rooms { + t.Errorf("Expected Rooms to be set correctly, got %v", opts.Rooms) + } + if opts.Except != except { + t.Errorf("Expected Except to be set correctly, got %v", opts.Except) + } +} + +func TestSessionInheritance(t *testing.T) { + rooms := types.NewSet[Room]() + + session := Session{ + SessionToPersist: &SessionToPersist{ + Sid: "socket123", + Pid: "private123", + Rooms: rooms, + Data: "sample data", + }, + MissedPackets: []any{"packet1", "packet2"}, + } + + if session.Sid != "socket123" { + t.Errorf("Expected Sid to be 'socket123', got %v", session.Sid) + } + if session.Pid != "private123" { + t.Errorf("Expected Pid to be 'private123', got %v", session.Pid) + } + if session.Rooms != rooms { + t.Errorf("Expected Rooms to be set correctly, got %v", session.Rooms) + } + if session.Data != "sample data" { + t.Errorf("Expected Data to be 'sample data', got %v", session.Data) + } + if len(session.MissedPackets) != 2 { + t.Errorf("Expected MissedPackets to contain 2 packets, got %d", len(session.MissedPackets)) + } +} + +func TestPersistedPacket(t *testing.T) { + flags := &BroadcastFlags{ + WriteOptions: WriteOptions{ + Volatile: true, + PreEncoded: true, + }, + Local: true, + Broadcast: false, + } + + opts := &BroadcastOptions{ + Rooms: types.NewSet[Room](), + Except: types.NewSet[Room](), + Flags: flags, + } + + packet := PersistedPacket{ + Id: "packet1", + EmittedAt: time.Now().Unix(), + Data: "packet data", + Opts: opts, + } + + if packet.Id != "packet1" { + t.Errorf("Expected Id to be 'packet1', got %v", packet.Id) + } + if packet.Opts != opts { + t.Errorf("Expected BroadcastOptions to be set correctly, got %v", packet.Opts) + } + if packet.Data != "packet data" { + t.Errorf("Expected Data to be 'packet data', got %v", packet.Data) + } +} + +func TestSessionWithTimestampInheritance(t *testing.T) { + sessionWithTimestamp := SessionWithTimestamp{ + SessionToPersist: &SessionToPersist{ + Sid: "sid123", + Pid: "pid123", + Rooms: types.NewSet[Room](), + Data: "some data", + }, + DisconnectedAt: time.Now().Unix(), + } + + if sessionWithTimestamp.Sid != "sid123" { + t.Errorf("Expected Sid to be 'sid123', got %v", sessionWithTimestamp.Sid) + } + if sessionWithTimestamp.Pid != "pid123" { + t.Errorf("Expected Pid to be 'pid123', got %v", sessionWithTimestamp.Pid) + } + if sessionWithTimestamp.Rooms == nil { + t.Errorf("Expected Rooms to be set, got nil") + } + if sessionWithTimestamp.Data != "some data" { + t.Errorf("Expected Data to be 'some data', got %v", sessionWithTimestamp.Data) + } + if sessionWithTimestamp.DisconnectedAt <= 0 { + t.Errorf("Expected DisconnectedAt to be a valid timestamp, got %d", sessionWithTimestamp.DisconnectedAt) + } +} diff --git a/socket/adapter.go b/socket/adapter.go index 5323386..9c8fafa 100644 --- a/socket/adapter.go +++ b/socket/adapter.go @@ -1,18 +1,17 @@ package socket import ( + "fmt" "sync/atomic" _types "github.com/zishang520/engine.io-go-parser/types" "github.com/zishang520/engine.io/v2/events" "github.com/zishang520/engine.io/v2/types" - "github.com/zishang520/engine.io/v2/utils" "github.com/zishang520/socket.io-go-parser/v2/parser" ) type ( AdapterBuilder struct { - AdapterConstructor } adapter struct { @@ -171,7 +170,7 @@ func (a *adapter) Broadcast(packet *parser.Packet, opts *BroadcastOptions) { // - `Flags` {*BroadcastFlags} flags for this packet // - `Except` {*types.Set[Room]} sids that should be excluded // - `Rooms` {*types.Set[Room]} list of rooms to broadcast to -func (a *adapter) BroadcastWithAck(packet *parser.Packet, opts *BroadcastOptions, clientCountCallback func(uint64), ack func([]any, error)) { +func (a *adapter) BroadcastWithAck(packet *parser.Packet, opts *BroadcastOptions, clientCountCallback func(uint64), ack Ack) { flags := &BroadcastFlags{} if opts != nil && opts.Flags != nil { flags = opts.Flags @@ -269,8 +268,16 @@ func (a *adapter) DisconnectSockets(opts *BroadcastOptions, status bool) { } func (a *adapter) apply(opts *BroadcastOptions, callback func(*Socket)) { + if opts == nil { + opts = &BroadcastOptions{ + Rooms: types.NewSet[Room](), + Except: types.NewSet[Room](), + } + } + rooms := opts.Rooms except := a.computeExceptSids(opts.Except) + if rooms != nil && rooms.Len() > 0 { ids := types.NewSet[SocketId]() for _, room := range rooms.Keys() { @@ -312,9 +319,8 @@ func (a *adapter) computeExceptSids(exceptRooms *types.Set[Room]) *types.Set[Soc } // Send a packet to the other Socket.IO servers in the cluster -func (a *adapter) ServerSideEmit(args []any) error { - utils.Log().Warning(`this adapter does not support the ServerSideEmit() functionality`) - return nil +func (a *adapter) ServerSideEmit(packet []any) error { + return fmt.Errorf(`this adapter does not support the ServerSideEmit() functionality`) } // Save the client session in order to restore it upon reconnection. diff --git a/socket/broadcast-operator.go b/socket/broadcast-operator.go index ce1302d..7e79744 100644 --- a/socket/broadcast-operator.go +++ b/socket/broadcast-operator.go @@ -185,7 +185,7 @@ func (b *BroadcastOperator) Emit(ev string, args ...any) error { Data: data, } - ack, withAck := data[data_len-1].(func([]any, error)) + ack, withAck := data[data_len-1].(Ack) if !withAck { b.adapter.Broadcast(packet, &BroadcastOptions{ @@ -261,9 +261,9 @@ func (b *BroadcastOperator) Emit(ev string, args ...any) error { // } // }) // -// Return: a `func(func([]any, error))` that will be fulfilled when all clients have acknowledged the event -func (b *BroadcastOperator) EmitWithAck(ev string, args ...any) func(func([]any, error)) { - return func(ack func([]any, error)) { +// Return: a `func(socket.Ack)` that will be fulfilled when all clients have acknowledged the event +func (b *BroadcastOperator) EmitWithAck(ev string, args ...any) func(Ack) { + return func(ack Ack) { b.Emit(ev, append(args, ack)...) } } @@ -378,104 +378,3 @@ func (b *BroadcastOperator) DisconnectSockets(status bool) { Flags: b.flags, }, status) } - -// Expose of subset of the attributes and methods of the Socket struct -type RemoteSocket struct { - id SocketId - handshake *Handshake - rooms *types.Set[Room] - data any - - operator *BroadcastOperator -} - -func MakeRemoteSocket() *RemoteSocket { - r := &RemoteSocket{} - return r -} - -func NewRemoteSocket(adapter Adapter, details SocketDetails) *RemoteSocket { - r := MakeRemoteSocket() - - r.Construct(adapter, details) - - return r -} - -func (r *RemoteSocket) Id() SocketId { - return r.id -} - -func (r *RemoteSocket) Handshake() *Handshake { - return r.handshake -} - -func (r *RemoteSocket) Rooms() *types.Set[Room] { - return r.rooms -} - -func (r *RemoteSocket) Data() any { - return r.data -} - -func (r *RemoteSocket) Construct(adapter Adapter, details SocketDetails) { - r.id = details.Id() - r.handshake = details.Handshake() - r.rooms = types.NewSet(details.Rooms().Keys()...) - r.data = details.Data() - r.operator = NewBroadcastOperator(adapter, types.NewSet(Room(r.id)), types.NewSet[Room](), &BroadcastFlags{ - ExpectSingleResponse: true, // so that remoteSocket.Emit() with acknowledgement behaves like socket.Emit() - }) -} - -// Adds a timeout in milliseconds for the next operation. -// -// io.FetchSockets()(func(sockets []*RemoteSocket, _ error){ -// -// for _, socket := range sockets { -// if (someCondition) { -// socket.Timeout(1000 * time.Millisecond).Emit("some-event", func(args []any, err error) { -// if err != nil { -// // the client did not acknowledge the event in the given delay -// } -// }) -// } -// } -// -// }) -// // Note: if possible, using a room instead of looping over all sockets is preferable -// -// io.Timeout(1000 * time.Millisecond).To(someConditionRoom).Emit("some-event", func(args []any, err error) { -// // ... -// }) -// -// Param: time.Duration - timeout -func (r *RemoteSocket) Timeout(timeout time.Duration) *BroadcastOperator { - return r.operator.Timeout(timeout) -} - -func (r *RemoteSocket) Emit(ev string, args ...any) error { - return r.operator.Emit(ev, args...) -} - -// Joins a room. -// -// Param: Room - a [Room], or a [Room] slice to expand -func (r *RemoteSocket) Join(room ...Room) { - r.operator.SocketsJoin(room...) -} - -// Leaves a room. -// -// Param: Room - a [Room], or a [Room] slice to expand -func (r *RemoteSocket) Leave(room ...Room) { - r.operator.SocketsLeave(room...) -} - -// Disconnects this client. -// -// Param: close - if `true`, closes the underlying connection -func (r *RemoteSocket) Disconnect(status bool) *RemoteSocket { - r.operator.DisconnectSockets(status) - return r -} diff --git a/socket/client.go b/socket/client.go index 0a284f1..4ac3266 100644 --- a/socket/client.go +++ b/socket/client.go @@ -177,12 +177,7 @@ func (c *Client) WriteToEngine(encodedPackets []_types.BufferInterface, opts *Wr } for _, encodedPacket := range encodedPackets { - switch data := encodedPacket.(type) { - case *_types.StringBuffer: - c.conn.Write(_types.NewStringBuffer(data.Bytes()), &opts.Options, nil) - case *_types.BytesBuffer: - c.conn.Write(_types.NewBytesBuffer(data.Bytes()), &opts.Options, nil) - } + c.conn.Write(encodedPacket.Clone(), &opts.Options, nil) } } diff --git a/socket/namespace-type.go b/socket/namespace-type.go index e05f653..83f4c09 100644 --- a/socket/namespace-type.go +++ b/socket/namespace-type.go @@ -7,6 +7,8 @@ import ( "github.com/zishang520/engine.io/v2/types" ) +type NamespaceMiddleware = func(*Socket, func(*ExtendedError)) + // A namespace is a communication channel that allows you to split the logic of your application over a single shared // connection. // @@ -76,7 +78,7 @@ type Namespace interface { Adapter() Adapter Name() string Ids() uint64 - Fns() *types.Slice[func(*Socket, func(*ExtendedError))] + Fns() *types.Slice[NamespaceMiddleware] // Construct() should be called after calling Prototype() Construct(*Server, string) @@ -89,10 +91,10 @@ type Namespace interface { InitAdapter() // Whether to remove child namespaces that have no sockets connected to them - Cleanup(func()) + Cleanup(types.Callable) // Sets up namespace middleware. - Use(func(*Socket, func(*ExtendedError))) Namespace + Use(NamespaceMiddleware) Namespace // Targets a room when emitting. To(...Room) *BroadcastOperator @@ -119,12 +121,12 @@ type Namespace interface { ServerSideEmit(string, ...any) error // Sends a message and expect an acknowledgement from the other Socket.IO servers of the cluster. - ServerSideEmitWithAck(string, ...any) func(func([]any, error)) + ServerSideEmitWithAck(string, ...any) func(Ack) error // @private // // Called when a packet is received from another Socket.IO server - OnServerSideEmit(string, ...any) + OnServerSideEmit([]any) // Gets a list of clients. AllSockets() (*types.Set[SocketId], error) diff --git a/socket/namespace.go b/socket/namespace.go index 9f3b5ef..8b8c52c 100644 --- a/socket/namespace.go +++ b/socket/namespace.go @@ -84,9 +84,9 @@ type namespace struct { server *Server - _fns *types.Slice[func(*Socket, func(*ExtendedError))] + _fns *types.Slice[NamespaceMiddleware] - _cleanup func() + _cleanup types.Callable } func MakeNamespace() Namespace { @@ -94,7 +94,7 @@ func MakeNamespace() Namespace { StrictEventEmitter: NewStrictEventEmitter(), sockets: &types.Map[SocketId, *Socket]{}, - _fns: types.NewSlice[func(*Socket, func(*ExtendedError))](), + _fns: types.NewSlice[NamespaceMiddleware](), _cleanup: nil, } @@ -143,7 +143,7 @@ func (n *namespace) Ids() uint64 { return n._ids.Add(1) } -func (n *namespace) Fns() *types.Slice[func(*Socket, func(*ExtendedError))] { +func (n *namespace) Fns() *types.Slice[NamespaceMiddleware] { return n._fns } @@ -172,7 +172,7 @@ func (n *namespace) InitAdapter() { // }) // // Param: func(*ExtendedError) - the middleware function -func (n *namespace) Use(fn func(*Socket, func(*ExtendedError))) Namespace { +func (n *namespace) Use(fn NamespaceMiddleware) Namespace { n._fns.Push(fn) return n } @@ -464,9 +464,7 @@ func (n *namespace) ServerSideEmit(ev string, args ...any) error { return errors.New(fmt.Sprintf(`"%s" is a reserved event name`, ev)) } - n.Proto().Adapter().ServerSideEmit(append([]any{ev}, args...)) - - return nil + return n.Proto().Adapter().ServerSideEmit(append([]any{ev}, args...)) } // Sends a message and expect an acknowledgement from the other Socket.IO servers of the cluster. @@ -481,16 +479,28 @@ func (n *namespace) ServerSideEmit(ev string, args ...any) error { // } // }) // -// Return: a `func(func([]any, error))` that will be fulfilled when all servers have acknowledged the event -func (n *namespace) ServerSideEmitWithAck(ev string, args ...any) func(func([]any, error)) { - return func(ack func([]any, error)) { - n.ServerSideEmit(ev, append(args, ack)...) +// Return: a `func(socket.Ack)` that will be fulfilled when all servers have acknowledged the event +func (n *namespace) ServerSideEmitWithAck(ev string, args ...any) func(Ack) error { + return func(ack Ack) error { + return n.ServerSideEmit(ev, append(args, ack)...) } } // Called when a packet is received from another Socket.IO server -func (n *namespace) OnServerSideEmit(ev string, args ...any) { - n.EmitUntyped(ev, args...) +func (n *namespace) OnServerSideEmit(args []any) { + // Convert the first argument to a string and pass the rest as args + if len(args) == 0 { + return // No arguments provided + } + + ev, ok := args[0].(string) + if !ok { + // Handle error, the first argument should be a string + return + } + + // Remove the first argument (event name) and pass the rest + n.EmitUntyped(ev, args[1:]...) } // Gets a list of socket ids. diff --git a/socket/parent-broadcast-adapter.go b/socket/parent-broadcast-adapter.go index 0d57edf..4abfdff 100644 --- a/socket/parent-broadcast-adapter.go +++ b/socket/parent-broadcast-adapter.go @@ -6,7 +6,6 @@ import ( type ( ParentBroadcastAdapterBuilder struct { - AdapterConstructor } // A dummy adapter that only supports broadcasting to child (concrete) namespaces. diff --git a/socket/parent-namespace.go b/socket/parent-namespace.go index a4bd4db..be30449 100644 --- a/socket/parent-namespace.go +++ b/socket/parent-namespace.go @@ -55,6 +55,10 @@ func NewParentNamespace(server *Server) ParentNamespace { return n } +func (p *parentNamespace) Children() *types.Set[Namespace] { + return p.children +} + func (p *parentNamespace) Adapter() Adapter { return p.adapter } @@ -70,10 +74,6 @@ func (p *parentNamespace) Emit(ev string, args ...any) error { return nil } -func (p *parentNamespace) Children() *types.Set[Namespace] { - return p.children -} - func (p *parentNamespace) CreateChild(name string) Namespace { parent_namespace_log.Debug("creating child namespace %s", name) namespace := NewNamespace(p.Server(), name) @@ -98,6 +98,7 @@ func (p *parentNamespace) CreateChild(name string) Namespace { p.Server()._nsps.Store(name, namespace) p.Server().Sockets().EmitReserved("new_namespace", namespace) + return namespace } diff --git a/socket/remote-socket.go b/socket/remote-socket.go new file mode 100644 index 0000000..80cb5f4 --- /dev/null +++ b/socket/remote-socket.go @@ -0,0 +1,117 @@ +package socket + +import ( + "time" + + "github.com/zishang520/engine.io/v2/types" +) + +type ( + SocketDetails interface { + Id() SocketId + Handshake() *Handshake + Rooms() *types.Set[Room] + Data() any + } + + // Expose of subset of the attributes and methods of the Socket struct + RemoteSocket struct { + id SocketId + handshake *Handshake + rooms *types.Set[Room] + data any + + operator *BroadcastOperator + } +) + +func MakeRemoteSocket() *RemoteSocket { + r := &RemoteSocket{} + return r +} + +func NewRemoteSocket(adapter Adapter, details SocketDetails) *RemoteSocket { + r := MakeRemoteSocket() + + r.Construct(adapter, details) + + return r +} + +func (r *RemoteSocket) Id() SocketId { + return r.id +} + +func (r *RemoteSocket) Handshake() *Handshake { + return r.handshake +} + +func (r *RemoteSocket) Rooms() *types.Set[Room] { + return r.rooms +} + +func (r *RemoteSocket) Data() any { + return r.data +} + +func (r *RemoteSocket) Construct(adapter Adapter, details SocketDetails) { + r.id = details.Id() + r.handshake = details.Handshake() + r.rooms = types.NewSet(details.Rooms().Keys()...) + r.data = details.Data() + r.operator = NewBroadcastOperator(adapter, types.NewSet(Room(r.id)), types.NewSet[Room](), &BroadcastFlags{ + ExpectSingleResponse: true, // so that remoteSocket.Emit() with acknowledgement behaves like socket.Emit() + }) +} + +// Adds a timeout in milliseconds for the next operation. +// +// io.FetchSockets()(func(sockets []*RemoteSocket, _ error){ +// +// for _, socket := range sockets { +// if (someCondition) { +// socket.Timeout(1000 * time.Millisecond).Emit("some-event", func(args []any, err error) { +// if err != nil { +// // the client did not acknowledge the event in the given delay +// } +// }) +// } +// } +// +// }) +// // Note: if possible, using a room instead of looping over all sockets is preferable +// +// io.Timeout(1000 * time.Millisecond).To(someConditionRoom).Emit("some-event", func(args []any, err error) { +// // ... +// }) +// +// Param: time.Duration - timeout +func (r *RemoteSocket) Timeout(timeout time.Duration) *BroadcastOperator { + return r.operator.Timeout(timeout) +} + +func (r *RemoteSocket) Emit(ev string, args ...any) error { + return r.operator.Emit(ev, args...) +} + +// Joins a room. +// +// Param: Room - a [Room], or a [Room] slice to expand +func (r *RemoteSocket) Join(room ...Room) { + r.operator.SocketsJoin(room...) +} + +// Leaves a room. +// +// Param: Room - a [Room], or a [Room] slice to expand +func (r *RemoteSocket) Leave(room ...Room) { + r.operator.SocketsLeave(room...) +} + +// Disconnects this client. +// +// Param: close - if `true`, closes the underlying connection +func (r *RemoteSocket) Disconnect(status bool) *RemoteSocket { + r.operator.DisconnectSockets(status) + return r +} diff --git a/socket/server.go b/socket/server.go index da2db52..c500f07 100644 --- a/socket/server.go +++ b/socket/server.go @@ -15,13 +15,14 @@ import ( "github.com/andybalholm/brotli" "github.com/zishang520/engine.io/v2/engine" + "github.com/zishang520/engine.io/v2/events" "github.com/zishang520/engine.io/v2/log" "github.com/zishang520/engine.io/v2/types" "github.com/zishang520/engine.io/v2/utils" "github.com/zishang520/socket.io-go-parser/v2/parser" ) -const clientVersion = "4.7.2" +const clientVersion = "4.7.5" var ( dotMapRegex = regexp.MustCompile(`\.map`) @@ -439,7 +440,7 @@ func (Server) sendFile(filename string, w http.ResponseWriter, r *http.Request) } defer file.Close() - encoding := utils.Contains(r.Header.Get("Accept-Encoding"), []string{"gzip", "deflate", "br"}) + encoding := utils.Contains(r.Header.Get("Accept-Encoding"), []string{"gzip", "deflate", "br" /*, "zstd"*/}) switch encoding { case "br": @@ -470,6 +471,13 @@ func (Server) sendFile(filename string, w http.ResponseWriter, r *http.Request) w.Header().Set("Content-Encoding", "deflate") w.WriteHeader(http.StatusOK) io.Copy(fl, file) + + // TODO: Implement zstd support. + // Go's standard library does not yet support zstd compression (see issue: https://github.com/golang/go/issues/62513). + // Consider using the klauspost/compress library for a high-performance implementation: + // https://github.com/klauspost/compress/tree/master/zstd + // case "zstd": + default: w.WriteHeader(http.StatusOK) io.Copy(w, file) @@ -512,7 +520,7 @@ func (s *Server) onconnection(conns ...any) { // Param: string | *regexp.Regexp | ParentNspNameMatchFn - nsp name // // Param: func(...any) - nsp `connection` ev handler -func (s *Server) Of(name any, fn func(...any)) Namespace { +func (s *Server) Of(name any, fn events.Listener) Namespace { switch n := name.(type) { case ParentNspNameMatchFn: parentNsp := NewParentNamespace(s) @@ -619,7 +627,7 @@ func (s *Server) Close(fn func(error)) { // }) // // Param: func(*ExtendedError) - the middleware function -func (s *Server) Use(fn func(*Socket, func(*ExtendedError))) *Server { +func (s *Server) Use(fn NamespaceMiddleware) *Server { s.sockets.Use(fn) return s } @@ -741,8 +749,8 @@ func (s *Server) ServerSideEmit(ev string, args ...any) error { // // Param: args - an array of arguments // -// Return: a `func(func([]any, error))` that will be fulfilled when all servers have acknowledged the event -func (s *Server) ServerSideEmitWithAck(ev string, args ...any) func(func([]any, error)) { +// Return: a `func(socket.Ack)` that will be fulfilled when all servers have acknowledged the event +func (s *Server) ServerSideEmitWithAck(ev string, args ...any) func(Ack) error { return s.sockets.ServerSideEmitWithAck(ev, args...) } diff --git a/socket/session-aware-adapter.go b/socket/session-aware-adapter.go index 82d76fb..86df4a6 100644 --- a/socket/session-aware-adapter.go +++ b/socket/session-aware-adapter.go @@ -10,7 +10,6 @@ import ( type ( SessionAwareAdapterBuilder struct { - AdapterConstructor } sessionAwareAdapter struct { @@ -69,8 +68,7 @@ func (s *sessionAwareAdapter) Construct(nsp Namespace) { } func (s *sessionAwareAdapter) PersistSession(session *SessionToPersist) { - _session := &SessionWithTimestamp{SessionToPersist: session, DisconnectedAt: time.Now().UnixMilli()} - s.sessions.Store(_session.Pid, _session) + s.sessions.Store(session.Pid, &SessionWithTimestamp{SessionToPersist: session, DisconnectedAt: time.Now().UnixMilli()}) } func (s *sessionAwareAdapter) RestoreSession(pid PrivateSessionId, offset string) (*Session, error) { diff --git a/socket/socket.go b/socket/socket.go index b3f1ce2..13f3311 100644 --- a/socket/socket.go +++ b/socket/socket.go @@ -23,6 +23,10 @@ var ( ) type ( + Ack = func([]any, error) + + SocketMiddleware = func([]any, func(error)) + Handshake struct { // The headers sent as part of the handshake Headers map[string][]string `json:"headers" msgpack:"headers"` @@ -111,8 +115,8 @@ type ( // TODO: remove this unused reference server *Server adapter Adapter - acks *types.Map[uint64, func([]any, error)] - fns *types.Slice[func([]any, func(error))] + acks *types.Map[uint64, Ack] + fns *types.Slice[SocketMiddleware] flags atomic.Pointer[BroadcastFlags] _anyListeners *types.Slice[events.Listener] _anyOutgoingListeners *types.Slice[events.Listener] @@ -126,8 +130,8 @@ func MakeSocket() *Socket { StrictEventEmitter: NewStrictEventEmitter(), // Initialize default value - acks: &types.Map[uint64, func([]any, error)]{}, - fns: types.NewSlice[func([]any, func(error))](), + acks: &types.Map[uint64, Ack]{}, + fns: types.NewSlice[SocketMiddleware](), _anyListeners: types.NewSlice[events.Listener](), _anyOutgoingListeners: types.NewSlice[events.Listener](), } @@ -188,7 +192,7 @@ func (s *Socket) Connected() bool { return s.connected.Load() } -func (s *Socket) Acks() *types.Map[uint64, func([]any, error)] { +func (s *Socket) Acks() *types.Map[uint64, Ack] { return s.acks } @@ -284,7 +288,7 @@ func (s *Socket) Emit(ev string, args ...any) error { Data: data, } // access last argument to see if it's an ACK callback - if fn, ok := data[data_len-1].(func([]any, error)); ok { + if fn, ok := data[data_len-1].(Ack); ok { id := s.nsp.Ids() socket_log.Debug("emitting packet with ack id %d", id) packet.Data = data[:data_len-1] @@ -332,14 +336,14 @@ func (s *Socket) Emit(ev string, args ...any) error { // }) // }) // -// Return: a `func(func([]any, error))` that will be fulfilled when all clients have acknowledged the event -func (s *Socket) EmitWithAck(ev string, args ...any) func(func([]any, error)) { - return func(ack func([]any, error)) { +// Return: a `func(socket.Ack)` that will be fulfilled when all clients have acknowledged the event +func (s *Socket) EmitWithAck(ev string, args ...any) func(Ack) { + return func(ack Ack) { s.Emit(ev, append(args, ack)...) } } -func (s *Socket) registerAckCallback(id uint64, ack func([]any, error)) { +func (s *Socket) registerAckCallback(id uint64, ack Ack) { timeout := s.flags.Load().Timeout if timeout == nil { s.acks.Store(id, ack) @@ -563,7 +567,7 @@ func (s *Socket) onevent(packet *parser.Packet) { // Produces an ack callback to emit with an event. // // Param: id - packet id -func (s *Socket) ack(id uint64) func([]any, error) { +func (s *Socket) ack(id uint64) Ack { sent := &sync.Once{} return func(args []any, _ error) { // prevent double callbacks @@ -785,7 +789,7 @@ func (s *Socket) dispatch(event []any) { // }); // // Param: fn - middleware function (event, next) -func (s *Socket) Use(fn func([]any, func(error))) *Socket { +func (s *Socket) Use(fn SocketMiddleware) *Socket { s.fns.Push(fn) return s } diff --git a/socket/type.go b/socket/type.go deleted file mode 100644 index 8abde88..0000000 --- a/socket/type.go +++ /dev/null @@ -1,14 +0,0 @@ -package socket - -import ( - "github.com/zishang520/engine.io/v2/types" -) - -type ( - SocketDetails interface { - Id() SocketId - Handshake() *Handshake - Rooms() *types.Set[Room] - Data() any - } -)