diff --git a/route/routers.go b/route/routers.go index 70abfd1..c3c17b5 100644 --- a/route/routers.go +++ b/route/routers.go @@ -37,7 +37,15 @@ func NewAPIRouter(swagger *openapi3.T, services *service.Services) (http.Handler e.Use(echo_middleware.JWTWithConfig(echo_middleware.JWTConfig{ Skipper: func(c echo.Context) bool { - return c.RealIP() == "::1" || c.RealIP() == "127.0.0.1" + if c.RealIP() == "::1" || c.RealIP() == "127.0.0.1" { + return true + } + + if c.Request().Method == echo.GET && c.Request().Header.Get(echo.HeaderUpgrade) == "websocket" { + return true + } + + return false }, ParseTokenFunc: func(token string, c echo.Context) (interface{}, error) { claims, code := jwt.Validate(token) diff --git a/service/action_type_service.go b/service/action_type_service.go index 21bf150..7ca6cd0 100644 --- a/service/action_type_service.go +++ b/service/action_type_service.go @@ -3,15 +3,19 @@ package service import ( "context" "errors" + "sync" "time" + "github.com/IceWhaleTech/CasaOS-Common/utils/logger" "github.com/IceWhaleTech/CasaOS-MessageBus/common" "github.com/IceWhaleTech/CasaOS-MessageBus/model" "github.com/IceWhaleTech/CasaOS-MessageBus/repository" + "go.uber.org/zap" ) type ActionService struct { ctx *context.Context + mutex sync.Mutex repository *repository.Repository inboundChannel chan model.Action subscriberChannels map[string]map[string][]chan model.Action @@ -121,7 +125,14 @@ func (s *ActionService) Unsubscribe(sourceID string, name string, c chan model.A } for i, subscriber := range s.subscriberChannels[sourceID][name] { + s.mutex.Lock() + defer s.mutex.Unlock() + if subscriber == c { + if i >= len(s.subscriberChannels[sourceID][name]) { + logger.Error("the i-th subscriber is removed before we get here - concurrency issue?", zap.Int("subscriber", i), zap.Int("total", len(s.subscriberChannels[sourceID][name]))) + return ErrAlreadySubscribed + } s.subscriberChannels[sourceID][name] = append(s.subscriberChannels[sourceID][name][:i], s.subscriberChannels[sourceID][name][i+1:]...) return nil } @@ -132,6 +143,7 @@ func (s *ActionService) Unsubscribe(sourceID string, name string, c chan model.A func (s *ActionService) Start(ctx *context.Context) { s.ctx = ctx + s.mutex = sync.Mutex{} s.inboundChannel = make(chan model.Action) s.subscriberChannels = make(map[string]map[string][]chan model.Action) diff --git a/service/event_type_service.go b/service/event_type_service.go index 7f17a5c..eda3c59 100644 --- a/service/event_type_service.go +++ b/service/event_type_service.go @@ -3,6 +3,7 @@ package service import ( "context" "errors" + "sync" "time" "github.com/IceWhaleTech/CasaOS-Common/utils/logger" @@ -14,6 +15,7 @@ import ( type EventService struct { ctx *context.Context + mutex sync.Mutex repository *repository.Repository inboundChannel chan model.Event subscriberChannels map[string]map[string][]chan model.Event @@ -123,8 +125,15 @@ func (s *EventService) Unsubscribe(sourceID string, name string, c chan model.Ev } for i, subscriber := range s.subscriberChannels[sourceID][name] { + s.mutex.Lock() + defer s.mutex.Unlock() + if subscriber == c { logger.Info("unsubscribing from event type", zap.String("sourceID", sourceID), zap.String("name", name), zap.Int("subscriber", i)) + if i >= len(s.subscriberChannels[sourceID][name]) { + logger.Error("the i-th subscriber is removed before we get here - concurrency issue?", zap.Int("subscriber", i), zap.Int("total", len(s.subscriberChannels[sourceID][name]))) + return ErrAlreadySubscribed + } s.subscriberChannels[sourceID][name] = append(s.subscriberChannels[sourceID][name][:i], s.subscriberChannels[sourceID][name][i+1:]...) return nil } @@ -135,6 +144,7 @@ func (s *EventService) Unsubscribe(sourceID string, name string, c chan model.Ev func (s *EventService) Start(ctx *context.Context) { s.ctx = ctx + s.mutex = sync.Mutex{} s.inboundChannel = make(chan model.Event) s.subscriberChannels = make(map[string]map[string][]chan model.Event) diff --git a/service/services.go b/service/services.go index f6932ac..0aa0afd 100644 --- a/service/services.go +++ b/service/services.go @@ -15,6 +15,7 @@ type Services struct { var ( ErrInboundChannelNotFound = errors.New("inbound channel not found") ErrSubscriberChannelsNotFound = errors.New("subscriber channels not found") + ErrAlreadySubscribed = errors.New("already subscribed") ) func (s *Services) Start(ctx *context.Context) {