Skip to content

Commit

Permalink
Merge pull request #297 from AikidoSec/AIK-2702
Browse files Browse the repository at this point in the history
Cache extractStringsFromUserInput result
  • Loading branch information
willem-delbare authored Jul 25, 2024
2 parents ffb1f37 + 8ba4439 commit 6fa9c0a
Show file tree
Hide file tree
Showing 16 changed files with 195 additions and 92 deletions.
36 changes: 35 additions & 1 deletion library/agent/Context.test.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import * as t from "tap";
import { extractStringsFromUserInputCached } from "../helpers/extractStringsFromUserInputCached";
import {
type Context,
getContext,
runWithContext,
bindContext,
updateContext,
} from "./Context";

const sampleContext: Context = {
remoteAddress: "::1",
method: "POST",
url: "http://localhost:4000",
query: {},
query: {
abc: "def",
},
headers: {},
body: undefined,
cookies: {},
Expand Down Expand Up @@ -98,3 +102,33 @@ t.test("Get context does work with bindContext", async (t) => {

emitter.emit("event");
});

t.test("it clears cache when context is mutated", async (t) => {
const context = { ...sampleContext };

runWithContext(context, () => {
t.same(extractStringsFromUserInputCached(getContext(), "body"), undefined);
t.same(
extractStringsFromUserInputCached(getContext(), "query"),
new Map(Object.entries({ abc: ".", def: ".abc" }))
);

updateContext(getContext(), "query", {});
t.same(extractStringsFromUserInputCached(getContext(), "body"), undefined);
t.same(
extractStringsFromUserInputCached(getContext(), "query"),
new Map(Object.entries({}))
);

runWithContext({ ...context, body: { a: "z" }, query: { b: "y" } }, () => {
t.same(
extractStringsFromUserInputCached(getContext(), "body"),
new Map(Object.entries({ a: ".", z: ".a" }))
);
t.same(
extractStringsFromUserInputCached(getContext(), "query"),
new Map(Object.entries({ b: ".", y: ".b" }))
);
});
});
});
32 changes: 31 additions & 1 deletion library/agent/Context.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import type { ParsedQs } from "qs";
import { extractStringsFromUserInput } from "../helpers/extractStringsFromUserInput";
import { ContextStorage } from "./context/ContextStorage";
import { AsyncResource } from "async_hooks";
import { Source, SOURCES } from "./Source";

export type User = { id: string; name?: string };

Expand All @@ -22,15 +24,35 @@ export type Context = {
graphql?: string[];
xml?: unknown;
subdomains?: string[]; // https://expressjs.com/en/5x/api.html#req.subdomains
cache?: Map<Source, ReturnType<typeof extractStringsFromUserInput>>;
};

/**
* Get the current request context that is being handled
*
* We don't want to allow the user to modify the context directly, so we use `Readonly<Context>`
*/
export function getContext() {
export function getContext(): Readonly<Context> | undefined {
return ContextStorage.getStore();
}

function isSourceKey(key: string): key is Source {
return SOURCES.includes(key as Source);
}

// We need to use a function to mutate the context because we need to clear the cache when the user input changes
export function updateContext<K extends keyof Context>(
context: Context,
key: K,
value: Context[K]
) {
context[key] = value;

if (context.cache && isSourceKey(key)) {
context.cache.delete(key);
}
}

/**
* Executes a function with a given request context
*
Expand Down Expand Up @@ -58,9 +80,17 @@ export function runWithContext<T>(context: Context, fn: () => T) {
current.xml = context.xml;
current.subdomains = context.subdomains;

// Clear all the cached user input strings
delete current.cache;

return fn();
}

// Cleanup lingering cache
// In tests the context is often passed by reference
// Make sure to clean up the cache before running the function
delete context.cache;

// If there's no context yet, we create a new context and run the function with it
return ContextStorage.run(context, fn);
}
Expand Down
4 changes: 2 additions & 2 deletions library/agent/applyHooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { satisfiesVersion } from "../helpers/satisfiesVersion";
import { escapeHTML } from "../helpers/escapeHTML";
import { Agent } from "./Agent";
import { attackKindHumanName } from "./Attack";
import { bindContext, getContext } from "./Context";
import { bindContext, getContext, updateContext } from "./Context";
import { BuiltinModule } from "./hooks/BuiltinModule";
import { ConstructorInterceptor } from "./hooks/ConstructorInterceptor";
import { Hooks } from "./hooks/Hooks";
Expand Down Expand Up @@ -209,7 +209,7 @@ function wrapWithoutArgumentModification(

if (result && context && !isAllowedIP) {
// Flag request as having an attack detected
context.attackDetected = true;
updateContext(context, "attackDetected", true);

agent.onDetectedAttack({
module: module,
Expand Down
25 changes: 25 additions & 0 deletions library/helpers/extractStringsFromUserInputCached.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import { Context } from "../agent/Context";
import { Source } from "../agent/Source";
import { extractStringsFromUserInput } from "./extractStringsFromUserInput";

export function extractStringsFromUserInputCached(
context: Context,
source: Source
): ReturnType<typeof extractStringsFromUserInput> | undefined {
if (!context[source]) {
return undefined;
}

if (!context.cache) {
context.cache = new Map();
}

let result = context.cache.get(source);

if (!result) {
result = extractStringsFromUserInput(context[source]);
context.cache.set(source, result);
}

return result;
}
11 changes: 7 additions & 4 deletions library/ratelimiting/shouldRateLimitRequest.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Agent } from "../agent/Agent";
import { Context } from "../agent/Context";
import { Context, updateContext } from "../agent/Context";
import { isLocalhostIP } from "../helpers/isLocalhostIP";

type Result =
Expand All @@ -16,7 +16,10 @@ type Result =
};

// eslint-disable-next-line max-lines-per-function
export function shouldRateLimitRequest(context: Context, agent: Agent): Result {
export function shouldRateLimitRequest(
context: Readonly<Context>,
agent: Agent
): Result {
const match = agent.getConfig().getEndpoint(context);

if (!match) {
Expand Down Expand Up @@ -61,7 +64,7 @@ export function shouldRateLimitRequest(context: Context, agent: Agent): Result {

// This function is executed for every middleware and route handler
// We want to count the request only once
context.consumedRateLimitForIP = true;
updateContext(context, "consumedRateLimitForIP", true);

if (!allowed) {
return { block: true, trigger: "ip" };
Expand All @@ -79,7 +82,7 @@ export function shouldRateLimitRequest(context: Context, agent: Agent): Result {

// This function is executed for every middleware and route handler
// We want to count the request only once
context.consumedRateLimitForUser = true;
updateContext(context, "consumedRateLimitForUser", true);

if (!allowed) {
return { block: true, trigger: "user" };
Expand Down
4 changes: 2 additions & 2 deletions library/sources/FastXmlParser.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable prefer-rest-params */
import { getContext } from "../agent/Context";
import { getContext, updateContext } from "../agent/Context";
import { Hooks } from "../agent/hooks/Hooks";
import { Wrapper } from "../agent/Wrapper";
import { isPlainObject } from "../helpers/isPlainObject";
Expand Down Expand Up @@ -30,7 +30,7 @@ export class FastXmlParser implements Wrapper {

// Replace the body in the context with the parsed result
if (result && isPlainObject(result)) {
context.xml = result;
updateContext(context, "xml", result);
}
}

Expand Down
4 changes: 2 additions & 2 deletions library/sources/FunctionsFramework.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { Agent } from "../agent/Agent";
import { setInstance } from "../agent/AgentSingleton";
import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting";
import { Token } from "../agent/api/Token";
import { getContext } from "../agent/Context";
import { getContext, updateContext } from "../agent/Context";
import { LoggerForTesting } from "../agent/logger/LoggerForTesting";
import {
createCloudFunctionWrapper,
Expand Down Expand Up @@ -50,7 +50,7 @@ function getExpressApp() {
asyncHandler(
createCloudFunctionWrapper((req, res) => {
const context = getContext();
context.attackDetected = true;
updateContext(context, "attackDetected", true);
res.send(context);
})
)
Expand Down
6 changes: 3 additions & 3 deletions library/sources/GraphQL.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable prefer-rest-params */
import { Agent } from "../agent/Agent";
import { getInstance } from "../agent/AgentSingleton";
import { getContext } from "../agent/Context";
import { getContext, updateContext } from "../agent/Context";
import { Hooks } from "../agent/hooks/Hooks";
import { Wrapper } from "../agent/Wrapper";
import type { ExecutionArgs } from "graphql/execution/execute";
Expand Down Expand Up @@ -60,9 +60,9 @@ export class GraphQL implements Wrapper {

if (userInputs.length > 0) {
if (Array.isArray(context.graphql)) {
context.graphql.push(...userInputs);
updateContext(context, "graphql", context.graphql.concat(userInputs));
} else {
context.graphql = userInputs;
updateContext(context, "graphql", userInputs);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions library/sources/Lambda.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { Agent } from "../agent/Agent";
import { setInstance } from "../agent/AgentSingleton";
import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting";
import { Token } from "../agent/api/Token";
import { getContext } from "../agent/Context";
import { getContext, updateContext } from "../agent/Context";
import { LoggerNoop } from "../agent/logger/LoggerNoop";
import { createLambdaWrapper, SQSEvent, APIGatewayProxyEvent } from "./Lambda";

Expand Down Expand Up @@ -427,7 +427,7 @@ t.test("it counts attacks", async () => {

const handler = createLambdaWrapper(async (event, context) => {
const ctx = getContext();
ctx.attackDetected = true;
updateContext(ctx, "attackDetected", true);
return ctx;
});

Expand Down
5 changes: 3 additions & 2 deletions library/sources/Xml2js.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable prefer-rest-params */
import { getContext, runWithContext } from "../agent/Context";
import { getContext, updateContext, runWithContext } from "../agent/Context";
import { Hooks } from "../agent/hooks/Hooks";
import { Wrapper } from "../agent/Wrapper";
import { isPlainObject } from "../helpers/isPlainObject";
Expand Down Expand Up @@ -36,8 +36,9 @@ export class Xml2js implements Wrapper {
const originalCallback = args[1] as Function;
args[1] = function wrapCallback(err: Error, result: unknown) {
if (result && isPlainObject(result)) {
context.xml = result;
updateContext(context, "xml", result);
}

runWithContext(context, () => originalCallback(err, result));
};

Expand Down
4 changes: 2 additions & 2 deletions library/sources/XmlMinusJs.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable prefer-rest-params */
import { getContext } from "../agent/Context";
import { getContext, updateContext } from "../agent/Context";
import { Hooks } from "../agent/hooks/Hooks";
import { Wrapper } from "../agent/Wrapper";
import { isPlainObject } from "../helpers/isPlainObject";
Expand Down Expand Up @@ -31,7 +31,7 @@ export class XmlMinusJs implements Wrapper {

// Replace the body in the context with the parsed result
if (parsed && isPlainObject(parsed)) {
context.xml = parsed;
updateContext(context, "xml", parsed);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Context } from "../../agent/Context";
import { InterceptorResult } from "../../agent/hooks/MethodInterceptor";
import { SOURCES } from "../../agent/Source";
import { extractStringsFromUserInput } from "../../helpers/extractStringsFromUserInput";
import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached";
import { detectPathTraversal } from "./detectPathTraversal";

/**
Expand All @@ -26,21 +26,23 @@ export function checkContextForPathTraversal({
}

for (const source of SOURCES) {
if (context[source]) {
const userInput = extractStringsFromUserInput(context[source]);
for (const [str, path] of userInput.entries()) {
if (detectPathTraversal(pathString, str, checkPathStart, isUrl)) {
return {
operation: operation,
kind: "path_traversal",
source: source,
pathToPayload: path,
metadata: {
filename: pathString,
},
payload: str,
};
}
const userInput = extractStringsFromUserInputCached(context, source);
if (!userInput) {
continue;
}

for (const [str, path] of userInput.entries()) {
if (detectPathTraversal(pathString, str, checkPathStart, isUrl)) {
return {
operation: operation,
kind: "path_traversal",
source: source,
pathToPayload: path,
metadata: {
filename: pathString,
},
payload: str,
};
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Context } from "../../agent/Context";
import { InterceptorResult } from "../../agent/hooks/MethodInterceptor";
import { SOURCES } from "../../agent/Source";
import { extractStringsFromUserInput } from "../../helpers/extractStringsFromUserInput";
import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached";
import { detectShellInjection } from "./detectShellInjection";

/**
Expand All @@ -18,19 +18,21 @@ export function checkContextForShellInjection({
context: Context;
}): InterceptorResult {
for (const source of SOURCES) {
if (context[source]) {
const userInput = extractStringsFromUserInput(context[source]);
for (const [str, path] of userInput.entries()) {
if (detectShellInjection(command, str)) {
return {
operation: operation,
kind: "shell_injection",
source: source,
pathToPayload: path,
metadata: {},
payload: str,
};
}
const userInput = extractStringsFromUserInputCached(context, source);
if (!userInput) {
continue;
}

for (const [str, path] of userInput.entries()) {
if (detectShellInjection(command, str)) {
return {
operation: operation,
kind: "shell_injection",
source: source,
pathToPayload: path,
metadata: {},
payload: str,
};
}
}
}
Expand Down
Loading

0 comments on commit 6fa9c0a

Please sign in to comment.