From a9924230ed09296b1a183e4a0e48b7112a6d3be2 Mon Sep 17 00:00:00 2001 From: Tiger Wang Date: Thu, 10 Nov 2022 15:48:21 -0500 Subject: [PATCH] add support to subscribe to more than one events of the same source id (#4) --- api/message_bus/openapi.yaml | 115 +++++++++++++++++++++----- cmd/message-bus-tool/cmd/constants.go | 2 +- cmd/message-bus-tool/cmd/subscribe.go | 16 ++-- route/api_route_action.go | 10 ++- route/api_route_event.go | 58 +++++++++---- service/event_type_service.go | 47 ++++++++--- service/event_type_service_test.go | 80 ++++++++++-------- 7 files changed, 237 insertions(+), 91 deletions(-) diff --git a/api/message_bus/openapi.yaml b/api/message_bus/openapi.yaml index 12a63dc..de7719c 100644 --- a/api/message_bus/openapi.yaml +++ b/api/message_bus/openapi.yaml @@ -45,7 +45,9 @@ tags: x-tagGroups: - name: Methods tags: + - EventType methods - Event methods + - ActionType methods - Action methods - name: Schemas @@ -66,7 +68,7 @@ paths: List all event types that are currently registered with the message bus. operationId: getEventTypes tags: - - Event methods + - EventType methods responses: "200": $ref: "#/components/responses/GetEventTypesOK" @@ -79,7 +81,7 @@ paths: Register a new event type with the message bus. operationId: RegisterEventType tags: - - Event methods + - EventType methods requestBody: $ref: "#/components/requestBodies/RegisterEventType" responses: @@ -97,9 +99,9 @@ paths: Get all event types that are registered with the message bus for a specific source ID. operationId: getEventTypesBySourceID tags: - - Event methods + - EventType methods parameters: - - $ref: "#/components/parameters/SourceId" + - $ref: "#/components/parameters/SourceID" responses: "200": $ref: "#/components/responses/GetEventTypesOK" @@ -113,16 +115,17 @@ paths: Get an event type that is registered with the message bus for a specific source ID and event name. operationId: getEventType tags: - - Event methods + - EventType methods parameters: - - $ref: "#/components/parameters/SourceId" - - $ref: "#/components/parameters/Name" + - $ref: "#/components/parameters/SourceID" + - $ref: "#/components/parameters/EventName" responses: "200": $ref: "#/components/responses/GetEventTypeOK" "404": $ref: "#/components/responses/ResponseNotFound" + /event/{source_id}/{name}: post: summary: Publish an event description: | @@ -131,8 +134,8 @@ paths: tags: - Event methods parameters: - - $ref: "#/components/parameters/SourceId" - - $ref: "#/components/parameters/Name" + - $ref: "#/components/parameters/SourceID" + - $ref: "#/components/parameters/EventName" requestBody: $ref: "#/components/requestBodies/PublishEvent" responses: @@ -143,7 +146,7 @@ paths: "404": $ref: "#/components/responses/ResponseNotFound" - /event_type/{source_id}/{name}/ws: + /event/{source_id}: get: summary: Subscribe to an event type (WebSocket) description: | @@ -152,8 +155,8 @@ paths: tags: - Event methods parameters: - - $ref: "#/components/parameters/SourceId" - - $ref: "#/components/parameters/Name" + - $ref: "#/components/parameters/SourceID" + - $ref: "#/components/parameters/EventNames" responses: "101": description: | @@ -161,13 +164,23 @@ paths: /action_type: get: + summary: List action types + description: | + List all action types that are currently registered with the message bus. operationId: getActionTypes + tags: + - ActionType methods responses: "200": $ref: "#/components/responses/GetActionTypesOK" post: + summary: Register an action type + description: | + Register a new action type with the message bus. operationId: RegisterActionType + tags: + - ActionType methods requestBody: $ref: "#/components/requestBodies/RegisterActionType" responses: @@ -180,9 +193,14 @@ paths: /action_type/{source_id}: get: + summary: Get action types by source ID + description: | + Get all action types that are registered with the message bus for a specific source ID. operationId: getActionTypesBySourceID + tags: + - ActionType methods parameters: - - $ref: "#/components/parameters/SourceId" + - $ref: "#/components/parameters/SourceID" responses: "200": $ref: "#/components/responses/GetActionTypesOK" @@ -191,21 +209,32 @@ paths: /action_type/{source_id}/{name}: get: + summary: Get an action type by source ID and name + description: | + Get an action type that is registered with the message bus for a specific source ID and action name. operationId: getActionType + tags: + - ActionType methods parameters: - - $ref: "#/components/parameters/SourceId" - - $ref: "#/components/parameters/Name" + - $ref: "#/components/parameters/SourceID" + - $ref: "#/components/parameters/ActionName" responses: "200": $ref: "#/components/responses/GetActionTypeOK" "404": $ref: "#/components/responses/ResponseNotFound" + /action/{source_id}/{name}: post: + summary: Trigger an action + description: | + Trigger an action on the message bus. operationId: triggerAction + tags: + - Action methods parameters: - - $ref: "#/components/parameters/SourceId" - - $ref: "#/components/parameters/Name" + - $ref: "#/components/parameters/SourceID" + - $ref: "#/components/parameters/ActionName" requestBody: $ref: "#/components/requestBodies/TriggerAction" responses: @@ -216,6 +245,22 @@ paths: "404": $ref: "#/components/responses/ResponseNotFound" + /action/{source_id}: + get: + summary: Subscribe to an action type (WebSocket) + description: | + Subscribe to an action type by source ID and name via WebSocket. + operationId: subscribeAction + tags: + - Action methods + parameters: + - $ref: "#/components/parameters/SourceID" + - $ref: "#/components/parameters/ActionNames" + responses: + "101": + description: | + The connection will be upgraded to a WebSocket connection and the client will receive actions as they are triggered. + components: securitySchemes: @@ -225,19 +270,51 @@ components: name: Authorization parameters: - SourceId: + SourceID: name: source_id in: path required: true schema: type: string + example: "local-storage" + + EventName: + name: name + in: path + required: true + schema: + type: string + example: "local-storage:disk:added" - Name: + EventNames: + name: names + in: query + allowEmptyValue: false + required: false + schema: + type: array + items: + type: string + example: "local-storage:disk:added,local-storage:disk:removed" + + ActionName: name: name in: path required: true schema: type: string + example: "local-storage:partition:format" + + ActionNames: + name: names + in: query + allowEmptyValue: false + required: false + schema: + type: array + items: + type: string + example: "local-storage:partition:format,local-storage:partition:mount" requestBodies: RegisterEventType: diff --git a/cmd/message-bus-tool/cmd/constants.go b/cmd/message-bus-tool/cmd/constants.go index d1dfb20..1a356c3 100644 --- a/cmd/message-bus-tool/cmd/constants.go +++ b/cmd/message-bus-tool/cmd/constants.go @@ -2,7 +2,7 @@ package cmd const ( FlagSourceID = "source-id" - FlagEventName = "event-name" + FlagEventNames = "event-name" FlagMessageBufferSize = "message-buffer-size" FlagBaseURL = "base-url" ) diff --git a/cmd/message-bus-tool/cmd/subscribe.go b/cmd/message-bus-tool/cmd/subscribe.go index 05f1f42..f2c69ba 100644 --- a/cmd/message-bus-tool/cmd/subscribe.go +++ b/cmd/message-bus-tool/cmd/subscribe.go @@ -29,12 +29,18 @@ var subscribeCmd = &cobra.Command{ panic(err) } - eventName, err := cmd.Flags().GetString(FlagEventName) + eventNames, err := cmd.Flags().GetString(FlagEventNames) if err != nil { panic(err) } - wsURL := fmt.Sprintf("ws://%s%s/event_type/%s/%s/ws", strings.TrimRight(baseURL, "/"), basePath, sourceID, eventName) + var wsURL string + + if eventNames == "" { + wsURL = fmt.Sprintf("ws://%s%s/event/%s", strings.TrimRight(baseURL, "/"), basePath, sourceID) + } else { + wsURL = fmt.Sprintf("ws://%s%s/event/%s?names=%s", strings.TrimRight(baseURL, "/"), basePath, sourceID, eventNames) + } fmt.Printf("subscribed to %s\n", wsURL) ws, err := websocket.Dial(wsURL, "", origin) @@ -64,16 +70,12 @@ func init() { subscribeCmd.Flags().UintP(FlagMessageBufferSize, "m", 1024, "message buffer size in bytes") subscribeCmd.Flags().StringP(FlagSourceID, "s", "", "source id") - subscribeCmd.Flags().StringP(FlagEventName, "n", "", "event name") + subscribeCmd.Flags().StringP(FlagEventNames, "n", "", "event names (comma separated)") if err := subscribeCmd.MarkFlagRequired(FlagSourceID); err != nil { panic(err) } - if err := subscribeCmd.MarkFlagRequired(FlagEventName); err != nil { - panic(err) - } - // Here you will define your flags and configuration settings. // Cobra supports Persistent Flags which will work for this command diff --git a/route/api_route_action.go b/route/api_route_action.go index 49e6b8c..8a14bee 100644 --- a/route/api_route_action.go +++ b/route/api_route_action.go @@ -13,14 +13,18 @@ func (r *APIRoute) RegisterActionType(ctx echo.Context) error { panic("implement me") // TODO: Implement } -func (r *APIRoute) GetActionTypesBySourceID(ctx echo.Context, sourceID codegen.SourceId) error { +func (r *APIRoute) GetActionTypesBySourceID(ctx echo.Context, sourceID codegen.SourceID) error { panic("implement me") // TODO: Implement } -func (r *APIRoute) GetActionType(ctx echo.Context, sourceID codegen.SourceId, name codegen.Name) error { +func (r *APIRoute) GetActionType(ctx echo.Context, sourceID codegen.SourceID, name codegen.EventName) error { panic("implement me") // TODO: Implement } -func (r *APIRoute) TriggerAction(ctx echo.Context, sourceID codegen.SourceId, name codegen.Name) error { +func (r *APIRoute) TriggerAction(ctx echo.Context, sourceID codegen.SourceID, name codegen.EventName) error { + panic("implement me") // TODO: Implement +} + +func (r *APIRoute) SubscribeAction(ctx echo.Context, sourceID codegen.SourceID, params codegen.SubscribeActionParams) error { panic("implement me") // TODO: Implement } diff --git a/route/api_route_event.go b/route/api_route_event.go index 04e41d6..bb94501 100644 --- a/route/api_route_event.go +++ b/route/api_route_event.go @@ -53,9 +53,9 @@ func (r *APIRoute) RegisterEventType(ctx echo.Context) error { return ctx.JSON(http.StatusOK, result) } -func (r *APIRoute) GetEventTypesBySourceID(ctx echo.Context, sourceID codegen.SourceId) error { +func (r *APIRoute) GetEventTypesBySourceID(ctx echo.Context, sourceID codegen.SourceID) error { results, err := r.services.EventTypeService.GetEventTypesBySourceID(sourceID) - if err != nil { + if err != nil || results == nil { message := err.Error() return ctx.JSON(http.StatusBadRequest, codegen.ResponseBadRequest{Message: &message}) } @@ -63,17 +63,23 @@ func (r *APIRoute) GetEventTypesBySourceID(ctx echo.Context, sourceID codegen.So return ctx.JSON(http.StatusOK, results) } -func (r *APIRoute) GetEventType(ctx echo.Context, sourceID codegen.SourceId, name codegen.Name) error { +func (r *APIRoute) GetEventType(ctx echo.Context, sourceID codegen.SourceID, name codegen.EventName) error { result, err := r.services.EventTypeService.GetEventType(sourceID, name) - if err != nil { + if err != nil || result == nil { message := err.Error() - return ctx.JSON(http.StatusBadRequest, codegen.ResponseBadRequest{Message: &message}) + return ctx.JSON(http.StatusNotFound, codegen.ResponseNotFound{Message: &message}) } return ctx.JSON(http.StatusOK, result) } -func (r *APIRoute) PublishEvent(ctx echo.Context, sourceID codegen.SourceId, name codegen.Name) error { +func (r *APIRoute) PublishEvent(ctx echo.Context, sourceID codegen.SourceID, name codegen.EventName) error { + eventType, err := r.services.EventTypeService.GetEventType(sourceID, name) + if err != nil || eventType == nil { + message := err.Error() + return ctx.JSON(http.StatusNotFound, codegen.ResponseNotFound{Message: &message}) + } + var properties []codegen.Property if err := ctx.Bind(&properties); err != nil { message := err.Error() @@ -98,25 +104,49 @@ func (r *APIRoute) PublishEvent(ctx echo.Context, sourceID codegen.SourceId, nam return ctx.JSON(http.StatusOK, out.EventAdapter(*result)) } -func (r *APIRoute) SubscribeEvent(c echo.Context, sourceID codegen.SourceId, name codegen.Name) error { +func (r *APIRoute) SubscribeEvent(c echo.Context, sourceID codegen.SourceID, params codegen.SubscribeEventParams) error { + var eventNames []string + if params.Names != nil { + for _, eventName := range *params.Names { + eventType, err := r.services.EventTypeService.GetEventType(sourceID, eventName) + if err != nil || eventType == nil { + message := err.Error() + return c.JSON(http.StatusBadRequest, codegen.ResponseBadRequest{Message: &message}) + } + eventNames = append(eventNames, eventName) + } + } else { + eventTypes, err := r.services.EventTypeService.GetEventTypesBySourceID(sourceID) + if err != nil { + return err + } + + for _, eventType := range eventTypes { + eventNames = append(eventNames, eventType.Name) + } + } + conn, _, _, err := ws.UpgradeHTTP(c.Request(), c.Response()) if err != nil { return err } - channel, err := r.services.EventTypeService.Subscribe(sourceID, name) + channel, err := r.services.EventTypeService.Subscribe(sourceID, eventNames) if err != nil { conn.Close() // need to close connection here, instead of defer, because of the goroutine return err } - go func(conn net.Conn, channel chan model.Event) { + go func(conn net.Conn, channel chan model.Event, eventNames []string) { defer conn.Close() - defer func() { - if err := r.services.EventTypeService.Unsubscribe(sourceID, name, channel); err != nil { - logger.Error("error when trying to unsubscribe an event type", zap.Error(err), zap.String("source_id", sourceID), zap.String("name", name)) + defer close(channel) + defer func(eventNames []string) { + for _, name := range eventNames { + if err := r.services.EventTypeService.Unsubscribe(sourceID, name, channel); err != nil { + logger.Error("error when trying to unsubscribe an event type", zap.Error(err), zap.String("source_id", sourceID), zap.String("name", name)) + } } - }() + }(eventNames) logger.Info("started", zap.String("remote_addr", conn.RemoteAddr().String())) @@ -152,7 +182,7 @@ func (r *APIRoute) SubscribeEvent(c echo.Context, sourceID codegen.SourceId, nam return } } - }(conn, channel) + }(conn, channel, eventNames) return nil } diff --git a/service/event_type_service.go b/service/event_type_service.go index 9efbed1..afe2306 100644 --- a/service/event_type_service.go +++ b/service/event_type_service.go @@ -68,14 +68,27 @@ func (s *EventTypeService) Publish(event model.Event) (*model.Event, error) { return &event, nil } -func (s *EventTypeService) Subscribe(sourceID string, name string) (chan model.Event, error) { - eventType, err := s.GetEventType(sourceID, name) - if err != nil { - return nil, err +func (s *EventTypeService) Subscribe(sourceID string, names []string) (chan model.Event, error) { + if len(names) == 0 { + eventTypes, err := s.GetEventTypesBySourceID(sourceID) + if err != nil { + return nil, err + } + + for _, eventType := range eventTypes { + names = append(names, eventType.Name) + } } - if eventType == nil { - return nil, ErrEventNameNotFound + for _, name := range names { + eventType, err := s.GetEventType(sourceID, name) + if err != nil { + return nil, err + } + + if eventType == nil { + return nil, ErrEventNameNotFound + } } if s.subscriberChannels == nil { @@ -86,12 +99,14 @@ func (s *EventTypeService) Subscribe(sourceID string, name string) (chan model.E s.subscriberChannels[sourceID] = make(map[string][]chan model.Event) } - if s.subscriberChannels[sourceID][name] == nil { - s.subscriberChannels[sourceID][name] = make([]chan model.Event, 0) - } - c := make(chan model.Event, 1) - s.subscriberChannels[sourceID][name] = append(s.subscriberChannels[sourceID][name], c) + + for _, name := range names { + if s.subscriberChannels[sourceID][name] == nil { + s.subscriberChannels[sourceID][name] = make([]chan model.Event, 0) + } + s.subscriberChannels[sourceID][name] = append(s.subscriberChannels[sourceID][name], c) + } return c, nil } @@ -113,7 +128,6 @@ func (s *EventTypeService) Unsubscribe(sourceID string, name string, c chan mode if subscriber == c { logger.Info("unsubscribing from event type", zap.String("sourceID", sourceID), zap.String("name", name), zap.Int("subscriber", i)) s.subscriberChannels[sourceID][name] = append(s.subscriberChannels[sourceID][name][:i], s.subscriberChannels[sourceID][name][i+1:]...) - close(c) return nil } } @@ -133,7 +147,14 @@ func (s *EventTypeService) Start(ctx *context.Context) { for sourceID, source := range s.subscriberChannels { for eventName, subscribers := range source { for _, subscriber := range subscribers { - close(subscriber) + select { + case _, ok := <-subscriber: + if ok { + close(subscriber) + } + default: + continue + } } delete(s.subscriberChannels[sourceID], eventName) } diff --git a/service/event_type_service_test.go b/service/event_type_service_test.go index 4547d21..49e0853 100644 --- a/service/event_type_service_test.go +++ b/service/event_type_service_test.go @@ -28,61 +28,73 @@ func TestEventTypeService(t *testing.T) { go service.Start(&ctx) sourceID := "Foo" - name := "Bar" + eventNames := []string{"Bar", "Baz"} // register event type - _, err = service.RegisterEventType(model.EventType{ - SourceID: sourceID, - Name: name, - PropertyTypeList: []model.PropertyType{{Name: "Property1"}, {Name: "Property2"}}, - }) + for _, name := range eventNames { + _, err = service.RegisterEventType(model.EventType{ + SourceID: sourceID, + Name: name, + PropertyTypeList: []model.PropertyType{{Name: "Property1"}, {Name: "Property2"}}, + }) + } assert.NilError(t, err) // get event types eventTypes, err := service.GetEventTypes() assert.NilError(t, err) - assert.Equal(t, len(eventTypes), 1) + assert.Equal(t, len(eventTypes), 2) // get event types by source id eventTypes, err = service.GetEventTypesBySourceID(sourceID) assert.NilError(t, err) - assert.Equal(t, len(eventTypes), 1) + assert.Equal(t, len(eventTypes), 2) // get event type - eventType, err := service.GetEventType(sourceID, name) - assert.NilError(t, err) - assert.Equal(t, eventType.SourceID, sourceID) - assert.Equal(t, eventType.Name, name) + for _, name := range eventNames { + eventType, err := service.GetEventType(sourceID, name) + assert.NilError(t, err) + assert.Equal(t, eventType.SourceID, sourceID) + assert.Equal(t, eventType.Name, name) + } // subscribe event type - channel, err := service.Subscribe(sourceID, name) + channel, err := service.Subscribe(sourceID, eventNames) assert.NilError(t, err) outputChannel := make(chan model.Event) - go func() { - event, ok := <-channel - if !ok { - t.Error("channel closed") + go func(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case event, ok := <-channel: + if !ok { + t.Error("channel closed") + } + outputChannel <- event + } + } + }(ctx) + + for _, name := range eventNames { + expectedEvent := model.Event{ + SourceID: sourceID, + Name: name, + Properties: []model.Property{ + {Name: "Property1", Value: "Value1"}, + {Name: "Property2", Value: "Value2"}, + }, } - outputChannel <- event - }() - - expectedEvent := model.Event{ - SourceID: sourceID, - Name: name, - Properties: []model.Property{ - {Name: "Property1", Value: "Value1"}, - {Name: "Property2", Value: "Value2"}, - }, - } - actualEvent1, err := service.Publish(expectedEvent) - assert.NilError(t, err) - assert.DeepEqual(t, model.Event{SourceID: actualEvent1.SourceID, Name: actualEvent1.Name, Properties: actualEvent1.Properties}, expectedEvent) + actualEvent1, err := service.Publish(expectedEvent) + assert.NilError(t, err) + assert.DeepEqual(t, model.Event{SourceID: actualEvent1.SourceID, Name: actualEvent1.Name, Properties: actualEvent1.Properties}, expectedEvent) - actualEvent2, ok := <-outputChannel - assert.Equal(t, ok, true) - assert.DeepEqual(t, actualEvent2, *actualEvent1) + actualEvent2, ok := <-outputChannel + assert.Equal(t, ok, true) + assert.DeepEqual(t, actualEvent2, *actualEvent1) + } }