Skip to content

Commit

Permalink
add downloader version that allows providing an auth header
Browse files Browse the repository at this point in the history
Signed-off-by: Dave Lee <[email protected]>
  • Loading branch information
dave-gray101 committed Sep 18, 2024
1 parent a882b1c commit 0656fab
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
20 changes: 12 additions & 8 deletions core/http/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
)

const apiKey = "joshua"
const bearerKey = "Bearer " + apiKey

const testPrompt = `### System:
You are an AI assistant that follows instruction extremely well. Help as much as you can.
Expand Down Expand Up @@ -74,14 +75,15 @@ func getModelStatus(url string) (response map[string]interface{}) {
return
}

func getModels(url string) (response []gallery.GalleryModel) {
func getModels(url string) ([]gallery.GalleryModel, error) {
response := []gallery.GalleryModel{}
uri := downloader.URI(url)
// TODO: No tests currently seem to exercise file:// urls. Fix?
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
err := uri.DownloadWithAuthorizationAndUnmarshal("", bearerKey, func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
return
return response, err
}

func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
Expand All @@ -103,7 +105,7 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Authorization", bearerKey)

// Make the request
client := &http.Client{}
Expand Down Expand Up @@ -143,7 +145,7 @@ func postRequestJSON[B any](url string, bodyJson *B) error {
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Authorization", bearerKey)

client := &http.Client{}
resp, err := client.Do(req)
Expand Down Expand Up @@ -179,7 +181,7 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Authorization", bearerKey)

client := &http.Client{}
resp, err := client.Do(req)
Expand Down Expand Up @@ -341,7 +343,8 @@ var _ = Describe("API test", func() {
Context("Applying models", func() {

It("applies models from a gallery", func() {
models := getModels("http://127.0.0.1:9090/models/available")
models, err := getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
Expand Down Expand Up @@ -374,7 +377,8 @@ var _ = Describe("API test", func() {
Expect(content["backend"]).To(Equal("bert-embeddings"))
Expect(content["foo"]).To(Equal("bar"))

models = getModels("http://127.0.0.1:9090/models/available")
models, err = getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2")))
Expand Down
15 changes: 14 additions & 1 deletion pkg/downloader/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ const (
type URI string

func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte) error) error {
return uri.DownloadWithAuthorizationAndUnmarshal(basePath, "", f)
}

func (uri URI) DownloadWithAuthorizationAndUnmarshal(basePath string, authorization string, f func(url string, i []byte) error) error {
url := uri.ResolveURL()

if strings.HasPrefix(url, LocalPrefix) {
Expand Down Expand Up @@ -63,7 +67,16 @@ func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte
}

// Send a GET request to the URL
response, err := http.Get(url)

req, err := http.NewRequest("GET", url, nil)
if err != nil {
return err
}
if authorization != "" {
req.Header.Add("Authorization", authorization)
}

response, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
Expand Down

0 comments on commit 0656fab

Please sign in to comment.