From 0656fabf2ab7b583a889a8022e29a3b660168bab Mon Sep 17 00:00:00 2001 From: Dave Lee Date: Wed, 18 Sep 2024 17:32:53 -0400 Subject: [PATCH] add downloader version that allows providing an auth header Signed-off-by: Dave Lee --- core/http/app_test.go | 20 ++++++++++++-------- pkg/downloader/uri.go | 15 ++++++++++++++- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/core/http/app_test.go b/core/http/app_test.go index 5ef1e2548b44..22994e98d8e9 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -32,6 +32,7 @@ import ( ) const apiKey = "joshua" +const bearerKey = "Bearer " + apiKey const testPrompt = `### System: You are an AI assistant that follows instruction extremely well. Help as much as you can. @@ -74,14 +75,15 @@ func getModelStatus(url string) (response map[string]interface{}) { return } -func getModels(url string) (response []gallery.GalleryModel) { +func getModels(url string) ([]gallery.GalleryModel, error) { + response := []gallery.GalleryModel{} uri := downloader.URI(url) // TODO: No tests currently seem to exercise file:// urls. Fix? - uri.DownloadAndUnmarshal("", func(url string, i []byte) error { + err := uri.DownloadWithAuthorizationAndUnmarshal("", bearerKey, func(url string, i []byte) error { // Unmarshal YAML data into a struct return json.Unmarshal(i, &response) }) - return + return response, err } func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) { @@ -103,7 +105,7 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[ return } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("Authorization", bearerKey) // Make the request client := &http.Client{} @@ -143,7 +145,7 @@ func postRequestJSON[B any](url string, bodyJson *B) error { } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) @@ -179,7 +181,7 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson * } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) @@ -341,7 +343,8 @@ var _ = Describe("API test", func() { Context("Applying models", func() { It("applies models from a gallery", func() { - models := getModels("http://127.0.0.1:9090/models/available") + models, err := getModels("http://127.0.0.1:9090/models/available") + Expect(err).To(BeNil()) Expect(len(models)).To(Equal(2), fmt.Sprint(models)) Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models)) Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models)) @@ -374,7 +377,8 @@ var _ = Describe("API test", func() { Expect(content["backend"]).To(Equal("bert-embeddings")) Expect(content["foo"]).To(Equal("bar")) - models = getModels("http://127.0.0.1:9090/models/available") + models, err = getModels("http://127.0.0.1:9090/models/available") + Expect(err).To(BeNil()) Expect(len(models)).To(Equal(2), fmt.Sprint(models)) Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2"))) Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2"))) diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 7fedd6461205..d74d353d8ac5 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -32,6 +32,10 @@ const ( type URI string func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte) error) error { + return uri.DownloadWithAuthorizationAndUnmarshal(basePath, "", f) +} + +func (uri URI) DownloadWithAuthorizationAndUnmarshal(basePath string, authorization string, f func(url string, i []byte) error) error { url := uri.ResolveURL() if strings.HasPrefix(url, LocalPrefix) { @@ -63,7 +67,16 @@ func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte } // Send a GET request to the URL - response, err := http.Get(url) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return err + } + if authorization != "" { + req.Header.Add("Authorization", authorization) + } + + response, err := http.DefaultClient.Do(req) if err != nil { return err }