diff --git a/pkg/auth/handler/webapp/viewmodels/base.go b/pkg/auth/handler/webapp/viewmodels/base.go index 36705ae13c..32b868bd96 100644 --- a/pkg/auth/handler/webapp/viewmodels/base.go +++ b/pkg/auth/handler/webapp/viewmodels/base.go @@ -21,6 +21,7 @@ import ( "github.com/authgear/authgear-server/pkg/util/httputil" "github.com/authgear/authgear-server/pkg/util/intl" "github.com/authgear/authgear-server/pkg/util/log" + "github.com/authgear/authgear-server/pkg/util/slice" "github.com/authgear/authgear-server/pkg/util/template" "github.com/authgear/authgear-server/pkg/util/wechat" ) @@ -267,17 +268,23 @@ func (m *BaseViewModeler) ViewModel(r *http.Request, rw http.ResponseWriter) Bas } return webapp.MakeURL(u, path, outQuery).String() }, - ForgotPasswordEnabled: *m.ForgotPassword.Enabled, - PublicSignupDisabled: m.Authentication.PublicSignupDisabled, - PageLoadedAt: int(now), - FlashMessageType: m.FlashMessage.Pop(r, rw), - ResolvedLanguageTag: resolvedLanguageTag, - ResolvedCLDRLocale: locale, - HTMLDir: htmlDir, - GoogleTagManagerContainerID: m.GoogleTagManager.ContainerID, - HasThirdPartyClient: hasThirdPartyApp, - AuthUISentryDSN: string(m.AuthUISentryDSN), - AuthUIWindowMessageAllowedOrigins: strings.Join(m.AuthUIWindowMessageAllowedOrigins, ","), + ForgotPasswordEnabled: *m.ForgotPassword.Enabled, + PublicSignupDisabled: m.Authentication.PublicSignupDisabled, + PageLoadedAt: int(now), + FlashMessageType: m.FlashMessage.Pop(r, rw), + ResolvedLanguageTag: resolvedLanguageTag, + ResolvedCLDRLocale: locale, + HTMLDir: htmlDir, + GoogleTagManagerContainerID: m.GoogleTagManager.ContainerID, + HasThirdPartyClient: hasThirdPartyApp, + AuthUISentryDSN: string(m.AuthUISentryDSN), + AuthUIWindowMessageAllowedOrigins: func() string { + requestProto := httputil.GetProto(r, bool(m.TrustProxy)) + processedAllowedOrgins := slice.Map(m.AuthUIWindowMessageAllowedOrigins, func(origin string) string { + return composeAuthUIWindowMessageAllowedOrigin(origin, requestProto) + }) + return strings.Join(processedAllowedOrgins, ",") + }(), LogUnknownError: func(err map[string]interface{}) string { if err != nil { m.Logger.WithFields(err).Errorf("unknown error: %v", err) @@ -326,3 +333,15 @@ func (m *BaseViewModeler) ViewModel(r *http.Request, rw http.ResponseWriter) Bas return model } + +// Assume allowed origin is either host or a real origin +func composeAuthUIWindowMessageAllowedOrigin(allowedOrigin string, proto string) string { + if strings.HasPrefix(allowedOrigin, "http://") || strings.HasPrefix(allowedOrigin, "https://") { + return allowedOrigin + } + u := url.URL{ + Scheme: proto, + Host: allowedOrigin, + } + return u.String() +} diff --git a/pkg/auth/handler/webapp/viewmodels/base_test.go b/pkg/auth/handler/webapp/viewmodels/base_test.go new file mode 100644 index 0000000000..0f5802dc41 --- /dev/null +++ b/pkg/auth/handler/webapp/viewmodels/base_test.go @@ -0,0 +1,27 @@ +package viewmodels + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestComposeAuthUIWindowMessageAllowedOrigin(t *testing.T) { + Convey("composeAuthUIWindowMessageAllowedOrigin", t, func() { + Convey("Given a origin", func() { + origin := "http://www.example.com" + Convey("It returns the origin unprocessed", func() { + requestProto := "http" + So(composeAuthUIWindowMessageAllowedOrigin(origin, requestProto), ShouldEqual, origin) + }) + }) + + Convey("Given a host", func() { + host := "www.example.com" + Convey("It returns the origin according to assgined proto", func() { + requestProto := "http" + So(composeAuthUIWindowMessageAllowedOrigin(host, requestProto), ShouldEqual, "http://www.example.com") + }) + }) + }) +}