Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ export JWKS_URL=http://lfx-platform-heimdall.lfx.svc.cluster.local:4457/.well-kn
# JWT audience
export AUDIENCE=lfx-v2-project-service

# JWT signature algorithm (PS256, PS384, PS512, RS256, RS384, RS512, ES256, ES384, ES512)
export JWT_SIGNATURE_ALGORITHM=PS256

# Skip the ETag validation that requires the correct revision on PUT/DELETE requests.
# When this is set to false, it means you need to make a GET request on the resource
# to get the ETag response header and use it as the ETag request header on the PUT/DELETE
Expand Down
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ func TestEndpoint(t *testing.T) {
| `JWKS_URL` | JWT verification endpoint | - | No |
| `AUDIENCE` | JWT audience | lfx-v2-project-service | No |
| `JWT_AUTH_DISABLED_MOCK_LOCAL_PRINCIPAL` | Mock auth for local dev | - | No |
| `JWT_SIGNATURE_ALGORITHM` | JWT signature algorithm | PS256 | No |

## Authorization (OpenFGA)

Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ help:
deps:
@echo "==> Installing dependencies..."
go mod download
go install goa.design/goa/$(GOA_VERSION)/cmd/goa@latest
go install goa.design/goa/$(GOA_VERSION)/cmd/goa@v3.22.6
@command -v golangci-lint >/dev/null 2>&1 || { \
echo "==> Installing golangci-lint..."; \
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest; \
Expand Down
2 changes: 1 addition & 1 deletion charts/lfx-v2-project-service/Chart.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ apiVersion: v2
name: lfx-v2-project-service
description: LFX Platform V2 Project Service chart
type: application
version: 0.5.3
version: 0.5.4
appVersion: "latest"
2 changes: 2 additions & 0 deletions charts/lfx-v2-project-service/templates/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ spec:
value: {{ .Values.app.skipEtagValidation | quote }}
- name: JWT_AUTH_DISABLED_MOCK_LOCAL_PRINCIPAL
value: {{ .Values.app.jwtAuthDisabledMockLocalPrincipal }}
- name: JWT_SIGNATURE_ALGORITHM
value: {{ .Values.app.jwtSignatureAlgorithm }}
ports:
- containerPort: {{ .Values.service.port }}
name: web
Expand Down
4 changes: 4 additions & 0 deletions charts/lfx-v2-project-service/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ app:
# jwtAuthDisabledMockLocalPrincipal mocks auth for local development to use a set principal
# (only use for local development)
jwtAuthDisabledMockLocalPrincipal: ""
# jwtSignatureAlgorithm is the JWT signature algorithm for token validation
# Supported: PS256 (default), PS384, PS512, RS256, RS384, RS512, ES256, ES384, ES512
# Algorithm names are case-sensitive and must be uppercase
jwtSignatureAlgorithm: "PS256"
# use_oidc_contextualizer is a boolean to determine if the OIDC contextualizer should be used
use_oidc_contextualizer: true

Expand Down
1 change: 1 addition & 0 deletions cmd/project-api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func main() {
JWKSURL: os.Getenv("JWKS_URL"),
Audience: os.Getenv("AUDIENCE"),
MockLocalPrincipal: os.Getenv("JWT_AUTH_DISABLED_MOCK_LOCAL_PRINCIPAL"),
SignatureAlgorithm: os.Getenv("JWT_SIGNATURE_ALGORITHM"),
}
jwtAuth, err := auth.NewJWTAuth(jwtAuthConfig)
if err != nil {
Expand Down
57 changes: 50 additions & 7 deletions internal/infrastructure/auth/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,40 @@ import (
)

const (
// PS256 is the default for Heimdall's JWT finalizer.
signatureAlgorithm = validator.PS256
defaultIssuer = "heimdall"
defaultAudience = "lfx-v2-project-service"
defaultJWKSURL = "http://heimdall:4457/.well-known/jwks"
// PS256 is the default signature algorithm used when JWT_SIGNATURE_ALGORITHM is not set.
defaultSignatureAlgorithm = validator.PS256
defaultIssuer = "heimdall"
defaultAudience = "lfx-v2-project-service"
defaultJWKSURL = "http://heimdall:4457/.well-known/jwks"
)

// parseSignatureAlgorithm converts the algorithm string to a validator.SignatureAlgorithm.
// Returns PS256 as default if algoString is empty.
// Algorithm names are case-sensitive and must be uppercase (e.g., "PS256").
func parseSignatureAlgorithm(algoString string) (validator.SignatureAlgorithm, error) {
if algoString == "" {
return validator.PS256, nil
}

algorithms := map[string]validator.SignatureAlgorithm{
"PS256": validator.PS256,
"PS384": validator.PS384,
"PS512": validator.PS512,
"RS256": validator.RS256,
"RS384": validator.RS384,
"RS512": validator.RS512,
"ES256": validator.ES256,
"ES384": validator.ES384,
"ES512": validator.ES512,
}

if algo, exists := algorithms[algoString]; exists {
return algo, nil
}

return "", errors.New("unsupported JWT signature algorithm: " + algoString + " (supported: PS256, PS384, PS512, RS256, RS384, RS512, ES256, ES384, ES512)")
}

// JWTAuthConfig holds the configuration parameters for JWT authentication.
type JWTAuthConfig struct {
// JWKSURL is the URL to the JSON Web Key Set endpoint
Expand All @@ -33,6 +60,8 @@ type JWTAuthConfig struct {
Audience string
// MockLocalPrincipal is used for local development to bypass JWT validation
MockLocalPrincipal string
// SignatureAlgorithm is the JWT signature algorithm (e.g., PS256, RS256, ES256)
SignatureAlgorithm string
}

var (
Expand Down Expand Up @@ -67,6 +96,20 @@ type JWTAuth struct {
var _ domain.Authenticator = (*JWTAuth)(nil)

func NewJWTAuth(config JWTAuthConfig) (*JWTAuth, error) {
// Parse signature algorithm
algo, err := parseSignatureAlgorithm(config.SignatureAlgorithm)
if err != nil {
slog.With(constants.ErrKey, err).Error("invalid JWT signature algorithm")
return nil, err
}

// Log algorithm selection (especially if non-default)
if config.SignatureAlgorithm != "" && config.SignatureAlgorithm != "PS256" {
slog.Info("using non-default JWT signature algorithm",
"algorithm", config.SignatureAlgorithm,
)
}

// Set up defaults if not provided
jwksURLStr := config.JWKSURL
if jwksURLStr == "" {
Expand All @@ -92,10 +135,10 @@ func NewJWTAuth(config JWTAuthConfig) (*JWTAuth, error) {
}
provider := jwks.NewCachingProvider(issuer, 5*time.Minute, jwks.WithCustomJWKSURI(jwksURL))

// Set up the JWT validator.
// Set up the JWT validator with selected algorithm.
jwtValidator, err := validator.New(
provider.KeyFunc,
signatureAlgorithm,
algo,
issuer.String(),
[]string{audience},
validator.WithCustomClaims(customClaims),
Expand Down
86 changes: 85 additions & 1 deletion internal/infrastructure/auth/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func TestJWTAuth_Constants(t *testing.T) {
assert.Equal(t, "heimdall", defaultIssuer)
assert.Equal(t, "lfx-v2-project-service", defaultAudience)
assert.Equal(t, "http://heimdall:4457/.well-known/jwks", defaultJWKSURL)
assert.NotNil(t, signatureAlgorithm)
assert.NotNil(t, defaultSignatureAlgorithm)
})
}

Expand Down Expand Up @@ -307,6 +307,38 @@ func TestJWTAuth_ConfigurationHandling(t *testing.T) {
shouldError: false,
description: "should accept mock principal",
},
{
name: "custom signature algorithm ES256",
config: JWTAuthConfig{
SignatureAlgorithm: "ES256",
},
shouldError: false,
description: "should accept valid signature algorithm",
},
{
name: "custom signature algorithm RS256",
config: JWTAuthConfig{
SignatureAlgorithm: "RS256",
},
shouldError: false,
description: "should accept RS256 signature algorithm",
},
{
name: "invalid signature algorithm",
config: JWTAuthConfig{
SignatureAlgorithm: "INVALID",
},
shouldError: true,
description: "should reject invalid signature algorithm",
},
{
name: "lowercase signature algorithm rejected",
config: JWTAuthConfig{
SignatureAlgorithm: "ps256",
},
shouldError: true,
description: "should reject lowercase signature algorithm",
},
}

for _, tt := range tests {
Expand All @@ -326,3 +358,55 @@ func TestJWTAuth_ConfigurationHandling(t *testing.T) {
})
}
}

func TestParseSignatureAlgorithm(t *testing.T) {
tests := []struct {
name string
algorithm string
wantErr bool
}{
// Valid algorithms - PS family
{name: "PS256 valid", algorithm: "PS256", wantErr: false},
{name: "PS384 valid", algorithm: "PS384", wantErr: false},
{name: "PS512 valid", algorithm: "PS512", wantErr: false},

// Valid algorithms - RS family
{name: "RS256 valid", algorithm: "RS256", wantErr: false},
{name: "RS384 valid", algorithm: "RS384", wantErr: false},
{name: "RS512 valid", algorithm: "RS512", wantErr: false},

// Valid algorithms - ES family
{name: "ES256 valid", algorithm: "ES256", wantErr: false},
{name: "ES384 valid", algorithm: "ES384", wantErr: false},
{name: "ES512 valid", algorithm: "ES512", wantErr: false},

// Empty string uses default
{name: "empty defaults to PS256", algorithm: "", wantErr: false},

// Invalid - case sensitivity
{name: "lowercase rejected", algorithm: "ps256", wantErr: true},
{name: "mixed case rejected", algorithm: "Ps256", wantErr: true},

// Invalid - HMAC algorithms not supported
{name: "HS256 unsupported", algorithm: "HS256", wantErr: true},
{name: "HS384 unsupported", algorithm: "HS384", wantErr: true},
{name: "HS512 unsupported", algorithm: "HS512", wantErr: true},

// Invalid - unknown algorithms
{name: "unknown algorithm", algorithm: "UNKNOWN", wantErr: true},
{name: "typo", algorithm: "PS265", wantErr: true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
algo, err := parseSignatureAlgorithm(tt.algorithm)
if tt.wantErr {
assert.Error(t, err, "expected error for algorithm %q", tt.algorithm)
assert.Empty(t, algo, "expected empty algorithm for %q", tt.algorithm)
} else {
assert.NoError(t, err, "unexpected error for algorithm %q", tt.algorithm)
assert.NotEmpty(t, algo, "expected valid algorithm for %q", tt.algorithm)
}
})
}
}