diff --git a/library/agent/Context.test.ts b/library/agent/Context.test.ts index ccc93cec9..546bec68b 100644 --- a/library/agent/Context.test.ts +++ b/library/agent/Context.test.ts @@ -1,16 +1,20 @@ import * as t from "tap"; +import { extractStringsFromUserInputCached } from "../helpers/extractStringsFromUserInputCached"; import { type Context, getContext, runWithContext, bindContext, + updateContext, } from "./Context"; const sampleContext: Context = { remoteAddress: "::1", method: "POST", url: "http://localhost:4000", - query: {}, + query: { + abc: "def", + }, headers: {}, body: undefined, cookies: {}, @@ -98,3 +102,33 @@ t.test("Get context does work with bindContext", async (t) => { emitter.emit("event"); }); + +t.test("it clears cache when context is mutated", async (t) => { + const context = { ...sampleContext }; + + runWithContext(context, () => { + t.same(extractStringsFromUserInputCached(getContext(), "body"), undefined); + t.same( + extractStringsFromUserInputCached(getContext(), "query"), + new Map(Object.entries({ abc: ".", def: ".abc" })) + ); + + updateContext(getContext(), "query", {}); + t.same(extractStringsFromUserInputCached(getContext(), "body"), undefined); + t.same( + extractStringsFromUserInputCached(getContext(), "query"), + new Map(Object.entries({})) + ); + + runWithContext({ ...context, body: { a: "z" }, query: { b: "y" } }, () => { + t.same( + extractStringsFromUserInputCached(getContext(), "body"), + new Map(Object.entries({ a: ".", z: ".a" })) + ); + t.same( + extractStringsFromUserInputCached(getContext(), "query"), + new Map(Object.entries({ b: ".", y: ".b" })) + ); + }); + }); +}); diff --git a/library/agent/Context.ts b/library/agent/Context.ts index 23a1b02b7..4e09b0a7f 100644 --- a/library/agent/Context.ts +++ b/library/agent/Context.ts @@ -1,6 +1,8 @@ import type { ParsedQs } from "qs"; +import { extractStringsFromUserInput } from "../helpers/extractStringsFromUserInput"; import { ContextStorage } from "./context/ContextStorage"; import { AsyncResource } from "async_hooks"; +import { Source, SOURCES } from "./Source"; export type User = { id: string; name?: string }; @@ -22,15 +24,35 @@ export type Context = { graphql?: string[]; xml?: unknown; subdomains?: string[]; // https://expressjs.com/en/5x/api.html#req.subdomains + cache?: Map>; }; /** * Get the current request context that is being handled + * + * We don't want to allow the user to modify the context directly, so we use `Readonly` */ -export function getContext() { +export function getContext(): Readonly | undefined { return ContextStorage.getStore(); } +function isSourceKey(key: string): key is Source { + return SOURCES.includes(key as Source); +} + +// We need to use a function to mutate the context because we need to clear the cache when the user input changes +export function updateContext( + context: Context, + key: K, + value: Context[K] +) { + context[key] = value; + + if (context.cache && isSourceKey(key)) { + context.cache.delete(key); + } +} + /** * Executes a function with a given request context * @@ -58,9 +80,17 @@ export function runWithContext(context: Context, fn: () => T) { current.xml = context.xml; current.subdomains = context.subdomains; + // Clear all the cached user input strings + delete current.cache; + return fn(); } + // Cleanup lingering cache + // In tests the context is often passed by reference + // Make sure to clean up the cache before running the function + delete context.cache; + // If there's no context yet, we create a new context and run the function with it return ContextStorage.run(context, fn); } diff --git a/library/agent/applyHooks.ts b/library/agent/applyHooks.ts index 428b078f1..3b19c780c 100644 --- a/library/agent/applyHooks.ts +++ b/library/agent/applyHooks.ts @@ -7,7 +7,7 @@ import { satisfiesVersion } from "../helpers/satisfiesVersion"; import { escapeHTML } from "../helpers/escapeHTML"; import { Agent } from "./Agent"; import { attackKindHumanName } from "./Attack"; -import { bindContext, getContext } from "./Context"; +import { bindContext, getContext, updateContext } from "./Context"; import { BuiltinModule } from "./hooks/BuiltinModule"; import { ConstructorInterceptor } from "./hooks/ConstructorInterceptor"; import { Hooks } from "./hooks/Hooks"; @@ -209,7 +209,7 @@ function wrapWithoutArgumentModification( if (result && context && !isAllowedIP) { // Flag request as having an attack detected - context.attackDetected = true; + updateContext(context, "attackDetected", true); agent.onDetectedAttack({ module: module, diff --git a/library/helpers/extractStringsFromUserInputCached.ts b/library/helpers/extractStringsFromUserInputCached.ts new file mode 100644 index 000000000..50669bbe0 --- /dev/null +++ b/library/helpers/extractStringsFromUserInputCached.ts @@ -0,0 +1,25 @@ +import { Context } from "../agent/Context"; +import { Source } from "../agent/Source"; +import { extractStringsFromUserInput } from "./extractStringsFromUserInput"; + +export function extractStringsFromUserInputCached( + context: Context, + source: Source +): ReturnType | undefined { + if (!context[source]) { + return undefined; + } + + if (!context.cache) { + context.cache = new Map(); + } + + let result = context.cache.get(source); + + if (!result) { + result = extractStringsFromUserInput(context[source]); + context.cache.set(source, result); + } + + return result; +} diff --git a/library/ratelimiting/shouldRateLimitRequest.ts b/library/ratelimiting/shouldRateLimitRequest.ts index 76a4107b1..ac8260383 100644 --- a/library/ratelimiting/shouldRateLimitRequest.ts +++ b/library/ratelimiting/shouldRateLimitRequest.ts @@ -1,5 +1,5 @@ import { Agent } from "../agent/Agent"; -import { Context } from "../agent/Context"; +import { Context, updateContext } from "../agent/Context"; import { isLocalhostIP } from "../helpers/isLocalhostIP"; type Result = @@ -16,7 +16,10 @@ type Result = }; // eslint-disable-next-line max-lines-per-function -export function shouldRateLimitRequest(context: Context, agent: Agent): Result { +export function shouldRateLimitRequest( + context: Readonly, + agent: Agent +): Result { const match = agent.getConfig().getEndpoint(context); if (!match) { @@ -61,7 +64,7 @@ export function shouldRateLimitRequest(context: Context, agent: Agent): Result { // This function is executed for every middleware and route handler // We want to count the request only once - context.consumedRateLimitForIP = true; + updateContext(context, "consumedRateLimitForIP", true); if (!allowed) { return { block: true, trigger: "ip" }; @@ -79,7 +82,7 @@ export function shouldRateLimitRequest(context: Context, agent: Agent): Result { // This function is executed for every middleware and route handler // We want to count the request only once - context.consumedRateLimitForUser = true; + updateContext(context, "consumedRateLimitForUser", true); if (!allowed) { return { block: true, trigger: "user" }; diff --git a/library/sources/FastXmlParser.ts b/library/sources/FastXmlParser.ts index 7d7a0dd6d..d6fe475f7 100644 --- a/library/sources/FastXmlParser.ts +++ b/library/sources/FastXmlParser.ts @@ -1,5 +1,5 @@ /* eslint-disable prefer-rest-params */ -import { getContext } from "../agent/Context"; +import { getContext, updateContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; import { Wrapper } from "../agent/Wrapper"; import { isPlainObject } from "../helpers/isPlainObject"; @@ -30,7 +30,7 @@ export class FastXmlParser implements Wrapper { // Replace the body in the context with the parsed result if (result && isPlainObject(result)) { - context.xml = result; + updateContext(context, "xml", result); } } diff --git a/library/sources/FunctionsFramework.test.ts b/library/sources/FunctionsFramework.test.ts index dd9a6e9f3..4b901d439 100644 --- a/library/sources/FunctionsFramework.test.ts +++ b/library/sources/FunctionsFramework.test.ts @@ -5,7 +5,7 @@ import { Agent } from "../agent/Agent"; import { setInstance } from "../agent/AgentSingleton"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; -import { getContext } from "../agent/Context"; +import { getContext, updateContext } from "../agent/Context"; import { LoggerForTesting } from "../agent/logger/LoggerForTesting"; import { createCloudFunctionWrapper, @@ -50,7 +50,7 @@ function getExpressApp() { asyncHandler( createCloudFunctionWrapper((req, res) => { const context = getContext(); - context.attackDetected = true; + updateContext(context, "attackDetected", true); res.send(context); }) ) diff --git a/library/sources/GraphQL.ts b/library/sources/GraphQL.ts index f051cfa97..984c04883 100644 --- a/library/sources/GraphQL.ts +++ b/library/sources/GraphQL.ts @@ -1,7 +1,7 @@ /* eslint-disable prefer-rest-params */ import { Agent } from "../agent/Agent"; import { getInstance } from "../agent/AgentSingleton"; -import { getContext } from "../agent/Context"; +import { getContext, updateContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; import { Wrapper } from "../agent/Wrapper"; import type { ExecutionArgs } from "graphql/execution/execute"; @@ -60,9 +60,9 @@ export class GraphQL implements Wrapper { if (userInputs.length > 0) { if (Array.isArray(context.graphql)) { - context.graphql.push(...userInputs); + updateContext(context, "graphql", context.graphql.concat(userInputs)); } else { - context.graphql = userInputs; + updateContext(context, "graphql", userInputs); } } } diff --git a/library/sources/Lambda.test.ts b/library/sources/Lambda.test.ts index c0e8d6fd9..bc0d2173f 100644 --- a/library/sources/Lambda.test.ts +++ b/library/sources/Lambda.test.ts @@ -5,7 +5,7 @@ import { Agent } from "../agent/Agent"; import { setInstance } from "../agent/AgentSingleton"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; -import { getContext } from "../agent/Context"; +import { getContext, updateContext } from "../agent/Context"; import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { createLambdaWrapper, SQSEvent, APIGatewayProxyEvent } from "./Lambda"; @@ -427,7 +427,7 @@ t.test("it counts attacks", async () => { const handler = createLambdaWrapper(async (event, context) => { const ctx = getContext(); - ctx.attackDetected = true; + updateContext(ctx, "attackDetected", true); return ctx; }); diff --git a/library/sources/Xml2js.ts b/library/sources/Xml2js.ts index 81f8984fd..9abb9afc6 100644 --- a/library/sources/Xml2js.ts +++ b/library/sources/Xml2js.ts @@ -1,5 +1,5 @@ /* eslint-disable prefer-rest-params */ -import { getContext, runWithContext } from "../agent/Context"; +import { getContext, updateContext, runWithContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; import { Wrapper } from "../agent/Wrapper"; import { isPlainObject } from "../helpers/isPlainObject"; @@ -36,8 +36,9 @@ export class Xml2js implements Wrapper { const originalCallback = args[1] as Function; args[1] = function wrapCallback(err: Error, result: unknown) { if (result && isPlainObject(result)) { - context.xml = result; + updateContext(context, "xml", result); } + runWithContext(context, () => originalCallback(err, result)); }; diff --git a/library/sources/XmlMinusJs.ts b/library/sources/XmlMinusJs.ts index 8c6e1729f..b07bbe859 100644 --- a/library/sources/XmlMinusJs.ts +++ b/library/sources/XmlMinusJs.ts @@ -1,5 +1,5 @@ /* eslint-disable prefer-rest-params */ -import { getContext } from "../agent/Context"; +import { getContext, updateContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; import { Wrapper } from "../agent/Wrapper"; import { isPlainObject } from "../helpers/isPlainObject"; @@ -31,7 +31,7 @@ export class XmlMinusJs implements Wrapper { // Replace the body in the context with the parsed result if (parsed && isPlainObject(parsed)) { - context.xml = parsed; + updateContext(context, "xml", parsed); } } diff --git a/library/vulnerabilities/path-traversal/checkContextForPathTraversal.ts b/library/vulnerabilities/path-traversal/checkContextForPathTraversal.ts index 9dfba72e7..d703c5d3c 100644 --- a/library/vulnerabilities/path-traversal/checkContextForPathTraversal.ts +++ b/library/vulnerabilities/path-traversal/checkContextForPathTraversal.ts @@ -1,7 +1,7 @@ import { Context } from "../../agent/Context"; import { InterceptorResult } from "../../agent/hooks/MethodInterceptor"; import { SOURCES } from "../../agent/Source"; -import { extractStringsFromUserInput } from "../../helpers/extractStringsFromUserInput"; +import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached"; import { detectPathTraversal } from "./detectPathTraversal"; /** @@ -26,21 +26,23 @@ export function checkContextForPathTraversal({ } for (const source of SOURCES) { - if (context[source]) { - const userInput = extractStringsFromUserInput(context[source]); - for (const [str, path] of userInput.entries()) { - if (detectPathTraversal(pathString, str, checkPathStart, isUrl)) { - return { - operation: operation, - kind: "path_traversal", - source: source, - pathToPayload: path, - metadata: { - filename: pathString, - }, - payload: str, - }; - } + const userInput = extractStringsFromUserInputCached(context, source); + if (!userInput) { + continue; + } + + for (const [str, path] of userInput.entries()) { + if (detectPathTraversal(pathString, str, checkPathStart, isUrl)) { + return { + operation: operation, + kind: "path_traversal", + source: source, + pathToPayload: path, + metadata: { + filename: pathString, + }, + payload: str, + }; } } } diff --git a/library/vulnerabilities/shell-injection/checkContextForShellInjection.ts b/library/vulnerabilities/shell-injection/checkContextForShellInjection.ts index f317b0b85..8c0c44fe7 100644 --- a/library/vulnerabilities/shell-injection/checkContextForShellInjection.ts +++ b/library/vulnerabilities/shell-injection/checkContextForShellInjection.ts @@ -1,7 +1,7 @@ import { Context } from "../../agent/Context"; import { InterceptorResult } from "../../agent/hooks/MethodInterceptor"; import { SOURCES } from "../../agent/Source"; -import { extractStringsFromUserInput } from "../../helpers/extractStringsFromUserInput"; +import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached"; import { detectShellInjection } from "./detectShellInjection"; /** @@ -18,19 +18,21 @@ export function checkContextForShellInjection({ context: Context; }): InterceptorResult { for (const source of SOURCES) { - if (context[source]) { - const userInput = extractStringsFromUserInput(context[source]); - for (const [str, path] of userInput.entries()) { - if (detectShellInjection(command, str)) { - return { - operation: operation, - kind: "shell_injection", - source: source, - pathToPayload: path, - metadata: {}, - payload: str, - }; - } + const userInput = extractStringsFromUserInputCached(context, source); + if (!userInput) { + continue; + } + + for (const [str, path] of userInput.entries()) { + if (detectShellInjection(command, str)) { + return { + operation: operation, + kind: "shell_injection", + source: source, + pathToPayload: path, + metadata: {}, + payload: str, + }; } } } diff --git a/library/vulnerabilities/sql-injection/checkContextForSqlInjection.ts b/library/vulnerabilities/sql-injection/checkContextForSqlInjection.ts index 29197a4c2..d75f43764 100644 --- a/library/vulnerabilities/sql-injection/checkContextForSqlInjection.ts +++ b/library/vulnerabilities/sql-injection/checkContextForSqlInjection.ts @@ -1,7 +1,7 @@ import { Context } from "../../agent/Context"; import { InterceptorResult } from "../../agent/hooks/MethodInterceptor"; import { SOURCES } from "../../agent/Source"; -import { extractStringsFromUserInput } from "../../helpers/extractStringsFromUserInput"; +import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached"; import { detectSQLInjection } from "./detectSQLInjection"; import { SQLDialect } from "./dialects/SQLDialect"; @@ -21,19 +21,21 @@ export function checkContextForSqlInjection({ dialect: SQLDialect; }): InterceptorResult { for (const source of SOURCES) { - if (context[source]) { - const userInput = extractStringsFromUserInput(context[source]); - for (const [str, path] of userInput.entries()) { - if (detectSQLInjection(sql, str, dialect)) { - return { - operation: operation, - kind: "sql_injection", - source: source, - pathToPayload: path, - metadata: {}, - payload: str, - }; - } + const userInput = extractStringsFromUserInputCached(context, source); + if (!userInput) { + continue; + } + + for (const [str, path] of userInput.entries()) { + if (detectSQLInjection(sql, str, dialect)) { + return { + operation: operation, + kind: "sql_injection", + source: source, + pathToPayload: path, + metadata: {}, + payload: str, + }; } } } diff --git a/library/vulnerabilities/ssrf/checkContextForSSRF.ts b/library/vulnerabilities/ssrf/checkContextForSSRF.ts index 1c080bd3c..ca58f772c 100644 --- a/library/vulnerabilities/ssrf/checkContextForSSRF.ts +++ b/library/vulnerabilities/ssrf/checkContextForSSRF.ts @@ -1,7 +1,7 @@ import { Context } from "../../agent/Context"; import { InterceptorResult } from "../../agent/hooks/MethodInterceptor"; import { SOURCES } from "../../agent/Source"; -import { extractStringsFromUserInput } from "../../helpers/extractStringsFromUserInput"; +import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached"; import { containsPrivateIPAddress } from "./containsPrivateIPAddress"; import { findHostnameInUserInput } from "./findHostnameInUserInput"; @@ -21,20 +21,22 @@ export function checkContextForSSRF({ context: Context; }): InterceptorResult { for (const source of SOURCES) { - if (context[source]) { - const userInput = extractStringsFromUserInput(context[source]); - for (const [str, path] of userInput.entries()) { - const found = findHostnameInUserInput(str, hostname, port); - if (found && containsPrivateIPAddress(hostname)) { - return { - operation: operation, - kind: "ssrf", - source: source, - pathToPayload: path, - metadata: {}, - payload: str, - }; - } + const userInput = extractStringsFromUserInputCached(context, source); + if (!userInput) { + continue; + } + + for (const [str, path] of userInput.entries()) { + const found = findHostnameInUserInput(str, hostname, port); + if (found && containsPrivateIPAddress(hostname)) { + return { + operation: operation, + kind: "ssrf", + source: source, + pathToPayload: path, + metadata: {}, + payload: str, + }; } } } diff --git a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts index df21cdd12..96fe64822 100644 --- a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts +++ b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts @@ -4,8 +4,8 @@ import { Agent } from "../../agent/Agent"; import { attackKindHumanName } from "../../agent/Attack"; import { Context, getContext } from "../../agent/Context"; import { Source, SOURCES } from "../../agent/Source"; -import { extractStringsFromUserInput } from "../../helpers/extractStringsFromUserInput"; import { escapeHTML } from "../../helpers/escapeHTML"; +import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached"; import { isPlainObject } from "../../helpers/isPlainObject"; import { findHostnameInUserInput } from "./findHostnameInUserInput"; import { isPrivateIP } from "./isPrivateIP"; @@ -181,17 +181,19 @@ function findHostnameInContext( port: number | undefined ): Location | undefined { for (const source of SOURCES) { - if (context[source]) { - const userInput = extractStringsFromUserInput(context[source]); - for (const [str, path] of userInput.entries()) { - const found = findHostnameInUserInput(str, hostname, port); - if (found) { - return { - source: source, - pathToPayload: path, - payload: str, - }; - } + const userInput = extractStringsFromUserInputCached(context, source); + if (!userInput) { + continue; + } + + for (const [str, path] of userInput.entries()) { + const found = findHostnameInUserInput(str, hostname, port); + if (found) { + return { + source: source, + pathToPayload: path, + payload: str, + }; } } }