From f372be88bf2c23e333eaa0a5e16aa22883a4af6f Mon Sep 17 00:00:00 2001 From: 7h3-3mp7y-m4n Date: Tue, 15 Oct 2024 14:03:18 +0530 Subject: [PATCH] added the unit test and updated the function for remote url Signed-off-by: 7h3-3mp7y-m4n --- uriget/example_test.go | 23 +++++++++++++++++++++++ uriget/uriget.go | 14 ++++++++++---- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/uriget/example_test.go b/uriget/example_test.go index ae490df..ee8c1d7 100644 --- a/uriget/example_test.go +++ b/uriget/example_test.go @@ -3,8 +3,11 @@ package uriget import ( "context" "fmt" + "log" "net/http" "net/url" + "os" + "testing" ) func ExampleGetFile_local() { @@ -51,3 +54,23 @@ func ExampleWithHttpClient() { fmt.Println(err) // Output: failed to make get request: Get "https://example.com": no proxy } + +func TestGetOci(t *testing.T) { + logger := log.New(os.Stdout, "TEST: ", log.LstdFlags) + o := &options{ + tempDir: t.TempDir(), + logger: logger, + } + + testUrl := "oci://ghcr.io/mathieu-benoit/policies:0.1.0" + + u, err := url.Parse(testUrl) + if err != nil { + t.Fatalf("failed to parse URL: %v", err) + } + + _, err = o.getOci(context.Background(), u) + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } +} diff --git a/uriget/uriget.go b/uriget/uriget.go index ccbf7ff..9e9a0fe 100644 --- a/uriget/uriget.go +++ b/uriget/uriget.go @@ -246,10 +246,13 @@ func (o *options) getOci(ctx context.Context, u *url.URL) ([]byte, error) { return nil, fmt.Errorf("invalid OCI URL format") } registry := parts[0] - repo := strings.Join(parts[1:], "/") + repo := strings.Join(parts[1:len(parts)-1], "/") tag := "latest" - if u.Fragment != "" { - tag = u.Fragment + lastPart := parts[len(parts)-1] + if strings.Contains(lastPart, ":") { + split := strings.Split(lastPart, ":") + repo = strings.Join(parts[1:len(parts)-1], "/") + "/" + split[0] + tag = split[1] } store, err := oci.New(o.tempDir) if err != nil { @@ -260,10 +263,13 @@ func (o *options) getOci(ctx context.Context, u *url.URL) ([]byte, error) { if err != nil { return nil, fmt.Errorf("failed to connect to remote repository: %w", err) } + if strings.HasPrefix(repoUrl, "localhost:") || strings.HasPrefix(repoUrl, "127.0.0.1:") { + remoteRepo.PlainHTTP = true + } manifestDescriptor, err := oras.Copy(ctx, remoteRepo, tag, store, tag, oras.DefaultCopyOptions) if err != nil { return nil, fmt.Errorf("failed to pull OCI image: %w", err) } - o.logger.Printf("Pulled OCI image: %s with manifest descriptor : %v", u.String(), manifestDescriptor.Digest) + o.logger.Printf("Pulled OCI image: %s with manifest descriptor: %v", u.String(), manifestDescriptor.Digest) return []byte(manifestDescriptor.Digest), nil }