Skip to content

Commit

Permalink
fix: parameterize hostname immutable
Browse files Browse the repository at this point in the history
Signed-off-by: Carlos Alexandro Becker <[email protected]>
  • Loading branch information
caarlos0 committed Aug 21, 2024
1 parent d7df96f commit cf196b0
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 12 deletions.
28 changes: 20 additions & 8 deletions aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,29 +178,41 @@ func NewDefaultV2Config(ctx context.Context) (awsv2.Config, error) {
// - profile: The shared config profile to use; sets SharedConfigProfile.
// - endpoint: The AWS service endpoint to send HTTP request.
func V2ConfigFromURLParams(ctx context.Context, q url.Values) (awsv2.Config, error) {
var endpoint string
var hostnameImmutable bool
var opts []func(*awsv2cfg.LoadOptions) error
for param, values := range q {
value := values[0]
switch param {
case "awsHostnameImmutable":
var err error
hostnameImmutable, err = strconv.ParseBool(value)
if err != nil {
return awsv2.Config{}, err
}
case "region":
opts = append(opts, awsv2cfg.WithRegion(value))
case "endpoint":
endpoint = value
case "profile":
opts = append(opts, awsv2cfg.WithSharedConfigProfile(value))
case "awssdk":
// ignore, should be handled before this
default:
return awsv2.Config{}, fmt.Errorf("unknown query parameter %q", param)
}

if endpoint != "" {
customResolver := awsv2.EndpointResolverWithOptionsFunc(
func(service, region string, options ...interface{}) (awsv2.Endpoint, error) {
return awsv2.Endpoint{
PartitionID: "aws",
URL: value,
URL: endpoint,
SigningRegion: region,
HostnameImmutable: true,
HostnameImmutable: hostnameImmutable,
}, nil
})
opts = append(opts, awsv2cfg.WithEndpointResolverWithOptions(customResolver))
case "profile":
opts = append(opts, awsv2cfg.WithSharedConfigProfile(value))
case "awssdk":
// ignore, should be handled before this
default:
return awsv2.Config{}, fmt.Errorf("unknown query parameter %q", param)
}
}
return awsv2cfg.LoadDefaultConfig(ctx, opts...)
Expand Down
37 changes: 33 additions & 4 deletions aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ package aws_test
import (
"context"
"net/url"
"reflect"
"testing"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws"
"github.com/google/go-cmp/cmp"
gcaws "gocloud.dev/aws"
Expand Down Expand Up @@ -147,12 +149,16 @@ func TestUseV2(t *testing.T) {
}

func TestV2ConfigFromURLParams(t *testing.T) {
const service = "s3"
const region = "us-east-1"
const partitionID = "aws"
ctx := context.Background()
tests := []struct {
name string
query url.Values
wantRegion string
wantErr bool
name string
query url.Values
wantRegion string
wantErr bool
wantEndpoint *awsv2.Endpoint
}{
{
name: "No overrides",
Expand All @@ -168,6 +174,16 @@ func TestV2ConfigFromURLParams(t *testing.T) {
query: url.Values{"region": {"my_region"}},
wantRegion: "my_region",
},
{
name: "Endpoint and hostname immutable",
query: url.Values{"endpoint": {"foo"}, "awsHostnameImmutable": {"true"}},
wantEndpoint: &awsv2.Endpoint{
PartitionID: partitionID,
SigningRegion: region,
URL: "foo",
HostnameImmutable: true,
},
},
// Can't test "profile", since AWS validates that the profile exists.
}

Expand All @@ -184,6 +200,19 @@ func TestV2ConfigFromURLParams(t *testing.T) {
if test.wantRegion != "" && got.Region != test.wantRegion {
t.Errorf("got region %q, want %q", got.Region, test.wantRegion)
}

if test.wantEndpoint != nil {
if got.EndpointResolverWithOptions == nil {
t.Fatalf("expected an EndpointResolverWithOptions, got nil")
}
gotE, err := got.EndpointResolverWithOptions.ResolveEndpoint(service, region)
if err != nil {
return
}
if !reflect.DeepEqual(gotE, *test.wantEndpoint) {
t.Errorf("got endpoint %+v, want %+v", gotE, *test.wantEndpoint)
}
}
})
}
}

0 comments on commit cf196b0

Please sign in to comment.