diff --git a/src/SPConfig.js b/src/SPConfig.js
index 281487c1..3cee9b4d 100644
--- a/src/SPConfig.js
+++ b/src/SPConfig.js
@@ -17,6 +17,7 @@ export default class SPConfig {
this.idpCert = argv.spIdpCert;
this.idpThumbprint = argv.spIdpThumbprint;
this.idpMetaUrl = argv.spIdpMetaUrl;
+ this.idpIssuerMatchOverride = argv.spIdpIssuerMatchOverride;
this.audience = argv.spAudience;
this.providerName = argv.spProviderName;
this.signAuthnRequests = argv.spSignAuthnRequests;
diff --git a/src/cli/index.js b/src/cli/index.js
index 704da48d..3ce7b2c1 100644
--- a/src/cli/index.js
+++ b/src/cli/index.js
@@ -175,6 +175,12 @@ export function processArgs() {
string: true,
default: "samlp",
},
+ spIdpCategory: {
+ description: "The IDP category for the default upstream IDP",
+ required: true,
+ string: true,
+ default: "id_me",
+ },
spIdpIssuer: {
description: "IdP Issuer URI",
required: false,
diff --git a/src/routes/acsHandlers.test.ts b/src/routes/acsHandlers.test.ts
index bf6bff94..efde670d 100644
--- a/src/routes/acsHandlers.test.ts
+++ b/src/routes/acsHandlers.test.ts
@@ -6,11 +6,9 @@ import { VsoClient } from "../VsoClient";
import { MVIRequestMetrics } from "../metrics";
import { TestCache } from "./types";
import {
- buildSamlResponseFunction,
defaultMockRequest,
+ b64encodedDataFromFile,
} from "../../test/testUtils";
-import { idpConfig } from "../../test/testServer";
-import { IDME_USER } from "../../test/testUsers";
import { accessiblePhoneNumber } from "../utils";
import samlp from "samlp";
jest.mock("passport");
@@ -593,7 +591,6 @@ describe("buildPassportLoginHandler", () => {
let req: any;
let mockResponse: any;
let mockNext: any;
- const buildSamlResponse = buildSamlResponseFunction(1);
beforeEach(async () => {
req = defaultMockRequest;
mockResponse = {
@@ -608,7 +605,7 @@ describe("buildPassportLoginHandler", () => {
});
it("happy path", () => {
- req.query.SAMLResponse = buildSamlResponse(IDME_USER, "3", idpConfig);
+ req.query.SAMLResponse = b64encodedDataFromFile("idp1_example.xml");
handlers.buildPassportLoginHandler("http://example.com/acs")(
req,
mockResponse,
diff --git a/src/routes/passport.js b/src/routes/passport.js
index 2a2050b8..4fbc00de 100644
--- a/src/routes/passport.js
+++ b/src/routes/passport.js
@@ -2,6 +2,7 @@ import passport from "passport";
import { Strategy } from "passport-wsfed-saml2";
import omit from "lodash.omit";
import { IDPProfileMapper } from "../IDPProfileMapper";
+import { issuerFromSamlResponse } from "../utils";
/**
* Creates the passport strategy using the response params
@@ -74,12 +75,31 @@ export function preparePassport(strategy) {
* @returns {*} A string with the correct spIdp key
*/
export const selectPassportStrategyKey = (req) => {
- const origin = req.headers.origin;
- let passportKey = "id_me";
- Object.entries(req.sps.options).forEach((spIdpEntry) => {
- if (spIdpEntry[1].idpSsoUrl.startsWith(origin)) {
- passportKey = spIdpEntry[0];
- }
+ const samlResponse = req.body?.SAMLResponse || req.query?.SAMLResponse;
+ const issuer = issuerFromSamlResponse(samlResponse);
+ const spIdpKeys = Object.keys(req.sps.options);
+ const foundSpIdpKey = spIdpKeys.find((spIdpKey) => {
+ const spIdpOption = req.sps.options[spIdpKey];
+ const domain = spIdpHostDomain(spIdpOption);
+ return spIdpOption.idpIssuerMatchOverride
+ ? issuer.includes(spIdpOption.idpIssuerMatchOverride)
+ : issuer.includes(domain);
});
- return passportKey;
+ return foundSpIdpKey ? foundSpIdpKey : spIdpKeys[0];
};
+
+/**
+ * Returns the domain that cooresponds to the host of the metadata url.
+ *
+ * @param {spConfig} spIdpOption The SP Config Option
+ * @returns {string} The domain name
+ */
+function spIdpHostDomain(spIdpOption) {
+ const url = new URL(spIdpOption.idpMetaUrl);
+ const domain_parts = url.host.split(".");
+ const domain =
+ domain_parts.length > 1
+ ? domain_parts[domain_parts.length - 2]
+ : domain_parts[0];
+ return domain;
+}
diff --git a/src/routes/passport.test.js b/src/routes/passport.test.js
index 87443e46..2309c8a1 100644
--- a/src/routes/passport.test.js
+++ b/src/routes/passport.test.js
@@ -1,5 +1,7 @@
+/* eslint-disable jsdoc/require-returns */
import "jest";
import { selectPassportStrategyKey } from "./passport";
+import { b64encodedDataFromFile } from "../../test/testUtils";
const mockReq = {
headers: {
@@ -8,24 +10,50 @@ const mockReq = {
sps: {
options: {
idp1: {
- idpSsoUrl: "http://login.example1.com/saml/sso",
+ category: "idp1",
+ idpMetaUrl: "https://api.idp1.com/saml/metadata/provider",
},
idp2: {
- idpSsoUrl: "http://login.example2.com/saml/sso",
+ category: "idp2",
+ idpMetaUrl: "https://idp.int.idp2.org/api/saml/metadata2023",
+ },
+ idp3: {
+ category: "idp3",
+ idpMetaUrl: "https://idp3:18443/realms/xxxx/protocol/saml/descriptor",
+ },
+ idp4: {
+ category: "idp4",
+ idpMetaUrl:
+ "https://deptyyy.idp4preview.com/app/yyyy/sso/saml/metadata",
+ idpIssuerMatchOverride: "idp4.com",
},
},
},
};
describe("selectPassportStrategyKey", () => {
test("selectPassportStrategyKey idp1", () => {
+ mockReq.body = { SAMLResponse: b64encodedDataFromFile("idp1_example.xml") };
expect(selectPassportStrategyKey(mockReq)).toBe("idp1");
});
test("selectPassportStrategyKey idp2", () => {
- mockReq.headers.origin = "http://login.example2.com";
+ mockReq.body = { SAMLResponse: b64encodedDataFromFile("idp2_example.xml") };
expect(selectPassportStrategyKey(mockReq)).toBe("idp2");
});
- test("selectPassportStrategyKey default 'id_me'", () => {
- mockReq.headers.origin = "http://login.example0.com";
- expect(selectPassportStrategyKey(mockReq)).toBe("id_me");
+
+ test("selectPassportStrategyKey idp3", () => {
+ mockReq.body = { SAMLResponse: b64encodedDataFromFile("idp3_example.xml") };
+ expect(selectPassportStrategyKey(mockReq)).toBe("idp3");
+ });
+
+ test("selectPassportStrategyKey idp4", () => {
+ mockReq.body = { SAMLResponse: b64encodedDataFromFile("idp4_example.xml") };
+ expect(selectPassportStrategyKey(mockReq)).toBe("idp4");
+ });
+
+ test("selectPassportStrategyKey default 'idp1'", () => {
+ mockReq.body = {
+ SAMLResponse: b64encodedDataFromFile("unmatched_example.xml"),
+ };
+ expect(selectPassportStrategyKey(mockReq)).toBe("idp1");
});
});
diff --git a/src/utils.js b/src/utils.js
index 6695f109..dd04c1f6 100644
--- a/src/utils.js
+++ b/src/utils.js
@@ -220,6 +220,29 @@ function getInResponseToFromSAML(samlResponse) {
}
}
+/**
+ * Retrieves the issuer from a b64 encoded SAMLResponse
+ *
+ * @param {string} samlResponse the raw samlResponse
+ * @returns {*} a string if Issuer is present
+ */
+export function issuerFromSamlResponse(samlResponse) {
+ try {
+ const decoded = Buffer.from(samlResponse, "base64").toString("ascii");
+ const parser = new DOMParser();
+ const issuerElems = parser
+ .parseFromString(decoded)
+ .documentElement.getElementsByTagNameNS(
+ "urn:oasis:names:tc:SAML:2.0:assertion",
+ "Issuer"
+ );
+ const issuer = issuerElems[0].textContent.trim();
+ return issuer;
+ } catch (err) {
+ logger.error("decodedSamlResponse failed: ", err);
+ }
+}
+
/**
* Retrieves ID assertion from SAMLRequest
*
diff --git a/test/samlResponses/decoded/idp1_example.xml b/test/samlResponses/decoded/idp1_example.xml
new file mode 100644
index 00000000..908c5dae
--- /dev/null
+++ b/test/samlResponses/decoded/idp1_example.xml
@@ -0,0 +1,7 @@
+
+ api.idp1.com
+
\ No newline at end of file
diff --git a/test/samlResponses/decoded/idp2_example.xml b/test/samlResponses/decoded/idp2_example.xml
new file mode 100644
index 00000000..bc82b11c
--- /dev/null
+++ b/test/samlResponses/decoded/idp2_example.xml
@@ -0,0 +1,7 @@
+
+
+ https://idp2.org/api/saml
+
\ No newline at end of file
diff --git a/test/samlResponses/decoded/idp3_example.xml b/test/samlResponses/decoded/idp3_example.xml
new file mode 100644
index 00000000..5385c1f8
--- /dev/null
+++ b/test/samlResponses/decoded/idp3_example.xml
@@ -0,0 +1,7 @@
+
+ https://idp3:18443/idp
+
\ No newline at end of file
diff --git a/test/samlResponses/decoded/idp4_example.xml b/test/samlResponses/decoded/idp4_example.xml
new file mode 100644
index 00000000..c9fdb2b8
--- /dev/null
+++ b/test/samlResponses/decoded/idp4_example.xml
@@ -0,0 +1,8 @@
+
+
+ http://www.idp4.com/yyyyyy
+
\ No newline at end of file
diff --git a/test/samlResponses/decoded/unmatched_example.xml b/test/samlResponses/decoded/unmatched_example.xml
new file mode 100644
index 00000000..4852da64
--- /dev/null
+++ b/test/samlResponses/decoded/unmatched_example.xml
@@ -0,0 +1,10 @@
+
+ http://wontfind/xxxx
+
+
+
+
\ No newline at end of file
diff --git a/test/testUtils.js b/test/testUtils.js
index f4691728..38a83a0b 100644
--- a/test/testUtils.js
+++ b/test/testUtils.js
@@ -1,6 +1,8 @@
import btoa from "btoa";
import { getSamlResponse } from "samlp";
import { getUser } from "./testUsers";
+const fs = require("fs");
+const path = require("path");
/**
* This test function builds the saml response function using session index
@@ -33,9 +35,6 @@ export let defaultMockRequest = {
query: {
relayState: "relay",
},
- headers: {
- origin: "https://idp.example.com",
- },
body: {
RelayState: "relay",
SAMLResponse: null
@@ -45,6 +44,7 @@ export let defaultMockRequest = {
id_me: {
getResponseParams: jest.fn(() => {}),
idpSsoUrl: "https://idp.example.com/saml/sso",
+ idpMetaUrl: "https://api.idmelabs.com/metadata",
},
},
},
@@ -72,3 +72,17 @@ export let defaultMockRequest = {
originalUrl: "http://original.example.com",
x_fowarded_host: "fowarded.example.com",
};
+
+/**
+ * Loads test data into a string.
+ *
+ * @param {*} fname The file with test data
+ */
+export function b64encodedDataFromFile(fname) {
+ const file = path.join("./test/samlResponses/decoded", fname);
+ const samlResponse = fs.readFileSync(file, "utf8", function (err, data) {
+ return data;
+ });
+ const encoded = Buffer.from(samlResponse, "ascii").toString("base64");
+ return encoded;
+}