From f8813f8c20c8fdaf9f1e1e7ad235e98c8110e144 Mon Sep 17 00:00:00 2001 From: Justin Wu <93353412+iustinum@users.noreply.github.com> Date: Fri, 25 Oct 2024 14:38:57 -0700 Subject: [PATCH] BED-4766: Correctly Format SAML Provider Details in ListAuthProviders Endpoint (#915) * chore: list properly formatted saml provider * chore: addresses array return of FormatSAMLProviderURLs --- cmd/api/src/api/v2/auth/sso.go | 3 ++- cmd/api/src/auth/bhsaml/db.go | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/cmd/api/src/api/v2/auth/sso.go b/cmd/api/src/api/v2/auth/sso.go index f0572c8d27..75a9de63e6 100644 --- a/cmd/api/src/api/v2/auth/sso.go +++ b/cmd/api/src/api/v2/auth/sso.go @@ -17,6 +17,7 @@ package auth import ( + "github.com/specterops/bloodhound/src/auth/bhsaml" "net/http" "strconv" "strings" @@ -116,7 +117,7 @@ func (s ManagementResource) ListAuthProviders(response http.ResponseWriter, requ } case model.SessionAuthProviderSAML: if ssoProvider.SAMLProvider != nil { - provider.Details = ssoProvider.SAMLProvider + provider.Details = bhsaml.FormatSAMLProviderURLs(request.Context(), *ssoProvider.SAMLProvider)[0] } } diff --git a/cmd/api/src/auth/bhsaml/db.go b/cmd/api/src/auth/bhsaml/db.go index 617ab3d3da..3c6874cbc3 100644 --- a/cmd/api/src/auth/bhsaml/db.go +++ b/cmd/api/src/auth/bhsaml/db.go @@ -25,7 +25,7 @@ import ( "github.com/specterops/bloodhound/src/serde" ) -func formatSAMLProviderURLs(requestContext context.Context, samlProviders ...model.SAMLProvider) model.SAMLProviders { +func FormatSAMLProviderURLs(requestContext context.Context, samlProviders ...model.SAMLProvider) model.SAMLProviders { for idx := 0; idx < len(samlProviders); idx++ { providerURLs := FormatServiceProviderURLs(*ctx.Get(requestContext).Host, samlProviders[idx].Name) @@ -42,7 +42,7 @@ func GetSAMLProviderByName(db database.Database, name string, requestContext con if samlProvider, err := db.LookupSAMLProviderByName(requestContext, name); err != nil { return model.SAMLProvider{}, err } else { - return formatSAMLProviderURLs(requestContext, samlProvider)[0], nil + return FormatSAMLProviderURLs(requestContext, samlProvider)[0], nil } } @@ -50,7 +50,7 @@ func GetSAMLProviderByID(db database.Database, id int32, requestContext context. if samlProvider, err := db.GetSAMLProvider(requestContext, id); err != nil { return model.SAMLProvider{}, err } else { - return formatSAMLProviderURLs(requestContext, samlProvider)[0], nil + return FormatSAMLProviderURLs(requestContext, samlProvider)[0], nil } } @@ -58,6 +58,6 @@ func GetAllSAMLProviders(db database.Database, requestContext context.Context) ( if samlProviders, err := db.GetAllSAMLProviders(requestContext); err != nil { return nil, err } else { - return formatSAMLProviderURLs(requestContext, samlProviders...), nil + return FormatSAMLProviderURLs(requestContext, samlProviders...), nil } }