diff --git a/core/http/app.go b/core/http/app.go index a2d8b87a2f73..47d89a106561 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -87,6 +87,8 @@ func API(application *application.Application) (*fiber.App, error) { router := fiber.New(fiberCfg) + router.Use(middleware.StripPathPrefix()) + router.Hooks().OnListen(func(listenData fiber.ListenData) error { scheme := "http" if listenData.TLS { diff --git a/core/http/app_test.go b/core/http/app_test.go index 7c57ba21a701..d27851da761e 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -237,6 +237,35 @@ func postInvalidRequest(url string) (error, int) { return nil, resp.StatusCode } +func getRequest(url string, header http.Header) (error, int, []byte) { + + req, err := http.NewRequest("GET", url, bytes.NewBufferString("")) + if err != nil { + return err, -1, nil + } + + req.Header = header + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err, -1, nil + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return err, -1, nil + } + + if resp.StatusCode < 200 || resp.StatusCode >= 400 { + return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)), resp.StatusCode, nil + } + + return nil, resp.StatusCode, body +} + const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b87640e8644b09c2aee6e3b/raw/f0e8c26bb72edc16d9fbafbfd6638072126ff225/bert-embeddings-gallery.yaml` //go:embed backend-assets/* @@ -345,6 +374,20 @@ var _ = Describe("API test", func() { }) }) + Context("URL routing Tests", func() { + It("Should support reverse-proxy", func() { + + err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", map[string][]string{ + "X-Forwarded-Proto": []string{"https"}, + "X-Forwarded-Host": []string{"example.org"}, + "X-Forwarded-Prefix": []string{"/myprefix/"}, + }) + Expect(err).Should(NotOccur()) + Expect(sc).To(Equal(200), "status code") + Expect(string(body)).To(Contain(``), "body") + }) + }) + Context("Applying models", func() { It("applies models from a gallery", func() { diff --git a/core/http/elements/buttons.go b/core/http/elements/buttons.go index 7cfe968ffe8b..2364a0b31669 100644 --- a/core/http/elements/buttons.go +++ b/core/http/elements/buttons.go @@ -16,7 +16,7 @@ func installButton(galleryName string) elem.Node { "class": "float-right inline-block rounded bg-primary px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong", "hx-swap": "outerHTML", // post the Model ID as param - "hx-post": "/browse/install/model/" + galleryName, + "hx-post": "browse/install/model/" + galleryName, }, elem.I( attrs.Props{ @@ -36,7 +36,7 @@ func reInstallButton(galleryName string) elem.Node { "hx-target": "#action-div-" + dropBadChars(galleryName), "hx-swap": "outerHTML", // post the Model ID as param - "hx-post": "/browse/install/model/" + galleryName, + "hx-post": "browse/install/model/" + galleryName, }, elem.I( attrs.Props{ @@ -80,7 +80,7 @@ func deleteButton(galleryID string) elem.Node { "hx-target": "#action-div-" + dropBadChars(galleryID), "hx-swap": "outerHTML", // post the Model ID as param - "hx-post": "/browse/delete/model/" + galleryID, + "hx-post": "browse/delete/model/" + galleryID, }, elem.I( attrs.Props{ diff --git a/core/http/elements/gallery.go b/core/http/elements/gallery.go index c9d7a1cb5be2..5ab685080755 100644 --- a/core/http/elements/gallery.go +++ b/core/http/elements/gallery.go @@ -47,7 +47,7 @@ func searchableElement(text, icon string) elem.Node { // "value": text, //"class": "inline-block bg-gray-200 rounded-full px-3 py-1 text-sm font-semibold text-gray-700 mr-2 mb-2", "href": "#!", - "hx-post": "/browse/search/models", + "hx-post": "browse/search/models", "hx-target": "#search-results", // TODO: this doesn't work // "hx-vals": `{ \"search\": \"` + text + `\" }`, diff --git a/core/http/elements/progressbar.go b/core/http/elements/progressbar.go index c9af98d9a5ca..7dc340b24ad1 100644 --- a/core/http/elements/progressbar.go +++ b/core/http/elements/progressbar.go @@ -64,7 +64,7 @@ func StartProgressBar(uid, progress, text string) string { return elem.Div( attrs.Props{ "hx-trigger": "done", - "hx-get": "/browse/job/" + uid, + "hx-get": "browse/job/" + uid, "hx-swap": "outerHTML", "hx-target": "this", }, @@ -77,7 +77,7 @@ func StartProgressBar(uid, progress, text string) string { }, elem.Text(bluemonday.StrictPolicy().Sanitize(text)), //Perhaps overly defensive elem.Div(attrs.Props{ - "hx-get": "/browse/job/progress/" + uid, + "hx-get": "browse/job/progress/" + uid, "hx-trigger": "every 600ms", "hx-target": "this", "hx-swap": "innerHTML", diff --git a/core/http/endpoints/explorer/dashboard.go b/core/http/endpoints/explorer/dashboard.go index 9c731d9a4f78..3c8966819c9c 100644 --- a/core/http/endpoints/explorer/dashboard.go +++ b/core/http/endpoints/explorer/dashboard.go @@ -6,6 +6,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/explorer" + "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/internal" ) @@ -14,6 +15,7 @@ func Dashboard() func(*fiber.Ctx) error { summary := fiber.Map{ "Title": "LocalAI API - " + internal.PrintableVersion(), "Version": internal.PrintableVersion(), + "BaseURL": utils.BaseURL(c), } if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 { diff --git a/core/http/endpoints/localai/gallery.go b/core/http/endpoints/localai/gallery.go index 23c5d4b8d29d..5b2968f43511 100644 --- a/core/http/endpoints/localai/gallery.go +++ b/core/http/endpoints/localai/gallery.go @@ -9,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/rs/zerolog/log" @@ -82,7 +83,8 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fibe Galleries: mgs.galleries, ConfigURL: input.ConfigURL, } - return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) + + return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())}) } } @@ -105,7 +107,7 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fib return err } - return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) + return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())}) } } diff --git a/core/http/endpoints/localai/welcome.go b/core/http/endpoints/localai/welcome.go index a14768861396..57cf88095e2f 100644 --- a/core/http/endpoints/localai/welcome.go +++ b/core/http/endpoints/localai/welcome.go @@ -4,6 +4,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" @@ -32,6 +33,7 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig, summary := fiber.Map{ "Title": "LocalAI API - " + internal.PrintableVersion(), "Version": internal.PrintableVersion(), + "BaseURL": utils.BaseURL(c), "Models": modelsWithoutConfig, "ModelsConfig": backendConfigs, "GalleryConfig": galleryConfigs, diff --git a/core/http/explorer.go b/core/http/explorer.go index bdcb93b16d55..36609add6b35 100644 --- a/core/http/explorer.go +++ b/core/http/explorer.go @@ -7,6 +7,7 @@ import ( "github.com/gofiber/fiber/v2/middleware/favicon" "github.com/gofiber/fiber/v2/middleware/filesystem" "github.com/mudler/LocalAI/core/explorer" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/routes" ) @@ -22,6 +23,7 @@ func Explorer(db *explorer.Database) *fiber.App { app := fiber.New(fiberCfg) + app.Use(middleware.StripPathPrefix()) routes.RegisterExplorerRoutes(app, db) httpFS := http.FS(embedDirStatic) diff --git a/core/http/middleware/strippathprefix.go b/core/http/middleware/strippathprefix.go new file mode 100644 index 000000000000..5c45d55d3645 --- /dev/null +++ b/core/http/middleware/strippathprefix.go @@ -0,0 +1,36 @@ +package middleware + +import ( + "strings" + + "github.com/gofiber/fiber/v2" +) + +// StripPathPrefix returns a middleware that strips a path prefix from the request path. +// The path prefix is obtained from the X-Forwarded-Prefix HTTP request header. +func StripPathPrefix() fiber.Handler { + return func(c *fiber.Ctx) error { + for _, prefix := range c.GetReqHeaders()["X-Forwarded-Prefix"] { + if prefix != "" { + path := c.Path() + pos := len(prefix) + + if prefix[pos-1] == '/' { + pos-- + } else { + prefix += "/" + } + + if strings.HasPrefix(path, prefix) { + c.Path(path[pos:]) + break + } else if prefix[:pos] == path { + c.Redirect(prefix) + return nil + } + } + } + + return c.Next() + } +} diff --git a/core/http/middleware/strippathprefix_test.go b/core/http/middleware/strippathprefix_test.go new file mode 100644 index 000000000000..529f815f71c0 --- /dev/null +++ b/core/http/middleware/strippathprefix_test.go @@ -0,0 +1,121 @@ +package middleware + +import ( + "net/http/httptest" + "testing" + + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/require" +) + +func TestStripPathPrefix(t *testing.T) { + var actualPath string + + app := fiber.New() + + app.Use(StripPathPrefix()) + + app.Get("/hello/world", func(c *fiber.Ctx) error { + actualPath = c.Path() + return nil + }) + + app.Get("/", func(c *fiber.Ctx) error { + actualPath = c.Path() + return nil + }) + + for _, tc := range []struct { + name string + path string + prefixHeader []string + expectStatus int + expectPath string + }{ + { + name: "without prefix and header", + path: "/hello/world", + expectStatus: 200, + expectPath: "/hello/world", + }, + { + name: "without prefix and headers on root path", + path: "/", + expectStatus: 200, + expectPath: "/", + }, + { + name: "without prefix but header", + path: "/hello/world", + prefixHeader: []string{"/otherprefix/"}, + expectStatus: 200, + expectPath: "/hello/world", + }, + { + name: "with prefix but non-matching header", + path: "/prefix/hello/world", + prefixHeader: []string{"/otherprefix/"}, + expectStatus: 404, + }, + { + name: "with prefix and matching header", + path: "/myprefix/hello/world", + prefixHeader: []string{"/myprefix/"}, + expectStatus: 200, + expectPath: "/hello/world", + }, + { + name: "with prefix and 1st header matching", + path: "/myprefix/hello/world", + prefixHeader: []string{"/myprefix/", "/otherprefix/"}, + expectStatus: 200, + expectPath: "/hello/world", + }, + { + name: "with prefix and 2nd header matching", + path: "/myprefix/hello/world", + prefixHeader: []string{"/otherprefix/", "/myprefix/"}, + expectStatus: 200, + expectPath: "/hello/world", + }, + { + name: "with prefix and header not ending with slash", + path: "/myprefix/hello/world", + prefixHeader: []string{"/myprefix"}, + expectStatus: 200, + expectPath: "/hello/world", + }, + { + name: "with prefix and non-matching header not ending with slash", + path: "/myprefix-suffix/hello/world", + prefixHeader: []string{"/myprefix"}, + expectStatus: 404, + }, + { + name: "redirect when prefix does not end with a slash", + path: "/myprefix", + prefixHeader: []string{"/myprefix"}, + expectStatus: 302, + expectPath: "/myprefix/", + }, + } { + t.Run(tc.name, func(t *testing.T) { + actualPath = "" + req := httptest.NewRequest("GET", tc.path, nil) + if tc.prefixHeader != nil { + req.Header["X-Forwarded-Prefix"] = tc.prefixHeader + } + + resp, err := app.Test(req, -1) + + require.NoError(t, err) + require.Equal(t, tc.expectStatus, resp.StatusCode, "response status code") + + if tc.expectStatus == 200 { + require.Equal(t, tc.expectPath, actualPath, "rewritten path") + } else if tc.expectStatus == 302 { + require.Equal(t, tc.expectPath, resp.Header.Get("Location"), "redirect location") + } + }) + } +} diff --git a/core/http/render.go b/core/http/render.go index 205f7ca3e5c8..2f889f57e177 100644 --- a/core/http/render.go +++ b/core/http/render.go @@ -10,6 +10,7 @@ import ( "github.com/gofiber/fiber/v2" fiberhtml "github.com/gofiber/template/html/v2" "github.com/microcosm-cc/bluemonday" + "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/schema" "github.com/russross/blackfriday" ) @@ -26,7 +27,9 @@ func notFoundHandler(c *fiber.Ctx) error { }) } else { // The client expects an HTML response - return c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{}) + return c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{ + "BaseURL": utils.BaseURL(c), + }) } } diff --git a/core/http/routes/ui.go b/core/http/routes/ui.go index 6ea38f35392f..92d20544b053 100644 --- a/core/http/routes/ui.go +++ b/core/http/routes/ui.go @@ -6,20 +6,21 @@ import ( "sort" "strings" - "github.com/microcosm-cc/bluemonday" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/http/elements" "github.com/mudler/LocalAI/core/http/endpoints/localai" + "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/xsync" - "github.com/rs/zerolog/log" "github.com/gofiber/fiber/v2" "github.com/google/uuid" + "github.com/microcosm-cc/bluemonday" + "github.com/rs/zerolog/log" ) type modelOpCache struct { @@ -91,6 +92,7 @@ func RegisterUIRoutes(app *fiber.App, app.Get("/p2p", func(c *fiber.Ctx) error { summary := fiber.Map{ "Title": "LocalAI - P2P dashboard", + "BaseURL": utils.BaseURL(c), "Version": internal.PrintableVersion(), //"Nodes": p2p.GetAvailableNodes(""), //"FederatedNodes": p2p.GetAvailableNodes(p2p.FederatedID), @@ -149,6 +151,7 @@ func RegisterUIRoutes(app *fiber.App, summary := fiber.Map{ "Title": "LocalAI - Models", + "BaseURL": utils.BaseURL(c), "Version": internal.PrintableVersion(), "Models": template.HTML(elements.ListModels(models, processingModels, galleryService)), "Repositories": appConfig.Galleries, @@ -308,6 +311,7 @@ func RegisterUIRoutes(app *fiber.App, summary := fiber.Map{ "Title": "LocalAI - Chat with " + c.Params("model"), + "BaseURL": utils.BaseURL(c), "ModelsConfig": backendConfigs, "Model": c.Params("model"), "Version": internal.PrintableVersion(), @@ -323,11 +327,12 @@ func RegisterUIRoutes(app *fiber.App, if len(backendConfigs) == 0 { // If no model is available redirect to the index which suggests how to install models - return c.Redirect("/") + return c.Redirect(utils.BaseURL(c)) } summary := fiber.Map{ "Title": "LocalAI - Talk", + "BaseURL": utils.BaseURL(c), "ModelsConfig": backendConfigs, "Model": backendConfigs[0], "IsP2PEnabled": p2p.IsP2PEnabled(), @@ -344,11 +349,12 @@ func RegisterUIRoutes(app *fiber.App, if len(backendConfigs) == 0 { // If no model is available redirect to the index which suggests how to install models - return c.Redirect("/") + return c.Redirect(utils.BaseURL(c)) } summary := fiber.Map{ "Title": "LocalAI - Chat with " + backendConfigs[0], + "BaseURL": utils.BaseURL(c), "ModelsConfig": backendConfigs, "Model": backendConfigs[0], "Version": internal.PrintableVersion(), @@ -364,6 +370,7 @@ func RegisterUIRoutes(app *fiber.App, summary := fiber.Map{ "Title": "LocalAI - Generate images with " + c.Params("model"), + "BaseURL": utils.BaseURL(c), "ModelsConfig": backendConfigs, "Model": c.Params("model"), "Version": internal.PrintableVersion(), @@ -380,11 +387,12 @@ func RegisterUIRoutes(app *fiber.App, if len(backendConfigs) == 0 { // If no model is available redirect to the index which suggests how to install models - return c.Redirect("/") + return c.Redirect(utils.BaseURL(c)) } summary := fiber.Map{ "Title": "LocalAI - Generate images with " + backendConfigs[0].Name, + "BaseURL": utils.BaseURL(c), "ModelsConfig": backendConfigs, "Model": backendConfigs[0].Name, "Version": internal.PrintableVersion(), @@ -400,6 +408,7 @@ func RegisterUIRoutes(app *fiber.App, summary := fiber.Map{ "Title": "LocalAI - Generate images with " + c.Params("model"), + "BaseURL": utils.BaseURL(c), "ModelsConfig": backendConfigs, "Model": c.Params("model"), "Version": internal.PrintableVersion(), @@ -416,11 +425,12 @@ func RegisterUIRoutes(app *fiber.App, if len(backendConfigs) == 0 { // If no model is available redirect to the index which suggests how to install models - return c.Redirect("/") + return c.Redirect(utils.BaseURL(c)) } summary := fiber.Map{ "Title": "LocalAI - Generate audio with " + backendConfigs[0].Name, + "BaseURL": utils.BaseURL(c), "ModelsConfig": backendConfigs, "Model": backendConfigs[0].Name, "IsP2PEnabled": p2p.IsP2PEnabled(), diff --git a/core/http/static/assets/font1.css b/core/http/static/assets/font1.css index f46cc3ff10ae..c640d54f72fa 100644 --- a/core/http/static/assets/font1.css +++ b/core/http/static/assets/font1.css @@ -7,33 +7,33 @@ https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&family=Roboto:wg font-style: normal; font-weight: 400; font-display: swap; - src: url(/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuLyfMZg.ttf) format('truetype'); + src: url(./UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuLyfMZg.ttf) format('truetype'); } @font-face { font-family: 'Inter'; font-style: normal; font-weight: 600; font-display: swap; - src: url(/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuGKYMZg.ttf) format('truetype'); + src: url(./UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuGKYMZg.ttf) format('truetype'); } @font-face { font-family: 'Inter'; font-style: normal; font-weight: 700; font-display: swap; - src: url(/static/assets/UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuFuYMZg.ttf) format('truetype'); + src: url(./UcCO3FwrK3iLTeHuS_fvQtMwCp50KnMw2boKoduKmMEVuFuYMZg.ttf) format('truetype'); } @font-face { font-family: 'Roboto'; font-style: normal; font-weight: 400; font-display: swap; - src: url(/static/assets/KFOmCnqEu92Fr1Me5Q.ttf) format('truetype'); + src: url(./KFOmCnqEu92Fr1Me5Q.ttf) format('truetype'); } @font-face { font-family: 'Roboto'; font-style: normal; font-weight: 500; font-display: swap; - src: url(/static/assets/KFOlCnqEu92Fr1MmEU9vAw.ttf) format('truetype'); + src: url(./KFOlCnqEu92Fr1MmEU9vAw.ttf) format('truetype'); } diff --git a/core/http/static/assets/font2.css b/core/http/static/assets/font2.css index f2f47e748f69..387b61d96ae1 100644 --- a/core/http/static/assets/font2.css +++ b/core/http/static/assets/font2.css @@ -7,33 +7,33 @@ https://fonts.googleapis.com/css?family=Roboto:300,400,500,700,900&display=swap font-style: normal; font-weight: 300; font-display: swap; - src: url(/static/assets//KFOlCnqEu92Fr1MmSU5fBBc9.ttf) format('truetype'); + src: url(./KFOlCnqEu92Fr1MmSU5fBBc9.ttf) format('truetype'); } @font-face { font-family: 'Roboto'; font-style: normal; font-weight: 400; font-display: swap; - src: url(/static/assets//KFOmCnqEu92Fr1Mu4mxP.ttf) format('truetype'); + src: url(./KFOmCnqEu92Fr1Mu4mxP.ttf) format('truetype'); } @font-face { font-family: 'Roboto'; font-style: normal; font-weight: 500; font-display: swap; - src: url(/static/assets//KFOlCnqEu92Fr1MmEU9fBBc9.ttf) format('truetype'); + src: url(./KFOlCnqEu92Fr1MmEU9fBBc9.ttf) format('truetype'); } @font-face { font-family: 'Roboto'; font-style: normal; font-weight: 700; font-display: swap; - src: url(/static/assets//KFOlCnqEu92Fr1MmWUlfBBc9.ttf) format('truetype'); + src: url(./KFOlCnqEu92Fr1MmWUlfBBc9.ttf) format('truetype'); } @font-face { font-family: 'Roboto'; font-style: normal; font-weight: 900; font-display: swap; - src: url(/static/assets//KFOlCnqEu92Fr1MmYUtfBBc9.ttf) format('truetype'); + src: url(./KFOlCnqEu92Fr1MmYUtfBBc9.ttf) format('truetype'); } diff --git a/core/http/static/chat.js b/core/http/static/chat.js index ef15f838d09e..67e0bb6015e9 100644 --- a/core/http/static/chat.js +++ b/core/http/static/chat.js @@ -143,7 +143,7 @@ function readInputImage() { // } // Source: https://stackoverflow.com/a/75751803/11386095 - const response = await fetch("/v1/chat/completions", { + const response = await fetch("v1/chat/completions", { method: "POST", headers: { Authorization: `Bearer ${key}`, diff --git a/core/http/static/image.js b/core/http/static/image.js index 315bdda089ba..079c9dc02adf 100644 --- a/core/http/static/image.js +++ b/core/http/static/image.js @@ -48,7 +48,7 @@ async function promptDallE(key, input) { document.getElementById("input").disabled = true; const model = document.getElementById("image-model").value; - const response = await fetch("/v1/images/generations", { + const response = await fetch("v1/images/generations", { method: "POST", headers: { Authorization: `Bearer ${key}`, diff --git a/core/http/static/talk.js b/core/http/static/talk.js index 3072da8473af..ecaa0f0bfdc8 100644 --- a/core/http/static/talk.js +++ b/core/http/static/talk.js @@ -122,7 +122,7 @@ async function sendAudioToWhisper(audioBlob) { formData.append('model', getWhisperModel()); API_KEY = localStorage.getItem("key"); - const response = await fetch('/v1/audio/transcriptions', { + const response = await fetch('v1/audio/transcriptions', { method: 'POST', headers: { 'Authorization': `Bearer ${API_KEY}` @@ -139,7 +139,7 @@ async function sendTextToChatGPT(text) { conversationHistory.push({ role: "user", content: text }); API_KEY = localStorage.getItem("key"); - const response = await fetch('/v1/chat/completions', { + const response = await fetch('v1/chat/completions', { method: 'POST', headers: { 'Authorization': `Bearer ${API_KEY}`, @@ -163,7 +163,7 @@ async function sendTextToChatGPT(text) { async function getTextToSpeechAudio(text) { API_KEY = localStorage.getItem("key"); - const response = await fetch('/v1/audio/speech', { + const response = await fetch('v1/audio/speech', { method: 'POST', headers: { diff --git a/core/http/static/tts.js b/core/http/static/tts.js index 7fc747299ae3..daead3a88ff3 100644 --- a/core/http/static/tts.js +++ b/core/http/static/tts.js @@ -19,7 +19,7 @@ async function tts(key, input) { document.getElementById("input").disabled = true; const model = document.getElementById("tts-model").value; - const response = await fetch("/tts", { + const response = await fetch("tts", { method: "POST", headers: { Authorization: `Bearer ${key}`, diff --git a/core/http/utils/baseurl.go b/core/http/utils/baseurl.go new file mode 100644 index 000000000000..9df73052f464 --- /dev/null +++ b/core/http/utils/baseurl.go @@ -0,0 +1,28 @@ +package utils + +import ( + "strings" + + "github.com/gofiber/fiber/v2" +) + +// BaseURL returns the base URL for the given HTTP request context, honouring the X-Forwarded-Proto, X-Forwarded-Host and X-Forwarded-Prefix HTTP headers. +// This is to allow the web app to run behind a reverse-proxy that may expose it under a different host, path and protocol (HTTPS). +// The returned URL is guaranteed to end with `/`. +// The method should be used in conjunction with the StripPathPrefix middleware. +func BaseURL(c *fiber.Ctx) string { + forwardedPrefix := c.GetReqHeaders()["X-Forwarded-Prefix"] + for _, prefix := range forwardedPrefix { + if len(prefix) > 0 { + if prefix[len(prefix)-1] != '/' { + prefix += "/" + } + + if strings.HasPrefix(c.OriginalURL(), prefix) { + return c.BaseURL() + prefix + } + } + } + + return c.BaseURL() + "/" +} diff --git a/core/http/utils/baseurl_test.go b/core/http/utils/baseurl_test.go new file mode 100644 index 000000000000..3be89959d209 --- /dev/null +++ b/core/http/utils/baseurl_test.go @@ -0,0 +1,99 @@ +package utils + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gofiber/fiber/v2" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/stretchr/testify/require" +) + +func TestBaseURL(t *testing.T) { + var actualURL string + + app := fiber.New() + + app.Use(middleware.StripPathPrefix()) + + app.Get("/hello/world", func(c *fiber.Ctx) error { + actualURL = BaseURL(c) + return nil + }) + + for _, tc := range []struct { + name string + prefix string + headers http.Header + expectURL string + }{ + { + name: "without prefix and header", + prefix: "/", + headers: map[string][]string{}, + expectURL: "http://example.com/", + }, + { + name: "without prefix but header", + prefix: "/", + headers: map[string][]string{ + "X-Forwarded-Prefix": []string{"/otherprefix/"}, + }, + expectURL: "http://example.com/", + }, + { + name: "with prefix and matching header", + prefix: "/myprefix/", + headers: map[string][]string{ + "X-Forwarded-Prefix": []string{"/myprefix/"}, + }, + expectURL: "http://example.com/myprefix/", + }, + { + name: "with prefix and 1st header matching", + prefix: "/myprefix/", + headers: map[string][]string{ + "X-Forwarded-Prefix": []string{"/myprefix/", "/otherprefix/"}, + }, + expectURL: "http://example.com/myprefix/", + }, + { + name: "with prefix and 2nd header matching", + prefix: "/myprefix/", + headers: map[string][]string{ + "X-Forwarded-Prefix": []string{"/otherprefix/", "/myprefix/"}, + }, + expectURL: "http://example.com/myprefix/", + }, + { + name: "with prefix and header not ending with slash", + prefix: "/myprefix/", + headers: map[string][]string{ + "X-Forwarded-Prefix": []string{"/myprefix"}, + }, + expectURL: "http://example.com/myprefix/", + }, + { + name: "with other protocol, host and path", + prefix: "/subpath/", + headers: map[string][]string{ + "X-Forwarded-Proto": []string{"https"}, + "X-Forwarded-Host": []string{"example.org"}, + "X-Forwarded-Prefix": []string{"/subpath/"}, + }, + expectURL: "https://example.org/subpath/", + }, + } { + t.Run(tc.name, func(t *testing.T) { + actualURL = "" + req := httptest.NewRequest("GET", tc.prefix+"hello/world", nil) + req.Header = tc.headers + resp, err := app.Test(req, -1) + + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode, "response status code") + require.Equal(t, tc.expectURL, actualURL, "base URL") + }) + } +} diff --git a/core/http/views/404.html b/core/http/views/404.html index 359d85055442..2f5a43864ce5 100644 --- a/core/http/views/404.html +++ b/core/http/views/404.html @@ -12,7 +12,7 @@

Welcome to your LocalAI instance!

- diff --git a/core/http/views/chat.html b/core/http/views/chat.html index 67d40bfd5817..b0f11281df07 100644 --- a/core/http/views/chat.html +++ b/core/http/views/chat.html @@ -28,7 +28,7 @@ {{template "views/partials/head" .}} - +