diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index 2466da2ebb..a0cecd7f55 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -23,26 +23,22 @@ import ( "sync/atomic" "time" - "github.com/OpenIMSDK/protocol/msggateway" - - "github.com/openimsdk/open-im-server/v3/pkg/authverify" - "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" - - "github.com/OpenIMSDK/protocol/constant" - - "github.com/openimsdk/open-im-server/v3/pkg/common/config" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" - "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" - + "github.com/go-playground/validator/v10" "github.com/redis/go-redis/v9" + "golang.org/x/sync/errgroup" + "github.com/OpenIMSDK/protocol/constant" + "github.com/OpenIMSDK/protocol/msggateway" "github.com/OpenIMSDK/tools/discoveryregistry" - - "github.com/go-playground/validator/v10" - "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/log" "github.com/OpenIMSDK/tools/utils" + + "github.com/openimsdk/open-im-server/v3/pkg/authverify" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" + "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" + "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" ) type LongConnServer interface { @@ -78,7 +74,6 @@ type WsServer struct { onlineUserNum atomic.Int64 onlineUserConnNum atomic.Int64 handshakeTimeout time.Duration - hubServer *Server validate *validator.Validate cache cache.MsgModel userClient *rpcclient.UserRpcClient @@ -183,27 +178,39 @@ func (ws *WsServer) Run() error { return http.ListenAndServe(":"+utils.IntToString(ws.port), nil) // Start listening } +var concurrentRequest = 3 + func (ws *WsServer) sendUserOnlineInfoToOtherNode(ctx context.Context, client *Client) error { conns, err := ws.disCov.GetConns(ctx, config.Config.RpcRegisterName.OpenImMessageGatewayName) if err != nil { return err } + + wg := errgroup.Group{} + wg.SetLimit(concurrentRequest) + // Online push user online message to other node for _, v := range conns { + v := v // safe closure var if v.Target() == ws.disCov.GetSelfConnTarget() { log.ZDebug(ctx, "Filter out this node", "node", v.Target()) continue } - msgClient := msggateway.NewMsgGatewayClient(v) - _, err := msgClient.MultiTerminalLoginCheck(ctx, &msggateway.MultiTerminalLoginCheckReq{ - UserID: client.UserID, - PlatformID: int32(client.PlatformID), Token: client.token, + + wg.Go(func() error { + msgClient := msggateway.NewMsgGatewayClient(v) + _, err := msgClient.MultiTerminalLoginCheck(ctx, &msggateway.MultiTerminalLoginCheckReq{ + UserID: client.UserID, + PlatformID: int32(client.PlatformID), Token: client.token, + }) + if err != nil { + log.ZWarn(ctx, "MultiTerminalLoginCheck err", err, "node", v.Target()) + } + return nil }) - if err != nil { - log.ZWarn(ctx, "MultiTerminalLoginCheck err", err, "node", v.Target()) - continue - } } + + _ = wg.Wait() return nil } @@ -289,70 +296,72 @@ func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Clien } fallthrough case constant.AllLoginButSameTermKick: - if clientOK { - isDeleteUser := ws.clients.deleteClients(newClient.UserID, oldClients) - if isDeleteUser { - ws.onlineUserNum.Add(-1) - } - for _, c := range oldClients { - err := c.KickOnlineMessage() - if err != nil { - log.ZWarn(c.ctx, "KickOnlineMessage", err) - } + if !clientOK { + return + } + + isDeleteUser := ws.clients.deleteClients(newClient.UserID, oldClients) + if isDeleteUser { + ws.onlineUserNum.Add(-1) + } + for _, c := range oldClients { + err := c.KickOnlineMessage() + if err != nil { + log.ZWarn(c.ctx, "KickOnlineMessage", err) } - m, err := ws.cache.GetTokensWithoutError( + } + m, err := ws.cache.GetTokensWithoutError( + newClient.ctx, + newClient.UserID, + newClient.PlatformID, + ) + if err != nil && err != redis.Nil { + log.ZWarn( newClient.ctx, + "get token from redis err", + err, + "userID", newClient.UserID, + "platformID", newClient.PlatformID, ) - if err != nil && err != redis.Nil { - log.ZWarn( - newClient.ctx, - "get token from redis err", - err, - "userID", - newClient.UserID, - "platformID", - newClient.PlatformID, - ) - return - } - if m == nil { - log.ZWarn( - newClient.ctx, - "m is nil", - errors.New("m is nil"), - "userID", - newClient.UserID, - "platformID", - newClient.PlatformID, - ) - return - } - log.ZDebug( + return + } + if m == nil { + log.ZWarn( newClient.ctx, - "get token from redis", + "m is nil", + errors.New("m is nil"), "userID", newClient.UserID, "platformID", newClient.PlatformID, - "tokenMap", - m, ) - - for k := range m { - if k != newClient.ctx.GetToken() { - m[k] = constant.KickedToken - } - } - log.ZDebug(newClient.ctx, "set token map is ", "token map", m, "userID", - newClient.UserID, "token", newClient.ctx.GetToken()) - err = ws.cache.SetTokenMapByUidPid(newClient.ctx, newClient.UserID, newClient.PlatformID, m) - if err != nil { - log.ZWarn(newClient.ctx, "SetTokenMapByUidPid err", err, "userID", newClient.UserID, "platformID", newClient.PlatformID) - return + return + } + log.ZDebug( + newClient.ctx, + "get token from redis", + "userID", + newClient.UserID, + "platformID", + newClient.PlatformID, + "tokenMap", + m, + ) + + for k := range m { + if k != newClient.ctx.GetToken() { + m[k] = constant.KickedToken } } + log.ZDebug(newClient.ctx, "set token map is ", "token map", m, "userID", + newClient.UserID, "token", newClient.ctx.GetToken()) + err = ws.cache.SetTokenMapByUidPid(newClient.ctx, newClient.UserID, newClient.PlatformID, m) + if err != nil { + log.ZWarn(newClient.ctx, "SetTokenMapByUidPid err", err, "userID", newClient.UserID, "platformID", newClient.PlatformID) + return + } } } @@ -404,7 +413,7 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { httpError(connContext, errs.ErrConnArgsErr) return } - if err := authverify.WsVerifyToken(token, userID, platformID); err != nil { + if err = authverify.WsVerifyToken(token, userID, platformID); err != nil { httpError(connContext, err) return }