diff --git a/gateway/gateway_events.go b/gateway/gateway_events.go index a2f717d5..484ad95e 100644 --- a/gateway/gateway_events.go +++ b/gateway/gateway_events.go @@ -43,6 +43,12 @@ type EventReady struct { func (EventReady) messageData() {} func (EventReady) eventData() {} +// EventResumed is the event sent by discord when you successfully resume +type EventResumed struct{} + +func (EventResumed) messageData() {} +func (EventResumed) eventData() {} + type EventApplicationCommandPermissionsUpdate struct { discord.ApplicationCommandPermissions } diff --git a/gateway/gateway_impl.go b/gateway/gateway_impl.go index d6a0d09a..92b3cb8b 100644 --- a/gateway/gateway_impl.go +++ b/gateway/gateway_impl.go @@ -92,8 +92,8 @@ func (g *gatewayImpl) open(ctx context.Context) error { g.config.Logger.Debug(g.formatLogs("opening gateway connection")) g.connMu.Lock() - defer g.connMu.Unlock() if g.conn != nil { + g.connMu.Unlock() return discord.ErrGatewayAlreadyConnected } g.status = StatusConnecting @@ -120,6 +120,7 @@ func (g *gatewayImpl) open(ctx context.Context) error { } g.config.Logger.Error(g.formatLogsf("error connecting to the gateway. url: %s, error: %s, body: %s", gatewayURL, err, body)) + g.connMu.Unlock() return err } @@ -128,13 +129,30 @@ func (g *gatewayImpl) open(ctx context.Context) error { }) g.conn = conn + g.connMu.Unlock() // reset rate limiter when connecting g.config.RateLimiter.Reset() g.status = StatusWaitingForHello - go g.listen(conn) + readyChan := make(chan error) + go g.listen(conn, readyChan) + + select { + case <-ctx.Done(): + closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + g.Close(closeCtx) + return ctx.Err() + case err = <-readyChan: + if err != nil { + closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + g.Close(closeCtx) + return fmt.Errorf("failed to open gateway connection: %w", err) + } + } return nil } @@ -226,6 +244,13 @@ func (g *gatewayImpl) reconnectTry(ctx context.Context, try int) error { } if err := g.open(ctx); err != nil { + var closeError *websocket.CloseError + if errors.As(err, &closeError) { + closeCode := CloseEventCodeByCode(closeError.Code) + if !closeCode.Reconnect { + return err + } + } if errors.Is(err, discord.ErrGatewayAlreadyConnected) { return err } @@ -279,7 +304,7 @@ func (g *gatewayImpl) sendHeartbeat() { g.lastHeartbeatSent = time.Now().UTC() } -func (g *gatewayImpl) identify() { +func (g *gatewayImpl) identify() error { g.status = StatusIdentifying g.config.Logger.Debug(g.formatLogs("sending Identify command...")) @@ -298,12 +323,13 @@ func (g *gatewayImpl) identify() { } if err := g.Send(context.TODO(), OpcodeIdentify, identify); err != nil { - g.config.Logger.Error(g.formatLogs("error sending Identify command err: ", err)) + return err } g.status = StatusWaitingForReady + return nil } -func (g *gatewayImpl) resume() { +func (g *gatewayImpl) resume() error { g.status = StatusResuming resume := MessageDataResume{ Token: g.token, @@ -313,16 +339,22 @@ func (g *gatewayImpl) resume() { g.config.Logger.Debug(g.formatLogs("sending Resume command...")) if err := g.Send(context.TODO(), OpcodeResume, resume); err != nil { - g.config.Logger.Error(g.formatLogs("error sending resume command err: ", err)) + return err } + return nil } -func (g *gatewayImpl) listen(conn *websocket.Conn) { +func (g *gatewayImpl) listen(conn *websocket.Conn, readyChan chan<- error) { defer g.config.Logger.Debug(g.formatLogs("exiting listen goroutine...")) loop: for { mt, data, err := conn.ReadMessage() if err != nil { + if g.status != StatusReady { + readyChan <- err + close(readyChan) + break loop + } g.connMu.Lock() sameConnection := g.conn == conn g.connMu.Unlock() @@ -382,9 +414,14 @@ loop: go g.heartbeat() if g.config.LastSequenceReceived == nil || g.config.SessionID == nil { - g.identify() + err = g.identify() } else { - g.resume() + err = g.resume() + } + if err != nil { + readyChan <- err + close(readyChan) + return } case OpcodeDispatch: @@ -418,6 +455,16 @@ loop: }) } g.eventHandlerFunc(message.T, message.S, g.config.ShardID, eventData) + if _, ok = eventData.(EventReady); ok { + g.config.Logger.Debug(g.formatLogs("ready successful")) + readyChan <- nil + close(readyChan) + } else if _, ok = eventData.(EventResumed); ok { + g.config.Logger.Debug(g.formatLogs("resume successful")) + g.status = StatusReady + readyChan <- nil + close(readyChan) + } case OpcodeHeartbeat: g.sendHeartbeat() diff --git a/gateway/gateway_messages.go b/gateway/gateway_messages.go index 301f8fa6..efbcafb8 100644 --- a/gateway/gateway_messages.go +++ b/gateway/gateway_messages.go @@ -112,7 +112,7 @@ func UnmarshalEventData(data []byte, eventType EventType) (EventData, error) { eventData = d case EventTypeResumed: - // no data + eventData = EventResumed{} case EventTypeApplicationCommandPermissionsUpdate: var d EventApplicationCommandPermissionsUpdate