Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ui): path prefix support via HTTP header #4497

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/http/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ func API(application *application.Application) (*fiber.App, error) {

router := fiber.New(fiberCfg)

router.Use(middleware.StripPathPrefix())

router.Hooks().OnListen(func(listenData fiber.ListenData) error {
scheme := "http"
if listenData.TLS {
Expand Down
52 changes: 52 additions & 0 deletions core/http/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,31 @@ func postInvalidRequest(url string) (error, int) {
return nil, resp.StatusCode
}

func getRequest(url string, header http.Header) (error, int, []byte) {

req, err := http.NewRequest("GET", url, nil)
if err != nil {
return err, -1, nil
}

req.Header = header

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err, -1, nil
}

defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return err, -1, nil
}

return nil, resp.StatusCode, body
}

const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b87640e8644b09c2aee6e3b/raw/f0e8c26bb72edc16d9fbafbfd6638072126ff225/bert-embeddings-gallery.yaml`

//go:embed backend-assets/*
Expand Down Expand Up @@ -345,6 +370,33 @@ var _ = Describe("API test", func() {
})
})

Context("URL routing Tests", func() {
It("Should support reverse-proxy when unauthenticated", func() {

err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
"X-Forwarded-Proto": {"https"},
"X-Forwarded-Host": {"example.org"},
"X-Forwarded-Prefix": {"/myprefix/"},
})
Expect(err).To(BeNil(), "error")
Expect(sc).To(Equal(401), "status code")
Expect(string(body)).To(ContainSubstring(`<base href="https://example.org/myprefix/" />`), "body")
})

It("Should support reverse-proxy when authenticated", func() {

err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
"Authorization": {bearerKey},
"X-Forwarded-Proto": {"https"},
"X-Forwarded-Host": {"example.org"},
"X-Forwarded-Prefix": {"/myprefix/"},
})
Expect(err).To(BeNil(), "error")
Expect(sc).To(Equal(200), "status code")
Expect(string(body)).To(ContainSubstring(`<base href="https://example.org/myprefix/" />`), "body")
})
})

Context("Applying models", func() {

It("applies models from a gallery", func() {
Expand Down
6 changes: 3 additions & 3 deletions core/http/elements/buttons.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func installButton(galleryName string) elem.Node {
"class": "float-right inline-block rounded bg-primary px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong",
"hx-swap": "outerHTML",
// post the Model ID as param
"hx-post": "/browse/install/model/" + galleryName,
"hx-post": "browse/install/model/" + galleryName,
},
elem.I(
attrs.Props{
Expand All @@ -36,7 +36,7 @@ func reInstallButton(galleryName string) elem.Node {
"hx-target": "#action-div-" + dropBadChars(galleryName),
"hx-swap": "outerHTML",
// post the Model ID as param
"hx-post": "/browse/install/model/" + galleryName,
"hx-post": "browse/install/model/" + galleryName,
},
elem.I(
attrs.Props{
Expand Down Expand Up @@ -80,7 +80,7 @@ func deleteButton(galleryID string) elem.Node {
"hx-target": "#action-div-" + dropBadChars(galleryID),
"hx-swap": "outerHTML",
// post the Model ID as param
"hx-post": "/browse/delete/model/" + galleryID,
"hx-post": "browse/delete/model/" + galleryID,
},
elem.I(
attrs.Props{
Expand Down
2 changes: 1 addition & 1 deletion core/http/elements/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func searchableElement(text, icon string) elem.Node {
// "value": text,
//"class": "inline-block bg-gray-200 rounded-full px-3 py-1 text-sm font-semibold text-gray-700 mr-2 mb-2",
"href": "#!",
"hx-post": "/browse/search/models",
"hx-post": "browse/search/models",
"hx-target": "#search-results",
// TODO: this doesn't work
// "hx-vals": `{ \"search\": \"` + text + `\" }`,
Expand Down
4 changes: 2 additions & 2 deletions core/http/elements/progressbar.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func StartProgressBar(uid, progress, text string) string {
return elem.Div(
attrs.Props{
"hx-trigger": "done",
"hx-get": "/browse/job/" + uid,
"hx-get": "browse/job/" + uid,
"hx-swap": "outerHTML",
"hx-target": "this",
},
Expand All @@ -77,7 +77,7 @@ func StartProgressBar(uid, progress, text string) string {
},
elem.Text(bluemonday.StrictPolicy().Sanitize(text)), //Perhaps overly defensive
elem.Div(attrs.Props{
"hx-get": "/browse/job/progress/" + uid,
"hx-get": "browse/job/progress/" + uid,
"hx-trigger": "every 600ms",
"hx-target": "this",
"hx-swap": "innerHTML",
Expand Down
2 changes: 2 additions & 0 deletions core/http/endpoints/explorer/dashboard.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/internal"
)

Expand All @@ -14,6 +15,7 @@ func Dashboard() func(*fiber.Ctx) error {
summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"BaseURL": utils.BaseURL(c),
}

if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
Expand Down
6 changes: 4 additions & 2 deletions core/http/endpoints/localai/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/rs/zerolog/log"
Expand Down Expand Up @@ -82,7 +83,8 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fibe
Galleries: mgs.galleries,
ConfigURL: input.ConfigURL,
}
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})

return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
}
}

Expand All @@ -105,7 +107,7 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fib
return err
}

return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
}
}

Expand Down
2 changes: 2 additions & 0 deletions core/http/endpoints/localai/welcome.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
Expand Down Expand Up @@ -32,6 +33,7 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"BaseURL": utils.BaseURL(c),
"Models": modelsWithoutConfig,
"ModelsConfig": backendConfigs,
"GalleryConfig": galleryConfigs,
Expand Down
2 changes: 2 additions & 0 deletions core/http/explorer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/gofiber/fiber/v2/middleware/favicon"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"
)

Expand All @@ -22,6 +23,7 @@ func Explorer(db *explorer.Database) *fiber.App {

app := fiber.New(fiberCfg)

app.Use(middleware.StripPathPrefix())
routes.RegisterExplorerRoutes(app, db)

httpFS := http.FS(embedDirStatic)
Expand Down
5 changes: 4 additions & 1 deletion core/http/middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/keyauth"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/utils"
)

// This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware
Expand Down Expand Up @@ -39,7 +40,9 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(401)
}
return ctx.Status(401).Render("views/login", nil)
return ctx.Status(401).Render("views/login", fiber.Map{
"BaseURL": utils.BaseURL(ctx),
})
}
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(500)
Expand Down
36 changes: 36 additions & 0 deletions core/http/middleware/strippathprefix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package middleware

import (
"strings"

"github.com/gofiber/fiber/v2"
)

// StripPathPrefix returns a middleware that strips a path prefix from the request path.
// The path prefix is obtained from the X-Forwarded-Prefix HTTP request header.
func StripPathPrefix() fiber.Handler {
return func(c *fiber.Ctx) error {
for _, prefix := range c.GetReqHeaders()["X-Forwarded-Prefix"] {
if prefix != "" {
path := c.Path()
pos := len(prefix)

if prefix[pos-1] == '/' {
pos--
} else {
prefix += "/"
}

if strings.HasPrefix(path, prefix) {
c.Path(path[pos:])
break
} else if prefix[:pos] == path {
c.Redirect(prefix)
return nil
}
}
}

return c.Next()
}
}
121 changes: 121 additions & 0 deletions core/http/middleware/strippathprefix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package middleware

import (
"net/http/httptest"
"testing"

"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)

func TestStripPathPrefix(t *testing.T) {
var actualPath string

app := fiber.New()

app.Use(StripPathPrefix())

app.Get("/hello/world", func(c *fiber.Ctx) error {
actualPath = c.Path()
return nil
})

app.Get("/", func(c *fiber.Ctx) error {
actualPath = c.Path()
return nil
})

for _, tc := range []struct {
name string
path string
prefixHeader []string
expectStatus int
expectPath string
}{
{
name: "without prefix and header",
path: "/hello/world",
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "without prefix and headers on root path",
path: "/",
expectStatus: 200,
expectPath: "/",
},
{
name: "without prefix but header",
path: "/hello/world",
prefixHeader: []string{"/otherprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix but non-matching header",
path: "/prefix/hello/world",
prefixHeader: []string{"/otherprefix/"},
expectStatus: 404,
},
{
name: "with prefix and matching header",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and 1st header matching",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix/", "/otherprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and 2nd header matching",
path: "/myprefix/hello/world",
prefixHeader: []string{"/otherprefix/", "/myprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and header not ending with slash",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and non-matching header not ending with slash",
path: "/myprefix-suffix/hello/world",
prefixHeader: []string{"/myprefix"},
expectStatus: 404,
},
{
name: "redirect when prefix does not end with a slash",
path: "/myprefix",
prefixHeader: []string{"/myprefix"},
expectStatus: 302,
expectPath: "/myprefix/",
},
} {
t.Run(tc.name, func(t *testing.T) {
actualPath = ""
req := httptest.NewRequest("GET", tc.path, nil)
if tc.prefixHeader != nil {
req.Header["X-Forwarded-Prefix"] = tc.prefixHeader
}

resp, err := app.Test(req, -1)

require.NoError(t, err)
require.Equal(t, tc.expectStatus, resp.StatusCode, "response status code")

if tc.expectStatus == 200 {
require.Equal(t, tc.expectPath, actualPath, "rewritten path")
} else if tc.expectStatus == 302 {
require.Equal(t, tc.expectPath, resp.Header.Get("Location"), "redirect location")
}
})
}
}
Loading
Loading