Skip to content

Commit f063ec9

Browse files
authored
Merge pull request #163 from EmmEff/oci-artifact-name-mapping
Resolve short named artifact issue
2 parents b5db2aa + 425d5ac commit f063ec9

File tree

2 files changed

+61
-27
lines changed

2 files changed

+61
-27
lines changed

client/oci.go

+30-15
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2018-2022, Sylabs Inc. All rights reserved.
1+
// Copyright (c) 2018-2023, Sylabs Inc. All rights reserved.
22
// This software is licensed under a 3-clause BSD license. Please consult the
33
// LICENSE.md file distributed with the sources of this project regarding your
44
// rights to use or distribute this software.
@@ -30,8 +30,12 @@ const mediaTypeSIFLayer = "application/vnd.sylabs.sif.layer.v1.sif"
3030
// ociRegistryAuth uses Cloud Library endpoint to determine if artifact can be pulled
3131
// directly from OCI registry.
3232
//
33-
// Returns url and credentials (if applicable) for that url.
34-
func (c *Client) ociRegistryAuth(ctx context.Context, name string, accessTypes []accessType) (*url.URL, *bearerTokenCredentials, error) {
33+
// Returns url, credentials (if applicable) for that url, and mapped name.
34+
//
35+
// The mapped name can be the same value as 'name' or mapped to a fully-qualified name
36+
// (ie. from "alpine" to "library/default/alpine") if supported by cloud library server.
37+
// It will never be an empty string ("")
38+
func (c *Client) ociRegistryAuth(ctx context.Context, name string, accessTypes []accessType) (*url.URL, *bearerTokenCredentials, string, error) {
3539
// Build raw query string to get token for specified namespace and access
3640
v := url.Values{}
3741
v.Set("namespace", name)
@@ -45,7 +49,7 @@ func (c *Client) ociRegistryAuth(ctx context.Context, name string, accessTypes [
4549

4650
req, err := c.newRequest(ctx, http.MethodGet, "v1/oci-redirect", v.Encode(), nil)
4751
if err != nil {
48-
return nil, nil, err
52+
return nil, nil, "", err
4953
}
5054

5155
if c.UserAgent != "" {
@@ -54,30 +58,35 @@ func (c *Client) ociRegistryAuth(ctx context.Context, name string, accessTypes [
5458

5559
res, err := c.HTTPClient.Do(req)
5660
if err != nil {
57-
return nil, nil, fmt.Errorf("error determining direct OCI registry access: %w", err)
61+
return nil, nil, "", fmt.Errorf("error determining direct OCI registry access: %w", err)
5862
}
5963
defer res.Body.Close()
6064

6165
if res.StatusCode != http.StatusOK {
62-
return nil, nil, fmt.Errorf("error determining direct OCI registry access: %w", err)
66+
return nil, nil, "", fmt.Errorf("error determining direct OCI registry access: %w", err)
6367
}
6468

6569
type ociDownloadRedirectResponse struct {
6670
Token string `json:"token"`
6771
RegistryURI string `json:"url"`
72+
Name string `json:"name"`
6873
}
6974

7075
var ociArtifactSpec ociDownloadRedirectResponse
7176

7277
if err := json.NewDecoder(res.Body).Decode(&ociArtifactSpec); err != nil {
73-
return nil, nil, fmt.Errorf("error decoding direct OCI registry access response: %w", err)
78+
return nil, nil, "", fmt.Errorf("error decoding direct OCI registry access response: %w", err)
79+
}
80+
81+
if ociArtifactSpec.Name != "" && ociArtifactSpec.Name != name {
82+
name = ociArtifactSpec.Name
7483
}
7584

7685
endpoint, err := url.Parse(ociArtifactSpec.RegistryURI)
7786
if err != nil {
78-
return nil, nil, fmt.Errorf("malformed OCI registry URI %v: %v", ociArtifactSpec.RegistryURI, err)
87+
return nil, nil, "", fmt.Errorf("malformed OCI registry URI %v: %v", ociArtifactSpec.RegistryURI, err)
7988
}
80-
return endpoint, &bearerTokenCredentials{authToken: ociArtifactSpec.Token}, nil
89+
return endpoint, &bearerTokenCredentials{authToken: ociArtifactSpec.Token}, name, nil
8190
}
8291

8392
const (
@@ -621,21 +630,27 @@ func (r *ociRegistry) getImageConfig(ctx context.Context, creds credentials, nam
621630

622631
var errOCIDownloadNotSupported = errors.New("not supported")
623632

624-
func (c *Client) newOCIRegistry(ctx context.Context, name string, accessTypes []accessType) (*ociRegistry, *bearerTokenCredentials, error) {
633+
func (c *Client) newOCIRegistry(ctx context.Context, name string, accessTypes []accessType) (*ociRegistry, *bearerTokenCredentials, string, error) {
625634
// Attempt to obtain (direct) OCI registry auth token
626-
registryURI, creds, err := c.ociRegistryAuth(ctx, name, accessTypes)
635+
originalName := name
636+
637+
registryURI, creds, name, err := c.ociRegistryAuth(ctx, name, accessTypes)
627638
if err != nil {
628-
return nil, nil, errOCIDownloadNotSupported
639+
return nil, nil, "", errOCIDownloadNotSupported
629640
}
630641

631642
// Download directly from OCI registry
632643
c.Logger.Logf("Using OCI registry endpoint %v", registryURI)
633644

634-
return &ociRegistry{baseURL: registryURI, httpClient: c.HTTPClient, logger: c.Logger}, creds, nil
645+
if name != "" && originalName != name {
646+
c.Logger.Logf("OCI artifact name \"%v\" mapped to \"%v\"", originalName, name)
647+
}
648+
649+
return &ociRegistry{baseURL: registryURI, httpClient: c.HTTPClient, logger: c.Logger}, creds, name, nil
635650
}
636651

637652
func (c *Client) ociDownloadImage(ctx context.Context, arch, name, tag string, w io.WriterAt, spec *Downloader, pb ProgressBar) error {
638-
reg, creds, err := c.newOCIRegistry(ctx, name, []accessType{accessTypePull})
653+
reg, creds, name, err := c.newOCIRegistry(ctx, name, []accessType{accessTypePull})
639654
if err != nil {
640655
return err
641656
}
@@ -663,7 +678,7 @@ func (e *unexpectedImageDigest) Error() string {
663678
}
664679

665680
func (c *Client) ociUploadImage(ctx context.Context, r io.Reader, size int64, name, arch string, tags []string, description, hash string, callback UploadCallback) error {
666-
reg, creds, err := c.newOCIRegistry(ctx, name, []accessType{accessTypePull, accessTypePush})
681+
reg, creds, name, err := c.newOCIRegistry(ctx, name, []accessType{accessTypePull, accessTypePush})
667682
if err != nil {
668683
return err
669684
}

client/oci_test.go

+31-12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
// Copyright (c) 2022-2023, Sylabs Inc. All rights reserved.
2+
// This software is licensed under a 3-clause BSD license. Please consult the
3+
// LICENSE.md file distributed with the sources of this project regarding your
4+
// rights to use or distribute this software.
5+
16
package client
27

38
import (
@@ -14,9 +19,13 @@ func TestOciRegistryAuth(t *testing.T) {
1419
tests := []struct {
1520
name string
1621
directOciDownloadSupported bool
22+
ref string
23+
mappedRef string
1724
}{
18-
{"Basic", true},
19-
{"NotSupported", false},
25+
{"Basic", true, "entity/collection/container", "entity/collection/container"},
26+
{"TwoElements", true, "entity/container", "entity/container"},
27+
{"ShortName", true, "alpine", "library/default/alpine"},
28+
{"NotSupported", false, "", ""},
2029
}
2130

2231
for _, tt := range tests {
@@ -34,9 +43,11 @@ func TestOciRegistryAuth(t *testing.T) {
3443
response := struct {
3544
Token string `json:"token"`
3645
RegistryURI string `json:"url"`
46+
Name string `json:"name"`
3747
}{
3848
Token: "xxx",
3949
RegistryURI: ociRegistryURI,
50+
Name: tt.mappedRef,
4051
}
4152

4253
if v := r.URL.Query().Get("namespace"); v == "" {
@@ -53,30 +64,38 @@ func TestOciRegistryAuth(t *testing.T) {
5364
}))
5465
defer testShimSrv.Close()
5566

56-
c, err := NewClient(&Config{
67+
clientCfg := &Config{
5768
BaseURL: testShimSrv.URL,
5869
Logger: &stdLogger{},
5970
UserAgent: "scs-library-client-unit-tests/1.0",
60-
})
71+
}
72+
73+
c, err := NewClient(clientCfg)
6174
if err != nil {
6275
t.Fatalf("error initializing client: %v", err)
6376
}
6477

65-
u, creds, err := c.ociRegistryAuth(context.Background(), "testproject/testrepo", []accessType{accessTypePull})
78+
u, creds, name, err := c.ociRegistryAuth(context.Background(), tt.ref, []accessType{accessTypePull})
6679
if tt.directOciDownloadSupported && err != nil {
6780
t.Fatalf("error getting OCI registry credentials: %v", err)
6881
} else if !tt.directOciDownloadSupported && err == nil {
6982
t.Fatal("unexpected success")
7083
}
7184

72-
if tt.directOciDownloadSupported {
73-
if got, want := u.String(), ociRegistryURI; got != want {
74-
t.Fatalf("unexpected OCI registry URI: got %v, want %v", got, want)
75-
}
85+
if !tt.directOciDownloadSupported {
86+
return
87+
}
7688

77-
if creds == nil {
78-
t.Fatal("expecting bearer token credential")
79-
}
89+
if got, want := name, tt.mappedRef; got != want {
90+
t.Fatalf("unexpected OCI artifact name: got %v, want %v", got, want)
91+
}
92+
93+
if got, want := u.String(), ociRegistryURI; got != want {
94+
t.Fatalf("unexpected OCI registry URI: got %v, want %v", got, want)
95+
}
96+
97+
if creds == nil {
98+
t.Fatal("expecting bearer token credential")
8099
}
81100
})
82101
}

0 commit comments

Comments
 (0)