From 913512484e14d9331f64887b381b0e037af7f29d Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Mon, 16 Dec 2024 14:52:11 +0200 Subject: [PATCH] Propagate bacerrors over nclprotocol (#4774) ## Summary by CodeRabbit - **New Features** - Introduced JSON serialization and deserialization for error handling. - Added standardized error handling mechanisms using the `bacerrors` package. - Implemented new error reporting functions for connection management. - **Bug Fixes** - Enhanced error handling logic in various components, improving robustness and clarity. - **Tests** - Added comprehensive unit tests for JSON marshalling and unmarshalling of error structures. - **Chores** - Updated import statements across multiple files to include the `bacerrors` package. --- pkg/bacerrors/json.go | 50 ++++++++ pkg/bacerrors/json_test.go | 114 ++++++++++++++++++ pkg/lib/ncl/encoder.go | 3 +- pkg/lib/ncl/encoder_test.go | 13 +- pkg/lib/ncl/error_response.go | 47 ++------ pkg/lib/ncl/publisher.go | 6 +- pkg/lib/ncl/responder.go | 40 +++--- .../nclprotocol/compute/controlplane.go | 5 +- .../nclprotocol/orchestrator/errors.go | 14 +++ .../nclprotocol/orchestrator/manager.go | 6 +- 10 files changed, 223 insertions(+), 75 deletions(-) create mode 100644 pkg/bacerrors/json.go create mode 100644 pkg/bacerrors/json_test.go create mode 100644 pkg/transport/nclprotocol/orchestrator/errors.go diff --git a/pkg/bacerrors/json.go b/pkg/bacerrors/json.go new file mode 100644 index 0000000000..a7b57af3f4 --- /dev/null +++ b/pkg/bacerrors/json.go @@ -0,0 +1,50 @@ +package bacerrors + +import ( + "encoding/json" +) + +// JSONError is a struct used for JSON serialization of errorImpl +type JSONError struct { + Cause string `json:"Cause"` + Hint string `json:"Hint"` + Retryable bool `json:"Retryable"` + FailsExecution bool `json:"FailsExecution"` + Component string `json:"Component"` + HTTPStatusCode int `json:"HTTPStatusCode"` + Details map[string]string `json:"Details"` + Code ErrorCode `json:"Code"` +} + +// MarshalJSON implements the json.Marshaler interface +func (e *errorImpl) MarshalJSON() ([]byte, error) { + return json.Marshal(&JSONError{ + Cause: e.cause, + Hint: e.hint, + Retryable: e.retryable, + FailsExecution: e.failsExecution, + Component: e.component, + HTTPStatusCode: e.httpStatusCode, + Details: e.details, + Code: e.code, + }) +} + +// UnmarshalJSON implements the json.Unmarshaler interface +func (e *errorImpl) UnmarshalJSON(data []byte) error { + var je JSONError + if err := json.Unmarshal(data, &je); err != nil { + return err + } + + e.cause = je.Cause + e.hint = je.Hint + e.retryable = je.Retryable + e.failsExecution = je.FailsExecution + e.component = je.Component + e.httpStatusCode = je.HTTPStatusCode + e.details = je.Details + e.code = je.Code + + return nil +} diff --git a/pkg/bacerrors/json_test.go b/pkg/bacerrors/json_test.go new file mode 100644 index 0000000000..048cb91af5 --- /dev/null +++ b/pkg/bacerrors/json_test.go @@ -0,0 +1,114 @@ +//go:build unit || !integration + +package bacerrors + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestErrorJSONMarshalling(t *testing.T) { + // Create an error with all fields populated + originalErr := &errorImpl{ + cause: "test error", + hint: "try this instead", + retryable: true, + failsExecution: true, + component: "TestComponent", + httpStatusCode: 404, + details: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + code: NotFoundError, + } + + // Marshal to JSON + jsonData, err := json.Marshal(originalErr) + require.NoError(t, err, "Failed to marshal error to JSON") + + // Unmarshal back to a new error + var unmarshaled errorImpl + err = json.Unmarshal(jsonData, &unmarshaled) + require.NoError(t, err, "Failed to unmarshal JSON to error") + + // Verify all fields match + assert.Equal(t, originalErr.cause, unmarshaled.cause, "Cause field mismatch") + assert.Equal(t, originalErr.hint, unmarshaled.hint, "Hint field mismatch") + assert.Equal(t, originalErr.retryable, unmarshaled.retryable, "Retryable field mismatch") + assert.Equal(t, originalErr.failsExecution, unmarshaled.failsExecution, "FailsExecution field mismatch") + assert.Equal(t, originalErr.component, unmarshaled.component, "Component field mismatch") + assert.Equal(t, originalErr.httpStatusCode, unmarshaled.httpStatusCode, "HTTPStatusCode field mismatch") + assert.Equal(t, originalErr.code, unmarshaled.code, "Code field mismatch") + assert.Equal(t, originalErr.details, unmarshaled.details, "Details field mismatch") +} + +func TestErrorJSONMarshallingEmpty(t *testing.T) { + // Test with minimal fields + originalErr := &errorImpl{ + cause: "minimal error", + } + + // Marshal to JSON + jsonData, err := json.Marshal(originalErr) + require.NoError(t, err, "Failed to marshal minimal error to JSON") + + // Unmarshal back to a new error + var unmarshaled errorImpl + err = json.Unmarshal(jsonData, &unmarshaled) + require.NoError(t, err, "Failed to unmarshal JSON to minimal error") + + // Verify fields + assert.Equal(t, originalErr.cause, unmarshaled.cause, "Cause field mismatch") + assert.Empty(t, unmarshaled.hint, "Hint should be empty") + assert.False(t, unmarshaled.retryable, "Retryable should be false") + assert.False(t, unmarshaled.failsExecution, "FailsExecution should be false") + assert.Empty(t, unmarshaled.component, "Component should be empty") + assert.Zero(t, unmarshaled.httpStatusCode, "HTTPStatusCode should be zero") + assert.Nil(t, unmarshaled.details, "Details should be nil") + assert.Zero(t, unmarshaled.code, "Code should be zero value") +} + +func TestErrorJSONMarshallingInvalid(t *testing.T) { + // Test unmarshalling invalid JSON + invalidJSON := []byte(`{"Cause": "test", "Retryable": "invalid"}`) + var unmarshaled errorImpl + err := json.Unmarshal(invalidJSON, &unmarshaled) + assert.Error(t, err, "Should fail to unmarshal invalid JSON") +} + +func TestErrorJSONFieldVisibility(t *testing.T) { + originalErr := &errorImpl{ + cause: "test error", + hint: "test hint", + retryable: true, + failsExecution: true, + component: "TestComponent", + httpStatusCode: 404, + details: map[string]string{ + "key": "value", + }, + code: NotFoundError, + // These fields should not be marshalled + wrappedErr: nil, + wrappingMsg: "should not appear", + stack: nil, + } + + // Marshal to JSON + jsonData, err := json.Marshal(originalErr) + require.NoError(t, err, "Failed to marshal error to JSON") + + // Convert to map to check field presence + var result map[string]interface{} + err = json.Unmarshal(jsonData, &result) + require.NoError(t, err, "Failed to unmarshal JSON to map") + + // Check that internal fields are not exposed + assert.NotContains(t, result, "wrappedErr", "wrappedErr should not be in JSON") + assert.NotContains(t, result, "wrappingMsg", "wrappingMsg should not be in JSON") + assert.NotContains(t, result, "stack", "stack should not be in JSON") +} diff --git a/pkg/lib/ncl/encoder.go b/pkg/lib/ncl/encoder.go index 3ba1185b7b..3f0782576e 100644 --- a/pkg/lib/ncl/encoder.go +++ b/pkg/lib/ncl/encoder.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "github.com/bacalhau-project/bacalhau/pkg/bacerrors" "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" "github.com/bacalhau-project/bacalhau/pkg/lib/validate" ) @@ -40,7 +41,7 @@ func newEncoder(config encoderConfig) (*encoder, error) { } // Register error response type - if err := config.messageRegistry.Register(ErrorMessageType, ErrorResponse{}); err != nil { + if err := config.messageRegistry.Register(BacErrorMessageType, bacerrors.New("")); err != nil { if errors.Is(err, envelope.ErrAlreadyRegistered{}) { return nil, fmt.Errorf("failed to register error response type: %w", err) } diff --git a/pkg/lib/ncl/encoder_test.go b/pkg/lib/ncl/encoder_test.go index 1ccfecb894..b6b3889e4e 100644 --- a/pkg/lib/ncl/encoder_test.go +++ b/pkg/lib/ncl/encoder_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/suite" + "github.com/bacalhau-project/bacalhau/pkg/bacerrors" "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" ) @@ -177,18 +178,18 @@ func (suite *EncoderTestSuite) TestErrorResponseRegistration() { suite.NotNil(encoder2) // Verify we can encode/decode error responses with both encoders - errorResp := NewErrorResponse(StatusServerError, "test error") - data, err := encoder1.encode(errorResp.ToEnvelope()) + errorResp := bacerrors.New("test error").WithCode(bacerrors.IOError) + data, err := encoder1.encode(BacErrorToEnvelope(errorResp)) suite.Require().NoError(err) decoded, err := encoder2.decode(data) suite.Require().NoError(err) - payload, ok := decoded.GetPayload(&ErrorResponse{}) + payload, ok := decoded.GetPayload(bacerrors.New("")) suite.True(ok) - errResp := payload.(*ErrorResponse) - suite.Equal(StatusServerError, errResp.StatusCode) - suite.Equal("test error", errResp.Message) + errResp := payload.(bacerrors.Error) + suite.Equal(bacerrors.IOError, errResp.Code()) + suite.Equal("test error", errResp.Error()) } func TestEncoderTestSuite(t *testing.T) { diff --git a/pkg/lib/ncl/error_response.go b/pkg/lib/ncl/error_response.go index dd4a340de1..6fb41faa14 100644 --- a/pkg/lib/ncl/error_response.go +++ b/pkg/lib/ncl/error_response.go @@ -4,53 +4,26 @@ import ( "fmt" "time" + "github.com/bacalhau-project/bacalhau/pkg/bacerrors" "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" ) const ( - // ErrorMessageType is used when responding with an error - ErrorMessageType = "ncl.ErrorResponse" + BacErrorMessageType = "BacError" // KeyStatusCode is the key for the status code KeyStatusCode = "Bacalhau-StatusCode" - // StatusBadRequest is the status code for a bad request - StatusBadRequest = 400 - - // StatusNotFound is the status code for a not handler found - StatusNotFound = 404 - - // StatusServerError is the status code for a server error - StatusServerError = 500 + // KeyErrorCode is the key for the error code + KeyErrorCode = "Bacalhau-ErrorCode" ) -// ErrorResponse is used to respond with an error -type ErrorResponse struct { - StatusCode int `json:"StatusCode"` - Message string `json:"Message"` -} - -// NewErrorResponse creates a new error response -func NewErrorResponse(statusCode int, message string) ErrorResponse { - return ErrorResponse{ - StatusCode: statusCode, - Message: message, - } -} - -// Error returns the error message -func (e *ErrorResponse) Error() string { - return fmt.Sprintf("status code: %d, message: %s", e.StatusCode, e.Message) -} - -// ToEnvelope converts the error to an envelope -func (e *ErrorResponse) ToEnvelope() *envelope.Message { - errMsg := envelope.NewMessage(e) - errMsg.WithMetadataValue(envelope.KeyMessageType, ErrorMessageType) - errMsg.WithMetadataValue(KeyStatusCode, fmt.Sprintf("%d", e.StatusCode)) +// BacErrorToEnvelope converts the error to an envelope +func BacErrorToEnvelope(err bacerrors.Error) *envelope.Message { + errMsg := envelope.NewMessage(err) + errMsg.WithMetadataValue(envelope.KeyMessageType, BacErrorMessageType) + errMsg.WithMetadataValue(KeyStatusCode, fmt.Sprintf("%d", err.HTTPStatusCode())) + errMsg.WithMetadataValue(KeyErrorCode, string(err.Code())) errMsg.WithMetadataValue(KeyEventTime, time.Now().Format(time.RFC3339)) return errMsg } - -// compile-time check for interface conformance -var _ error = &ErrorResponse{} diff --git a/pkg/lib/ncl/publisher.go b/pkg/lib/ncl/publisher.go index c1363150f0..aebee3966c 100644 --- a/pkg/lib/ncl/publisher.go +++ b/pkg/lib/ncl/publisher.go @@ -7,6 +7,7 @@ import ( "github.com/nats-io/nats.go" + "github.com/bacalhau-project/bacalhau/pkg/bacerrors" "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" "github.com/bacalhau-project/bacalhau/pkg/lib/validate" ) @@ -94,10 +95,9 @@ func (p *publisher) Request(ctx context.Context, request PublishRequest) (*envel } // Check if response is an error - if errorResponse, ok := message.GetPayload(&ErrorResponse{}); ok { - return nil, errorResponse.(*ErrorResponse) + if errorResponse, ok := message.GetPayload(bacerrors.New("")); ok { + return nil, errorResponse.(bacerrors.Error) } - return message, nil } diff --git a/pkg/lib/ncl/responder.go b/pkg/lib/ncl/responder.go index b95ea08778..a229d696ba 100644 --- a/pkg/lib/ncl/responder.go +++ b/pkg/lib/ncl/responder.go @@ -9,6 +9,7 @@ import ( "github.com/nats-io/nats.go" "github.com/rs/zerolog/log" + "github.com/bacalhau-project/bacalhau/pkg/bacerrors" "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" "github.com/bacalhau-project/bacalhau/pkg/lib/validate" ) @@ -142,8 +143,7 @@ func (r *responder) handleRequest(requestMsg *nats.Msg) { // Deserialize request envelope request, err := r.encoder.decode(requestMsg.Data) if err != nil { - errorResponse := NewErrorResponse(StatusBadRequest, err.Error()) - r.sendErrorResponse(requestMsg, errorResponse) + r.sendErrorResponse(requestMsg, bacerrors.Wrap(err, "failed to deserialize request")) return } @@ -158,18 +158,15 @@ func (r *responder) handleRequest(requestMsg *nats.Msg) { Str("messageType", messageType). Str("subject", requestMsg.Subject). Msg("No handler registered for message type") - errorResponse := NewErrorResponse( - StatusNotFound, fmt.Errorf("no handler found for message type: %s", messageType).Error()) - r.sendErrorResponse(requestMsg, errorResponse) + r.sendErrorResponse(requestMsg, bacerrors.New("no handler found for message type: %s", messageType). + WithCode(bacerrors.NotFoundError)) return } // Process request with the appropriate handler response, err := handler.HandleRequest(ctx, request) if err != nil { - errorResponse := NewErrorResponse( - StatusServerError, fmt.Errorf("failed to process request: %w", err).Error()) - r.sendErrorResponse(requestMsg, errorResponse) + r.sendErrorResponse(requestMsg, bacerrors.Wrap(err, "failed to process request")) return } @@ -187,8 +184,16 @@ func (r *responder) sendResponse(requestMsg *nats.Msg, response *envelope.Messag // Serialize response data, err := r.encoder.encode(response) if err != nil { - errorResponse := NewErrorResponse(StatusServerError, err.Error()) - r.sendOrLogError(requestMsg, response, errorResponse) + // If we failed to encode an error response, just log it + if response.Metadata.Get(envelope.KeyMessageType) == BacErrorMessageType { + log.Error().Err(err). + Str("subject", requestMsg.Subject). + Msg("Failed to encode error response") + return + } + + // For normal responses that fail to encode, send a new error response + r.sendErrorResponse(requestMsg, bacerrors.Wrap(err, "failed to encode response")) return } @@ -205,19 +210,8 @@ func (r *responder) sendResponse(requestMsg *nats.Msg, response *envelope.Messag // sendErrorResponse is a convenience method to send an error response. // It converts the ErrorResponse to an envelope before sending. -func (r *responder) sendErrorResponse(requestMsg *nats.Msg, response ErrorResponse) { - r.sendResponse(requestMsg, response.ToEnvelope()) -} - -// sendOrLogError handles errors that occur while sending error responses. -// If we fail to send an error response, we log it instead of trying again -// to avoid potential infinite loops. -func (r *responder) sendOrLogError(requestMsg *nats.Msg, originalResponse *envelope.Message, errorResponse ErrorResponse) { - if originalResponse.Metadata.Get(envelope.KeyMessageType) == ErrorMessageType { - log.Error().Msgf("failed to send error response to %s: %s", requestMsg.Subject, originalResponse.Payload) - } else { - r.sendResponse(requestMsg, originalResponse) - } +func (r *responder) sendErrorResponse(requestMsg *nats.Msg, err bacerrors.Error) { + r.sendResponse(requestMsg, BacErrorToEnvelope(err)) } // compile-time check for interface conformance diff --git a/pkg/transport/nclprotocol/compute/controlplane.go b/pkg/transport/nclprotocol/compute/controlplane.go index 5a55fac937..f3ed2b969d 100644 --- a/pkg/transport/nclprotocol/compute/controlplane.go +++ b/pkg/transport/nclprotocol/compute/controlplane.go @@ -3,16 +3,17 @@ package compute import ( "context" "fmt" - "strings" "sync" "time" "github.com/rs/zerolog/log" + "github.com/bacalhau-project/bacalhau/pkg/bacerrors" "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" "github.com/bacalhau-project/bacalhau/pkg/models" "github.com/bacalhau-project/bacalhau/pkg/models/messages" + "github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes" "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" ) @@ -105,7 +106,7 @@ func (cp *ControlPlane) run(ctx context.Context) { case <-heartbeat.C: if err := cp.heartbeat(ctx); err != nil { - if strings.Contains(err.Error(), "handshake required") { + if bacerrors.IsErrorWithCode(err, nodes.HandshakeRequired) { cp.healthTracker.HandshakeRequired() return } diff --git a/pkg/transport/nclprotocol/orchestrator/errors.go b/pkg/transport/nclprotocol/orchestrator/errors.go new file mode 100644 index 0000000000..7db6485de9 --- /dev/null +++ b/pkg/transport/nclprotocol/orchestrator/errors.go @@ -0,0 +1,14 @@ +package orchestrator + +import ( + "github.com/bacalhau-project/bacalhau/pkg/bacerrors" + "github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes" +) + +const errComponent = "ConnectionManager" + +// NewErrHandshakeRequired returns a standardized error for when a handshake is required +func NewErrHandshakeRequired(nodeID string) bacerrors.Error { + return nodes.NewErrHandshakeRequired(nodeID). + WithComponent(errComponent) +} diff --git a/pkg/transport/nclprotocol/orchestrator/manager.go b/pkg/transport/nclprotocol/orchestrator/manager.go index db990ccff4..a6b6bef117 100644 --- a/pkg/transport/nclprotocol/orchestrator/manager.go +++ b/pkg/transport/nclprotocol/orchestrator/manager.go @@ -248,7 +248,7 @@ func (cm *ComputeManager) handleHeartbeatRequest(ctx context.Context, msg *envel // Verify data plane exists dataPlane, exists := cm.getDataPlane(request.NodeID) if !exists { - return nil, fmt.Errorf("no active data plane for node %s - handshake required", request.NodeID) + return nil, NewErrHandshakeRequired(request.NodeID) } // Process through node manager with sequence info @@ -271,7 +271,7 @@ func (cm *ComputeManager) handleNodeInfoUpdateRequest(ctx context.Context, msg * // Verify data plane exists if _, ok := cm.dataPlanes.Load(request.NodeInfo.ID()); !ok { // Return error asking node to reconnect since it has no active data plane - return nil, fmt.Errorf("no active data plane - handshake required") + return nil, NewErrHandshakeRequired(request.NodeInfo.ID()) } // Process through node manager @@ -293,7 +293,7 @@ func (cm *ComputeManager) handleShutdownRequest(ctx context.Context, msg *envelo // Get data plane to access sequence numbers dataPlane, exists := cm.getDataPlane(notification.NodeID) if !exists { - return nil, fmt.Errorf("no active data plane for node %s", notification.NodeID) + return nil, NewErrHandshakeRequired(notification.NodeID) } response, err := cm.nodeManager.ShutdownNotice(ctx, nodes.ExtendedShutdownNoticeRequest{