From 2b5e625cf98693703dba9a31f1b89ed5b170d524 Mon Sep 17 00:00:00 2001 From: Caio Teixeira Date: Fri, 24 May 2024 14:51:37 -0300 Subject: [PATCH] auth: add signature middleware (#13) What SignatureMiddleware validates if the API caller has ownership of the configured Stellar public key that signed the request. This adds authentication when requesting the resources. Why Security reasons. --- cmd/serve.go | 21 ++- cmd/utils/custom_set_value.go | 26 +++ cmd/utils/custom_set_value_test.go | 126 +++++++++++++++ go.mod | 2 +- internal/serve/auth/mock.go | 18 +++ internal/serve/auth/signature_verifier.go | 153 ++++++++++++++++++ .../serve/auth/signature_verifier_test.go | 118 ++++++++++++++ internal/serve/httperror/errors.go | 28 +++- internal/serve/httperror/errors_test.go | 5 +- .../serve/httphandler/payments_handler.go | 6 +- internal/serve/middleware/middleware.go | 49 ++++++ internal/serve/middleware/middleware_test.go | 138 ++++++++++++++++ internal/serve/serve.go | 31 ++-- internal/utils/utils.go | 17 ++ internal/utils/utils_test.go | 77 +++++++++ 15 files changed, 792 insertions(+), 23 deletions(-) create mode 100644 cmd/utils/custom_set_value.go create mode 100644 cmd/utils/custom_set_value_test.go create mode 100644 internal/serve/auth/mock.go create mode 100644 internal/serve/auth/signature_verifier.go create mode 100644 internal/serve/auth/signature_verifier_test.go create mode 100644 internal/serve/middleware/middleware.go create mode 100644 internal/serve/middleware/middleware_test.go create mode 100644 internal/utils/utils.go create mode 100644 internal/utils/utils_test.go diff --git a/cmd/serve.go b/cmd/serve.go index 0211c46..4e2c524 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -7,6 +7,7 @@ import ( "github.com/spf13/cobra" "github.com/stellar/go/support/config" supportlog "github.com/stellar/go/support/log" + "github.com/stellar/wallet-backend/cmd/utils" "github.com/stellar/wallet-backend/internal/serve" ) @@ -35,15 +36,33 @@ func (c *serveCmd) Command() *cobra.Command { FlagDefault: "postgres://postgres@localhost:5432/wallet-backend?sslmode=disable", Required: true, }, + { + Name: "server-base-url", + Usage: "The server base URL", + OptType: types.String, + ConfigKey: &cfg.ServerBaseURL, + FlagDefault: "http://localhost:8000", + Required: true, + }, + { + Name: "wallet-signing-key", + Usage: "The public key of the Stellar account that signs the payloads when making HTTP Request to this server.", + OptType: types.String, + CustomSetValue: utils.SetConfigOptionStellarPublicKey, + ConfigKey: &cfg.WalletSigningKey, + Required: true, + }, } cmd := &cobra.Command{ Use: "serve", Short: "Run Wallet Backend server", - Run: func(_ *cobra.Command, _ []string) { + PersistentPreRun: func(_ *cobra.Command, _ []string) { cfgOpts.Require() if err := cfgOpts.SetValues(); err != nil { c.Logger.Fatalf("Error setting values of config options: %s", err.Error()) } + }, + Run: func(_ *cobra.Command, _ []string) { c.Run(cfg) }, } diff --git a/cmd/utils/custom_set_value.go b/cmd/utils/custom_set_value.go new file mode 100644 index 0000000..f075a40 --- /dev/null +++ b/cmd/utils/custom_set_value.go @@ -0,0 +1,26 @@ +package utils + +import ( + "fmt" + + "github.com/spf13/viper" + "github.com/stellar/go/keypair" + "github.com/stellar/go/support/config" +) + +func SetConfigOptionStellarPublicKey(co *config.ConfigOption) error { + publicKey := viper.GetString(co.Name) + + kp, err := keypair.ParseAddress(publicKey) + if err != nil { + return fmt.Errorf("validating public key in %s: %w", co.Name, err) + } + + key, ok := co.ConfigKey.(*string) + if !ok { + return fmt.Errorf("the expected type for the config key in %s is a string, but a %T was provided instead", co.Name, co.ConfigKey) + } + *key = kp.Address() + + return nil +} diff --git a/cmd/utils/custom_set_value_test.go b/cmd/utils/custom_set_value_test.go new file mode 100644 index 0000000..d390d5c --- /dev/null +++ b/cmd/utils/custom_set_value_test.go @@ -0,0 +1,126 @@ +package utils + +import ( + "go/types" + "os" + "strings" + "testing" + + "github.com/spf13/cobra" + "github.com/stellar/go/support/config" + "github.com/stellar/wallet-backend/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// customSetterTestCase is a test case to test a custom_set_value function. +type customSetterTestCase[T any] struct { + name string + args []string + envValue string + wantErrContains string + wantResult T +} + +// customSetterTester tests a custom_set_value function, according with the customSetterTestCase provided. +func customSetterTester[T any](t *testing.T, tc customSetterTestCase[T], co config.ConfigOption) { + t.Helper() + ClearTestEnvironment(t) + if tc.envValue != "" { + envName := strings.ToUpper(co.Name) + envName = strings.ReplaceAll(envName, "-", "_") + t.Setenv(envName, tc.envValue) + } + + // start the CLI command + testCmd := cobra.Command{ + RunE: func(cmd *cobra.Command, args []string) error { + co.Require() + return co.SetValue() + }, + } + // mock the command line output + buf := new(strings.Builder) + testCmd.SetOut(buf) + + // Initialize the command for the given option + err := co.Init(&testCmd) + require.NoError(t, err) + + // execute command line + if len(tc.args) > 0 { + testCmd.SetArgs(tc.args) + } + err = testCmd.Execute() + + // check the result + if tc.wantErrContains != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrContains) + } else { + assert.NoError(t, err) + } + + if !utils.IsEmpty(tc.wantResult) { + destPointer := utils.UnwrapInterfaceToPointer[T](co.ConfigKey) + assert.Equal(t, tc.wantResult, *destPointer) + } +} + +// clearTestEnvironment removes all envs from the test environment. It's useful +// to make tests independent from the localhost environment variables. +func ClearTestEnvironment(t *testing.T) { + t.Helper() + + // remove all envs from tghe test environment + for _, env := range os.Environ() { + key := env[:strings.Index(env, "=")] + t.Setenv(key, "") + } +} + +func TestSetConfigOptionStellarPublicKey(t *testing.T) { + opts := struct{ sep10SigningPublicKey string }{} + + co := config.ConfigOption{ + Name: "wallet-signing-key", + OptType: types.String, + CustomSetValue: SetConfigOptionStellarPublicKey, + ConfigKey: &opts.sep10SigningPublicKey, + } + expectedPublicKey := "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S" + + testCases := []customSetterTestCase[string]{ + { + name: "returns an error if the public key is empty", + wantErrContains: "validating public key in wallet-signing-key: strkey is 0 bytes long; minimum valid length is 5", + }, + { + name: "returns an error if the public key is invalid", + args: []string{"--wallet-signing-key", "invalid_public_key"}, + wantErrContains: "validating public key in wallet-signing-key: base32 decode failed: illegal base32 data at input byte 18", + }, + { + name: "returns an error if the public key is invalid (private key instead)", + args: []string{"--wallet-signing-key", "SDISQRUPIHAO5WIIGY4QRDCINZSA44TX3OIIUK3C63NUKN5DABKEQ276"}, + wantErrContains: "validating public key in wallet-signing-key: invalid version byte", + }, + { + name: "handles Stellar public key through the CLI flag", + args: []string{"--wallet-signing-key", "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S"}, + wantResult: expectedPublicKey, + }, + { + name: "handles Stellar public key through the ENV vars", + envValue: "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S", + wantResult: expectedPublicKey, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts.sep10SigningPublicKey = "" + customSetterTester(t, tc, co) + }) + } +} diff --git a/go.mod b/go.mod index eeb5532..1866e67 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/rubenv/sql-migrate v1.6.1 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.0 + github.com/spf13/viper v1.17.0 github.com/stellar/go v0.0.0-20240416222646-fd107948e6c4 github.com/stretchr/testify v1.9.0 ) @@ -47,7 +48,6 @@ require ( github.com/spf13/afero v1.10.0 // indirect github.com/spf13/cast v1.5.1 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/spf13/viper v1.17.0 // indirect github.com/stellar/go-xdr v0.0.0-20231122183749-b53fb00bcac2 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect diff --git a/internal/serve/auth/mock.go b/internal/serve/auth/mock.go new file mode 100644 index 0000000..b78b273 --- /dev/null +++ b/internal/serve/auth/mock.go @@ -0,0 +1,18 @@ +package auth + +import ( + "context" + + "github.com/stretchr/testify/mock" +) + +type MockSignatureVerifier struct { + mock.Mock +} + +var _ SignatureVerifier = (*MockSignatureVerifier)(nil) + +func (sv *MockSignatureVerifier) VerifySignature(ctx context.Context, signatureHeaderContent string, reqBody []byte) error { + args := sv.Called(ctx, signatureHeaderContent, reqBody) + return args.Error(0) +} diff --git a/internal/serve/auth/signature_verifier.go b/internal/serve/auth/signature_verifier.go new file mode 100644 index 0000000..2dbb502 --- /dev/null +++ b/internal/serve/auth/signature_verifier.go @@ -0,0 +1,153 @@ +package auth + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "net/url" + "regexp" + "strconv" + "strings" + "time" + + "github.com/stellar/go/keypair" + "github.com/stellar/go/strkey" + "github.com/stellar/go/support/log" +) + +type SignatureVerifier interface { + VerifySignature(ctx context.Context, signatureHeaderContent string, rawReqBody []byte) error +} + +var ( + ErrStellarSignatureNotVerified = errors.New("neither Signature nor X-Stellar-Signature header could be verified") +) + +type ErrInvalidTimestampFormat struct { + TimestampString string + timestampValueError bool +} + +func (e ErrInvalidTimestampFormat) Error() string { + if e.timestampValueError { + return fmt.Sprintf("signature format different than expected. expected unix seconds, got: %s", e.TimestampString) + } + return fmt.Sprintf("malformed timestamp: %s", e.TimestampString) +} + +type ErrExpiredSignatureTimestamp struct { + ExpiredSignatureTimestamp time.Time + CheckTime time.Time +} + +func (e ErrExpiredSignatureTimestamp) Error() string { + return fmt.Sprintf("signature timestamp has expired. sig timestamp: %s, check time %s", e.ExpiredSignatureTimestamp.Format(time.RFC3339), e.CheckTime.Format(time.RFC3339)) +} + +type StellarSignatureVerifier struct { + ServerHostname string + WalletSigningKey string +} + +var _ SignatureVerifier = (*StellarSignatureVerifier)(nil) + +// VerifySignature verifies the Signature or X-Stellar-Signature content and checks if the signature is signed for a known caller. +func (sv *StellarSignatureVerifier) VerifySignature(ctx context.Context, signatureHeaderContent string, rawReqBody []byte) error { + t, s, err := ExtractTimestampedSignature(signatureHeaderContent) + if err != nil { + log.Ctx(ctx).Error(err) + return ErrStellarSignatureNotVerified + } + + // 2 seconds + err = VerifyGracePeriodSeconds(t, 2*time.Second) + if err != nil { + log.Ctx(ctx).Error(err) + return ErrStellarSignatureNotVerified + } + + signatureBytes, err := base64.StdEncoding.DecodeString(s) + if err != nil { + log.Ctx(ctx).Errorf("unable to decode signature value %s: %s", s, err.Error()) + return ErrStellarSignatureNotVerified + } + + payload := t + "." + sv.ServerHostname + "." + string(rawReqBody) + + // TODO: perhaps add possibility to have more than one signing key. + kp, err := keypair.ParseAddress(sv.WalletSigningKey) + if err != nil { + return fmt.Errorf("parsing wallet signing key %s: %w", sv.WalletSigningKey, err) + } + + err = kp.Verify([]byte(payload), signatureBytes) + if err != nil { + log.Ctx(ctx).Errorf("unable to verify the signature: %s", err.Error()) + return ErrStellarSignatureNotVerified + } + + return nil +} + +func ExtractTimestampedSignature(signatureHeaderContent string) (t string, s string, err error) { + parts := strings.SplitN(signatureHeaderContent, ",", 2) + if len(parts) != 2 { + return "", "", fmt.Errorf("malformed header: %s", signatureHeaderContent) + } + + tHeaderContent := parts[0] + timestampParts := strings.SplitN(tHeaderContent, "=", 2) + if len(timestampParts) != 2 || strings.TrimSpace(timestampParts[0]) != "t" { + return "", "", &ErrInvalidTimestampFormat{TimestampString: tHeaderContent} + } + t = strings.TrimSpace(timestampParts[1]) + + sHeaderContent := parts[1] + signatureParts := strings.SplitN(sHeaderContent, "=", 2) + if len(signatureParts) != 2 || strings.TrimSpace(signatureParts[0]) != "s" { + return "", "", fmt.Errorf("malformed signature: %s", signatureParts) + } + s = strings.TrimSpace(signatureParts[1]) + + return t, s, nil +} + +func VerifyGracePeriodSeconds(timestampString string, gracePeriod time.Duration) error { + // Note: from Nov 20th, 2286 this RegEx will fail because of an extra digit + if ok, _ := regexp.MatchString(`^\d{10}$`, timestampString); !ok { + return &ErrInvalidTimestampFormat{TimestampString: timestampString, timestampValueError: true} + } + + timestampUnix, err := strconv.ParseInt(timestampString, 10, 64) + if err != nil { + return fmt.Errorf("unable to parse timestamp value %s: %v", timestampString, err) + } + + return verifyGracePeriod(time.Unix(timestampUnix, 0), gracePeriod) +} + +func verifyGracePeriod(timestamp time.Time, gracePeriod time.Duration) error { + now := time.Now() + if !timestamp.Add(gracePeriod).After(now) { + return &ErrExpiredSignatureTimestamp{ExpiredSignatureTimestamp: timestamp, CheckTime: now} + } + + return nil +} + +func NewStellarSignatureVerifier(serverHostName, walletSigningKey string) (*StellarSignatureVerifier, error) { + if !strkey.IsValidEd25519PublicKey(walletSigningKey) { + return nil, fmt.Errorf("invalid wallet signing key") + } + + u, err := url.ParseRequestURI(serverHostName) + if err != nil { + return nil, fmt.Errorf("invalid server hostname: %w", err) + } + + return &StellarSignatureVerifier{ + ServerHostname: u.Hostname(), + WalletSigningKey: walletSigningKey, + }, nil +} diff --git a/internal/serve/auth/signature_verifier_test.go b/internal/serve/auth/signature_verifier_test.go new file mode 100644 index 0000000..aae45f6 --- /dev/null +++ b/internal/serve/auth/signature_verifier_test.go @@ -0,0 +1,118 @@ +package auth + +import ( + "context" + "fmt" + "net/url" + "strconv" + "testing" + "time" + + "github.com/stellar/go/keypair" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSignatureVerifierVerifySignature(t *testing.T) { + host, err := url.ParseRequestURI("https://example.com") + require.NoError(t, err) + signingKey := keypair.MustRandom() + + ctx := context.Background() + signatureVerifier, err := NewStellarSignatureVerifier(host.String(), signingKey.Address()) + require.NoError(t, err) + + t.Run("returns_error_when_the_wallet_signing_key_is_not_the_singer", func(t *testing.T) { + signer := keypair.MustRandom() + now := time.Now() + reqBody := `{"value": "new value"}` + sig := fmt.Sprintf("%d.%s.%s", now.Unix(), host.Hostname(), reqBody) + sig, err = signer.SignBase64([]byte(sig)) + require.NoError(t, err) + signatureHeaderContent := fmt.Sprintf("t=%d, s=%s", now.Unix(), sig) + + err := signatureVerifier.VerifySignature(ctx, signatureHeaderContent, []byte(reqBody)) + assert.EqualError(t, err, ErrStellarSignatureNotVerified.Error()) + }) + + t.Run("successfully_verifies_signature", func(t *testing.T) { + now := time.Now() + reqBody := `{"value": "new value"}` + sig := fmt.Sprintf("%d.%s.%s", now.Unix(), host.Hostname(), reqBody) + sig, err = signingKey.SignBase64([]byte(sig)) + require.NoError(t, err) + signatureHeaderContent := fmt.Sprintf("t=%d, s=%s", now.Unix(), sig) + + err := signatureVerifier.VerifySignature(ctx, signatureHeaderContent, []byte(reqBody)) + assert.NoError(t, err) + + // When there's no request body + now = time.Now() + sig = fmt.Sprintf("%d.%s.%s", now.Unix(), host.Hostname(), "") + sig, err = signingKey.SignBase64([]byte(sig)) + require.NoError(t, err) + signatureHeaderContent = fmt.Sprintf("t=%d, s=%s", now.Unix(), sig) + + err = signatureVerifier.VerifySignature(ctx, signatureHeaderContent, []byte{}) + assert.NoError(t, err) + }) +} + +func TestExtractTimestampedSignature(t *testing.T) { + t.Run("invalid_header_content", func(t *testing.T) { + ts, s, err := ExtractTimestampedSignature("") + assert.EqualError(t, err, "malformed header: ") + assert.Empty(t, ts) + assert.Empty(t, s) + + ts, s, err = ExtractTimestampedSignature("a,b") + var errTimestampFormat *ErrInvalidTimestampFormat + assert.ErrorAs(t, err, &errTimestampFormat) + assert.EqualError(t, err, "malformed timestamp: a") + assert.Empty(t, ts) + assert.Empty(t, s) + + ts, s, err = ExtractTimestampedSignature("t=abc,b") + assert.EqualError(t, err, "malformed signature: [b]") + assert.Empty(t, ts) + assert.Empty(t, s) + }) + + t.Run("successfully_extracts_timestamp_and_signature", func(t *testing.T) { + ts, s, err := ExtractTimestampedSignature("t=123,s=abc") + assert.NoError(t, err) + assert.Equal(t, "123", ts) + assert.Equal(t, "abc", s) + }) +} + +func TestVerifyGracePeriodSeconds(t *testing.T) { + t.Run("invalid_timestamp", func(t *testing.T) { + var errTimestampFormat *ErrInvalidTimestampFormat + err := VerifyGracePeriodSeconds("", 2*time.Second) + assert.ErrorAs(t, err, &errTimestampFormat) + assert.EqualError(t, err, "signature format different than expected. expected unix seconds, got: ") + + err = VerifyGracePeriodSeconds("123", 2*time.Second) + assert.ErrorAs(t, err, &errTimestampFormat) + assert.EqualError(t, err, "signature format different than expected. expected unix seconds, got: 123") + + err = VerifyGracePeriodSeconds("12345678910", 2*time.Second) + assert.ErrorAs(t, err, &errTimestampFormat) + assert.EqualError(t, err, "signature format different than expected. expected unix seconds, got: 12345678910") + }) + + t.Run("successfully_verifies_grace_period", func(t *testing.T) { + var errExpiredSignatureTimestamp *ErrExpiredSignatureTimestamp + now := time.Now().Add(-5 * time.Second) + ts := now.Unix() + err := VerifyGracePeriodSeconds(strconv.FormatInt(ts, 10), 2*time.Second) + assert.ErrorAs(t, err, &errExpiredSignatureTimestamp) + assert.ErrorContains(t, err, fmt.Sprintf("signature timestamp has expired. sig timestamp: %s, check time", now.Format(time.RFC3339))) + + now = time.Now().Add(-1 * time.Second) + ts = now.Unix() + err = VerifyGracePeriodSeconds(strconv.FormatInt(ts, 10), 2*time.Second) + assert.NoError(t, err) + }) +} diff --git a/internal/serve/httperror/errors.go b/internal/serve/httperror/errors.go index 05017f4..06a7e16 100644 --- a/internal/serve/httperror/errors.go +++ b/internal/serve/httperror/errors.go @@ -1,8 +1,10 @@ package httperror import ( + "context" "net/http" + "github.com/stellar/go/support/log" "github.com/stellar/go/support/render/httpjson" ) @@ -23,11 +25,6 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.Error.Render(w) } -var InternalServerError = errorResponse{ - Status: http.StatusInternalServerError, - Error: "An error occurred while processing this request.", -} - var NotFound = errorResponse{ Status: http.StatusNotFound, Error: "The resource at the url requested was not found.", @@ -48,3 +45,24 @@ func BadRequest(message string) errorResponse { Error: message, } } + +func Unauthorized(message string) errorResponse { + if message == "" { + message = "Not authorized." + } + + return errorResponse{ + Status: http.StatusUnauthorized, + Error: message, + } +} + +func InternalServerError(ctx context.Context, message string, err error) errorResponse { + // TODO: track error in Sentry + log.Ctx(ctx).Error(err) + + return errorResponse{ + Status: http.StatusInternalServerError, + Error: "An error occurred while processing this request.", + } +} diff --git a/internal/serve/httperror/errors_test.go b/internal/serve/httperror/errors_test.go index 64256bb..a353494 100644 --- a/internal/serve/httperror/errors_test.go +++ b/internal/serve/httperror/errors_test.go @@ -1,6 +1,7 @@ package httperror import ( + "context" "fmt" "io" "net/http" @@ -17,7 +18,7 @@ func TestErrorResponseRender(t *testing.T) { want errorResponse }{ { - in: InternalServerError, + in: InternalServerError(context.Background(), "", nil), want: errorResponse{Status: http.StatusInternalServerError, Error: "An error occurred while processing this request."}, }, { @@ -49,7 +50,7 @@ func TestErrorHandler(t *testing.T) { want errorResponse }{ { - in: ErrorHandler{InternalServerError}, + in: ErrorHandler{InternalServerError(context.Background(), "", nil)}, want: errorResponse{Status: http.StatusInternalServerError, Error: "An error occurred while processing this request."}, }, { diff --git a/internal/serve/httphandler/payments_handler.go b/internal/serve/httphandler/payments_handler.go index 89dd452..9aee57d 100644 --- a/internal/serve/httphandler/payments_handler.go +++ b/internal/serve/httphandler/payments_handler.go @@ -28,8 +28,7 @@ func (h PaymentsHandler) SubscribeAddress(w http.ResponseWriter, r *http.Request err = h.PaymentModel.SubscribeAddress(ctx, reqBody.Address) if err != nil { - httperror.InternalServerError.Render(w) - // TODO: track in Sentry + httperror.InternalServerError(ctx, "", err).Render(w) return } } @@ -46,8 +45,7 @@ func (h PaymentsHandler) UnsubscribeAddress(w http.ResponseWriter, r *http.Reque err = h.PaymentModel.UnsubscribeAddress(ctx, reqBody.Address) if err != nil { - httperror.InternalServerError.Render(w) - // TODO: track in Sentry + httperror.InternalServerError(ctx, "", err).Render(w) return } } diff --git a/internal/serve/middleware/middleware.go b/internal/serve/middleware/middleware.go new file mode 100644 index 0000000..96b3346 --- /dev/null +++ b/internal/serve/middleware/middleware.go @@ -0,0 +1,49 @@ +package middleware + +import ( + "bytes" + "fmt" + "io" + "net/http" + + "github.com/stellar/go/support/log" + "github.com/stellar/wallet-backend/internal/serve/auth" + "github.com/stellar/wallet-backend/internal/serve/httperror" +) + +const MaxBodySize int64 = 10_240 // 10kb + +func SignatureMiddleware(signatureVerifier auth.SignatureVerifier) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + sig := req.Header.Get("Signature") + if sig == "" { + sig = req.Header.Get("X-Stellar-Signature") + if sig == "" { + httperror.Unauthorized("").Render(rw) + return + } + } + + ctx := req.Context() + + reqBody, err := io.ReadAll(io.LimitReader(req.Body, MaxBodySize)) + if err != nil { + err = fmt.Errorf("reading request body: %w", err) + httperror.InternalServerError(ctx, "", err).Render(rw) + return + } + + err = signatureVerifier.VerifySignature(ctx, sig, reqBody) + if err != nil { + err = fmt.Errorf("checking request signature: %w", err) + log.Ctx(ctx).Error(err) + httperror.Unauthorized("").Render(rw) + return + } + + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + next.ServeHTTP(rw, req) + }) + } +} diff --git a/internal/serve/middleware/middleware_test.go b/internal/serve/middleware/middleware_test.go new file mode 100644 index 0000000..cc23ba5 --- /dev/null +++ b/internal/serve/middleware/middleware_test.go @@ -0,0 +1,138 @@ +package middleware + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi" + "github.com/stellar/go/support/log" + "github.com/stellar/wallet-backend/internal/serve/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestSignatureMiddleware(t *testing.T) { + signatureVerifierMock := auth.MockSignatureVerifier{} + defer signatureVerifierMock.AssertExpectations(t) + + r := chi.NewRouter() + r.Group(func(r chi.Router) { + r.Use(SignatureMiddleware(&signatureVerifierMock)) + + r.Get("/authenticated", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + + r.Post("/authenticated", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + }) + + r.Get("/unauthenticated", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + + t.Run("returns_Unauthorized_error_when_no_header_is_sent", func(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/authenticated", nil) + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns_Unauthorized_when_a_unexpected_error_occurs_validating_the_token", func(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/authenticated", nil) + req.Header.Set("Signature", "signature") + + getEntries := log.DefaultLogger.StartTest(log.ErrorLevel) + signatureVerifierMock. + On("VerifySignature", mock.Anything, "signature", []byte{}). + Return(errors.New("unexpected error")). + Once() + + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error": "Not authorized."}`, string(respBody)) + + entries := getEntries() + require.Len(t, entries, 1) + assert.Equal(t, entries[0].Message, "checking request signature: unexpected error") + }) + + t.Run("returns_the_response_successfully", func(t *testing.T) { + // Without body - GET requests + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/authenticated", nil) + req.Header.Set("X-Stellar-Signature", "signature") + + signatureVerifierMock. + On("VerifySignature", mock.Anything, "signature", []byte{}). + Return(nil). + Once() + + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"status":"ok"}`, string(respBody)) + + // With body - POST, PUT, PATCH requests + reqBody := `{"status": "ok"}` + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/authenticated", strings.NewReader(reqBody)) + req.Header.Set("X-Stellar-Signature", "signature") + + signatureVerifierMock. + On("VerifySignature", mock.Anything, "signature", []byte(reqBody)). + Return(nil). + Once() + + r.ServeHTTP(w, req) + + resp = w.Result() + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"status":"ok"}`, string(respBody)) + }) + + t.Run("doesn't_return_Unauthorized_for_unauthenticated_routes", func(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/unauthenticated", nil) + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"status":"ok"}`, string(respBody)) + }) +} diff --git a/internal/serve/serve.go b/internal/serve/serve.go index 562ad84..fa9e738 100644 --- a/internal/serve/serve.go +++ b/internal/serve/serve.go @@ -10,19 +10,24 @@ import ( "github.com/stellar/go/support/render/health" "github.com/stellar/wallet-backend/internal/data" "github.com/stellar/wallet-backend/internal/db" + "github.com/stellar/wallet-backend/internal/serve/auth" "github.com/stellar/wallet-backend/internal/serve/httperror" "github.com/stellar/wallet-backend/internal/serve/httphandler" + "github.com/stellar/wallet-backend/internal/serve/middleware" ) type Configs struct { - Logger *supportlog.Entry - Port int - DatabaseURL string + Logger *supportlog.Entry + Port int + ServerBaseURL string + WalletSigningKey string + DatabaseURL string } type handlerDeps struct { - Logger *supportlog.Entry - Models *data.Models + Logger *supportlog.Entry + Models *data.Models + SignatureVerifier auth.SignatureVerifier } func Serve(cfg Configs) error { @@ -49,16 +54,22 @@ func Serve(cfg Configs) error { func getHandlerDeps(cfg Configs) (handlerDeps, error) { dbConnectionPool, err := db.OpenDBConnectionPool(cfg.DatabaseURL) if err != nil { - return handlerDeps{}, fmt.Errorf("error connecting to the database: %w", err) + return handlerDeps{}, fmt.Errorf("connecting to the database: %w", err) } models, err := data.NewModels(dbConnectionPool) if err != nil { - return handlerDeps{}, fmt.Errorf("error creating models for Serve: %w", err) + return handlerDeps{}, fmt.Errorf("creating models for Serve: %w", err) + } + + signatureVerifier, err := auth.NewStellarSignatureVerifier(cfg.ServerBaseURL, cfg.WalletSigningKey) + if err != nil { + return handlerDeps{}, fmt.Errorf("instantiating stellar signature verifier: %w", err) } return handlerDeps{ - Logger: cfg.Logger, - Models: models, + Logger: cfg.Logger, + Models: models, + SignatureVerifier: signatureVerifier, }, nil } @@ -71,7 +82,7 @@ func handler(deps handlerDeps) http.Handler { // Authenticated routes mux.Group(func(r chi.Router) { - // r.Use(...authMiddleware...) + r.Use(middleware.SignatureMiddleware(deps.SignatureVerifier)) r.Route("/payments", func(r chi.Router) { handler := &httphandler.PaymentsHandler{ diff --git a/internal/utils/utils.go b/internal/utils/utils.go new file mode 100644 index 0000000..3dbdc80 --- /dev/null +++ b/internal/utils/utils.go @@ -0,0 +1,17 @@ +package utils + +import "reflect" + +// IsEmpty checks if a value is empty. +func IsEmpty[T any](v T) bool { + return reflect.ValueOf(&v).Elem().IsZero() +} + +// UnwrapInterfaceToPointer unwraps an interface to a pointer of the given type. +func UnwrapInterfaceToPointer[T any](i interface{}) *T { + t, ok := i.(*T) + if ok { + return t + } + return nil +} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go new file mode 100644 index 0000000..557d025 --- /dev/null +++ b/internal/utils/utils_test.go @@ -0,0 +1,77 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUnwrapInterfaceToPointer(t *testing.T) { + // Test with a string + strValue := "test" + strValuePtr := &strValue + i := interface{}(strValuePtr) + + unwrappedValue := UnwrapInterfaceToPointer[string](i) + assert.Equal(t, "test", *unwrappedValue) + + // Test with a struct + type testStruct struct { + Name string + } + testStructValue := testStruct{Name: "test"} + testStructValuePtr := &testStructValue + i = interface{}(testStructValuePtr) + assert.Equal(t, testStruct{Name: "test"}, *UnwrapInterfaceToPointer[testStruct](i)) +} + +func TestIsEmpty(t *testing.T) { + type testCase struct { + name string + isEmptyFn func() bool + expected bool + } + + // testStruct is used just for testing empty and non empty structs. + type testStruct struct{ Name string } + + // Define test cases + testCases := []testCase{ + // String + {name: "String empty", isEmptyFn: func() bool { return IsEmpty[string]("") }, expected: true}, + {name: "String non-empty", isEmptyFn: func() bool { return IsEmpty[string]("not empty") }, expected: false}, + // Int + {name: "Int zero", isEmptyFn: func() bool { return IsEmpty[int](0) }, expected: true}, + {name: "Int non-zero", isEmptyFn: func() bool { return IsEmpty[int](1) }, expected: false}, + // Slice: + {name: "Slice nil", isEmptyFn: func() bool { return IsEmpty[[]string](nil) }, expected: true}, + {name: "Slice empty", isEmptyFn: func() bool { return IsEmpty[[]string]([]string{}) }, expected: false}, + {name: "Slice non-empty", isEmptyFn: func() bool { return IsEmpty[[]string]([]string{"not empty"}) }, expected: false}, + // Struct: + {name: "Struct zero", isEmptyFn: func() bool { return IsEmpty[testStruct](testStruct{}) }, expected: true}, + {name: "Struct non-zero", isEmptyFn: func() bool { return IsEmpty[testStruct](testStruct{Name: "not empty"}) }, expected: false}, + // Pointer: + {name: "Pointer nil", isEmptyFn: func() bool { return IsEmpty[*string](nil) }, expected: true}, + {name: "Pointer non-nil", isEmptyFn: func() bool { return IsEmpty[*string](new(string)) }, expected: false}, + // Function: + {name: "Function nil", isEmptyFn: func() bool { return IsEmpty[func() string](nil) }, expected: true}, + {name: "Function non-nil", isEmptyFn: func() bool { return IsEmpty[func() string](func() string { return "not empty" }) }, expected: false}, + // Interface: + {name: "Interface nil", isEmptyFn: func() bool { return IsEmpty[interface{}](nil) }, expected: true}, + {name: "Interface non-nil", isEmptyFn: func() bool { return IsEmpty[interface{}](new(string)) }, expected: false}, + // Map: + {name: "Map nil", isEmptyFn: func() bool { return IsEmpty[map[string]string](nil) }, expected: true}, + {name: "Map empty", isEmptyFn: func() bool { return IsEmpty[map[string]string](map[string]string{}) }, expected: false}, + {name: "Map non-empty", isEmptyFn: func() bool { return IsEmpty[map[string]string](map[string]string{"not empty": "not empty"}) }, expected: false}, + // Channel: + {name: "Channel nil", isEmptyFn: func() bool { return IsEmpty[chan string](nil) }, expected: true}, + {name: "Channel non-nil", isEmptyFn: func() bool { return IsEmpty[chan string](make(chan string)) }, expected: false}, + } + + // Run test cases + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.isEmptyFn()) + }) + } +}