diff --git a/benchmarks/express/benchmark.js b/benchmarks/express/benchmark.js index 9a907009e..591bae628 100644 --- a/benchmarks/express/benchmark.js +++ b/benchmarks/express/benchmark.js @@ -4,7 +4,7 @@ const { promisify } = require("util"); const exec = promisify(require("child_process").exec); // Accepted percentage of performance decrease -const AcceptedDecrease = 30; // % +const AcceptedDecrease = 40; // % function generateWrkCommandForUrl(url) { // Define the command with awk included diff --git a/library/agent/Agent.test.ts b/library/agent/Agent.test.ts index d0a3f011b..a1c15c2eb 100644 --- a/library/agent/Agent.test.ts +++ b/library/agent/Agent.test.ts @@ -14,6 +14,7 @@ import { LoggerForTesting } from "./logger/LoggerForTesting"; import { LoggerNoop } from "./logger/LoggerNoop"; import { Wrapper } from "./Wrapper"; import { Context } from "./Context"; +import { createTestAgent } from "../helpers/createTestAgent"; t.test("it throws error if serverless is empty string", async () => { t.throws( @@ -30,12 +31,17 @@ t.test("it throws error if serverless is empty string", async () => { }); t.test("it sends started event", async (t) => { - const logger = new LoggerNoop(); + const logger = new LoggerForTesting(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); agent.start([new MongoDB()]); + const mongodb = require("mongodb"); + t.match(api.getEvents(), [ { type: "started", @@ -44,13 +50,11 @@ t.test("it sends started event", async (t) => { hostname: hostname(), version: "0.0.0", ipAddress: ip(), - packages: { - mongodb: "6.8.0", - }, + packages: {}, preventedPrototypePollution: false, nodeEnv: "", serverless: false, - stack: ["mongodb"], + stack: [], os: { name: platform(), version: release(), @@ -62,13 +66,22 @@ t.test("it sends started event", async (t) => { }, }, ]); + + t.same(logger.getMessages(), [ + "Starting agent...", + "Found token, reporting enabled!", + "mongodb@6.8.0 is supported!", + ]); }); t.test("it throws error if already started", async () => { const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); agent.start([new MongoDB()]); t.throws(() => agent.start([new MongoDB()]), "Agent already started!"); }); @@ -82,9 +95,15 @@ class WrapperForTesting implements Wrapper { t.test("it logs if package is supported or not", async () => { const logger = new LoggerForTesting(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); agent.start([new WrapperForTesting()]); + + agent.onPackageWrapped("shell-quote", { version: "1.8.1", supported: false }); + t.same(logger.getMessages(), [ "Starting agent...", "Found token, reporting enabled!", @@ -95,22 +114,30 @@ t.test("it logs if package is supported or not", async () => { t.test("it starts in non-blocking mode", async () => { const logger = new LoggerForTesting(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(false, logger, api, token, undefined); - agent.start([new MongoDB()]); + const agent = createTestAgent({ + block: false, + api, + logger, + token: new Token("123"), + }); + agent.start([]); + t.same(logger.getMessages(), [ "Starting agent...", "Dry mode enabled, no requests will be blocked!", "Found token, reporting enabled!", - "mongodb@6.8.0 is supported!", ]); }); t.test("when prevent prototype pollution is enabled", async (t) => { const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, "lambda"); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + serverless: "lambda", + }); agent.onPrototypePollutionPrevented(); agent.start([]); t.match(api.getEvents(), [ @@ -126,9 +153,12 @@ t.test("when prevent prototype pollution is enabled", async (t) => { t.test("it does not start interval in serverless mode", async () => { const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, "lambda"); - + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + serverless: "lambda", + }); // This would otherwise keep the process running agent.start([]); }); @@ -136,8 +166,11 @@ t.test("it does not start interval in serverless mode", async () => { t.test("when attack detected", async () => { const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); agent.onDetectedAttack({ module: "mongodb", kind: "nosql_injection", @@ -196,8 +229,11 @@ t.test("when attack detected", async () => { t.test("it checks if user agent is a string", async () => { const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); agent.onDetectedAttack({ module: "mongodb", kind: "nosql_injection", @@ -267,8 +303,11 @@ t.test( block: true, receivedAnyStats: false, }); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); agent.start([]); t.match(api.getEvents(), [ { @@ -332,8 +371,11 @@ t.test( block: true, receivedAnyStats: false, }); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); agent.start([]); t.match(api.getEvents(), [ { @@ -368,8 +410,11 @@ t.test("it sends heartbeat when reached max timings", async () => { const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); agent.start([]); for (let i = 0; i < 1000; i++) { agent.getInspectionStatistics().onInspectedCall({ @@ -458,8 +503,11 @@ t.test("it logs when failed to report event", async () => { const logger = new LoggerForTesting(); const api = new ReportingAPIThatThrows(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); agent.start([]); await waitForCalls(); @@ -514,8 +562,11 @@ t.test("unable to prevent prototype pollution", async () => { const logger = new LoggerForTesting(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); agent.start([]); agent.unableToPreventPrototypePollution({ mongoose: "1.0.0" }); t.same(logger.getMessages(), [ @@ -542,9 +593,11 @@ t.test("unable to prevent prototype pollution", async () => { t.test("when payload is object", async () => { const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); - + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); agent.onDetectedAttack({ module: "mongodb", kind: "nosql_injection", @@ -642,8 +695,11 @@ t.test("it sends hostnames and routes along with heartbeat", async () => { const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); agent.start([]); agent.onConnectHostname("aikido.dev", 443); @@ -738,8 +794,11 @@ t.test( async () => { const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); t.same(agent.shouldBlock(), true); agent.start([]); @@ -755,8 +814,12 @@ t.test( async () => { const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(false, logger, api, token, undefined); + const agent = createTestAgent({ + block: false, + api, + logger, + token: new Token("123"), + }); t.same(agent.shouldBlock(), false); agent.start([]); @@ -778,8 +841,12 @@ t.test("it enables blocking mode after sending startup event", async () => { allowedIPAddresses: [], block: true, }); - const token = new Token("123"); - const agent = new Agent(false, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + block: false, + api, + logger, + }); t.same(agent.shouldBlock(), false); agent.start([]); @@ -800,8 +867,11 @@ t.test("it goes into monitoring mode after sending startup event", async () => { allowedIPAddresses: [], block: false, }); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); t.same(agent.shouldBlock(), true); agent.start([]); diff --git a/library/agent/Agent.ts b/library/agent/Agent.ts index 329b6a951..7d37195dc 100644 --- a/library/agent/Agent.ts +++ b/library/agent/Agent.ts @@ -409,22 +409,7 @@ export class Agent { } } - this.wrappedPackages = wrapInstalledPackages(this, wrappers); - - for (const pkg in this.wrappedPackages) { - const details = this.wrappedPackages[pkg]; - - /* c8 ignore next 3 */ - if (!details.version) { - continue; - } - - if (details.supported) { - this.logger.log(`${pkg}@${details.version} is supported!`); - } else { - this.logger.log(`${pkg}@${details.version} is not supported!`); - } - } + wrapInstalledPackages(wrappers); // Send startup event and wait for config // Then start heartbeats and polling for config changes @@ -442,6 +427,26 @@ export class Agent { this.logger.log(`Failed to wrap method ${name} in module ${module}`); } + onFailedToWrapModule(module: string, error: Error) { + this.logger.log(`Failed to wrap module ${module}: ${error.message}`); + } + + onPackageWrapped(name: string, details: WrappedPackage) { + if (this.wrappedPackages[name]) { + // Already reported as wrapped + return; + } + this.wrappedPackages[name] = details; + + if (details.version) { + if (details.supported) { + this.logger.log(`${name}@${details.version} is supported!`); + } else { + this.logger.log(`${name}@${details.version} is not supported!`); + } + } + } + onFailedToWrapPackage(module: string) { this.logger.log(`Failed to wrap package ${module}`); } diff --git a/library/agent/applyHooks.test.ts b/library/agent/applyHooks.test.ts index b2b133c5a..bc7a6fb1f 100644 --- a/library/agent/applyHooks.test.ts +++ b/library/agent/applyHooks.test.ts @@ -1,11 +1,11 @@ import * as t from "tap"; -import { Agent } from "./Agent"; import { ReportingAPIForTesting } from "./api/ReportingAPIForTesting"; import { Token } from "./api/Token"; import { applyHooks } from "./applyHooks"; import { Context, runWithContext } from "./Context"; import { Hooks } from "./hooks/Hooks"; -import { LoggerForTesting } from "./logger/LoggerForTesting"; +import { wrapExport } from "./hooks/wrapExport"; +import { createTestAgent } from "../helpers/createTestAgent"; const context: Context = { remoteAddress: "::1", @@ -20,183 +20,12 @@ const context: Context = { route: "/posts/:id", }; -function createAgent() { - const logger = new LoggerForTesting(); - const api = new ReportingAPIForTesting(); - const agent = new Agent(true, logger, api, new Token("123"), "lambda"); +const reportingAPI = new ReportingAPIForTesting(); - return { - agent, - logger, - api, - }; -} - -t.test("it ignores if package is not installed", async (t) => { - const hooks = new Hooks(); - hooks.addPackage("unknown").withVersion("^1.0.0"); - - const { agent } = createAgent(); - t.same(applyHooks(hooks, agent), {}); -}); - -t.test("it ignores if packages have empty selectors", async (t) => { - const hooks = new Hooks(); - hooks.addPackage("shell-quote").withVersion("^1.0.0"); - - const { agent } = createAgent(); - t.same(applyHooks(hooks, agent), { - "shell-quote": { - version: "1.8.1", - supported: false, - }, - }); -}); - -t.test("it ignores unknown selectors", async (t) => { - const hooks = new Hooks(); - hooks - .addPackage("shell-quote") - .withVersion("^1.0.0") - .addSubject((exports) => exports.doesNotExist) - .inspect("method", () => {}); - - const { agent } = createAgent(); - t.same(applyHooks(hooks, agent), { - "shell-quote": { - version: "1.8.1", - supported: true, - }, - }); - - // Force require to load shell-quote - require("shell-quote"); -}); - -t.test("it tries to wrap method that does not exist", async (t) => { - const hooks = new Hooks(); - hooks - .addPackage("shell-quote") - .withVersion("^1.0.0") - .addSubject((exports) => exports) - .inspect("does_not_exist", () => {}) - .modifyArguments("another_method_that_does_not_exist", (args) => { - return args; - }) - .inspectResult("another_second_method_that_does_not_exist", () => {}); - - const { agent, logger } = createAgent(); - t.same(applyHooks(hooks, agent), { - "shell-quote": { - version: "1.8.1", - supported: true, - }, - }); - - t.same(logger.getMessages(), [ - "Failed to wrap method another_second_method_that_does_not_exist in module shell-quote", - "Failed to wrap method another_method_that_does_not_exist in module shell-quote", - "Failed to wrap method does_not_exist in module shell-quote", - ]); -}); - -t.test("it ignores if version is not supported", async (t) => { - const hooks = new Hooks(); - hooks - .addPackage("shell-quote") - .withVersion("^2.0.0") - .addSubject((exports) => exports) - .inspect("method", () => {}); - - const { agent } = createAgent(); - t.same(applyHooks(hooks, agent), { - "shell-quote": { - version: "1.8.1", - supported: false, - }, - }); -}); - -function removeStackTraceErrorMessage(error: string) { - const [msg] = error.split("\n"); - - return msg; -} - -t.test("it adds try/catch around the wrapped method", async (t) => { - const hooks = new Hooks(); - const connection = hooks - .addPackage("mysql2") - .withVersion("^3.0.0") - .addSubject((exports) => exports.Connection.prototype); - connection.inspect("query", () => { - throw new Error("THIS SHOULD BE CATCHED"); - }); - connection.modifyArguments("execute", () => { - throw new Error("THIS SHOULD BE CATCHED"); - }); - connection.inspectResult("execute", () => { - throw new Error("THIS SHOULD BE CATCHED"); - }); - - const { agent, logger } = createAgent(); - t.same(applyHooks(hooks, agent), { - mysql2: { - version: "3.11.0", - supported: true, - }, - }); - - const mysql = require("mysql2/promise"); - const actualConnection = await mysql.createConnection({ - host: "localhost", - user: "root", - password: "mypassword", - database: "catsdb", - port: 27015, - multipleStatements: true, - }); - - const [queryRows] = await runWithContext(context, () => - actualConnection.query("SELECT 1 as number") - ); - t.same(queryRows, [{ number: 1 }]); - - const [executeRows] = await runWithContext(context, () => - actualConnection.execute("SELECT 1 as number") - ); - t.same(executeRows, [{ number: 1 }]); - - t.same(logger.getMessages().map(removeStackTraceErrorMessage), [ - 'Internal error in module "mysql2" in method "query"', - 'Internal error in module "mysql2" in method "execute"', - 'Internal error in module "mysql2" in method "execute"', - ]); - - await actualConnection.end(); -}); - -t.test("it hooks into dns module", async (t) => { - const seenDomains: string[] = []; - - const hooks = new Hooks(); - hooks - .addBuiltinModule("dns") - .addSubject((exports) => exports.promises) - .inspect("lookup", (args) => { - if (typeof args[0] === "string") { - seenDomains.push(args[0]); - } - }); - - const { agent } = createAgent(); - t.same(applyHooks(hooks, agent), {}); - - const { lookup } = require("dns/promises"); - - await runWithContext(context, async () => await lookup("google.com")); - - t.same(seenDomains, ["google.com"]); +const agent = createTestAgent({ + serverless: "lambda", + api: reportingAPI, + token: new Token("123"), }); t.test( @@ -206,30 +35,50 @@ t.test( const hooks = new Hooks(); let modifyCalled = false; - hooks.addGlobal("fetch").inspect((args) => { - modifyCalled = true; + hooks.addGlobal("fetch", { + modifyArgs: (args) => { + modifyCalled = true; + return args; + }, }); let inspectCalled = false; - hooks.addGlobal("atob").modifyArguments((args) => { - inspectCalled = true; - - return args; + hooks.addGlobal("atob", { + inspectArgs: (args) => { + inspectCalled = true; + }, }); // Unknown global - hooks.addGlobal("unknown").inspect(() => {}); + hooks.addGlobal("unknown", { + inspectArgs: (args) => { + return; + }, + }); + + // Without name + // @ts-expect-error Test with invalid arguments + const error = t.throws(() => hooks.addGlobal()); + t.ok(error instanceof Error); + if (error instanceof Error) { + t.match(error.message, /Name is required/); + } // Without interceptor - hooks.addGlobal("setTimeout"); + // @ts-expect-error Test with invalid arguments + const error2 = t.throws(() => hooks.addGlobal("setTimeout")); + t.ok(error2 instanceof Error); + if (error2 instanceof Error) { + t.match(error2.message, /Interceptors are required/); + } - const { agent, logger } = createAgent(); - t.same(applyHooks(hooks, agent), {}); + applyHooks(hooks); await runWithContext(context, async () => { await fetch("https://aikido.dev"); t.same(modifyCalled, true); + t.same(inspectCalled, false); atob("aGVsbG8gd29ybGQ="); t.same(inspectCalled, true); }); @@ -240,24 +89,28 @@ t.test("it ignores route if force protection off is on", async (t) => { const inspectionCalls: { args: unknown[] }[] = []; const hooks = new Hooks(); - hooks - .addBuiltinModule("dns/promises") - .addSubject((exports) => exports) - .inspect("lookup", (args) => { - inspectionCalls.push({ args }); + hooks.addBuiltinModule("dns/promises").onRequire((exports, pkgInfo) => { + wrapExport(exports, "lookup", pkgInfo, { + inspectArgs: (args, agent) => { + inspectionCalls.push({ args }); + }, }); + }); - const { agent, api } = createAgent(); - applyHooks(hooks, agent); + applyHooks(hooks); - api.setResult({ + reportingAPI.setResult({ success: true, endpoints: [ { method: "GET", route: "/route", forceProtectionOff: true, - rateLimiting: undefined, + rateLimiting: { + enabled: false, + maxRequests: 0, + windowSizeInMS: 0, + }, }, ], heartbeatIntervalInMS: 10 * 60 * 1000, @@ -302,24 +155,24 @@ t.test("it ignores route if force protection off is on", async (t) => { t.test("it does not report attack if IP is allowed", async (t) => { const hooks = new Hooks(); - hooks - .addBuiltinModule("os") - .addSubject((exports) => exports) - .inspect("hostname", (args, subject, agent) => { - return { - operation: "os.hostname", - source: "body", - pathToPayload: "path", - payload: "payload", - metadata: {}, - kind: "path_traversal", - }; + hooks.addBuiltinModule("os").onRequire((exports, pkgInfo) => { + wrapExport(exports, "hostname", pkgInfo, { + inspectArgs: (args, agent) => { + return { + operation: "os.hostname", + source: "body", + pathToPayload: "path", + payload: "payload", + metadata: {}, + kind: "path_traversal", + }; + }, }); + }); - const { agent, api } = createAgent(); - applyHooks(hooks, agent); + applyHooks(hooks); - api.setResult({ + reportingAPI.setResult({ success: true, endpoints: [], configUpdatedAt: 0, @@ -330,7 +183,7 @@ t.test("it does not report attack if IP is allowed", async (t) => { // Read rules from API await agent.flushStats(1000); - api.clear(); + reportingAPI.clear(); const { hostname } = require("os"); @@ -339,29 +192,5 @@ t.test("it does not report attack if IP is allowed", async (t) => { t.ok(typeof name === "string"); }); - t.same(api.getEvents(), []); -}); - -t.test("it can get the result of a method", async (t) => { - let receivedResult: unknown | undefined; - let receivedArgs: unknown[] | undefined; - - const hooks = new Hooks(); - hooks - .addBuiltinModule("path") - .addSubject((exports) => exports) - .inspectResult("extname", (args, result) => { - receivedArgs = args; - receivedResult = result; - }); - - const { agent } = createAgent(); - t.same(applyHooks(hooks, agent), {}); - - const { extname } = require("path"); - - await runWithContext(context, async () => extname("file.txt")); - - t.same(receivedArgs, ["file.txt"]); - t.same(receivedResult, ".txt"); + t.same(reportingAPI.getEvents(), []); }); diff --git a/library/agent/applyHooks.ts b/library/agent/applyHooks.ts index a3fb1cd54..754a0e9ed 100644 --- a/library/agent/applyHooks.ts +++ b/library/agent/applyHooks.ts @@ -1,91 +1,22 @@ -/* eslint-disable max-lines-per-function */ -import { join, resolve } from "path"; -import { cleanupStackTrace } from "../helpers/cleanupStackTrace"; -import { wrap } from "../helpers/wrap"; -import { getPackageVersion } from "../helpers/getPackageVersion"; -import { satisfiesVersion } from "../helpers/satisfiesVersion"; -import { escapeHTML } from "../helpers/escapeHTML"; -import { Agent } from "./Agent"; -import { attackKindHumanName } from "./Attack"; -import { bindContext, getContext, updateContext } from "./Context"; -import { BuiltinModule } from "./hooks/BuiltinModule"; -import { ConstructorInterceptor } from "./hooks/ConstructorInterceptor"; import { Hooks } from "./hooks/Hooks"; import { - InterceptorResult, - MethodInterceptor, -} from "./hooks/MethodInterceptor"; -import { ModifyingArgumentsMethodInterceptor } from "./hooks/ModifyingArgumentsInterceptor"; -import { Package } from "./hooks/Package"; -import { WrappableFile } from "./hooks/WrappableFile"; -import { WrappableSubject } from "./hooks/WrappableSubject"; -import { MethodResultInterceptor } from "./hooks/MethodResultInterceptor"; -import { isPackageInstalled } from "../helpers/isPackageInstalled"; + setBuiltinModulesToPatch, + setPackagesToPatch, + wrapRequire, +} from "./hooks/wrapRequire"; +import { wrapExport } from "./hooks/wrapExport"; /** * Hooks allows you to register packages and then wrap specific methods on * the exports of the package. This doesn't do the actual wrapping yet. * - * That's where applyHooks comes in, we take the registered packages and - * its methods and do the actual wrapping so that we can intercept method calls. + * This method wraps the require function and sets up the hooks. + * Globals are wrapped directly. */ -export function applyHooks(hooks: Hooks, agent: Agent) { - const wrapped: Record = {}; - - hooks.getPackages().forEach((pkg) => { - const version = getPackageVersion(pkg.getName()); - - if (!version) { - return; - } - - wrapped[pkg.getName()] = { - version, - supported: false, - }; - - const versions = pkg - .getVersions() - .map((versioned) => { - if (!satisfiesVersion(versioned.getRange(), version)) { - return []; - } - - return { - subjects: versioned.getSubjects(), - files: versioned.getFiles(), - }; - }) - .flat(); - - const files = versions.map((hook) => hook.files).flat(); - const subjects = versions.map((hook) => hook.subjects).flat(); - - if (subjects.length === 0 && files.length === 0) { - return; - } - - wrapped[pkg.getName()] = { - version, - supported: true, - }; - - if (subjects.length > 0) { - wrapPackage(pkg, subjects, agent); - } - - if (files.length > 0) { - wrapFiles(pkg, files, agent); - } - }); - - hooks.getBuiltInModules().forEach((module) => { - const subjects = module.getSubjects(); - - if (subjects.length > 0) { - wrapBuiltInModule(module, subjects, agent); - } - }); +export function applyHooks(hooks: Hooks) { + setPackagesToPatch(hooks.getPackages()); + setBuiltinModulesToPatch(hooks.getBuiltInModules()); + wrapRequire(); hooks.getGlobals().forEach((g) => { const name = g.getName(); @@ -94,307 +25,14 @@ export function applyHooks(hooks: Hooks, agent: Agent) { return; } - g.getMethodInterceptors() - .reverse() // Reverse to make sure we wrap in the order they were added - .forEach((interceptor) => { - if (interceptor instanceof ModifyingArgumentsMethodInterceptor) { - wrapWithArgumentModification(global, interceptor, name, agent); - } else { - wrapWithoutArgumentModification(global, interceptor, name, agent); - } - }); - }); - - return wrapped; -} - -function wrapFiles(pkg: Package, files: WrappableFile[], agent: Agent) { - files.forEach((file) => { - try { - const exports = require(join(pkg.getName(), file.getRelativePath())); - - file - .getSubjects() - .forEach( - (subject) => wrapSubject(exports, subject, pkg.getName(), agent), - agent - ); - } catch (error) { - agent.onFailedToWrapFile(pkg.getName(), file.getRelativePath()); - } - }); -} - -function wrapBuiltInModule( - module: BuiltinModule, - subjects: WrappableSubject[], - agent: Agent -) { - if (!isPackageInstalled(module.getName())) { - return; - } - try { - const exports = require(module.getName()); - - subjects.forEach( - (selector) => wrapSubject(exports, selector, module.getName(), agent), - agent + wrapExport( + global, + name, + { + name: name, + type: "global", + }, + g.getInterceptors() ); - } catch (error) { - agent.onFailedToWrapPackage(module.getName()); - } -} - -function wrapPackage(pkg: Package, subjects: WrappableSubject[], agent: Agent) { - try { - const exports = require(pkg.getName()); - - subjects.forEach( - (selector) => wrapSubject(exports, selector, pkg.getName(), agent), - agent - ); - } catch (error) { - agent.onFailedToWrapPackage(pkg.getName()); - } -} - -/** - * Wraps a method call with an interceptor that doesn't modify the arguments of the method call. - */ -function wrapWithoutArgumentModification( - subject: unknown, - method: MethodInterceptor, - module: string, - agent: Agent -) { - const libraryRoot = resolve(__dirname, ".."); - - try { - wrap(subject, method.getName(), function wrap(original: Function) { - return function wrap() { - // eslint-disable-next-line prefer-rest-params - const args = Array.from(arguments); - const context = getContext(); - - for (let i = 0; i < args.length; i++) { - if (typeof args[i] === "function") { - args[i] = bindContext(args[i]); - } - } - - if (context) { - const matches = agent.getConfig().getEndpoints(context); - - if (matches.find((match) => match.forceProtectionOff)) { - return original.apply( - // @ts-expect-error We don't now the type of this - this, - args - ); - } - } - - const start = performance.now(); - let result: InterceptorResult = undefined; - - try { - // @ts-expect-error We don't now the type of this - result = method.getInterceptor()(args, this, agent, context); - } catch (error: any) { - agent.getInspectionStatistics().interceptorThrewError(module); - agent.onErrorThrownByInterceptor({ - error: error, - method: method.getName(), - module: module, - }); - } - - const end = performance.now(); - agent.getInspectionStatistics().onInspectedCall({ - sink: module, - attackDetected: !!result, - blocked: agent.shouldBlock(), - durationInMs: end - start, - withoutContext: !context, - }); - - const isAllowedIP = - context && - context.remoteAddress && - agent.getConfig().isAllowedIP(context.remoteAddress); - - if (result && context && !isAllowedIP) { - // Flag request as having an attack detected - updateContext(context, "attackDetected", true); - - agent.onDetectedAttack({ - module: module, - operation: result.operation, - kind: result.kind, - source: result.source, - blocked: agent.shouldBlock(), - stack: cleanupStackTrace(new Error().stack!, libraryRoot), - path: result.pathToPayload, - metadata: result.metadata, - request: context, - payload: result.payload, - }); - - if (agent.shouldBlock()) { - throw new Error( - `Zen has blocked ${attackKindHumanName(result.kind)}: ${result.operation}(...) originating from ${result.source}${escapeHTML(result.pathToPayload)}` - ); - } - } - - return original.apply( - // @ts-expect-error We don't now the type of this - this, - args - ); - }; - }); - } catch (error) { - agent.onFailedToWrapMethod(module, method.getName()); - } -} - -/** - * Wraps a method call with an interceptor that modifies the arguments of the method call. - */ -function wrapWithArgumentModification( - subject: unknown, - method: ModifyingArgumentsMethodInterceptor, - module: string, - agent: Agent -) { - try { - wrap(subject, method.getName(), function wrap(original: Function) { - return function wrap() { - // eslint-disable-next-line prefer-rest-params - const args = Array.from(arguments); - let updatedArgs = args; - - try { - // @ts-expect-error We don't now the type of this - updatedArgs = method.getInterceptor()(args, this, agent); - } catch (error: any) { - agent.onErrorThrownByInterceptor({ - error: error, - method: method.getName(), - module: module, - }); - } - - return original.apply( - // @ts-expect-error We don't now the type of this - this, - updatedArgs - ); - }; - }); - } catch (error) { - agent.onFailedToWrapMethod(module, method.getName()); - } -} - -function wrapNewInstance( - subject: unknown, - constructor: ConstructorInterceptor, - module: string, - agent: Agent -) { - const subjects = constructor.getSubjects(); - - if (subjects.length === 0) { - return; - } - - try { - wrap(subject, constructor.getName(), function wrap(original: Function) { - return function wrap() { - // eslint-disable-next-line prefer-rest-params - const args = Array.from(arguments); - - // @ts-expect-error It's a constructor - const newInstance = new original(...args); - subjects.forEach((subject) => { - wrapSubject(newInstance, subject, module, agent); - }); - - return newInstance; - }; - }); - } catch (error) { - agent.onFailedToWrapMethod(module, constructor.getName()); - } -} - -/** - * Wraps a method call with an interceptor that is called after the method call has returned. - * Returns the arguments and the result of the method call. - */ -function wrapWithResult( - subject: unknown, - method: MethodResultInterceptor, - module: string, - agent: Agent -) { - try { - wrap(subject, method.getName(), function wrap(original: Function) { - return function wrap() { - // eslint-disable-next-line prefer-rest-params - const args = Array.from(arguments); - - const result = original.apply( - // @ts-expect-error We don't now the type of this - this, - args - ); - - try { - // @ts-expect-error We don't now the type of this - method.getInterceptor()(args, result, this, agent); - } catch (error: any) { - agent.onErrorThrownByInterceptor({ - error: error, - method: method.getName(), - module: module, - }); - } - - return result; - }; - }); - } catch (error) { - agent.onFailedToWrapMethod(module, method.getName()); - } -} - -function wrapSubject( - exports: unknown, - subject: WrappableSubject, - module: string, - agent: Agent -) { - const theSubject = subject.getSelector()(exports); - - if (!theSubject) { - return; - } - - subject - .getMethodInterceptors() - .reverse() // Reverse to make sure we wrap in the order they were added - .forEach((method) => { - if (method instanceof ModifyingArgumentsMethodInterceptor) { - wrapWithArgumentModification(theSubject, method, module, agent); - } else if (method instanceof MethodInterceptor) { - wrapWithoutArgumentModification(theSubject, method, module, agent); - } else if (method instanceof MethodResultInterceptor) { - wrapWithResult(theSubject, method, module, agent); - } else { - wrapNewInstance(theSubject, method, module, agent); - } - }); + }); } diff --git a/library/agent/context/user.test.ts b/library/agent/context/user.test.ts index af0d49bc8..14eb59d89 100644 --- a/library/agent/context/user.test.ts +++ b/library/agent/context/user.test.ts @@ -1,23 +1,12 @@ import * as t from "tap"; -import { Agent } from "../Agent"; -import { setInstance } from "../AgentSingleton"; -import { ReportingAPIForTesting } from "../api/ReportingAPIForTesting"; import { type Context, getContext, runWithContext } from "../Context"; import { LoggerForTesting } from "../logger/LoggerForTesting"; -import { LoggerNoop } from "../logger/LoggerNoop"; import { setUser } from "./user"; +import { createTestAgent } from "../../helpers/createTestAgent"; -t.test("it does not set user if empty id", async (t) => { - setInstance( - new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ) - ); +createTestAgent(); +t.test("it does not set user if empty id", async (t) => { const context: Context = { remoteAddress: "::1", method: "POST", @@ -40,30 +29,10 @@ t.test("it does not set user if empty id", async (t) => { }); t.test("it does not set user if not inside context", async () => { - setInstance( - new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ) - ); - setUser({ id: "id" }); }); t.test("it sets user", async (t) => { - setInstance( - new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ) - ); - const context: Context = { remoteAddress: "::1", method: "POST", @@ -88,16 +57,6 @@ t.test("it sets user", async (t) => { }); t.test("it sets user with number as ID", async (t) => { - setInstance( - new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ) - ); - const context: Context = { remoteAddress: "::1", method: "POST", @@ -122,16 +81,6 @@ t.test("it sets user with number as ID", async (t) => { }); t.test("it sets user with name", async (t) => { - setInstance( - new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ) - ); - const context: Context = { remoteAddress: "::1", method: "POST", @@ -158,9 +107,7 @@ t.test("it sets user with name", async (t) => { t.test("it logs when setUser has invalid input", async () => { const logger = new LoggerForTesting(); - setInstance( - new Agent(true, logger, new ReportingAPIForTesting(), undefined, undefined) - ); + createTestAgent({ logger }); setUser(1); t.same(logger.getMessages(), [ diff --git a/library/agent/hooks/BuiltinModule.ts b/library/agent/hooks/BuiltinModule.ts index c907dff15..6d3685a77 100644 --- a/library/agent/hooks/BuiltinModule.ts +++ b/library/agent/hooks/BuiltinModule.ts @@ -1,7 +1,7 @@ -import { WrappableSubject } from "./WrappableSubject"; +import { RequireInterceptor } from "./RequireInterceptor"; export class BuiltinModule { - private subjects: WrappableSubject[] = []; + private requireInterceptors: RequireInterceptor[] = []; constructor(private readonly name: string) { if (!this.name) { @@ -13,14 +13,11 @@ export class BuiltinModule { return this.name; } - addSubject(selector: (exports: any) => unknown): WrappableSubject { - const fn = new WrappableSubject(selector); - this.subjects.push(fn); - - return fn; + onRequire(interceptor: RequireInterceptor) { + this.requireInterceptors.push(interceptor); } - getSubjects() { - return this.subjects; + getRequireInterceptors() { + return this.requireInterceptors; } } diff --git a/library/agent/hooks/ConstructorInterceptor.ts b/library/agent/hooks/ConstructorInterceptor.ts deleted file mode 100644 index 8b74699a1..000000000 --- a/library/agent/hooks/ConstructorInterceptor.ts +++ /dev/null @@ -1,32 +0,0 @@ -import { WrappableSubject } from "./WrappableSubject"; - -/* - * We want to be notified whenever a new instance of S3 is created - * - * const AWS = require("aws-sdk"); - * new AWS.S3(); - */ -export class ConstructorInterceptor { - private readonly subjects: WrappableSubject[] = []; - - constructor(private readonly name: string) { - if (!this.name) { - throw new Error("Name is required"); - } - } - - getName() { - return this.name; - } - - addSubject(selector: (exports: any) => unknown): WrappableSubject { - const fn = new WrappableSubject(selector); - this.subjects.push(fn); - - return fn; - } - - getSubjects() { - return this.subjects; - } -} diff --git a/library/agent/hooks/Global.ts b/library/agent/hooks/Global.ts index 2f9dc00cf..b9a64ea58 100644 --- a/library/agent/hooks/Global.ts +++ b/library/agent/hooks/Global.ts @@ -1,53 +1,23 @@ -import { Interceptor, MethodInterceptor } from "./MethodInterceptor"; -import { - ModifyingArgumentsInterceptor, - ModifyingArgumentsMethodInterceptor, -} from "./ModifyingArgumentsInterceptor"; +import { InterceptorObject } from "./wrapExport"; export class Global { - private methods: (MethodInterceptor | ModifyingArgumentsMethodInterceptor)[] = - []; - - constructor(private readonly name: string) { + constructor( + private readonly name: string, + private readonly interceptors: InterceptorObject + ) { if (!this.name) { throw new Error("Name is required"); } - } - - /** - * Inspect method calls without modifying arguments - * - * This is the preferred way to use when wrapping methods - */ - inspect(interceptor: Interceptor) { - const method = new MethodInterceptor(this.name, interceptor); - this.methods.push(method); - - return this; - } - - /** - * Inspect methods call and return modified arguments - * - * e.g. to append our middleware to express routes - * - * Don't use this unless you have to, it's better to use inspect - */ - modifyArguments(interceptor: ModifyingArgumentsInterceptor) { - const method = new ModifyingArgumentsMethodInterceptor( - this.name, - interceptor - ); - this.methods.push(method); - - return this; + if (!this.interceptors) { + throw new Error("Interceptors are required"); + } } getName() { return this.name; } - getMethodInterceptors() { - return this.methods; + getInterceptors() { + return this.interceptors; } } diff --git a/library/agent/hooks/Hooks.test.ts b/library/agent/hooks/Hooks.test.ts index 32afef394..3980ea26c 100644 --- a/library/agent/hooks/Hooks.test.ts +++ b/library/agent/hooks/Hooks.test.ts @@ -14,23 +14,23 @@ t.test("withVersion throws error if version is empty", async (t) => { t.throws(() => subject.withVersion("")); }); -t.test("file throws error if path is empty", async (t) => { +t.test("throws error if interceptor is not a function", async (t) => { const hooks = new Hooks(); - const subject = hooks.addPackage("package").withVersion("^1.0.0"); + const vPackage = hooks.addPackage("package").withVersion("^1.0.0"); - t.throws(() => subject.addFile("")); + // @ts-expect-error Testing invalid input + t.throws(() => subject.onRequire("")); }); -t.test("method throws error if name is empty", async (t) => { +t.test("returns require interceptors", async (t) => { const hooks = new Hooks(); - const subject = hooks - .addPackage("package") - .withVersion("^1.0.0") - .addSubject((exports) => exports); - - t.throws(() => subject.inspect("", () => {})); - t.throws(() => subject.modifyArguments("", (args) => args)); - t.throws(() => subject.inspectResult("", (args, result) => {})); + + const interceptor = () => {}; + + const vPackage = hooks.addPackage("package").withVersion("^1.0.0"); + vPackage.onRequire(interceptor); + + t.same(vPackage.getRequireInterceptors(), [interceptor]); }); t.test("add builtin module throws if name is empty", async (t) => { @@ -42,15 +42,5 @@ t.test("add builtin module throws if name is empty", async (t) => { t.test("it throws error if global name is empty", async () => { const hooks = new Hooks(); - t.throws(() => hooks.addGlobal("")); -}); - -t.test("it throws if name is empty", async () => { - const hooks = new Hooks(); - const subject = hooks - .addPackage("package") - .withVersion("^1.0.0") - .addSubject((exports) => exports); - - t.throws(() => subject.inspectNewInstance("")); + t.throws(() => hooks.addGlobal("", {})); }); diff --git a/library/agent/hooks/Hooks.ts b/library/agent/hooks/Hooks.ts index 689554214..988166d29 100644 --- a/library/agent/hooks/Hooks.ts +++ b/library/agent/hooks/Hooks.ts @@ -1,6 +1,7 @@ import { BuiltinModule } from "./BuiltinModule"; import { Global } from "./Global"; import { Package } from "./Package"; +import { InterceptorObject } from "./wrapExport"; export class Hooks { private readonly packages: Package[] = []; @@ -14,11 +15,9 @@ export class Hooks { return pkg; } - addGlobal(name: string): Global { - const global = new Global(name); + addGlobal(name: string, interceptors: InterceptorObject) { + const global = new Global(name, interceptors); this.globals.push(global); - - return global; } addBuiltinModule(name: string): BuiltinModule { diff --git a/library/agent/hooks/InterceptorResult.ts b/library/agent/hooks/InterceptorResult.ts new file mode 100644 index 000000000..99ca0f9e6 --- /dev/null +++ b/library/agent/hooks/InterceptorResult.ts @@ -0,0 +1,11 @@ +import { Kind } from "../Attack"; +import { Source } from "../Source"; + +export type InterceptorResult = { + operation: string; + kind: Kind; + source: Source; + pathToPayload: string; + metadata: Record; + payload: unknown; +} | void; diff --git a/library/agent/hooks/MethodInterceptor.ts b/library/agent/hooks/MethodInterceptor.ts deleted file mode 100644 index 64f1a53b3..000000000 --- a/library/agent/hooks/MethodInterceptor.ts +++ /dev/null @@ -1,37 +0,0 @@ -import { Agent } from "../Agent"; -import { Kind } from "../Attack"; -import { Source } from "../Source"; - -export type InterceptorResult = { - operation: string; - kind: Kind; - source: Source; - pathToPayload: string; - metadata: Record; - payload: unknown; -} | void; - -export type Interceptor = ( - args: unknown[], - subject: unknown, - agent: Agent -) => InterceptorResult; - -export class MethodInterceptor { - constructor( - private readonly name: string, - private readonly interceptor: Interceptor - ) { - if (!this.name) { - throw new Error("Method name is required"); - } - } - - getName() { - return this.name; - } - - getInterceptor() { - return this.interceptor; - } -} diff --git a/library/agent/hooks/MethodResultInterceptor.ts b/library/agent/hooks/MethodResultInterceptor.ts deleted file mode 100644 index d8c26e960..000000000 --- a/library/agent/hooks/MethodResultInterceptor.ts +++ /dev/null @@ -1,27 +0,0 @@ -import { Agent } from "../Agent"; - -export type ResultInterceptor = ( - args: unknown[], - result: unknown, - subject: unknown, - agent: Agent -) => void; - -export class MethodResultInterceptor { - constructor( - private readonly name: string, - private readonly interceptor: ResultInterceptor - ) { - if (!this.name) { - throw new Error("Method name is required"); - } - } - - getName() { - return this.name; - } - - getInterceptor() { - return this.interceptor; - } -} diff --git a/library/agent/hooks/ModifyingArgumentsInterceptor.ts b/library/agent/hooks/ModifyingArgumentsInterceptor.ts deleted file mode 100644 index 5d4d2bbab..000000000 --- a/library/agent/hooks/ModifyingArgumentsInterceptor.ts +++ /dev/null @@ -1,26 +0,0 @@ -import { Agent } from "../Agent"; - -export type ModifyingArgumentsInterceptor = ( - args: unknown[], - subject: unknown, - agent: Agent -) => unknown[]; - -export class ModifyingArgumentsMethodInterceptor { - constructor( - private readonly name: string, - private readonly interceptor: ModifyingArgumentsInterceptor - ) { - if (!this.name) { - throw new Error("Method name is required"); - } - } - - getName() { - return this.name; - } - - getInterceptor() { - return this.interceptor; - } -} diff --git a/library/agent/hooks/RequireInterceptor.ts b/library/agent/hooks/RequireInterceptor.ts new file mode 100644 index 000000000..c6566233a --- /dev/null +++ b/library/agent/hooks/RequireInterceptor.ts @@ -0,0 +1,6 @@ +import { WrapPackageInfo } from "./WrapPackageInfo"; + +export type RequireInterceptor = ( + exports: any, + pkgInfo: WrapPackageInfo +) => void | unknown; diff --git a/library/agent/hooks/VersionedPackage.ts b/library/agent/hooks/VersionedPackage.ts index e4ff647ef..6985543fd 100644 --- a/library/agent/hooks/VersionedPackage.ts +++ b/library/agent/hooks/VersionedPackage.ts @@ -1,9 +1,8 @@ -import { WrappableSubject } from "./WrappableSubject"; -import { WrappableFile } from "./WrappableFile"; +import { RequireInterceptor } from "./RequireInterceptor"; export class VersionedPackage { - private subjects: WrappableSubject[] = []; - private files: WrappableFile[] = []; + private requireInterceptors: RequireInterceptor[] = []; + private requireFileInterceptors = new Map(); constructor(private readonly range: string) { if (!this.range) { @@ -15,25 +14,51 @@ export class VersionedPackage { return this.range; } - addFile(relativePath: string): WrappableFile { - const file = new WrappableFile(relativePath); - this.files.push(file); + onRequire(interceptor: RequireInterceptor) { + if (typeof interceptor !== "function") { + throw new Error("Interceptor must be a function"); + } + + this.requireInterceptors.push(interceptor); - return file; + return this; } - addSubject(selector: (exports: any) => unknown): WrappableSubject { - const fn = new WrappableSubject(selector); - this.subjects.push(fn); + onFileRequire(relativePath: string, interceptor: RequireInterceptor) { + if (relativePath.length === 0) { + throw new Error("Relative path must not be empty"); + } + + if (this.requireFileInterceptors.has(relativePath)) { + throw new Error(`Interceptor for ${relativePath} already exists`); + } + + if (relativePath.startsWith("/")) { + throw new Error( + "Absolute paths are not allowed for require file interceptors" + ); + } + + if (relativePath.includes("..")) { + throw new Error( + "Relative paths with '..' are not allowed for require file interceptors" + ); + } + + if (relativePath.startsWith("./")) { + relativePath = relativePath.slice(2); + } + + this.requireFileInterceptors.set(relativePath, interceptor); - return fn; + return this; } - getSubjects() { - return this.subjects; + getRequireInterceptors() { + return this.requireInterceptors; } - getFiles() { - return this.files; + getRequireFileInterceptor(relativePath: string) { + return this.requireFileInterceptors.get(relativePath); } } diff --git a/library/agent/hooks/WrapPackageInfo.ts b/library/agent/hooks/WrapPackageInfo.ts new file mode 100644 index 000000000..55c77079b --- /dev/null +++ b/library/agent/hooks/WrapPackageInfo.ts @@ -0,0 +1,24 @@ +export type WrapPackageInfo = { + /** + * Name of the package. + */ + name: string; + /** + * Version of the package, only set if the module is not a builtin module. + */ + version?: string; + /** + * Type of the wrap target. + */ + type: "builtin" | "external" | "global"; + /** + * Only set if the module is not a builtin module. + */ + path?: { + base: string; + /** + * Path of the imported js file relative to the module base directory. + */ + relative: string; + }; +}; diff --git a/library/agent/hooks/WrappableFile.ts b/library/agent/hooks/WrappableFile.ts deleted file mode 100644 index 31e5dd084..000000000 --- a/library/agent/hooks/WrappableFile.ts +++ /dev/null @@ -1,35 +0,0 @@ -import { WrappableSubject } from "./WrappableSubject"; - -/** - * Normally we use require-in-the-middle to wrap the exports of a package. - * - * However, sometimes the export don't contain the subjects that we need to wrap. - * - * In that case, we can require the library file directly and wrap the exports of the file. - * - * Using require-in-the-middle is preferred because we don't have to require any files until the package is actually used. - */ -export class WrappableFile { - private subjects: WrappableSubject[] = []; - - constructor(private readonly relativePath: string) { - if (!this.relativePath) { - throw new Error("Relative path is required"); - } - } - - getRelativePath() { - return this.relativePath; - } - - addSubject(selector: (exports: any) => unknown): WrappableSubject { - const fn = new WrappableSubject(selector); - this.subjects.push(fn); - - return fn; - } - - getSubjects() { - return this.subjects; - } -} diff --git a/library/agent/hooks/WrappableSubject.ts b/library/agent/hooks/WrappableSubject.ts deleted file mode 100644 index 1973782d6..000000000 --- a/library/agent/hooks/WrappableSubject.ts +++ /dev/null @@ -1,85 +0,0 @@ -import { ConstructorInterceptor } from "./ConstructorInterceptor"; -import { Interceptor, MethodInterceptor } from "./MethodInterceptor"; -import { - MethodResultInterceptor, - ResultInterceptor, -} from "./MethodResultInterceptor"; -import { - ModifyingArgumentsInterceptor, - ModifyingArgumentsMethodInterceptor, -} from "./ModifyingArgumentsInterceptor"; - -/** - * A subject represents an object from package exports that we want to hook into. - */ -export class WrappableSubject { - private methods: ( - | MethodInterceptor - | ModifyingArgumentsMethodInterceptor - | ConstructorInterceptor - | MethodResultInterceptor - )[] = []; - - constructor(private readonly selector: (exports: unknown) => unknown) {} - - /** - * Inspect method calls without modifying arguments - * - * This is the preferred way to use when wrapping methods - */ - inspect(methodName: string, interceptor: Interceptor) { - const method = new MethodInterceptor(methodName, interceptor); - this.methods.push(method); - - return this; - } - - /** - * Inspection of method results. Also includes the arguments passed to the method. - * - * ! Currently only useable for sources and not for sinks. ! - * - * If not necessary, use inspect instead. - */ - inspectResult(methodName: string, interceptor: ResultInterceptor) { - const method = new MethodResultInterceptor(methodName, interceptor); - this.methods.push(method); - - return this; - } - - /** - * Inspect methods call and return modified arguments - * - * e.g. to append our middleware to express routes - * - * Don't use this unless you have to, it's better to use inspect - */ - modifyArguments( - methodName: string, - interceptor: ModifyingArgumentsInterceptor - ) { - const method = new ModifyingArgumentsMethodInterceptor( - methodName, - interceptor - ); - this.methods.push(method); - - return this; - } - - inspectNewInstance(name: string) { - const construct = new ConstructorInterceptor(name); - this.methods.push(construct); - - return construct; - } - - getSelector() { - return this.selector; - } - - getMethodInterceptors() { - return this.methods; - } -} diff --git a/library/agent/hooks/getModuleInfoFromPath.test.ts b/library/agent/hooks/getModuleInfoFromPath.test.ts new file mode 100644 index 000000000..8a626ba77 --- /dev/null +++ b/library/agent/hooks/getModuleInfoFromPath.test.ts @@ -0,0 +1,33 @@ +import * as t from "tap"; +import { getModuleInfoFromPath } from "./getModuleInfoFromPath"; + +t.test("it works", async (t) => { + t.same( + getModuleInfoFromPath( + "/Users/aikido/Projects/sec/node_modules/mysql/lib/Connection.js" + ), + { + name: "mysql", + base: "/Users/aikido/Projects/sec/node_modules/mysql", + path: "lib/Connection.js", + } + ); +}); + +t.test("it works with scoped package", async (t) => { + t.same( + getModuleInfoFromPath( + "/Users/aikido/Projects/sec/node_modules/@google-cloud/functions-framework/build/src/logger.js" + ), + { + name: "@google-cloud/functions-framework", + base: "/Users/aikido/Projects/sec/node_modules/@google-cloud/functions-framework", + path: "build/src/logger.js", + } + ); +}); + +t.test("returns undefined for invalid path", async (t) => { + const info = getModuleInfoFromPath("/Users/aikido/Projects/sec"); + t.equal(info, undefined); +}); diff --git a/library/agent/hooks/getModuleInfoFromPath.ts b/library/agent/hooks/getModuleInfoFromPath.ts new file mode 100644 index 000000000..8e4954d4f --- /dev/null +++ b/library/agent/hooks/getModuleInfoFromPath.ts @@ -0,0 +1,44 @@ +import { sep } from "path"; + +export type ModulePathInfo = { + /** + * Name of the module, including the scope if it exists. + */ + name: string; + /** + * Absolute path to the package inside node_modules. + */ + base: string; + /** + * Relative path to required file inside the package folder. + */ + path: string; +}; + +/** + * Get the module name and dir from a path that is inside a node_modules folder. + */ +export function getModuleInfoFromPath( + path: string +): ModulePathInfo | undefined { + const segments = path.split(sep); + const i = segments.lastIndexOf("node_modules"); + + if (i === -1 || i + 1 >= segments.length) { + return undefined; + } + + const isScoped = segments[i + 1][0] === "@"; + + const name = isScoped + ? segments[i + 1] + "/" + segments[i + 2] + : segments[i + 1]; + + const offset = isScoped ? 3 : 2; + + return { + name: name, + base: segments.slice(0, i + offset).join(sep), + path: segments.slice(i + offset).join(sep), + }; +} diff --git a/library/agent/hooks/isBuiltinModule.test.ts b/library/agent/hooks/isBuiltinModule.test.ts new file mode 100644 index 000000000..06e61aeb6 --- /dev/null +++ b/library/agent/hooks/isBuiltinModule.test.ts @@ -0,0 +1,11 @@ +import * as t from "tap"; +import { isBuiltinModule } from "./isBuiltinModule"; + +t.test("it works", async (t) => { + t.equal(isBuiltinModule("fs"), true); + t.equal(isBuiltinModule("mysql"), false); + t.equal(isBuiltinModule("http"), true); + t.equal(isBuiltinModule("node:http"), true); + t.equal(isBuiltinModule("test"), false); + t.equal(isBuiltinModule(""), false); +}); diff --git a/library/agent/hooks/isBuiltinModule.ts b/library/agent/hooks/isBuiltinModule.ts new file mode 100644 index 000000000..4a26164ed --- /dev/null +++ b/library/agent/hooks/isBuiltinModule.ts @@ -0,0 +1,18 @@ +import * as mod from "module"; +import { removeNodePrefix } from "../../helpers/removeNodePrefix"; + +// Added in Node.js v9.3.0, v8.10.0, v6.13.0 +const moduleList = mod.builtinModules; + +/** + * Returns true if the module is a builtin module, otherwise false. + */ +export function isBuiltinModule(moduleName: string) { + // Added in Node.js v18.6.0, v16.17.0 + if (typeof mod.isBuiltin === "function") { + return mod.isBuiltin(moduleName); + } + + // The modulelist does not include the node: prefix + return moduleList.includes(removeNodePrefix(moduleName)); +} diff --git a/library/agent/hooks/isMainJsFile.test.ts b/library/agent/hooks/isMainJsFile.test.ts new file mode 100644 index 000000000..6b7535df1 --- /dev/null +++ b/library/agent/hooks/isMainJsFile.test.ts @@ -0,0 +1,354 @@ +import * as t from "tap"; +import { isMainJsFile } from "./isMainJsFile"; +import type { PackageJson } from "type-fest"; + +const basePackageJson: PackageJson = { + name: "aikido-module", + version: "1.0.0", + main: "./index.js", +}; + +t.test("package.json main: is main file", async (t) => { + t.ok( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "./index.js", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/index.js", + basePackageJson + ) + ); + t.ok( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "index.js", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/index.js", + basePackageJson + ) + ); + + // Is true because require id and package name are the same + t.ok( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "test.js", + }, + "aikido-module", + "/home/user/proj/node_modules/aikido-module/test.js", + basePackageJson + ) + ); + + // Fallback if main field is not set + t.ok( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "index.js", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/index.js", + // @ts-expect-error main can not be undefined in types + { + ...basePackageJson, + ...{ main: undefined }, + } + ) + ); +}); + +t.test("package.json main: is not main file", async (t) => { + t.notOk( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "test.js", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/test.js", + basePackageJson + ) + ); + + // Path and filename do not match + t.notOk( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "index.js", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/test.js", + // @ts-expect-error main can not be undefined in types + { + ...basePackageJson, + ...{ main: undefined }, + } + ) + ); +}); + +t.test("package.json exports: is main file", async (t) => { + t.ok( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "index.cjs", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/index.cjs", + // @ts-expect-error Merge + { + ...basePackageJson, + ...{ main: "index.mjs", exports: "index.cjs" }, + } + ) + ); + t.ok( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "index.cjs", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/index.cjs", + // @ts-expect-error Merge + { + ...basePackageJson, + ...{ main: "index.mjs", exports: "./index.cjs" }, + } + ) + ); + t.ok( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "./test/index.cjs", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/test/index.cjs", + // @ts-expect-error Merge + { + ...basePackageJson, + ...{ main: "index.mjs", exports: "test/index.cjs" }, + } + ) + ); + t.ok( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "./test/index.cjs", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/test/index.cjs", + // @ts-expect-error Merge + { + ...basePackageJson, + ...{ main: "index.mjs", exports: ["test/index.cjs"] }, + } + ) + ); + t.ok( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "index.cjs", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/index.cjs", + // @ts-expect-error Merge + { + ...basePackageJson, + ...{ + main: "index.mjs", + exports: { + ".": "./index.cjs", + "./test": "./test/abc.cjs", + }, + }, + } + ) + ); + t.ok( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "index.cjs", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/index.cjs", + // @ts-expect-error Merge + { + ...basePackageJson, + ...{ + main: "index.mjs", + exports: { + ".": { + require: "./index.cjs", + import: "./index.mjs", + }, + "./test": "./test/abc.cjs", + }, + }, + } + ) + ); + t.ok( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "index.cjs", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/index.cjs", + // @ts-expect-error Merge + { + ...basePackageJson, + ...{ + main: "index.mjs", + exports: { + ".": { + node: "./index.cjs", + import: "./index.mjs", + }, + "./test": "./test/abc.cjs", + }, + }, + } + ) + ); +}); + +t.test("package.json exports: is not main file", async (t) => { + t.notOk( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "index.cjs", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/index.cjs", + // @ts-expect-error Merge + { + ...basePackageJson, + ...{ main: "index.mjs" }, + } + ) + ); + t.notOk( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "./test/index2.cjs", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/test/index2.cjs", + // @ts-expect-error Merge + { + ...basePackageJson, + ...{ main: "index.mjs", exports: "test/index.cjs" }, + } + ) + ); + t.notOk( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "./test/index.cjs", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/test/index.cjs", + // @ts-expect-error Merge + { + ...basePackageJson, + ...{ main: "index.mjs", exports: [] }, + } + ) + ); + t.notOk( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "./test/index.cjs", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/test/index.cjs", + // @ts-expect-error Merge + { + ...basePackageJson, + ...{ main: "index.mjs", exports: null }, + } + ) + ); + t.notOk( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "index.cjs", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/index.cjs", + // @ts-expect-error Merge + { + ...basePackageJson, + ...{ + main: "index.mjs", + exports: { + "./abc": "./index.cjs", + "./test": "./test/abc.cjs", + }, + }, + } + ) + ); + t.notOk( + isMainJsFile( + { + name: "aikido-module", + base: "/home/user/proj/node_modules/aikido-module", + path: "index.cjs", + }, + "abc", + "/home/user/proj/node_modules/aikido-module/index.cjs", + // @ts-expect-error Merge + { + ...basePackageJson, + ...{ + main: "index.mjs", + exports: { + ".": { + browser: "./index.cjs", + import: "./index.mjs", + }, + "./test": "./test/abc.cjs", + }, + }, + } + ) + ); +}); diff --git a/library/agent/hooks/isMainJsFile.ts b/library/agent/hooks/isMainJsFile.ts new file mode 100644 index 000000000..5ef885b14 --- /dev/null +++ b/library/agent/hooks/isMainJsFile.ts @@ -0,0 +1,97 @@ +import type { PackageJson } from "type-fest"; +import type { ModulePathInfo } from "./getModuleInfoFromPath"; +import { resolve } from "path"; +import { isPlainObject } from "../../helpers/isPlainObject"; + +/** + * This function checks if the required file is the main file of the package. + * It does this by checking the package.json file of the package. + */ +export function isMainJsFile( + pathInfo: ModulePathInfo, + requireId: string, + filename: string, + packageJson: PackageJson +) { + // If the name of the package is the same as the requireId (the argument passed to require), then it is the main file + if (pathInfo.name === requireId) { + return true; + } + + // Check package.json main field + if ( + typeof packageJson.main === "string" && + resolve(pathInfo.base, packageJson.main) === filename + ) { + return true; + } + + // Defaults to index.js if main field is not set + if (packageJson.main === undefined) { + if (resolve(pathInfo.base, "index.js") === filename) { + return true; + } + } + + // Check exports field + return doesMainExportMatchFilename( + packageJson.exports, + pathInfo.base, + filename + ); +} + +const allowedExportConditions = [ + "default", + "node", + "node-addons", + "require", +] as const; + +/** + * This function checks if the main package exported js file is the same as the passed file. + */ +function doesMainExportMatchFilename( + exportsField: PackageJson["exports"], + base: string, + filename: string +) { + if (!exportsField) { + return false; + } + + if (typeof exportsField === "string") { + if (resolve(base, exportsField) === filename) { + return true; + } + } + + if (Array.isArray(exportsField)) { + for (const value of exportsField) { + if (typeof value === "string" && resolve(base, value) === filename) { + return true; + } + } + } else if (isPlainObject(exportsField)) { + for (const [key, value] of Object.entries(exportsField)) { + if ([".", "./", "./index.js"].includes(key)) { + if (typeof value === "string" && resolve(base, value) === filename) { + return true; + } + if (isPlainObject(value)) { + for (const condition of allowedExportConditions) { + if ( + condition in value && + typeof value[condition] === "string" && + resolve(base, value[condition]) === filename + ) { + return true; + } + } + } + } + } + } + + return false; +} diff --git a/library/agent/hooks/wrapDefaultOrNamed.ts b/library/agent/hooks/wrapDefaultOrNamed.ts new file mode 100644 index 000000000..3c6c932b4 --- /dev/null +++ b/library/agent/hooks/wrapDefaultOrNamed.ts @@ -0,0 +1,16 @@ +import { createWrappedFunction, wrap } from "../../helpers/wrap"; + +/** + * This function allows to wrap a default export or a named export. + * If the name is undefined, it will wrap the default export of a module. + */ +export function wrapDefaultOrNamed( + module: any, + name: string | undefined, + wrapper: (original: Function) => Function +) { + if (typeof name === "undefined") { + return createWrappedFunction(module, wrapper); + } + return wrap(module, name, wrapper); +} diff --git a/library/agent/hooks/wrapExport.test.ts b/library/agent/hooks/wrapExport.test.ts new file mode 100644 index 000000000..13c9c5862 --- /dev/null +++ b/library/agent/hooks/wrapExport.test.ts @@ -0,0 +1,211 @@ +import * as t from "tap"; +import { wrapExport } from "./wrapExport"; +import { LoggerForTesting } from "../logger/LoggerForTesting"; +import { Token } from "../api/Token"; +import { bindContext } from "../Context"; +import { createTestAgent } from "../../helpers/createTestAgent"; + +t.test("Agent is not initialized", async (t) => { + try { + wrapExport( + {}, + "test", + { name: "test", type: "external" }, + { + inspectArgs: () => {}, + } + ); + t.fail(); + } catch (e) { + t.same(e.message, "Can not wrap exports if agent is not initialized"); + } +}); + +const logger = new LoggerForTesting(); + +const agent = createTestAgent({ + logger, + token: new Token("123"), +}); + +t.test("Inspect args", async (t) => { + t.plan(2); + const toWrap = { + test(input: string) { + return input; + }, + }; + + wrapExport( + toWrap, + "test", + { name: "test", type: "external" }, + { + inspectArgs: (args) => { + t.same(args, ["input"]); + }, + } + ); + + t.same(toWrap.test("input"), "input"); +}); + +t.test("Modify args", async (t) => { + const toWrap = { + test(input: string) { + return input; + }, + }; + + wrapExport( + toWrap, + "test", + { name: "test", type: "external" }, + { + modifyArgs: (args) => { + return ["modified"]; + }, + } + ); + + t.same(toWrap.test("input"), "modified"); +}); + +t.test("Modify return value", async (t) => { + const toWrap = { + test() { + return "test"; + }, + }; + + wrapExport( + toWrap, + "test", + { name: "test", type: "external" }, + { + modifyReturnValue: (args) => { + return "modified"; + }, + } + ); + + t.same(toWrap.test(), "modified"); +}); + +t.test("Combine interceptors", async (t) => { + const toWrap = { + test(input: string) { + return input; + }, + }; + + wrapExport( + toWrap, + "test", + { name: "test", type: "external" }, + { + inspectArgs: (args) => { + t.same(args, ["input"]); + }, + modifyArgs: (args) => { + return ["modArgs"]; + }, + modifyReturnValue: (args, returnVal) => { + return returnVal + "modReturn"; + }, + } + ); + + t.same(toWrap.test("input"), "modArgsmodReturn"); +}); + +t.test("Catches error in interceptors", async (t) => { + const toWrap = { + test() { + return "test"; + }, + }; + + wrapExport( + toWrap, + "test", + { name: "test", type: "external" }, + { + inspectArgs: () => { + throw new Error("Error in interceptor"); + }, + modifyArgs: () => { + throw new Error("Error in interceptor"); + }, + modifyReturnValue: () => { + throw new Error("Error in interceptor"); + }, + } + ); + + t.same(toWrap.test(), "test"); + t.match( + logger.getMessages(), + /Internal error in module "test" in method "test/ + ); +}); + +t.test("With callback", async (t) => { + const toWrap = { + test(input: string, callback: (input: string) => void) { + callback(input); + }, + }; + + wrapExport( + toWrap, + "test", + { name: "test", type: "external" }, + { + inspectArgs: (args) => { + t.same(args, ["input", bindContext(() => {})]); + }, + } + ); + + toWrap.test("input", () => {}); +}); + +t.test("Wrap non existing method", async (t) => { + const toWrap = {}; + + logger.clear(); + + wrapExport( + toWrap, + "test123", + { name: "test", type: "external" }, + { + inspectArgs: () => {}, + } + ); + + t.match(logger.getMessages(), [ + "Failed to wrap method test123 in module test", + ]); +}); + +t.test("Wrap default export", async (t) => { + t.plan(2); + const toWrap = (input: string) => { + return input; + }; + + const patched = wrapExport( + toWrap, + undefined, + { name: "test", type: "external" }, + { + inspectArgs: (args) => { + t.same(args, ["input"]); + }, + } + ) as Function; + + t.same(patched("input"), "input"); +}); diff --git a/library/agent/hooks/wrapExport.ts b/library/agent/hooks/wrapExport.ts new file mode 100644 index 000000000..4c9624975 --- /dev/null +++ b/library/agent/hooks/wrapExport.ts @@ -0,0 +1,195 @@ +/* eslint-disable max-lines-per-function */ +import { resolve } from "path"; +import { cleanupStackTrace } from "../../helpers/cleanupStackTrace"; +import { escapeHTML } from "../../helpers/escapeHTML"; +import { Agent } from "../Agent"; +import { getInstance } from "../AgentSingleton"; +import { attackKindHumanName } from "../Attack"; +import { bindContext, getContext, updateContext } from "../Context"; +import { InterceptorResult } from "./InterceptorResult"; +import { WrapPackageInfo } from "./WrapPackageInfo"; +import { wrapDefaultOrNamed } from "./wrapDefaultOrNamed"; + +type InspectArgsInterceptor = ( + args: unknown[], + agent: Agent, + subject: unknown +) => InterceptorResult | void; + +type ModifyArgsInterceptor = (args: unknown[], agent: Agent) => unknown[]; + +type ModifyReturnValueInterceptor = ( + args: unknown[], + returnValue: unknown, + agent: Agent +) => unknown; + +export type InterceptorObject = { + inspectArgs?: InspectArgsInterceptor; + modifyArgs?: ModifyArgsInterceptor; + modifyReturnValue?: ModifyReturnValueInterceptor; +}; + +// Used for cleaning up the stack trace +const libraryRoot = resolve(__dirname, "../.."); + +/** + * Wraps a function with the provided interceptors. + * If the function is not part of an object, like default exports, pass undefined as methodName and the function as subject. + */ +export function wrapExport( + subject: unknown, + methodName: string | undefined, + pkgInfo: WrapPackageInfo, + interceptors: InterceptorObject +) { + const agent = getInstance(); + if (!agent) { + throw new Error("Can not wrap exports if agent is not initialized"); + } + + try { + return wrapDefaultOrNamed( + subject, + methodName, + function wrap(original: Function) { + return function wrap() { + // eslint-disable-next-line prefer-rest-params + let args = Array.from(arguments); + const context = getContext(); + + // Run inspectArgs interceptor if provided + if (typeof interceptors.inspectArgs === "function") { + // Bind context to functions in arguments + for (let i = 0; i < args.length; i++) { + if (typeof args[i] === "function") { + args[i] = bindContext(args[i]); + } + } + + inspectArgs.call( + // @ts-expect-error We don't now the type of this + this, + args, + interceptors.inspectArgs, + context, + agent, + pkgInfo, + methodName || "" + ); + } + + // Run modifyArgs interceptor if provided + if (typeof interceptors.modifyArgs === "function") { + try { + args = interceptors.modifyArgs(args, agent); + } catch (error: any) { + agent.onErrorThrownByInterceptor({ + error: error, + method: methodName || "default export", + module: pkgInfo.name, + }); + } + } + + const returnVal = original.apply( + // @ts-expect-error We don't now the type of this + this, + args + ); + + // Run modifyReturnValue interceptor if provided + if (typeof interceptors.modifyReturnValue === "function") { + try { + return interceptors.modifyReturnValue(args, returnVal, agent); + } catch (error: any) { + agent.onErrorThrownByInterceptor({ + error: error, + method: methodName || "default export", + module: pkgInfo.name, + }); + } + } + + return returnVal; + }; + } + ); + } catch (error) { + agent.onFailedToWrapMethod(pkgInfo.name, methodName || "default export"); + } +} + +function inspectArgs( + args: unknown[], + interceptor: InspectArgsInterceptor, + context: ReturnType, + agent: Agent, + pkgInfo: WrapPackageInfo, + methodName: string +) { + if (context) { + const matches = agent.getConfig().getEndpoints(context); + + if (matches.find((match) => match.forceProtectionOff)) { + return; + } + } + + const start = performance.now(); + let result: InterceptorResult = undefined; + + try { + result = interceptor( + args, + agent, + // @ts-expect-error We don't now the type of this + this + ); + } catch (error: any) { + agent.getInspectionStatistics().interceptorThrewError(pkgInfo.name); + agent.onErrorThrownByInterceptor({ + error: error, + method: methodName, + module: pkgInfo.name, + }); + } + + const end = performance.now(); + agent.getInspectionStatistics().onInspectedCall({ + sink: pkgInfo.name, + attackDetected: !!result, + blocked: agent.shouldBlock(), + durationInMs: end - start, + withoutContext: !context, + }); + + const isAllowedIP = + context && + context.remoteAddress && + agent.getConfig().isAllowedIP(context.remoteAddress); + + if (result && context && !isAllowedIP) { + // Flag request as having an attack detected + updateContext(context, "attackDetected", true); + + agent.onDetectedAttack({ + module: pkgInfo.name, + operation: result.operation, + kind: result.kind, + source: result.source, + blocked: agent.shouldBlock(), + stack: cleanupStackTrace(new Error().stack!, libraryRoot), + path: result.pathToPayload, + metadata: result.metadata, + request: context, + payload: result.payload, + }); + + if (agent.shouldBlock()) { + throw new Error( + `Zen has blocked ${attackKindHumanName(result.kind)}: ${result.operation}(...) originating from ${result.source}${escapeHTML(result.pathToPayload)}` + ); + } + } +} diff --git a/library/agent/hooks/wrapNewInstance.test.ts b/library/agent/hooks/wrapNewInstance.test.ts new file mode 100644 index 000000000..275eb04a7 --- /dev/null +++ b/library/agent/hooks/wrapNewInstance.test.ts @@ -0,0 +1,90 @@ +/* eslint-disable max-classes-per-file */ +import * as t from "tap"; +import { wrapNewInstance } from "./wrapNewInstance"; +import { LoggerForTesting } from "../logger/LoggerForTesting"; +import { Token } from "../api/Token"; +import { createTestAgent } from "../../helpers/createTestAgent"; + +t.test("Agent is not initialized", async (t) => { + try { + wrapNewInstance({}, "test", { name: "test", type: "external" }, () => {}); + t.fail(); + } catch (e) { + t.same(e.message, "Can not wrap new instance if agent is not initialized"); + } +}); + +const logger = new LoggerForTesting(); +const agent = createTestAgent({ + logger, + token: new Token("123"), +}); + +t.test("Inspect args", async (t) => { + const exports = { + test: class Test { + constructor(private input: string) {} + + getInput() { + return this.input; + } + }, + }; + + wrapNewInstance( + exports, + "test", + { name: "test", type: "external" }, + (exports) => { + exports.testMethod = function test() { + return "aikido"; + }; + } + ); + + const instance = new exports.test("input"); + t.same(instance.getInput(), "input"); + // @ts-expect-error Test method is added by interceptor + t.same(instance.testMethod(), "aikido"); +}); + +t.test("Wrap non existing class", async (t) => { + const exports = {}; + + wrapNewInstance( + exports, + "test", + { name: "testmod", type: "external" }, + () => {} + ); + + t.same(logger.getMessages(), [ + "Failed to wrap method test in module testmod", + ]); +}); + +t.test("Can wrap default export", async (t) => { + let testExport = class Test { + constructor(private input: string) {} + + getInput() { + return this.input; + } + }; + + testExport = wrapNewInstance( + testExport, + undefined, + { name: "test", type: "external" }, + (exports) => { + exports.testMethod = function test() { + return "aikido"; + }; + } + ) as any; + + const instance = new testExport("input"); + t.same(instance.getInput(), "input"); + // @ts-expect-error Test method is added by interceptor + t.same(instance.testMethod(), "aikido"); +}); diff --git a/library/agent/hooks/wrapNewInstance.ts b/library/agent/hooks/wrapNewInstance.ts new file mode 100644 index 000000000..3923de4c8 --- /dev/null +++ b/library/agent/hooks/wrapNewInstance.ts @@ -0,0 +1,40 @@ +import { getInstance } from "../AgentSingleton"; +import { wrapDefaultOrNamed } from "./wrapDefaultOrNamed"; +import { WrapPackageInfo } from "./WrapPackageInfo"; + +/** + * Intercepts the creation of a new instance of a class, to wrap it's methods and properties. + */ +export function wrapNewInstance( + subject: unknown, + className: string | undefined, + pkgInfo: WrapPackageInfo, + interceptor: (exports: any) => void +) { + const agent = getInstance(); + if (!agent) { + throw new Error("Can not wrap new instance if agent is not initialized"); + } + + try { + return wrapDefaultOrNamed( + subject, + className, + function wrap(original: Function) { + return function wrap() { + // eslint-disable-next-line prefer-rest-params + const args = Array.from(arguments); + + // @ts-expect-error It's a constructor + const newInstance = new original(...args); + + interceptor(newInstance); + + return newInstance; + }; + } + ); + } catch (error) { + agent.onFailedToWrapMethod(pkgInfo.name, className || "default export"); + } +} diff --git a/library/agent/hooks/wrapRequire.test.ts b/library/agent/hooks/wrapRequire.test.ts new file mode 100644 index 000000000..af1090058 --- /dev/null +++ b/library/agent/hooks/wrapRequire.test.ts @@ -0,0 +1,390 @@ +import * as t from "tap"; +import { + wrapRequire, + setPackagesToPatch, + setBuiltinModulesToPatch, + getOriginalRequire, +} from "./wrapRequire"; +import { Package } from "./Package"; +import { BuiltinModule } from "./BuiltinModule"; + +t.test("Wrap require does not throw an error", async (t) => { + wrapRequire(); + t.pass(); +}); + +t.test("Wrapping require twice does not throw an error", async (t) => { + wrapRequire(); + t.pass(); +}); + +t.test("Can wrap external package", async (t) => { + const initialSqlite3 = require("sqlite3"); + + const pkg = new Package("sqlite3"); + pkg.withVersion("^5.0.0").onRequire((exports, pkgInfo) => { + exports._test = "aikido"; + t.same(pkgInfo.name, "sqlite3"); + t.same(pkgInfo.type, "external"); + t.ok(pkgInfo.path?.base.endsWith("node_modules/sqlite3")); + t.same(pkgInfo.path?.relative, "lib/sqlite3.js"); + }); + setPackagesToPatch([pkg]); + + // Require patched sqlite3 + const sqlite3 = require("sqlite3"); + t.same(sqlite3._test, "aikido"); + + // Get cached sqlite3 + const sqlite3Cached = require("sqlite3"); + t.same(sqlite3Cached._test, "aikido"); + + // Reset packages to patch, so on the next require we get the original sqlite3 + setPackagesToPatch([]); + const unpatchedSqlite3 = require("sqlite3"); + + t.same(initialSqlite3, unpatchedSqlite3); +}); + +t.test("Can wrap file of external package", async (t) => { + const initialHonoBase = require("hono/hono-base"); + + const pkg = new Package("hono"); + pkg + .withVersion("^4.0.0") + .onFileRequire("dist/cjs/hono-base.js", (exports, pkgInfo) => { + exports._test = "aikido"; + t.same(pkgInfo.name, "hono"); + t.same(pkgInfo.type, "external"); + t.ok(pkgInfo.path?.base.endsWith("node_modules/hono")); + t.same(pkgInfo.path?.relative, "dist/cjs/hono-base.js"); + }); + setPackagesToPatch([pkg]); + + // Require patched version of hono + const honoBase = require("hono/hono-base"); + t.same(honoBase._test, "aikido"); + + // Reset packages to patch, so on the next require we get the original hono + setPackagesToPatch([]); + const unpatchedHonoBase = require("hono/hono-base"); + t.same(initialHonoBase, unpatchedHonoBase); +}); + +t.test("Can wrap builtin module", async (t) => { + const initialFs = require("fs"); + + const module = new BuiltinModule("fs"); + module.onRequire((exports, pkgInfo) => { + exports._test = "aikido"; + t.same(pkgInfo.name, "fs"); + t.same(pkgInfo.type, "builtin"); + t.same(pkgInfo.path, undefined); + }); + setBuiltinModulesToPatch([module]); + + // Require patched fs + const fs = require("fs"); + t.same(fs._test, "aikido"); + + // Get cached fs + const fsCached = require("fs"); + t.same(fsCached._test, "aikido"); + + // Reset builtin modules to patch, so on the next require we get the original fs + setBuiltinModulesToPatch([]); + const unpatchedFs = require("fs"); + t.same(initialFs, unpatchedFs); +}); + +t.test("Does not wrap package with not matching version", async (t) => { + const initialSqlite3 = require("sqlite3"); + + const pkg = new Package("sqlite3"); + pkg.withVersion("^100.0.0").onRequire((exports, pkgInfo) => { + exports._test = "aikido"; + }); + setPackagesToPatch([pkg]); + + // Require original sqlite3 + const sqlite3 = require("sqlite3"); + t.same(sqlite3, initialSqlite3); +}); + +t.test("Does not wrap package with no interceptors", async (t) => { + const initialSqlite3 = require("sqlite3"); + + const pkg = new Package("sqlite3"); + pkg.withVersion("^5.0.0"); + setPackagesToPatch([pkg]); + + // Require original sqlite3 + const sqlite3 = require("sqlite3"); + t.same(sqlite3, initialSqlite3); +}); + +t.test("Does not wrap package without version", async (t) => { + const initialSqlite3 = require("sqlite3"); + + const pkg = new Package("sqlite3"); + setPackagesToPatch([pkg]); + + // Require original sqlite3 + const sqlite3 = require("sqlite3"); + t.same(sqlite3, initialSqlite3); +}); + +t.test("Replace default export", async (t) => { + const initialSqlite3 = require("sqlite3"); + + const pkg = new Package("sqlite3"); + pkg.withVersion("^5.0.0").onRequire((exports, pkgInfo) => { + return "aikido"; + }); + setPackagesToPatch([pkg]); + + // Require patched sqlite3 + const sqlite3 = require("sqlite3"); + t.same(sqlite3, "aikido"); + + // Reset packages to patch, so on the next require we get the original sqlite3 + setPackagesToPatch([]); + const unpatchedSqlite3 = require("sqlite3"); + + t.same(initialSqlite3, unpatchedSqlite3); +}); + +t.test("Confirm its caching the exports", async (t) => { + let counter = 0; + + const pkg = new Package("sqlite3"); + pkg.withVersion("^5.0.0").onRequire((exports, pkgInfo) => { + counter++; + return "aikido"; + }); + setPackagesToPatch([pkg]); + + // Require patched sqlite3 + const sqlite3 = require("sqlite3"); + t.same(sqlite3, "aikido"); + const sqlite3Cached = require("sqlite3"); + t.same(sqlite3Cached, "aikido"); + + setPackagesToPatch([]); + + t.same(counter, 1); +}); + +t.test("Returns original exports on exception", async (t) => { + const initialSqlite3 = require("sqlite3"); + + const pkg = new Package("sqlite3"); + pkg.withVersion("^5.0.0").onRequire((exports, pkgInfo) => { + exports._test = "aikido"; + throw new Error("Test error"); + }); + setPackagesToPatch([pkg]); + + // Should return original sqlite3 + const sqlite3 = require("sqlite3"); + t.same(sqlite3, initialSqlite3); +}); + +t.test("Require non-existing package", async (t) => { + const error = t.throws(() => require("unknown")); + t.ok(error instanceof Error); + if (error instanceof Error) { + t.match(error.message, /Cannot find module .unknown./); + } +}); + +t.test("Not wrapped using original require", async (t) => { + const initialFs = require("fs"); + + const mod = new BuiltinModule("fs"); + mod.onRequire((exports, pkgInfo) => { + exports._test = "aikido"; + }); + setBuiltinModulesToPatch([mod]); + + // Require patched sqlite3 + const fs = require("fs"); + t.same(fs._test, "aikido"); + + // Require original sqlite3 + const fsOriginal = getOriginalRequire()("fs"); + t.same(fsOriginal, initialFs); +}); + +t.test("Require json file", async (t) => { + const json = require("../../package.json"); + t.same(json.name, "@aikidosec/firewall"); +}); + +t.test("Pass invalid arguments to VersionedPackage", async (t) => { + t.same( + // @ts-expect-error Test with invalid arguments + (t.throws(() => new Package()) as Error).message, + "Package name is required" + ); + t.same( + // @ts-expect-error Test with invalid arguments + (t.throws(() => new Package("test").withVersion()) as Error).message, + "Version range is required" + ); + t.same( + ( + t.throws(() => + // @ts-expect-error Test with invalid arguments + new Package("test").withVersion("^1.0.0").onRequire() + ) as Error + ).message, + "Interceptor must be a function" + ); + t.same( + ( + t.throws(() => + new Package("test").withVersion("^1.0.0").onFileRequire("", () => {}) + ) as Error + ).message, + "Relative path must not be empty" + ); + t.same( + ( + t.throws(() => + new Package("test") + .withVersion("^1.0.0") + .onFileRequire("test", () => {}) + .onFileRequire("test", () => {}) + ) as Error + ).message, + "Interceptor for test already exists" + ); + t.same( + ( + t.throws(() => + new Package("test") + .withVersion("^1.0.0") + .onFileRequire("/test", () => {}) + ) as Error + ).message, + "Absolute paths are not allowed for require file interceptors" + ); + t.same( + ( + t.throws(() => + new Package("test") + .withVersion("^1.0.0") + .onFileRequire("../test", () => {}) + ) as Error + ).message, + "Relative paths with '..' are not allowed for require file interceptors" + ); + + t.same( + new Package("test") + .withVersion("^1.0.0") + .onFileRequire("./test", () => {}) + .getRequireFileInterceptor("test"), + () => {} + ); +}); + +t.test("Add two packages with same name", async (t) => { + let intercepted = 0; + const pkg = new Package("sqlite3"); + pkg.withVersion("^5.0.0").onRequire(() => { + intercepted++; + }); + const pkg2 = new Package("sqlite3"); + pkg2.withVersion("^5.0.0").onRequire(() => { + intercepted++; + }); + + setPackagesToPatch([pkg, pkg2]); + + // Require patched sqlite3 + const sqlite3 = require("sqlite3"); + t.same(intercepted, 2); + + setPackagesToPatch([]); +}); + +t.test("Add two builtin modules with same name", async (t) => { + let intercepted = 0; + const mod = new BuiltinModule("fs"); + mod.onRequire(() => { + intercepted++; + }); + const mod2 = new BuiltinModule("fs"); + mod2.onRequire(() => { + intercepted++; + }); + + setBuiltinModulesToPatch([mod, mod2]); + + // Require patched fs + const fs = require("fs"); + t.same(intercepted, 2); + + setBuiltinModulesToPatch([]); +}); + +t.test( + "Wraps process.getBuiltinModule", + { + skip: !process.getBuiltinModule + ? "Not available in Node.js < v22.3.0" + : false, + }, + async (t) => { + const originalFs = require("fs"); + + const mod = new BuiltinModule("fs"); + mod.onRequire(() => { + return "aikido"; + }); + setBuiltinModulesToPatch([mod]); + + // Require patched fs + const fs = process.getBuiltinModule("fs"); + t.same(fs, "aikido"); + + setBuiltinModulesToPatch([]); + + const fsUnpatched = require("fs"); + t.same(fsUnpatched, originalFs); + } +); + +t.test( + "process.getBuiltinModule with non-existing module", + { + skip: !process.getBuiltinModule + ? "Not available in Node.js < v22.3.0" + : false, + }, + async (t) => { + const error = t.throws(() => process.getBuiltinModule("unknown")); + t.ok(error instanceof Error); + if (error instanceof Error) { + t.match(error.message, /Cannot find module .unknown./); + } + } +); + +t.test( + "process.getBuiltinModule with non-builtin module", + { + skip: !process.getBuiltinModule + ? "Not available in Node.js < v22.3.0" + : false, + }, + async (t) => { + const error = t.throws(() => process.getBuiltinModule("sqlite3")); + t.ok(error instanceof Error); + if (error instanceof Error) { + t.match(error.message, /Cannot find module .sqlite3./); + } + } +); diff --git a/library/agent/hooks/wrapRequire.ts b/library/agent/hooks/wrapRequire.ts new file mode 100644 index 000000000..d8b90c402 --- /dev/null +++ b/library/agent/hooks/wrapRequire.ts @@ -0,0 +1,307 @@ +/* eslint-disable max-lines-per-function */ +import * as mod from "module"; +import { BuiltinModule } from "./BuiltinModule"; +import { isBuiltinModule } from "./isBuiltinModule"; +import { getModuleInfoFromPath } from "./getModuleInfoFromPath"; +import { Package } from "./Package"; +import { satisfiesVersion } from "../../helpers/satisfiesVersion"; +import { removeNodePrefix } from "../../helpers/removeNodePrefix"; +import { RequireInterceptor } from "./RequireInterceptor"; +import type { PackageJson } from "type-fest"; +import { isMainJsFile } from "./isMainJsFile"; +import { WrapPackageInfo } from "./WrapPackageInfo"; +import { getInstance } from "../AgentSingleton"; + +const originalRequire = mod.prototype.require; +let isRequireWrapped = false; + +let packages: Package[] = []; +let builtinModules: BuiltinModule[] = []; +let pkgCache = new Map(); +let builtinCache = new Map(); + +/** + * Wraps the require function to intercept require calls. + * This function makes sure that the require function is only wrapped once. + */ +export function wrapRequire() { + if (isRequireWrapped) { + return; + } + + // @ts-expect-error Not included in the Node.js types + if (typeof mod._resolveFilename !== "function") { + throw new Error( + `Could not find the _resolveFilename function in node:module using Node.js version ${process.version}` + ); + } + + // Prevent wrapping the require function multiple times + isRequireWrapped = true; + + // @ts-expect-error TS doesn't know that we are not overwriting the subproperties + mod.prototype.require = function wrapped() { + // eslint-disable-next-line prefer-rest-params + return patchedRequire.call(this, arguments); + }; + + // Wrap process.getBuiltinModule, which allows requiring builtin modules (since Node.js v22.3.0) + if (typeof process.getBuiltinModule === "function") { + process.getBuiltinModule = function wrappedGetBuiltinModule() { + // eslint-disable-next-line prefer-rest-params + return patchedRequire.call(this, arguments); + }; + } +} + +/** + * Update the list of external packages that should be patched. + */ +export function setPackagesToPatch(packagesToPatch: Package[]) { + packages = packagesToPatch; + // Reset cache + pkgCache = new Map(); +} + +/** + * Update the list of builtin modules that should be patched. + */ +export function setBuiltinModulesToPatch( + builtinModulesToPatch: BuiltinModule[] +) { + builtinModules = builtinModulesToPatch; + // Reset cache + builtinCache = new Map(); +} + +/** + * Our custom require function that intercepts require calls. + */ +function patchedRequire(this: mod | NodeJS.Process, args: IArguments) { + // Apply the original require function + const originalExports = originalRequire.apply( + this, + args as unknown as [string] + ); + + if (!args.length || typeof args[0] !== "string") { + return originalExports; + } + + /** + * Parameter that is passed to the require function + * Can be a module name, a relative / absolute path + */ + const id = args[0] as string; + + try { + // Check if it's a builtin module + // They are easier to patch (no file patching) + // Separate handling for builtin modules improves the performance + if (isBuiltinModule(id)) { + // Call function for patching builtin modules with the same context (this) + return patchBuiltinModule.call(this, id, originalExports); + } + + // Call function for patching external packages + return patchPackage.call(this as mod, id, originalExports); + } catch (error) { + if (error instanceof Error) { + getInstance()?.onFailedToWrapModule(id, error); + } + + return originalExports; + } +} + +/** + * Run all require interceptors for the builtin module and cache the result. + */ +function patchBuiltinModule(id: string, originalExports: unknown) { + const moduleName = removeNodePrefix(id); + + // Check if already cached + if (builtinCache.has(moduleName)) { + return builtinCache.get(moduleName); + } + + // Check if we want to patch this builtin module + const matchingBuiltins = builtinModules.filter( + (m) => m.getName() === moduleName + ); + + // We don't want to patch this builtin module + if (!matchingBuiltins.length) { + return originalExports; + } + + // Get interceptors from all matching builtin modules + const interceptors = matchingBuiltins + .map((m) => m.getRequireInterceptors()) + .flat(); + + return executeInterceptors( + interceptors, + originalExports, + builtinCache, + moduleName, + { + name: moduleName, + type: "builtin", + } + ); +} + +/** + * Run all require interceptors for the package and cache the result. + * Also checks the package versions. Not used for builtin modules. + */ +function patchPackage(this: mod, id: string, originalExports: unknown) { + // Get the full filepath of the required js file + // @ts-expect-error Not included in the Node.js types + const filename = mod._resolveFilename(id, this); + if (!filename) { + throw new Error("Could not resolve filename using _resolveFilename"); + } + + // Ignore .json files + if (filename.endsWith(".json")) { + return originalExports; + } + + // Check if cache has the filename + if (pkgCache.has(filename)) { + return pkgCache.get(filename); + } + + // Parses the filename to extract the module name, the base dir of the module and the relative path of the included file + const pathInfo = getModuleInfoFromPath(filename); + if (!pathInfo) { + // Can happen if the package is not inside a node_modules folder, like the dev build of our library itself + return originalExports; + } + + const moduleName = pathInfo.name; + + // Get all versioned packages for the module name + const versionedPackages = packages + .filter((pkg) => pkg.getName() === moduleName) + .map((pkg) => pkg.getVersions()) + .flat(); + + // We don't want to patch this package because we do not have any hooks for it + if (!versionedPackages.length) { + return originalExports; + } + + // Read the package.json of the required package + const packageJson = originalRequire( + `${pathInfo.base}/package.json` + ) as PackageJson; + + // Get the version of the installed package + const installedPkgVersion = packageJson.version; + if (!installedPkgVersion) { + throw new Error( + `Could not get installed package version for ${moduleName}` + ); + } + + // Check if the installed package version is supported (get all matching versioned packages) + const matchingVersionedPackages = versionedPackages.filter((pkg) => + satisfiesVersion(pkg.getRange(), installedPkgVersion) + ); + + const agent = getInstance(); + if (agent) { + // Report to the agent that the package was wrapped or not if it's version is not supported + agent.onPackageWrapped(moduleName, { + version: installedPkgVersion, + supported: !!matchingVersionedPackages.length, + }); + } + + if (!matchingVersionedPackages.length) { + // We don't want to patch this package version + return originalExports; + } + + // Check if the required file is the main file of the package or another js file inside the package + const isMainFile = isMainJsFile(pathInfo, id, filename, packageJson); + + let interceptors: RequireInterceptor[] = []; + + if (isMainFile) { + interceptors = matchingVersionedPackages + .map((pkg) => pkg.getRequireInterceptors()) + .flat(); + } else { + // If it's not the main file, we want to check if the want to patch the required file + interceptors = matchingVersionedPackages + .map((pkg) => pkg.getRequireFileInterceptor(pathInfo.path) || []) + .flat(); + } + + return executeInterceptors( + interceptors, + originalExports, + pkgCache, + filename as string, + { + name: pathInfo.name, + version: installedPkgVersion, + type: "external", + path: { + base: pathInfo.base, + relative: pathInfo.path, + }, + } + ); +} + +/** + * Executes the provided require interceptor functions and sets the cache. + */ +function executeInterceptors( + interceptors: RequireInterceptor[], + exports: unknown, + cache: Map, + cacheKey: string, + wrapPackageInfo: WrapPackageInfo +) { + // Cache because we need to prevent this called again if module is imported inside interceptors + cache.set(cacheKey, exports); + + // Return early if no interceptors + if (!interceptors.length) { + return exports; + } + + // Foreach interceptor function + for (const interceptor of interceptors) { + // If one interceptor fails, we don't want to stop the other interceptors + try { + const returnVal = interceptor(exports, wrapPackageInfo); + // If the interceptor returns a value, we want to use this value as the new exports + if (typeof returnVal !== "undefined") { + exports = returnVal; + } + } catch (error) { + if (error instanceof Error) { + getInstance()?.onFailedToWrapModule(wrapPackageInfo.name, error); + } + } + } + + // Finally cache the result + cache.set(cacheKey, exports); + + return exports; +} + +/** + * Returns the unwrapped require function. + */ +export function getOriginalRequire() { + return originalRequire; +} diff --git a/library/agent/wrapInstalledPackages.ts b/library/agent/wrapInstalledPackages.ts index 96478f723..6b7b364a8 100644 --- a/library/agent/wrapInstalledPackages.ts +++ b/library/agent/wrapInstalledPackages.ts @@ -1,13 +1,12 @@ -import { Agent } from "./Agent"; import { applyHooks } from "./applyHooks"; import { Hooks } from "./hooks/Hooks"; import { Wrapper } from "./Wrapper"; -export function wrapInstalledPackages(agent: Agent, wrappers: Wrapper[]) { +export function wrapInstalledPackages(wrappers: Wrapper[]) { const hooks = new Hooks(); wrappers.forEach((wrapper) => { wrapper.wrap(hooks); }); - return applyHooks(hooks, agent); + return applyHooks(hooks); } diff --git a/library/helpers/createTestAgent.ts b/library/helpers/createTestAgent.ts new file mode 100644 index 000000000..b497e633e --- /dev/null +++ b/library/helpers/createTestAgent.ts @@ -0,0 +1,30 @@ +import { Agent } from "../agent/Agent"; +import { setInstance } from "../agent/AgentSingleton"; +import type { ReportingAPI } from "../agent/api/ReportingAPI"; +import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; +import type { Token } from "../agent/api/Token"; +import type { Logger } from "../agent/logger/Logger"; +import { LoggerNoop } from "../agent/logger/LoggerNoop"; + +/** + * Create a test agent for testing purposes + */ +export function createTestAgent(opts?: { + block?: boolean; + logger?: Logger; + api?: ReportingAPI; + token?: Token; + serverless?: string; +}) { + const agent = new Agent( + opts?.block ?? true, + opts?.logger ?? new LoggerNoop(), + opts?.api ?? new ReportingAPIForTesting(), + opts?.token, // Defaults to undefined + opts?.serverless // Defaults to undefined + ); + + setInstance(agent); + + return agent; +} diff --git a/library/helpers/getPackageVersion.ts b/library/helpers/getPackageVersion.ts index d7cacbeb1..065a8a660 100644 --- a/library/helpers/getPackageVersion.ts +++ b/library/helpers/getPackageVersion.ts @@ -1,4 +1,5 @@ import { sep } from "path"; +import { getOriginalRequire } from "../agent/hooks/wrapRequire"; /** * Get the installed version of a package @@ -23,7 +24,7 @@ export function getPackageVersion(pkg: string): string | null { const index = parts.indexOf(lookup); const root = parts.slice(0, index + 1).join(sep); - return require(`${root}/package.json`).version; + return getOriginalRequire()(`${root}/package.json`).version; } catch (error) { return null; } diff --git a/library/helpers/removeNodePrefix.ts b/library/helpers/removeNodePrefix.ts new file mode 100644 index 000000000..71eafd159 --- /dev/null +++ b/library/helpers/removeNodePrefix.ts @@ -0,0 +1,9 @@ +/** + * Removes the "node:" prefix from the id of a builtin module. + */ +export function removeNodePrefix(id: string) { + if (id.startsWith("node:")) { + return id.slice(5); + } + return id; +} diff --git a/library/helpers/wrap.ts b/library/helpers/wrap.ts index 133068859..3c7f3b52e 100644 --- a/library/helpers/wrap.ts +++ b/library/helpers/wrap.ts @@ -3,21 +3,32 @@ type WrappedFunction = T & { }; export function wrap( - nodule: any, + module: any, name: string, wrapper: (original: Function) => Function ) { - if (!nodule[name]) { + if (!module[name]) { throw new Error(`no original function ${name} to wrap`); } - if (typeof nodule[name] !== "function") { + if (typeof module[name] !== "function") { throw new Error( - `original must be a function, instead found: ${typeof nodule[name]}` + `original must be a function, instead found: ${typeof module[name]}` ); } - const original = nodule[name]; + const original = module[name]; + const wrapped = createWrappedFunction(original, wrapper); + + defineProperty(module, name, wrapped); + + return wrapped; +} + +export function createWrappedFunction( + original: Function, + wrapper: (original: Function) => Function +): Function { const wrapped = wrapper(original); defineProperty(wrapped, "__original", original); @@ -28,14 +39,12 @@ export function wrap( // .inspect("realpath", (args) => {...}) // We don't want to lose the original function's properties. // Most of the functions we're wrapping don't have any properties, so this is a rare case. - for (const prop in nodule[name]) { - if (nodule[name].hasOwnProperty(prop)) { - defineProperty(wrapped, prop, nodule[name][prop]); + for (const prop in original) { + if (original.hasOwnProperty(prop)) { + defineProperty(wrapped, prop, original[prop as keyof Function]); } } - defineProperty(nodule, name, wrapped); - return wrapped; } diff --git a/library/package-lock.json b/library/package-lock.json index 68f4c0b7a..4dcca9374 100644 --- a/library/package-lock.json +++ b/library/package-lock.json @@ -19,7 +19,7 @@ "@types/express": "^4.17.21", "@types/ip": "^1.1.3", "@types/mysql": "^2.15.25", - "@types/node": "^20.11.5", + "@types/node": "^22.3.0", "@types/pg": "^8.11.0", "@types/qs": "^6.9.11", "@types/shell-quote": "^1.7.5", @@ -55,6 +55,7 @@ "sqlite3": "^5.1.7", "supertest": "^6.3.4", "tap": "^18.6.1", + "type-fest": "^4.24.0", "typescript": "^5.3.3", "undici": "^6.12.0", "xml-js": "^1.6.11", @@ -3519,11 +3520,10 @@ } }, "node_modules/@types/node": { - "version": "20.16.3", - "resolved": "https://registry.npmjs.org/@types/node/-/node-20.16.3.tgz", - "integrity": "sha512-/wdGiWRkMOm53gAsSyFMXFZHbVg7C6CbkrzHNpaHoYfsUWPg7m6ZRKtvQjgvQ9i8WT540a3ydRlRQbxjY30XxQ==", + "version": "22.7.0", + "resolved": "https://registry.npmjs.org/@types/node/-/node-22.7.0.tgz", + "integrity": "sha512-MOdOibwBs6KW1vfqz2uKMlxq5xAfAZ98SZjO8e3XnAbFnTJtAspqhWk7hrdSAs9/Y14ZWMiy7/MxMUzAOadYEw==", "dev": true, - "license": "MIT", "dependencies": { "undici-types": "~6.19.2" } @@ -7108,6 +7108,18 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/globals/node_modules/type-fest": { + "version": "0.20.2", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.20.2.tgz", + "integrity": "sha512-Ne+eE4r0/iWnpAxD852z3A+N0Bt5RN//NjJwRd2VFHEmrywxf5vsZlh4R6lixl6B+wz/8d+maTSAkN1FIkI3LQ==", + "dev": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/globalthis": { "version": "1.0.4", "resolved": "https://registry.npmjs.org/globalthis/-/globalthis-1.0.4.tgz", @@ -13164,13 +13176,12 @@ } }, "node_modules/type-fest": { - "version": "0.20.2", - "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.20.2.tgz", - "integrity": "sha512-Ne+eE4r0/iWnpAxD852z3A+N0Bt5RN//NjJwRd2VFHEmrywxf5vsZlh4R6lixl6B+wz/8d+maTSAkN1FIkI3LQ==", + "version": "4.26.1", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-4.26.1.tgz", + "integrity": "sha512-yOGpmOAL7CkKe/91I5O3gPICmJNLJ1G4zFYVAsRHg7M64biSnPtRj0WNQt++bRkjYOqjWXrhnUw1utzmVErAdg==", "dev": true, - "license": "(MIT OR CC0-1.0)", "engines": { - "node": ">=10" + "node": ">=16" }, "funding": { "url": "https://github.com/sponsors/sindresorhus" diff --git a/library/package.json b/library/package.json index 3e2696de0..5fcf09fb5 100644 --- a/library/package.json +++ b/library/package.json @@ -42,7 +42,7 @@ "@types/express": "^4.17.21", "@types/ip": "^1.1.3", "@types/mysql": "^2.15.25", - "@types/node": "^20.11.5", + "@types/node": "^22.3.0", "@types/pg": "^8.11.0", "@types/qs": "^6.9.11", "@types/shell-quote": "^1.7.5", @@ -78,6 +78,7 @@ "sqlite3": "^5.1.7", "supertest": "^6.3.4", "tap": "^18.6.1", + "type-fest": "^4.24.0", "typescript": "^5.3.3", "undici": "^6.12.0", "xml-js": "^1.6.11", diff --git a/library/ratelimiting/shouldRateLimitRequest.test.ts b/library/ratelimiting/shouldRateLimitRequest.test.ts index 01c3e9a05..baad11a4a 100644 --- a/library/ratelimiting/shouldRateLimitRequest.test.ts +++ b/library/ratelimiting/shouldRateLimitRequest.test.ts @@ -1,15 +1,14 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Endpoint } from "../agent/Config"; import type { Context } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { shouldRateLimitRequest } from "./shouldRateLimitRequest"; +import { createTestAgent } from "../helpers/createTestAgent"; function createContext( - remoteAddress: string = undefined, - userId: string = undefined, + remoteAddress: string | undefined = undefined, + userId: string | undefined = undefined, route: string = "/login", method: string = "POST" ): Context { @@ -32,10 +31,10 @@ async function createAgent( endpoints: Endpoint[] = [], allowedIpAddresses: string[] = [] ) { - const agent = new Agent( - false, - new LoggerNoop(), - new ReportingAPIForTesting({ + const agent = createTestAgent({ + block: false, + token: new Token("123"), + api: new ReportingAPIForTesting({ allowedIPAddresses: allowedIpAddresses, success: true, heartbeatIntervalInMS: 10 * 60 * 1000, @@ -43,9 +42,7 @@ async function createAgent( configUpdatedAt: 0, endpoints: endpoints, }), - new Token("123"), - undefined - ); + }); agent.start([]); diff --git a/library/sinks/AwsSDKVersion2.test.ts b/library/sinks/AwsSDKVersion2.test.ts index 2a9f97cff..a447f6cec 100644 --- a/library/sinks/AwsSDKVersion2.test.ts +++ b/library/sinks/AwsSDKVersion2.test.ts @@ -1,9 +1,7 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Context, runWithContext } from "../agent/Context"; -import { LoggerForTesting } from "../agent/logger/LoggerForTesting"; import { AwsSDKVersion2 } from "./AwsSDKVersion2"; +import { createTestAgent } from "../helpers/createTestAgent"; // Suppress upgrade to SDK v3 notice require("aws-sdk/lib/maintenance_mode_message").suppress = true; @@ -26,14 +24,7 @@ const unsafeContext: Context = { }; t.test("it works", async (t) => { - const logger = new LoggerForTesting(); - const agent = new Agent( - true, - logger, - new ReportingAPIForTesting(), - undefined, - undefined - ); + const agent = createTestAgent(); agent.start([new AwsSDKVersion2()]); diff --git a/library/sinks/AwsSDKVersion2.ts b/library/sinks/AwsSDKVersion2.ts index fa3a23ada..c8233d2ef 100644 --- a/library/sinks/AwsSDKVersion2.ts +++ b/library/sinks/AwsSDKVersion2.ts @@ -1,6 +1,8 @@ import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; +import { wrapExport } from "../agent/hooks/wrapExport"; +import { wrapNewInstance } from "../agent/hooks/wrapNewInstance"; import { Wrapper } from "../agent/Wrapper"; import { isPlainObject } from "../helpers/isPlainObject"; import { checkContextForPathTraversal } from "../vulnerabilities/path-traversal/checkContextForPathTraversal"; @@ -67,15 +69,17 @@ export class AwsSDKVersion2 implements Wrapper { } wrap(hooks: Hooks) { - const s3 = hooks + hooks .addPackage("aws-sdk") .withVersion("^2.0.0") - .addSubject((exports) => exports) - .inspectNewInstance("S3") - .addSubject((exports) => exports); - - operationsWithKey.forEach((operation) => { - s3.inspect(operation, (args) => this.inspectS3Operation(args, operation)); - }); + .onRequire((exports, pkgInfo) => { + wrapNewInstance(exports, "S3", pkgInfo, (instance) => { + for (const operation of operationsWithKey) { + wrapExport(instance, operation, pkgInfo, { + inspectArgs: (args) => this.inspectS3Operation(args, operation), + }); + } + }); + }); } } diff --git a/library/sinks/BetterSQLite3.test.ts b/library/sinks/BetterSQLite3.test.ts index 3770a7be9..3c08ebdf9 100644 --- a/library/sinks/BetterSQLite3.test.ts +++ b/library/sinks/BetterSQLite3.test.ts @@ -1,9 +1,7 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { runWithContext, type Context } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { BetterSQLite3 } from "./BetterSQLite3"; +import { createTestAgent } from "../helpers/createTestAgent"; const dangerousContext: Context = { remoteAddress: "::1", @@ -49,13 +47,7 @@ const safeContext: Context = { }; t.test("it detects SQL injections", async (t) => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - "lambda" - ); + const agent = createTestAgent(); agent.start([new BetterSQLite3()]); const betterSqlite3 = require("better-sqlite3"); diff --git a/library/sinks/BetterSQLite3.ts b/library/sinks/BetterSQLite3.ts index 08e311992..f29f170b0 100644 --- a/library/sinks/BetterSQLite3.ts +++ b/library/sinks/BetterSQLite3.ts @@ -1,6 +1,7 @@ import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { Wrapper } from "../agent/Wrapper"; import { checkContextForPathTraversal } from "../vulnerabilities/path-traversal/checkContextForPathTraversal"; import { checkContextForSqlInjection } from "../vulnerabilities/sql-injection/checkContextForSqlInjection"; @@ -61,24 +62,27 @@ export class BetterSQLite3 implements Wrapper { } wrap(hooks: Hooks) { - const subjects = hooks - .addPackage("better-sqlite3") - .withVersion("^11.0.0 || ^10.0.0 || ^9.0.0 || ^8.0.0") - .addSubject((exports) => exports.prototype); - const sqlFunctions = ["prepare", "exec", "pragma"]; - - for (const func of sqlFunctions) { - subjects.inspect(func, (args) => { - return this.inspectQuery(`better-sqlite3.${func}`, args); - }); - } - const fsPathFunctions = ["backup", "loadExtension"]; - for (const func of fsPathFunctions) { - subjects.inspect(func, (args) => { - return this.inspectPath(`better-sqlite3.${func}`, args); + + hooks + .addPackage("better-sqlite3") + .withVersion("^11.0.0 || ^10.0.0 || ^9.0.0 || ^8.0.0") + .onRequire((exports, pkgInfo) => { + for (const func of sqlFunctions) { + wrapExport(exports.prototype, func, pkgInfo, { + inspectArgs: (args) => { + return this.inspectQuery(`better-sqlite3.${func}`, args); + }, + }); + } + for (const func of fsPathFunctions) { + wrapExport(exports.prototype, func, pkgInfo, { + inspectArgs: (args) => { + return this.inspectPath(`better-sqlite3.${func}`, args); + }, + }); + } }); - } } } diff --git a/library/sinks/ChildProcess.test.ts b/library/sinks/ChildProcess.test.ts index b57e72b1a..7c616a9c0 100644 --- a/library/sinks/ChildProcess.test.ts +++ b/library/sinks/ChildProcess.test.ts @@ -1,10 +1,8 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Context, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { ChildProcess } from "./ChildProcess"; -import { execFile, execFileSync, fork } from "child_process"; +import { execFile, execFileSync } from "child_process"; +import { createTestAgent } from "../helpers/createTestAgent"; const unsafeContext: Context = { remoteAddress: "::1", @@ -31,13 +29,9 @@ function throws(fn: () => void, wanted: string | RegExp) { } t.test("it works", async (t) => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - "lambda" - ); + const agent = createTestAgent({ + serverless: "lambda", + }); agent.start([new ChildProcess()]); diff --git a/library/sinks/ChildProcess.ts b/library/sinks/ChildProcess.ts index c88523d31..3a2fa7849 100644 --- a/library/sinks/ChildProcess.ts +++ b/library/sinks/ChildProcess.ts @@ -1,6 +1,7 @@ import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { Wrapper } from "../agent/Wrapper"; import { isPlainObject } from "../helpers/isPlainObject"; import { checkContextForPathTraversal } from "../vulnerabilities/path-traversal/checkContextForPathTraversal"; @@ -17,19 +18,43 @@ const PATH_PREFIXES = [ export class ChildProcess implements Wrapper { wrap(hooks: Hooks) { - const childProcess = hooks.addBuiltinModule("child_process"); - - childProcess - .addSubject((exports) => exports) - .inspect("exec", (args) => this.inspectExec(args, "exec")) - .inspect("execSync", (args) => this.inspectExec(args, "execSync")) - .inspect("spawn", (args) => this.inspectSpawn(args, "spawn")) - .inspect("spawnSync", (args) => this.inspectSpawn(args, "spawnSync")) - .inspect("execFile", (args) => this.inspectExecFile(args, "execFile")) - .inspect("execFileSync", (args) => - this.inspectExecFile(args, "execFileSync") - ) - .inspect("fork", (args) => this.inspectFork(args, "fork")); + hooks.addBuiltinModule("child_process").onRequire((exports, pkgInfo) => { + wrapExport(exports, "exec", pkgInfo, { + inspectArgs: (args) => { + return this.inspectExec(args, "exec"); + }, + }); + wrapExport(exports, "execSync", pkgInfo, { + inspectArgs: (args) => { + return this.inspectExec(args, "execSync"); + }, + }); + wrapExport(exports, "spawn", pkgInfo, { + inspectArgs: (args) => { + return this.inspectSpawn(args, "spawn"); + }, + }); + wrapExport(exports, "spawnSync", pkgInfo, { + inspectArgs: (args) => { + return this.inspectSpawn(args, "spawnSync"); + }, + }); + wrapExport(exports, "execFile", pkgInfo, { + inspectArgs: (args) => { + return this.inspectExecFile(args, "execFile"); + }, + }); + wrapExport(exports, "execFileSync", pkgInfo, { + inspectArgs: (args) => { + return this.inspectExecFile(args, "execFileSync"); + }, + }); + wrapExport(exports, "fork", pkgInfo, { + inspectArgs: (args) => { + return this.inspectFork(args, "fork"); + }, + }); + }); } private inspectFork(args: unknown[], name: string) { diff --git a/library/sinks/Fetch.test.ts b/library/sinks/Fetch.test.ts index b2ed54089..1ad0bb4ee 100644 --- a/library/sinks/Fetch.test.ts +++ b/library/sinks/Fetch.test.ts @@ -1,13 +1,12 @@ /* eslint-disable prefer-rest-params */ import * as t from "tap"; -import { Agent } from "../agent/Agent"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { wrap } from "../helpers/wrap"; import { Fetch } from "./Fetch"; import * as dns from "dns"; +import { createTestAgent } from "../helpers/createTestAgent"; const calls: Record = {}; wrap(dns, "lookup", function lookup(original) { @@ -71,13 +70,11 @@ t.test( { skip: !global.fetch ? "fetch is not available" : false }, async (t) => { const api = new ReportingAPIForTesting(); - const agent = new Agent( - true, - new LoggerNoop(), + const agent = createTestAgent({ + token: new Token("123"), api, - new Token("123"), - undefined - ); + }); + agent.start([new Fetch()]); t.same(agent.getHostnames().asArray(), []); diff --git a/library/sinks/Fetch.ts b/library/sinks/Fetch.ts index 0306cd1b4..a637060e3 100644 --- a/library/sinks/Fetch.ts +++ b/library/sinks/Fetch.ts @@ -2,7 +2,7 @@ import { lookup } from "dns"; import { Agent } from "../agent/Agent"; import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; import { Wrapper } from "../agent/Wrapper"; import { getPortFromURL } from "../helpers/getPortFromURL"; import { tryParseURL } from "../helpers/tryParseURL"; @@ -136,18 +136,17 @@ export class Fetch implements Wrapper { globalThis.fetch().catch(() => {}); } - hooks - .addGlobal("fetch") + hooks.addGlobal("fetch", { // Whenever a request is made, we'll check the hostname whether it's a private IP - .inspect((args, subject, agent) => this.inspectFetch(args, agent)) - // We're not really modifying the arguments here, but we need to patch the global dispatcher - .modifyArguments((args, subject, agent) => { + inspectArgs: (args, agent) => this.inspectFetch(args, agent), + modifyArgs: (args, agent) => { if (!this.patchedGlobalDispatcher) { this.patchGlobalDispatcher(agent); this.patchedGlobalDispatcher = true; } return args; - }); + }, + }); } } diff --git a/library/sinks/FileSystem.test.ts b/library/sinks/FileSystem.test.ts index 3ad5d678d..b55d9c92d 100644 --- a/library/sinks/FileSystem.test.ts +++ b/library/sinks/FileSystem.test.ts @@ -1,9 +1,7 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Context, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { FileSystem } from "./FileSystem"; +import { createTestAgent } from "../helpers/createTestAgent"; const unsafeContext: Context = { remoteAddress: "::1", @@ -47,13 +45,7 @@ function throws(fn: () => void, wanted: string | RegExp) { } t.test("it works", async (t) => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - "lambda" - ); + const agent = createTestAgent({ serverless: "lambda" }); agent.start([new FileSystem()]); @@ -62,6 +54,7 @@ t.test("it works", async (t) => { writeFileSync, rename, realpath, + promises: fsDotPromise, realpathSync, } = require("fs"); const { writeFile: writeFilePromise } = require("fs/promises"); @@ -95,6 +88,11 @@ t.test("it works", async (t) => { "some other file content to test with", { encoding: "utf-8" } ); + await fsDotPromise.writeFile( + "./test.txt", + "some other file content to test with", + { encoding: "utf-8" } + ); rename("./test.txt", "./test2.txt", (err) => {}); rename(new URL("file:///test123.txt"), "test2.txt", (err) => {}); rename(Buffer.from("./test123.txt"), "test2.txt", (err) => {}); @@ -135,7 +133,7 @@ t.test("it works", async (t) => { { encoding: "utf-8" } ) ); - + t.ok(error instanceof Error); if (error instanceof Error) { t.match( error.message, @@ -143,6 +141,21 @@ t.test("it works", async (t) => { ); } + const error2 = await t.rejects(() => + fsDotPromise.writeFile( + "../../test.txt", + "some other file content to test with", + { encoding: "utf-8" } + ) + ); + t.ok(error2 instanceof Error); + if (error2 instanceof Error) { + t.match( + error2.message, + "Zen has blocked a path traversal attack: fs.writeFile(...) originating from body.file.matches" + ); + } + throws( () => rename("../../test.txt", "./test2.txt", (err) => {}), "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches" diff --git a/library/sinks/FileSystem.ts b/library/sinks/FileSystem.ts index eab9d2c94..1141f13c7 100644 --- a/library/sinks/FileSystem.ts +++ b/library/sinks/FileSystem.ts @@ -1,6 +1,7 @@ import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { Wrapper } from "../agent/Wrapper"; import { getSemverNodeVersion } from "../helpers/getNodeVersion"; import { isVersionGreaterOrEqual } from "../helpers/isVersionGreaterOrEqual"; @@ -91,42 +92,51 @@ export class FileSystem implements Wrapper { } wrap(hooks: Hooks) { - const fs = hooks.addBuiltinModule("fs"); - const callbackStyle = fs.addSubject((exports) => exports); - const promiseStyle = hooks - .addBuiltinModule("fs/promises") - .addSubject((exports) => exports); - - const functions = this.getFunctions(); - - Object.keys(functions).forEach((name) => { - const { pathsArgs, sync, promise } = functions[name]; - callbackStyle.inspect(name, (args) => { - return this.inspectPath(args, name, pathsArgs); - }); + hooks.addBuiltinModule("fs").onRequire((exports, pkgInfo) => { + const functions = this.getFunctions(); - if (sync) { - callbackStyle.inspect(`${name}Sync`, (args) => { - return this.inspectPath(args, `${name}Sync`, pathsArgs); - }); - } + Object.keys(functions).forEach((name) => { + const { pathsArgs, sync, promise } = functions[name]; - if (promise) { - promiseStyle.inspect(name, (args) => { - return this.inspectPath(args, name, pathsArgs); + wrapExport(exports, name, pkgInfo, { + inspectArgs: (args) => { + return this.inspectPath(args, name, pathsArgs); + }, }); - } - }); - fs.addSubject((exports) => exports.realpath).inspect("native", (args) => { - return this.inspectPath(args, "realpath.native", 1); + if (sync) { + wrapExport(exports, `${name}Sync`, pkgInfo, { + inspectArgs: (args) => { + return this.inspectPath(args, `${name}Sync`, pathsArgs); + }, + }); + } + }); + + // Wrap realpath.native + wrapExport(exports.realpath, "native", pkgInfo, { + inspectArgs: (args) => { + return this.inspectPath(args, "realpath.native", 1); + }, + }); + wrapExport(exports.realpathSync, "native", pkgInfo, { + inspectArgs: (args) => { + return this.inspectPath(args, "realpathSync.native", 1); + }, + }); }); - fs.addSubject((exports) => exports.realpathSync).inspect( - "native", - (args) => { - return this.inspectPath(args, "realpathSync.native", 1); - } - ); + hooks.addBuiltinModule("fs/promises").onRequire((exports, pkgInfo) => { + const functions = this.getFunctions(); + Object.keys(functions).forEach((name) => { + const { pathsArgs, sync, promise } = functions[name]; + + if (promise) { + wrapExport(exports, name, pkgInfo, { + inspectArgs: (args) => this.inspectPath(args, name, pathsArgs), + }); + } + }); + }); } } diff --git a/library/sinks/HTTPRequest.axios.test.ts b/library/sinks/HTTPRequest.axios.test.ts index a84a667df..c1f4a9b7a 100644 --- a/library/sinks/HTTPRequest.axios.test.ts +++ b/library/sinks/HTTPRequest.axios.test.ts @@ -1,10 +1,8 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { HTTPRequest } from "./HTTPRequest"; +import { createTestAgent } from "../helpers/createTestAgent"; const context: Context = { remoteAddress: "::1", @@ -24,13 +22,10 @@ const context: Context = { const redirectTestUrl = "http://ssrf-redirects.testssandbox.com"; t.test("it works", { skip: "SSRF redirect check disabled atm" }, async (t) => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - new Token("123"), - undefined - ); + const agent = createTestAgent({ + token: new Token("123"), + }); + agent.start([new HTTPRequest()]); t.same(agent.getHostnames().asArray(), []); @@ -84,7 +79,7 @@ t.test("it works", { skip: "SSRF redirect check disabled atm" }, async (t) => { if (error2 instanceof Error) { t.match( error2.message, - "Aikido firewall has blocked a server-side request forgery: http.request(...) originating from body.image" + "Zen has blocked a server-side request forgery: http.request(...) originating from body.image" ); } }); diff --git a/library/sinks/HTTPRequest.followRedirects.test.ts b/library/sinks/HTTPRequest.followRedirects.test.ts index 01895b2db..469c2450d 100644 --- a/library/sinks/HTTPRequest.followRedirects.test.ts +++ b/library/sinks/HTTPRequest.followRedirects.test.ts @@ -1,10 +1,8 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { HTTPRequest } from "./HTTPRequest"; +import { createTestAgent } from "../helpers/createTestAgent"; const context: Context = { remoteAddress: "::1", @@ -42,13 +40,9 @@ t.before(async () => { const redirectTestUrl = "http://ssrf-redirects.testssandbox.com"; t.test("it works", { skip: "SSRF redirect check disabled atm" }, (t) => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - new Token("123"), - undefined - ); + const agent = createTestAgent({ + token: new Token("123"), + }); agent.start([new HTTPRequest()]); const { http } = require("follow-redirects"); @@ -112,7 +106,7 @@ t.test("it works", { skip: "SSRF redirect check disabled atm" }, (t) => { t.ok(e instanceof Error); t.same( e.message, - "Redirected request failed: Aikido firewall has blocked a server-side request forgery: http.request(...) originating from body.image" + "Redirected request failed: Zen has blocked a server-side request forgery: http.request(...) originating from body.image" ); }); response.end(); diff --git a/library/sinks/HTTPRequest.needle.test.ts b/library/sinks/HTTPRequest.needle.test.ts index cc1bd643d..4b4b1ef61 100644 --- a/library/sinks/HTTPRequest.needle.test.ts +++ b/library/sinks/HTTPRequest.needle.test.ts @@ -1,10 +1,8 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { HTTPRequest } from "./HTTPRequest"; +import { createTestAgent } from "../helpers/createTestAgent"; const context: Context = { remoteAddress: "::1", @@ -24,13 +22,9 @@ const context: Context = { const redirectTestUrl = "http://ssrf-redirects.testssandbox.com"; t.test("it works", { skip: "SSRF redirect check disabled atm" }, async (t) => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - new Token("123"), - undefined - ); + const agent = createTestAgent({ + token: new Token("123"), + }); agent.start([new HTTPRequest()]); t.same(agent.getHostnames().asArray(), []); diff --git a/library/sinks/HTTPRequest.nodeFetch.test.ts b/library/sinks/HTTPRequest.nodeFetch.test.ts index 7ad3951d8..1c733c7b8 100644 --- a/library/sinks/HTTPRequest.nodeFetch.test.ts +++ b/library/sinks/HTTPRequest.nodeFetch.test.ts @@ -1,10 +1,8 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { HTTPRequest } from "./HTTPRequest"; +import { createTestAgent } from "../helpers/createTestAgent"; const context: Context = { remoteAddress: "::1", @@ -24,13 +22,9 @@ const context: Context = { const redirectTestUrl = "http://ssrf-redirects.testssandbox.com"; t.test("it works", { skip: "SSRF redirect check disabled atm" }, async (t) => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - new Token("123"), - undefined - ); + const agent = createTestAgent({ + token: new Token("123"), + }); agent.start([new HTTPRequest()]); t.same(agent.getHostnames().asArray(), []); diff --git a/library/sinks/HTTPRequest.redirect.test.ts b/library/sinks/HTTPRequest.redirect.test.ts index 66374adf7..b35e06ed0 100644 --- a/library/sinks/HTTPRequest.redirect.test.ts +++ b/library/sinks/HTTPRequest.redirect.test.ts @@ -1,11 +1,9 @@ /* eslint-disable prefer-rest-params */ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { HTTPRequest } from "./HTTPRequest"; +import { createTestAgent } from "../helpers/createTestAgent"; const context: Context = { remoteAddress: "::1", @@ -34,13 +32,9 @@ const redirectUrl = { }; t.test("it works", { skip: "SSRF redirect check disabled atm" }, (t) => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - new Token("123"), - undefined - ); + const agent = createTestAgent({ + token: new Token("123"), + }); agent.start([new HTTPRequest()]); const http = require("http"); diff --git a/library/sinks/HTTPRequest.test.ts b/library/sinks/HTTPRequest.test.ts index 85bb98934..521b23df2 100644 --- a/library/sinks/HTTPRequest.test.ts +++ b/library/sinks/HTTPRequest.test.ts @@ -1,13 +1,11 @@ /* eslint-disable prefer-rest-params */ import * as dns from "dns"; import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { wrap } from "../helpers/wrap"; import { HTTPRequest } from "./HTTPRequest"; +import { createTestAgent } from "../helpers/createTestAgent"; const calls: Record = {}; wrap(dns, "lookup", function lookup(original) { @@ -50,13 +48,9 @@ const context: Context = { }; t.test("it works", (t) => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - new Token("123"), - undefined - ); + const agent = createTestAgent({ + token: new Token("123"), + }); agent.start([new HTTPRequest()]); t.same(agent.getHostnames().asArray(), []); diff --git a/library/sinks/HTTPRequest.ts b/library/sinks/HTTPRequest.ts index a428c63a5..660aca7c1 100644 --- a/library/sinks/HTTPRequest.ts +++ b/library/sinks/HTTPRequest.ts @@ -3,7 +3,7 @@ import { type RequestOptions } from "http"; import { Agent } from "../agent/Agent"; import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; import { Wrapper } from "../agent/Wrapper"; import { getPortFromURL } from "../helpers/getPortFromURL"; import { checkContextForSSRF } from "../vulnerabilities/ssrf/checkContextForSSRF"; @@ -11,6 +11,7 @@ import { inspectDNSLookupCalls } from "../vulnerabilities/ssrf/inspectDNSLookupC import { isRedirectToPrivateIP } from "../vulnerabilities/ssrf/isRedirectToPrivateIP"; import { getUrlFromHTTPRequestArgs } from "./http-request/getUrlFromHTTPRequestArgs"; import { wrapResponseHandler } from "./http-request/wrapResponseHandler"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { isOptionsObject } from "./http-request/isOptionsObject"; export class HTTPRequest implements Wrapper { @@ -159,26 +160,22 @@ export class HTTPRequest implements Wrapper { wrap(hooks: Hooks) { const modules = ["http", "https"] as const; - - modules.forEach((module) => { - hooks - .addBuiltinModule(module) - .addSubject((exports) => exports) - // Whenever a request is made, we'll check the hostname whether it's a private IP - .inspect("request", (args, subject, agent) => - this.inspectHttpRequest(args, agent, module) - ) - .inspect("get", (args, subject, agent) => - this.inspectHttpRequest(args, agent, module) - ) - // Whenever a request is made, we'll modify the options to pass a custom lookup function - // that will inspect resolved IP address (and thus preventing TOCTOU attacks) - .modifyArguments("request", (args, subject, agent) => { - return this.monitorDNSLookups(args, agent, module); - }) - .modifyArguments("get", (args, subject, agent) => { - return this.monitorDNSLookups(args, agent, module); - }); - }); + const methods = ["request", "get"] as const; + + for (const module of modules) { + hooks.addBuiltinModule(module).onRequire((exports, pkgInfo) => { + for (const method of methods) { + wrapExport(exports, method, pkgInfo, { + // Whenever a request is made, we'll check the hostname whether it's a private IP + inspectArgs: (args, agent) => + this.inspectHttpRequest(args, agent, module), + // Whenever a request is made, we'll modify the options to pass a custom lookup function + // that will inspect resolved IP address (and thus preventing TOCTOU attacks) + modifyArgs: (args, agent) => + this.monitorDNSLookups(args, agent, module), + }); + } + }); + } } } diff --git a/library/sinks/MongoDB.test.ts b/library/sinks/MongoDB.test.ts index b8feba3b6..5eddca076 100644 --- a/library/sinks/MongoDB.test.ts +++ b/library/sinks/MongoDB.test.ts @@ -1,9 +1,7 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Context, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { MongoDB } from "./MongoDB"; +import { createTestAgent } from "../helpers/createTestAgent"; const unsafeContext: Context = { remoteAddress: "::1", @@ -36,13 +34,9 @@ const safeContext: Context = { }; t.test("it inspects method calls and blocks if needed", async (t) => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - "lambda" - ); + const agent = createTestAgent({ + serverless: "lambda", + }); agent.start([new MongoDB()]); const { MongoClient } = require("mongodb"); diff --git a/library/sinks/MongoDB.ts b/library/sinks/MongoDB.ts index 9179843b4..05fd28ef2 100644 --- a/library/sinks/MongoDB.ts +++ b/library/sinks/MongoDB.ts @@ -1,11 +1,12 @@ /* eslint-disable prefer-rest-params */ import type { Collection } from "mongodb"; import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; import { detectNoSQLInjection } from "../vulnerabilities/nosql-injection/detectNoSQLInjection"; import { isPlainObject } from "../helpers/isPlainObject"; import { Context, getContext } from "../agent/Context"; import { Wrapper } from "../agent/Wrapper"; +import { wrapExport } from "../agent/hooks/wrapExport"; const OPERATIONS_WITH_FILTER = [ "count", @@ -186,30 +187,33 @@ export class MongoDB implements Wrapper { } wrap(hooks: Hooks) { - const mongodb = hooks + hooks .addPackage("mongodb") - .withVersion("^4.0.0 || ^5.0.0 || ^6.0.0"); - - const collection = mongodb.addSubject( - (exports) => exports.Collection.prototype - ); - - OPERATIONS_WITH_FILTER.forEach((operation) => { - collection.inspect(operation, (args, collection) => - this.inspectOperation(operation, args, collection as Collection) - ); - }); - - collection.inspect("distinct", (args, collection) => - this.inspectDistinct(args, collection as Collection) - ); - - collection.inspect("bulkWrite", (args, collection) => - this.inspectBulkWrite(args, collection as Collection) - ); - - collection.inspect("aggregate", (args, collection) => - this.inspectAggregate(args, collection as Collection) - ); + .withVersion("^4.0.0 || ^5.0.0 || ^6.0.0") + .onRequire((exports, pkgInfo) => { + const collectionProto = exports.Collection.prototype; + + OPERATIONS_WITH_FILTER.forEach((operation) => { + wrapExport(collectionProto, operation, pkgInfo, { + inspectArgs: (args, agent, collection) => + this.inspectOperation(operation, args, collection as Collection), + }); + }); + + wrapExport(collectionProto, "bulkWrite", pkgInfo, { + inspectArgs: (args, agent, collection) => + this.inspectBulkWrite(args, collection as Collection), + }); + + wrapExport(collectionProto, "aggregate", pkgInfo, { + inspectArgs: (args, agent, collection) => + this.inspectAggregate(args, collection as Collection), + }); + + wrapExport(collectionProto, "distinct", pkgInfo, { + inspectArgs: (args, agent, collection) => + this.inspectDistinct(args, collection as Collection), + }); + }); } } diff --git a/library/sinks/MySQL.test.ts b/library/sinks/MySQL.test.ts index 654a4d4ef..055bdd045 100644 --- a/library/sinks/MySQL.test.ts +++ b/library/sinks/MySQL.test.ts @@ -1,10 +1,8 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { getContext, runWithContext, type Context } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { MySQL } from "./MySQL"; import type { Connection } from "mysql"; +import { createTestAgent } from "../helpers/createTestAgent"; function query(sql: string, connection: Connection) { return new Promise((resolve, reject) => { @@ -46,13 +44,7 @@ const context: Context = { }; t.test("it detects SQL injections", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - "lambda" - ); + const agent = createTestAgent(); agent.start([new MySQL()]); const mysql = require("mysql"); diff --git a/library/sinks/MySQL.ts b/library/sinks/MySQL.ts index 56d565c8b..58749ffba 100644 --- a/library/sinks/MySQL.ts +++ b/library/sinks/MySQL.ts @@ -1,6 +1,7 @@ import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { Wrapper } from "../agent/Wrapper"; import { isPlainObject } from "../helpers/isPlainObject"; import { checkContextForSqlInjection } from "../vulnerabilities/sql-injection/checkContextForSqlInjection"; @@ -48,12 +49,13 @@ export class MySQL implements Wrapper { } wrap(hooks: Hooks) { - const mysql = hooks.addPackage("mysql").withVersion("^2.0.0"); - - const connection = mysql - .addFile("lib/Connection") - .addSubject((exports) => exports.prototype); - - connection.inspect("query", (args) => this.inspectQuery(args)); + hooks + .addPackage("mysql") + .withVersion("^2.0.0") + .onFileRequire("lib/Connection.js", (exports, pkgInfo) => { + wrapExport(exports.prototype, "query", pkgInfo, { + inspectArgs: (args) => this.inspectQuery(args), + }); + }); } } diff --git a/library/sinks/MySQL2.test.ts b/library/sinks/MySQL2.test.ts index a2604aa8a..468780749 100644 --- a/library/sinks/MySQL2.test.ts +++ b/library/sinks/MySQL2.test.ts @@ -1,9 +1,7 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { runWithContext, type Context } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { MySQL2 } from "./MySQL2"; +import { createTestAgent } from "../helpers/createTestAgent"; const dangerousContext: Context = { remoteAddress: "::1", @@ -34,13 +32,7 @@ const safeContext: Context = { }; t.test("it detects SQL injections", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - "lambda" - ); + const agent = createTestAgent(); agent.start([new MySQL2()]); const mysql = require("mysql2/promise"); diff --git a/library/sinks/MySQL2.ts b/library/sinks/MySQL2.ts index 43e5318c2..fe0d1757d 100644 --- a/library/sinks/MySQL2.ts +++ b/library/sinks/MySQL2.ts @@ -1,6 +1,7 @@ import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { Wrapper } from "../agent/Wrapper"; import { isPlainObject } from "../helpers/isPlainObject"; import { checkContextForSqlInjection } from "../vulnerabilities/sql-injection/checkContextForSqlInjection"; @@ -49,17 +50,20 @@ export class MySQL2 implements Wrapper { } wrap(hooks: Hooks) { - const mysql2 = hooks.addPackage("mysql2").withVersion("^3.0.0"); - const connection = mysql2.addSubject( - (exports) => exports.Connection.prototype - ); - - connection.inspect("query", (args) => - this.inspectQuery("mysql2.query", args) - ); + hooks + .addPackage("mysql2") + .withVersion("^3.0.0") + .onRequire((exports, pkgInfo) => { + // Wrap connection.query + wrapExport(exports.Connection.prototype, "query", pkgInfo, { + inspectArgs: (args, agent) => this.inspectQuery("mysql2.query", args), + }); - connection.inspect("execute", (args) => - this.inspectQuery("mysql2.execute", args) - ); + // Wrap connection.execute + wrapExport(exports.Connection.prototype, "execute", pkgInfo, { + inspectArgs: (args, agent) => + this.inspectQuery("mysql2.execute", args), + }); + }); } } diff --git a/library/sinks/NodeSQLite.test.ts b/library/sinks/NodeSQLite.test.ts index c1450318f..2e7249b6f 100644 --- a/library/sinks/NodeSQLite.test.ts +++ b/library/sinks/NodeSQLite.test.ts @@ -1,10 +1,8 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { runWithContext, type Context } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { NodeSQLite } from "./NodeSqlite"; import { isPackageInstalled } from "../helpers/isPackageInstalled"; +import { createTestAgent } from "../helpers/createTestAgent"; const dangerousContext: Context = { remoteAddress: "::1", @@ -35,13 +33,7 @@ const safeContext: Context = { }; t.test("does not break when the Node.js version is too low", async (t) => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); + const agent = createTestAgent(); agent.start([new NodeSQLite()]); t.end(); @@ -55,13 +47,7 @@ t.test( : false, }, async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); + const agent = createTestAgent(); agent.start([new NodeSQLite()]); const { DatabaseSync } = require("node:sqlite"); diff --git a/library/sinks/NodeSqlite.ts b/library/sinks/NodeSqlite.ts index 93a4f871e..a5c9a44eb 100644 --- a/library/sinks/NodeSqlite.ts +++ b/library/sinks/NodeSqlite.ts @@ -1,6 +1,7 @@ import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { Wrapper } from "../agent/Wrapper"; import { checkContextForSqlInjection } from "../vulnerabilities/sql-injection/checkContextForSqlInjection"; import type { SQLDialect } from "../vulnerabilities/sql-injection/dialects/SQLDialect"; @@ -10,19 +11,17 @@ export class NodeSQLite implements Wrapper { private readonly dialect: SQLDialect = new SQLDialectSQLite(); wrap(hooks: Hooks) { - const database = hooks - .addBuiltinModule("node:sqlite") - .addSubject((exports) => { - return exports.DatabaseSync.prototype; - }); - const sqlFunctions = ["exec", "prepare"]; - for (const func of sqlFunctions) { - database.inspect(func, (args) => { - return this.inspectQuery(`node:sqlite.${func}`, args); - }); - } + // Omit node: prefix because its an internal module + hooks.addBuiltinModule("sqlite").onRequire((exports, pkgInfo) => { + const dbSyncProto = exports.DatabaseSync.prototype; + for (const func of sqlFunctions) { + wrapExport(dbSyncProto, func, pkgInfo, { + inspectArgs: (args) => this.inspectQuery(`node:sqlite.${func}`, args), + }); + } + }); } private inspectQuery(operation: string, args: unknown[]): InterceptorResult { diff --git a/library/sinks/Path.test.ts b/library/sinks/Path.test.ts index f89598486..52b646236 100644 --- a/library/sinks/Path.test.ts +++ b/library/sinks/Path.test.ts @@ -1,9 +1,7 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Context, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { Path } from "./Path"; +import { createTestAgent } from "../helpers/createTestAgent"; const unsafeContext: Context = { remoteAddress: "::1", @@ -28,13 +26,7 @@ const unsafeAbsoluteContext: Context = { }; t.test("it works", async (t) => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); + const agent = createTestAgent(); agent.start([new Path()]); diff --git a/library/sinks/Path.ts b/library/sinks/Path.ts index b201aa57b..0b971781c 100644 --- a/library/sinks/Path.ts +++ b/library/sinks/Path.ts @@ -1,5 +1,7 @@ import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; +import { wrapExport } from "../agent/hooks/wrapExport"; +import { WrapPackageInfo } from "../agent/hooks/WrapPackageInfo"; import { Wrapper } from "../agent/Wrapper"; import { checkContextForPathTraversal } from "../vulnerabilities/path-traversal/checkContextForPathTraversal"; @@ -33,17 +35,18 @@ export class Path implements Wrapper { } wrap(hooks: Hooks): void { - hooks - .addBuiltinModule("path/posix") - .addSubject((exports) => exports) - .inspect("join", (args) => this.inspectPath(args, "join")) - .inspect("resolve", (args) => this.inspectPath(args, "resolve")) - .inspect("normalize", (args) => this.inspectPath(args, "normalize")); - hooks - .addBuiltinModule("path/win32") - .addSubject((exports) => exports) - .inspect("join", (args) => this.inspectPath(args, "join")) - .inspect("resolve", (args) => this.inspectPath(args, "resolve")) - .inspect("normalize", (args) => this.inspectPath(args, "normalize")); + const functions = ["join", "resolve", "normalize"]; + + const onRequire = (exports: any, pkgInfo: WrapPackageInfo) => { + for (const func of functions) { + wrapExport(exports, func, pkgInfo, { + inspectArgs: (args) => this.inspectPath(args, func), + }); + } + }; + + hooks.addBuiltinModule("path").onRequire(onRequire); + hooks.addBuiltinModule("path/posix").onRequire(onRequire); + hooks.addBuiltinModule("path/win32").onRequire(onRequire); } } diff --git a/library/sinks/Postgres.pool.test.ts b/library/sinks/Postgres.pool.test.ts index f5e9310f7..4f7e526c3 100644 --- a/library/sinks/Postgres.pool.test.ts +++ b/library/sinks/Postgres.pool.test.ts @@ -1,9 +1,7 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { runWithContext, type Context } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { Postgres } from "./Postgres"; +import { createTestAgent } from "../helpers/createTestAgent"; const context: Context = { remoteAddress: "::1", @@ -21,13 +19,10 @@ const context: Context = { }; t.test("it detects SQL injections", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - "lambda" - ); + const agent = createTestAgent({ + serverless: "lambda", + }); + agent.start([new Postgres()]); const { Pool } = require("pg"); diff --git a/library/sinks/Postgres.test.ts b/library/sinks/Postgres.test.ts index cf269d66f..ed4d6c4f8 100644 --- a/library/sinks/Postgres.test.ts +++ b/library/sinks/Postgres.test.ts @@ -1,9 +1,7 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { getContext, runWithContext, type Context } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { Postgres } from "./Postgres"; +import { createTestAgent } from "../helpers/createTestAgent"; const context: Context = { remoteAddress: "::1", @@ -21,13 +19,7 @@ const context: Context = { }; t.test("it inspects query method calls and blocks if needed", async (t) => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - "lambda" - ); + const agent = createTestAgent(); agent.start([new Postgres()]); const { Client } = require("pg"); diff --git a/library/sinks/Postgres.ts b/library/sinks/Postgres.ts index a1cfc62e7..171143129 100644 --- a/library/sinks/Postgres.ts +++ b/library/sinks/Postgres.ts @@ -1,11 +1,12 @@ import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; import { Wrapper } from "../agent/Wrapper"; import { getContext } from "../agent/Context"; import { checkContextForSqlInjection } from "../vulnerabilities/sql-injection/checkContextForSqlInjection"; import { SQLDialect } from "../vulnerabilities/sql-injection/dialects/SQLDialect"; import { SQLDialectPostgres } from "../vulnerabilities/sql-injection/dialects/SQLDialectPostgres"; import { isPlainObject } from "../helpers/isPlainObject"; +import { wrapExport } from "../agent/hooks/wrapExport"; export class Postgres implements Wrapper { private readonly dialect: SQLDialect = new SQLDialectPostgres(); @@ -48,9 +49,13 @@ export class Postgres implements Wrapper { } wrap(hooks: Hooks) { - const pg = hooks.addPackage("pg").withVersion("^7.0.0 || ^8.0.0"); - - const client = pg.addSubject((exports) => exports.Client.prototype); - client.inspect("query", (args) => this.inspectQuery(args)); + hooks + .addPackage("pg") + .withVersion("^7.0.0 || ^8.0.0") + .onRequire((exports, pkgInfo) => { + wrapExport(exports.Client.prototype, "query", pkgInfo, { + inspectArgs: (args) => this.inspectQuery(args), + }); + }); } } diff --git a/library/sinks/SQLite3.test.ts b/library/sinks/SQLite3.test.ts index 79b53685e..09f5e9438 100644 --- a/library/sinks/SQLite3.test.ts +++ b/library/sinks/SQLite3.test.ts @@ -1,10 +1,8 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { runWithContext, type Context } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { SQLite3 } from "./SQLite3"; import { promisify } from "util"; +import { createTestAgent } from "../helpers/createTestAgent"; const dangerousContext: Context = { remoteAddress: "::1", @@ -50,13 +48,9 @@ const safeContext: Context = { }; t.test("it detects SQL injections", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - "lambda" - ); + const agent = createTestAgent({ + serverless: "lambda", + }); agent.start([new SQLite3()]); const sqlite3 = require("sqlite3"); diff --git a/library/sinks/SQLite3.ts b/library/sinks/SQLite3.ts index 25dd5e7b8..7aebce1ee 100644 --- a/library/sinks/SQLite3.ts +++ b/library/sinks/SQLite3.ts @@ -1,6 +1,7 @@ import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { Wrapper } from "../agent/Wrapper"; import { checkContextForPathTraversal } from "../vulnerabilities/path-traversal/checkContextForPathTraversal"; import { checkContextForSqlInjection } from "../vulnerabilities/sql-injection/checkContextForSqlInjection"; @@ -61,11 +62,6 @@ export class SQLite3 implements Wrapper { } wrap(hooks: Hooks) { - const sqlite3 = hooks.addPackage("sqlite3").withVersion("^5.0.0"); - const database = sqlite3.addSubject((exports) => { - return exports.Database.prototype; - }); - const sqlFunctions = [ "run", "get", @@ -76,14 +72,25 @@ export class SQLite3 implements Wrapper { "map", ]; - for (const func of sqlFunctions) { - database.inspect(func, (args) => { - return this.inspectQuery(`sqlite3.${func}`, args); - }); - } + hooks + .addPackage("sqlite3") + .withVersion("^5.0.0") + .onRequire((exports, pkgInfo) => { + const db = exports.Database.prototype; - database.inspect("backup", (args) => { - return this.inspectPath(`sqlite3.backup`, args); - }); + for (const func of sqlFunctions) { + wrapExport(db, func, pkgInfo, { + inspectArgs: (args, agent) => { + return this.inspectQuery(`sqlite3.${func}`, args); + }, + }); + } + + wrapExport(db, "backup", pkgInfo, { + inspectArgs: (args, agent) => { + return this.inspectPath(`sqlite3.backup`, args); + }, + }); + }); } } diff --git a/library/sinks/Shelljs.test.ts b/library/sinks/Shelljs.test.ts index e8902924f..26886a985 100644 --- a/library/sinks/Shelljs.test.ts +++ b/library/sinks/Shelljs.test.ts @@ -1,11 +1,9 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; -import { getContext, runWithContext, type Context } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; +import { runWithContext, type Context } from "../agent/Context"; import { Shelljs } from "./Shelljs"; import { ChildProcess } from "./ChildProcess"; import { FileSystem } from "./FileSystem"; +import { createTestAgent } from "../helpers/createTestAgent"; const dangerousContext: Context = { remoteAddress: "::1", @@ -50,16 +48,10 @@ const safeContext: Context = { route: "/posts/:id", }; -t.test("it detects shell injections", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); - agent.start([new Shelljs(), new FileSystem(), new ChildProcess()]); +const agent = createTestAgent(); +agent.start([new Shelljs(), new FileSystem(), new ChildProcess()]); +t.test("it detects shell injections", async () => { const shelljs = require("shelljs"); const error = await t.rejects(async () => { @@ -78,15 +70,6 @@ t.test("it detects shell injections", async () => { }); t.test("it does not detect injection in safe context", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); - agent.start([new Shelljs(), new FileSystem(), new ChildProcess()]); - const shelljs = require("shelljs"); try { @@ -100,15 +83,6 @@ t.test("it does not detect injection in safe context", async () => { }); t.test("it does not detect injection without context", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); - agent.start([new Shelljs(), new FileSystem(), new ChildProcess()]); - const shelljs = require("shelljs"); try { @@ -120,15 +94,6 @@ t.test("it does not detect injection without context", async () => { }); t.test("it detects async shell injections", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); - agent.start([new Shelljs(), new FileSystem(), new ChildProcess()]); - const shelljs = require("shelljs"); const error = await t.rejects(async () => { @@ -175,15 +140,6 @@ t.test("it detects async shell injections", async () => { }); t.test("it prevents path injections using ls", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); - agent.start([new Shelljs(), new FileSystem(), new ChildProcess()]); - const shelljs = require("shelljs"); const error = await t.rejects(async () => { @@ -199,15 +155,6 @@ t.test("it prevents path injections using ls", async () => { }); t.test("it prevents path injections using cat", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); - agent.start([new Shelljs(), new FileSystem(), new ChildProcess()]); - const shelljs = require("shelljs"); const error = await t.rejects(async () => { @@ -231,15 +178,6 @@ t.test("it prevents path injections using cat", async () => { t.test( "it does not prevent path injections using cat with safe context", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); - agent.start([new Shelljs(), new FileSystem(), new ChildProcess()]); - const shelljs = require("shelljs"); try { @@ -254,15 +192,6 @@ t.test( ); t.test("invalid arguments are passed to shelljs", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); - agent.start([new Shelljs(), new FileSystem(), new ChildProcess()]); - const shelljs = require("shelljs"); runWithContext(safeContext, () => { diff --git a/library/sinks/Shelljs.ts b/library/sinks/Shelljs.ts index 306091e3d..6e88edc80 100644 --- a/library/sinks/Shelljs.ts +++ b/library/sinks/Shelljs.ts @@ -1,6 +1,7 @@ import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { Wrapper } from "../agent/Wrapper"; import { isPlainObject } from "../helpers/isPlainObject"; import { checkContextForShellInjection } from "../vulnerabilities/shell-injection/checkContextForShellInjection"; @@ -40,12 +41,25 @@ export class Shelljs implements Wrapper { } wrap(hooks: Hooks) { - const shelljs = hooks.addPackage("shelljs").withVersion("^0.8.0 || ^0.7.0"); - const exports = shelljs.addSubject((exports) => exports); - - // We need to wrap exec, because shelljs is not using child_process.exec directly, it spawns a subprocess and shares the command via a json file. That subprocess then executes the command. - exports.inspect("exec", (args) => { - return this.inspectExec("exec", args); - }); + hooks + .addPackage("shelljs") + .withVersion("^0.8.0 || ^0.7.0") + // We need to wrap exec, because shelljs is not using child_process.exec directly, it spawns a subprocess and shares the command via a json file. That subprocess then executes the command. + .onFileRequire("src/common.js", (exports, pkgInfo) => { + wrapExport(exports, "register", pkgInfo, { + modifyArgs: (args) => { + if ( + args.length > 0 && + args[0] === "exec" && + typeof args[1] === "function" + ) { + args[1] = wrapExport(args[1], undefined, pkgInfo, { + inspectArgs: (args) => this.inspectExec("exec", args), + }); + } + return args; + }, + }); + }); } } diff --git a/library/sinks/Undici.test.ts b/library/sinks/Undici.test.ts index 4fc8ba2a9..69322724b 100644 --- a/library/sinks/Undici.test.ts +++ b/library/sinks/Undici.test.ts @@ -1,7 +1,6 @@ /* eslint-disable prefer-rest-params */ import * as dns from "dns"; import * as t from "tap"; -import { Agent } from "../agent/Agent"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; @@ -9,6 +8,7 @@ import { LoggerForTesting } from "../agent/logger/LoggerForTesting"; import { wrap } from "../helpers/wrap"; import { getMajorNodeVersion } from "../helpers/getNodeVersion"; import { Undici } from "./Undici"; +import { createTestAgent } from "../helpers/createTestAgent"; const calls: Record = {}; wrap(dns, "lookup", function lookup(original) { @@ -63,7 +63,11 @@ t.test( async (t) => { const logger = new LoggerForTesting(); const api = new ReportingAPIForTesting(); - const agent = new Agent(true, logger, api, new Token("123"), undefined); + const agent = createTestAgent({ + api, + logger, + token: new Token("123"), + }); agent.start([new Undici()]); diff --git a/library/sinks/Undici.ts b/library/sinks/Undici.ts index f320ca0da..f699ed60e 100644 --- a/library/sinks/Undici.ts +++ b/library/sinks/Undici.ts @@ -2,7 +2,7 @@ import { lookup } from "dns"; import { Agent } from "../agent/Agent"; import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; -import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../agent/hooks/InterceptorResult"; import { Wrapper } from "../agent/Wrapper"; import { getMajorNodeVersion, @@ -11,6 +11,7 @@ import { import { checkContextForSSRF } from "../vulnerabilities/ssrf/checkContextForSSRF"; import { inspectDNSLookupCalls } from "../vulnerabilities/ssrf/inspectDNSLookupCalls"; import { wrapDispatch } from "./undici/wrapDispatch"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { getHostnameAndPortFromArgs } from "./undici/getHostnameAndPortFromArgs"; const methods = [ @@ -106,31 +107,31 @@ export class Undici implements Wrapper { const undici = hooks .addPackage("undici") .withVersion("^4.0.0 || ^5.0.0 || ^6.0.0") - .addSubject((exports) => exports); - - undici.inspect("setGlobalDispatcher", (args, subject, agent) => { - if (this.patchedGlobalDispatcher) { - agent.log( - `undici.setGlobalDispatcher was called, we can't provide protection!` - ); - } - }); - - methods.forEach((method) => { - undici - // Whenever a request is made, we'll check the hostname whether it's a private IP - .inspect(method, (args, subject, agent) => - this.inspect(args, agent, method) - ) - // We're not really modifying the arguments here, but we need to patch the global dispatcher - .modifyArguments(method, (args, subject, agent) => { - if (!this.patchedGlobalDispatcher) { - this.patchGlobalDispatcher(agent); - this.patchedGlobalDispatcher = true; - } - - return args; + .onRequire((exports, pkgInfo) => { + // Print a warning that we can't provide protection if setGlobalDispatcher is called + wrapExport(exports, "setGlobalDispatcher", pkgInfo, { + inspectArgs: (args, agent) => { + if (this.patchedGlobalDispatcher) { + agent.log( + `undici.setGlobalDispatcher was called, we can't provide protection!` + ); + } + }, }); - }); + // Wrap all methods that can make requests + for (const method of methods) { + wrapExport(exports, method, pkgInfo, { + // Whenever a request is made, we'll check the hostname whether it's a private IP + // If global dispatcher is not patched, we'll patch it + inspectArgs: (args, agent) => { + if (!this.patchedGlobalDispatcher) { + this.patchGlobalDispatcher(agent); + this.patchedGlobalDispatcher = true; + } + return this.inspect(args, agent, method); + }, + }); + } + }); } } diff --git a/library/sources/Express.test.ts b/library/sources/Express.test.ts index 62b64a21b..83805dfbb 100644 --- a/library/sources/Express.test.ts +++ b/library/sources/Express.test.ts @@ -1,19 +1,15 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { setInstance } from "../agent/AgentSingleton"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { setUser } from "../agent/context/user"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { Express } from "./Express"; import { FileSystem } from "../sinks/FileSystem"; import { HTTPServer } from "./HTTPServer"; +import { createTestAgent } from "../helpers/createTestAgent"; // Before require("express") -const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting({ +const agent = createTestAgent({ + api: new ReportingAPIForTesting({ success: true, endpoints: [ { @@ -62,11 +58,11 @@ const agent = new Agent( heartbeatIntervalInMS: 10 * 60 * 1000, allowedIPAddresses: ["4.3.2.1"], }), - new Token("123"), - "lambda" -); + token: new Token("123"), + serverless: "lambda", +}); + agent.start([new Express(), new FileSystem(), new HTTPServer()]); -setInstance(agent); import * as express from "express"; import * as request from "supertest"; diff --git a/library/sources/Express.ts b/library/sources/Express.ts index c333eb886..3ae0fda7b 100644 --- a/library/sources/Express.ts +++ b/library/sources/Express.ts @@ -5,6 +5,7 @@ import { Agent } from "../agent/Agent"; import { Hooks } from "../agent/hooks/Hooks"; import { Wrapper } from "../agent/Wrapper"; import { wrapRequestHandler } from "./express/wrapRequestHandler"; +import { wrapExport } from "../agent/hooks/wrapExport"; export class Express implements Wrapper { // Wrap all the functions passed to app.METHOD(...) @@ -34,24 +35,21 @@ export class Express implements Wrapper { } wrap(hooks: Hooks) { - const express = hooks.addPackage("express").withVersion("^4.0.0 || ^5.0.0"); - - const route = express.addSubject((exports) => exports.Route.prototype); - const expressMethodNames = METHODS.map((method) => method.toLowerCase()); - expressMethodNames.forEach((method) => { - route.modifyArguments(method, (args, subject, agent) => { - return this.wrapArgs(args, agent); + hooks + .addPackage("express") + .withVersion("^4.0.0 || ^5.0.0") + .onRequire((exports, pkgInfo) => { + for (const method of expressMethodNames) { + wrapExport(exports.Route.prototype, method, pkgInfo, { + modifyArgs: (args, agent) => this.wrapArgs(args, agent), + }); + } + + wrapExport(exports.application, "use", pkgInfo, { + modifyArgs: (args, agent) => this.wrapArgs(args, agent), + }); }); - }); - - express - .addSubject((exports) => { - return exports.application; - }) - .modifyArguments("use", (args, subject, agent) => - this.wrapArgs(args, agent) - ); } } diff --git a/library/sources/FastXmlParser.test.ts b/library/sources/FastXmlParser.test.ts index e79cfbeaf..63ca3c128 100644 --- a/library/sources/FastXmlParser.test.ts +++ b/library/sources/FastXmlParser.test.ts @@ -1,18 +1,10 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { getContext, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { FastXmlParser } from "./FastXmlParser"; +import { createTestAgent } from "../helpers/createTestAgent"; t.test("it works", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); + const agent = createTestAgent(); agent.start([new FastXmlParser()]); diff --git a/library/sources/FastXmlParser.ts b/library/sources/FastXmlParser.ts index d6fe475f7..c4ca8d3ce 100644 --- a/library/sources/FastXmlParser.ts +++ b/library/sources/FastXmlParser.ts @@ -1,6 +1,8 @@ /* eslint-disable prefer-rest-params */ import { getContext, updateContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; +import { wrapExport } from "../agent/hooks/wrapExport"; +import { wrapNewInstance } from "../agent/hooks/wrapNewInstance"; import { Wrapper } from "../agent/Wrapper"; import { isPlainObject } from "../helpers/isPlainObject"; @@ -35,16 +37,18 @@ export class FastXmlParser implements Wrapper { } wrap(hooks: Hooks) { - const fastXmlParser = hooks + hooks .addPackage("fast-xml-parser") - .withVersion("^4.0.0"); - - fastXmlParser - .addSubject((exports) => exports) - .inspectNewInstance("XMLParser") - .addSubject((exports) => exports) - .inspectResult("parse", (args, result) => - this.inspectParse(args, result) - ); + .withVersion("^4.0.0") + .onRequire((exports, pkgInfo) => { + wrapNewInstance(exports, "XMLParser", pkgInfo, (instance) => { + wrapExport(instance, "parse", pkgInfo, { + modifyReturnValue: (args, returnValue) => { + this.inspectParse(args, returnValue); + return returnValue; + }, + }); + }); + }); } } diff --git a/library/sources/FunctionsFramework.test.ts b/library/sources/FunctionsFramework.test.ts index 4b901d439..1e476d1cc 100644 --- a/library/sources/FunctionsFramework.test.ts +++ b/library/sources/FunctionsFramework.test.ts @@ -1,17 +1,15 @@ import * as t from "tap"; import * as express from "express"; import * as request from "supertest"; -import { Agent } from "../agent/Agent"; -import { setInstance } from "../agent/AgentSingleton"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; -import { Token } from "../agent/api/Token"; import { getContext, updateContext } from "../agent/Context"; -import { LoggerForTesting } from "../agent/logger/LoggerForTesting"; import { createCloudFunctionWrapper, FunctionsFramework, } from "./FunctionsFramework"; import * as asyncHandler from "express-async-handler"; +import { createTestAgent } from "../helpers/createTestAgent"; +import { Token } from "../agent/api/Token"; function getExpressApp() { const app = express(); @@ -75,16 +73,10 @@ t.test("it sets context", async (t) => { }); t.test("it counts requests", async (t) => { - const logger = new LoggerForTesting(); - const agent = new Agent( - true, - logger, - new ReportingAPIForTesting(), - undefined, - "gcp" - ); + const agent = createTestAgent({ + serverless: "gcp", + }); agent.start([]); - setInstance(agent); const app = getExpressApp(); @@ -97,16 +89,10 @@ t.test("it counts requests", async (t) => { }); t.test("it counts attacks", async (t) => { - const logger = new LoggerForTesting(); - const agent = new Agent( - true, - logger, - new ReportingAPIForTesting(), - undefined, - "gcp" - ); + const agent = createTestAgent({ + serverless: "gcp", + }); agent.start([]); - setInstance(agent); const app = getExpressApp(); @@ -119,16 +105,10 @@ t.test("it counts attacks", async (t) => { }); t.test("it counts request if error", async (t) => { - const logger = new LoggerForTesting(); - const agent = new Agent( - true, - logger, - new ReportingAPIForTesting(), - undefined, - "gcp" - ); + const agent = createTestAgent({ + serverless: "gcp", + }); agent.start([]); - setInstance(agent); const app = getExpressApp(); @@ -141,11 +121,13 @@ t.test("it counts request if error", async (t) => { }); t.test("it flushes stats first invoke", async (t) => { - const logger = new LoggerForTesting(); const api = new ReportingAPIForTesting(); - const agent = new Agent(true, logger, api, new Token("123"), "gcp"); + const agent = createTestAgent({ + api, + serverless: "gcp", + token: new Token("123"), + }); agent.start([]); - setInstance(agent); api.clear(); @@ -161,16 +143,10 @@ t.test("it flushes stats first invoke", async (t) => { }); t.test("it hooks into functions framework", async () => { - const logger = new LoggerForTesting(); - const agent = new Agent( - true, - logger, - new ReportingAPIForTesting(), - undefined, - "gcp" - ); + const agent = createTestAgent({ + serverless: "gcp", + }); agent.start([new FunctionsFramework()]); - setInstance(agent); const framework = require("@google-cloud/functions-framework"); framework.http("hello", (req, res) => { diff --git a/library/sources/FunctionsFramework.ts b/library/sources/FunctionsFramework.ts index 30ebfc22b..4fe1dcc2c 100644 --- a/library/sources/FunctionsFramework.ts +++ b/library/sources/FunctionsFramework.ts @@ -1,6 +1,7 @@ import { getInstance } from "../agent/AgentSingleton"; import { getContext, runWithContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { Wrapper } from "../agent/Wrapper"; import type { HttpFunction } from "@google-cloud/functions-framework"; @@ -56,17 +57,18 @@ export class FunctionsFramework implements Wrapper { wrap(hooks: Hooks) { const functions = hooks .addPackage("@google-cloud/functions-framework") - .withVersion("^3.0.0"); - - functions - .addSubject((exports) => exports) - .modifyArguments("http", (args) => { - if (args.length === 2 && typeof args[1] === "function") { - const httpFunction = args[1] as HttpFunction; - args[1] = createCloudFunctionWrapper(httpFunction); - } + .withVersion("^3.0.0") + .onRequire((exports, pkgInfo) => { + wrapExport(exports, "http", pkgInfo, { + modifyArgs: (args) => { + if (args.length === 2 && typeof args[1] === "function") { + const httpFunction = args[1] as HttpFunction; + args[1] = createCloudFunctionWrapper(httpFunction); + } - return args; + return args; + }, + }); }); } } diff --git a/library/sources/GraphQL.test.ts b/library/sources/GraphQL.test.ts new file mode 100644 index 000000000..900e60b19 --- /dev/null +++ b/library/sources/GraphQL.test.ts @@ -0,0 +1,89 @@ +import * as t from "tap"; +import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; +import { getContext, runWithContext } from "../agent/Context"; +import { GraphQL } from "./GraphQL"; +import { Token } from "../agent/api/Token"; +import { createTestAgent } from "../helpers/createTestAgent"; + +t.test("it works", async () => { + const agent = createTestAgent({ + api: new ReportingAPIForTesting({ + success: true, + endpoints: [ + { + method: "POST", + route: "/graphql", + forceProtectionOff: false, + rateLimiting: { + enabled: true, + maxRequests: 3, + windowSizeInMS: 60 * 1000, + }, + graphql: { + name: "getFile", + type: "query", + }, + }, + ], + allowedIPAddresses: [], + configUpdatedAt: 0, + heartbeatIntervalInMS: 10 * 60 * 1000, + blockedUserIds: [], + }), + token: new Token("123"), + }); + + agent.start([new GraphQL()]); + + const { graphql, buildSchema } = require("graphql"); + + const schema = buildSchema(` + type Query { + getFile(path: String): String + } + `); + + const root = { + getFile: ({ path }: { path: string }) => { + return "file content"; + }, + }; + + const query = async (path: string) => { + return await graphql({ + schema, + source: `{ getFile(path: "${path}") }`, + rootValue: root, + }); + }; + + const context = { + remoteAddress: "::1", + method: "POST", + url: "http://localhost:4000/graphql", + query: {}, + headers: {}, + body: { query: '{ getFile(path: "/etc/bashrc") }' }, + cookies: {}, + routeParams: {}, + source: "express", + route: "/graphql", + }; + + await query("/etc/bashrc"); + + await runWithContext(context, async () => { + await query("/etc/bashrc"); + t.same(getContext()?.graphql, ["/etc/bashrc"]); + }); + + // Rate limiting works + await runWithContext(context, async () => { + const success = await query("/etc/bashrc"); + t.same(success.data.getFile, "file content"); + await query("/etc/bashrc"); + await query("/etc/bashrc"); + const result = await query("/etc/bashrc"); + t.same(result.errors[0].message, "You are rate limited by Zen."); + }); +}); diff --git a/library/sources/GraphQL.ts b/library/sources/GraphQL.ts index fc2d9a0a8..9c26fb626 100644 --- a/library/sources/GraphQL.ts +++ b/library/sources/GraphQL.ts @@ -1,6 +1,5 @@ /* eslint-disable prefer-rest-params */ import { Agent } from "../agent/Agent"; -import { getInstance } from "../agent/AgentSingleton"; import { getContext, updateContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; import { Wrapper } from "../agent/Wrapper"; @@ -8,15 +7,11 @@ import type { ExecutionArgs } from "graphql/execution/execute"; import { isPlainObject } from "../helpers/isPlainObject"; import { extractInputsFromDocument } from "./graphql/extractInputsFromDocument"; import { extractTopLevelFieldsFromDocument } from "./graphql/extractTopLevelFieldsFromDocument"; -import { wrap } from "../helpers/wrap"; import { shouldRateLimitOperation } from "./graphql/shouldRateLimitOperation"; +import { wrapExport } from "../agent/hooks/wrapExport"; export class GraphQL implements Wrapper { - private inspectGraphQLExecute( - args: unknown[], - subject: unknown, - agent: Agent - ): void { + private inspectGraphQLExecute(args: unknown[], agent: Agent): void { if (!Array.isArray(args) || typeof args[0] !== "object") { return; } @@ -65,61 +60,61 @@ export class GraphQL implements Wrapper { } } - private createExecuteWrapper(original: Function) { - const { GraphQLError } = require("graphql"); - - return function wrappedExecute(this: unknown) { - const context = getContext(); - const agent = getInstance(); - - if (!context || !agent) { - return original.apply(this, arguments); - } - - const args = Array.from(arguments); + private handleRateLimiting( + args: unknown[], + origReturnVal: unknown, + agent: Agent + ) { + const context = getContext(); - if (!Array.isArray(args) || !isPlainObject(args[0])) { - return original.apply(this, arguments); - } + if (!context || !agent) { + return origReturnVal; + } - const result = shouldRateLimitOperation( - agent, - context, - args[0] as unknown as ExecutionArgs - ); + if (!Array.isArray(args) || !isPlainObject(args[0])) { + return origReturnVal; + } - if (result.block) { - return { - errors: [ - new GraphQLError("You are rate limited by Aikido firewall.", { - nodes: [result.field], - extensions: { - code: "RATE_LIMITED_BY_AIKIDO_FIREWALL", - ipAddress: context.remoteAddress, - }, - }), - ], - }; - } + const result = shouldRateLimitOperation( + agent, + context, + args[0] as unknown as ExecutionArgs + ); + + if (result.block) { + const { GraphQLError } = require("graphql"); + + return { + errors: [ + new GraphQLError("You are rate limited by Zen.", { + nodes: [result.field], + extensions: { + code: "RATE_LIMITED_BY_ZEN", + ipAddress: context.remoteAddress, + }, + }), + ], + }; + } - return original.apply(this, arguments); - }; + return origReturnVal; } wrap(hooks: Hooks) { + const methods = ["execute", "executeSync"] as const; + hooks .addPackage("graphql") .withVersion("^16.0.0") - .addFile("execution/execute.js") - .addSubject((exports) => { - // We don't have a hook yet to modify the return value of a function - // We need to refactor this system to allow for that - // For now, we'll wrap the execute function manually - wrap(exports, "execute", this.createExecuteWrapper); - - return exports; - }) - .inspect("execute", this.inspectGraphQLExecute) - .inspect("executeSync", this.inspectGraphQLExecute); + .onFileRequire("execution/execute.js", (exports, pkgInfo) => { + for (const method of methods) { + wrapExport(exports, method, pkgInfo, { + modifyReturnValue: (args, returnValue, agent) => + this.handleRateLimiting(args, returnValue, agent), + inspectArgs: (args, agent) => + this.inspectGraphQLExecute(args, agent), + }); + } + }); } } diff --git a/library/sources/HTTP2Server.test.ts b/library/sources/HTTP2Server.test.ts index 202eb82ae..0b4bb0828 100644 --- a/library/sources/HTTP2Server.test.ts +++ b/library/sources/HTTP2Server.test.ts @@ -1,17 +1,15 @@ import * as t from "tap"; import { Token } from "../agent/api/Token"; import { connect, IncomingHttpHeaders } from "http2"; -import { Agent } from "../agent/Agent"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { getContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { HTTPServer } from "./HTTPServer"; import { isLocalhostIP } from "../helpers/isLocalhostIP"; import { wrap } from "../helpers/wrap"; import * as pkg from "../helpers/isPackageInstalled"; -import { readFileSync } from "fs"; import { resolve } from "path"; import { FileSystem } from "../sinks/FileSystem"; +import { createTestAgent } from "../helpers/createTestAgent"; const originalIsPackageInstalled = pkg.isPackageInstalled; wrap(pkg, "isPackageInstalled", function wrap() { @@ -57,15 +55,14 @@ const api = new ReportingAPIForTesting({ ], heartbeatIntervalInMS: 10 * 60 * 1000, }); -const agent = new Agent( - true, - new LoggerNoop(), +const agent = createTestAgent({ + token: new Token("123"), api, - new Token("abc"), - undefined -); +}); agent.start([new HTTPServer(), new FileSystem()]); +const { readFileSync } = require("fs"); + t.beforeEach(() => { delete process.env.AIKIDO_MAX_BODY_SIZE_MB; }); diff --git a/library/sources/HTTPServer.test.ts b/library/sources/HTTPServer.test.ts index f0f5f4b31..d828e0ce8 100644 --- a/library/sources/HTTPServer.test.ts +++ b/library/sources/HTTPServer.test.ts @@ -15,12 +15,11 @@ wrap(pkg, "isPackageInstalled", function wrap() { }); import * as t from "tap"; -import { Agent } from "../agent/Agent"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { getContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { fetch } from "../helpers/fetch"; import { HTTPServer } from "./HTTPServer"; +import { createTestAgent } from "../helpers/createTestAgent"; // Before require("http") const api = new ReportingAPIForTesting({ @@ -50,13 +49,10 @@ const api = new ReportingAPIForTesting({ ], heartbeatIntervalInMS: 10 * 60 * 1000, }); -const agent = new Agent( - true, - new LoggerNoop(), +const agent = createTestAgent({ + token: new Token("123"), api, - new Token("abc"), - "lambda" -); +}); agent.start([new HTTPServer()]); t.setTimeout(30 * 1000); diff --git a/library/sources/HTTPServer.ts b/library/sources/HTTPServer.ts index fa122f303..da4a0beb2 100644 --- a/library/sources/HTTPServer.ts +++ b/library/sources/HTTPServer.ts @@ -1,5 +1,7 @@ import { Agent } from "../agent/Agent"; import { Hooks } from "../agent/hooks/Hooks"; +import { wrapExport } from "../agent/hooks/wrapExport"; +import { wrapNewInstance } from "../agent/hooks/wrapNewInstance"; import { Wrapper } from "../agent/Wrapper"; import { isPackageInstalled } from "../helpers/isPackageInstalled"; import { createRequestListener } from "./http-server/createRequestListener"; @@ -54,42 +56,51 @@ export class HTTPServer implements Wrapper { wrap(hooks: Hooks) { ["http", "https", "http2"].forEach((module) => { - const subjects = hooks - .addBuiltinModule(module) - .addSubject((exports) => exports); + hooks.addBuiltinModule(module).onRequire((exports, pkgInfo) => { + // Server classes are not exported in the http2 module + if (module !== "http2") { + wrapExport(exports, "Server", pkgInfo, { + modifyArgs: (args, agent) => { + return this.wrapRequestListener(args, module, agent); + }, + }); + } - // Server classes are not exported in the http2 module - if (module !== "http2") { - subjects.modifyArguments("Server", (args, subject, agent) => { - return this.wrapRequestListener(args, module, agent); + wrapExport(exports, "createServer", pkgInfo, { + modifyArgs: (args, agent) => { + return this.wrapRequestListener(args, module, agent); + }, }); - } - subjects - .modifyArguments("createServer", (args, subject, agent) => { - return this.wrapRequestListener(args, module, agent); - }) - .inspectNewInstance("createServer") - .addSubject((exports) => exports) - .modifyArguments("on", (args, subject, agent) => { - return this.wrapOn(args, module, agent); + wrapNewInstance(exports, "createServer", pkgInfo, (instance) => { + wrapExport(instance, "on", pkgInfo, { + modifyArgs: (args, agent) => { + return this.wrapOn(args, module, agent); + }, + }); }); - if (module === "http2") { - subjects.modifyArguments( - "createSecureServer", - (args, subject, agent) => { - return this.wrapRequestListener(args, module, agent); - } - ); - - subjects - .inspectNewInstance("createSecureServer") - .addSubject((exports) => exports) - .modifyArguments("on", (args, subject, agent) => { - return this.wrapOn(args, module, agent); + if (module === "http2") { + wrapExport(exports, "createSecureServer", pkgInfo, { + modifyArgs: (args, agent) => { + return this.wrapRequestListener(args, module, agent); + }, }); - } + + wrapNewInstance( + exports, + "createSecureServer", + pkgInfo, + (instance) => { + wrapExport(instance, "on", pkgInfo, { + modifyArgs: (args, agent) => { + return this.wrapOn(args, module, agent); + }, + }); + } + ); + } + }); }); } } diff --git a/library/sources/Hapi.test.ts b/library/sources/Hapi.test.ts index 5f86e2836..7763f99fb 100644 --- a/library/sources/Hapi.test.ts +++ b/library/sources/Hapi.test.ts @@ -1,18 +1,14 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { setInstance } from "../agent/AgentSingleton"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { setUser } from "../agent/context/user"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { Hapi } from "./Hapi"; import { FileSystem } from "../sinks/FileSystem"; import { HTTPServer } from "./HTTPServer"; +import { createTestAgent } from "../helpers/createTestAgent"; -const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting({ +const agent = createTestAgent({ + api: new ReportingAPIForTesting({ success: true, endpoints: [ { @@ -31,11 +27,9 @@ const agent = new Agent( heartbeatIntervalInMS: 10 * 60 * 1000, allowedIPAddresses: ["4.3.2.1"], }), - new Token("123"), - undefined -); + token: new Token("123"), +}); agent.start([new Hapi(), new FileSystem(), new HTTPServer()]); -setInstance(agent); import * as hapi from "@hapi/hapi"; import * as request from "supertest"; diff --git a/library/sources/Hapi.ts b/library/sources/Hapi.ts index 728be748c..626b4def4 100644 --- a/library/sources/Hapi.ts +++ b/library/sources/Hapi.ts @@ -9,6 +9,9 @@ import { Hooks } from "../agent/hooks/Hooks"; import { Wrapper } from "../agent/Wrapper"; import { isPlainObject } from "../helpers/isPlainObject"; import { wrapRequestHandler } from "./hapi/wrapRequestHandler"; +import { wrapNewInstance } from "../agent/hooks/wrapNewInstance"; +import { WrapPackageInfo } from "../agent/hooks/WrapPackageInfo"; +import { wrapExport } from "../agent/hooks/wrapExport"; export class Hapi implements Wrapper { private wrapRouteHandler(args: unknown[], agent: Agent) { @@ -79,25 +82,29 @@ export class Hapi implements Wrapper { return args; } - wrap(hooks: Hooks) { - const hapi = hooks.addPackage("@hapi/hapi").withVersion("^21.0.0"); - const exports = hapi.addSubject((exports) => exports); - - const subjects = [ - exports.inspectNewInstance("server").addSubject((exports) => exports), - exports.inspectNewInstance("Server").addSubject((exports) => exports), - ]; + private wrapServer(server: unknown, pkgInfo: WrapPackageInfo) { + wrapExport(server, "route", pkgInfo, { + modifyArgs: (args, agent) => this.wrapRouteHandler(args, agent), + }); + wrapExport(server, "ext", pkgInfo, { + modifyArgs: (args, agent) => this.wrapExtensionFunction(args, agent), + }); + wrapExport(server, "decorate", pkgInfo, { + modifyArgs: (args, agent) => this.wrapDecorateFunction(args, agent), + }); + } - for (const subject of subjects) { - subject.modifyArguments("route", (args, subject, agent) => { - return this.wrapRouteHandler(args, agent); - }); - subject.modifyArguments("ext", (args, subject, agent) => { - return this.wrapExtensionFunction(args, agent); - }); - subject.modifyArguments("decorate", (args, subject, agent) => { - return this.wrapDecorateFunction(args, agent); + wrap(hooks: Hooks) { + const hapi = hooks + .addPackage("@hapi/hapi") + .withVersion("^21.0.0") + .onRequire((exports, pkgInfo) => { + wrapNewInstance(exports, "Server", pkgInfo, (server) => { + this.wrapServer(server, pkgInfo); + }); + wrapNewInstance(exports, "server", pkgInfo, (server) => { + this.wrapServer(server, pkgInfo); + }); }); - } } } diff --git a/library/sources/Hono.test.ts b/library/sources/Hono.test.ts index 2ed0c5b2c..5ef999949 100644 --- a/library/sources/Hono.test.ts +++ b/library/sources/Hono.test.ts @@ -1,21 +1,18 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { setInstance } from "../agent/AgentSingleton"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { setUser } from "../agent/context/user"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { Hono as HonoInternal } from "./Hono"; import { HTTPServer } from "./HTTPServer"; import { getMajorNodeVersion } from "../helpers/getNodeVersion"; import { fetch } from "../helpers/fetch"; import { getContext } from "../agent/Context"; import { isLocalhostIP } from "../helpers/isLocalhostIP"; +import { createTestAgent } from "../helpers/createTestAgent"; -const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting({ +const agent = createTestAgent({ + token: new Token("123"), + api: new ReportingAPIForTesting({ success: true, endpoints: [ { @@ -34,11 +31,8 @@ const agent = new Agent( heartbeatIntervalInMS: 10 * 60 * 1000, allowedIPAddresses: ["4.3.2.1"], }), - new Token("123"), - undefined -); +}); agent.start([new HonoInternal(), new HTTPServer()]); -setInstance(agent); function getApp() { const { Hono } = require("hono"); diff --git a/library/sources/Hono.ts b/library/sources/Hono.ts index bb3639902..4774448a7 100644 --- a/library/sources/Hono.ts +++ b/library/sources/Hono.ts @@ -4,6 +4,7 @@ import { Agent } from "../agent/Agent"; import { Hooks } from "../agent/hooks/Hooks"; import { Wrapper } from "../agent/Wrapper"; import { wrapRequestHandler } from "./hono/wrapRequestHandler"; +import { wrapExport } from "../agent/hooks/wrapExport"; export class Hono implements Wrapper { // Wrap all the functions passed to hono.METHOD(...) @@ -25,17 +26,18 @@ export class Hono implements Wrapper { } wrap(hooks: Hooks) { - const hono = hooks + hooks .addPackage("hono") .withVersion("^4.0.0") - .addFile("hono-base"); - - hono - .addSubject((exports) => { - return exports.HonoBase.prototype; + .onFileRequire("dist/hono-base.js", (exports, pkgInfo) => { + wrapExport(exports.HonoBase.prototype, "addRoute", pkgInfo, { + modifyArgs: (args, agent) => this.wrapArgs(args, agent), + }); }) - .modifyArguments("addRoute", (args, original, agent) => { - return this.wrapArgs(args, agent); + .onFileRequire("dist/cjs/hono-base.js", (exports, pkgInfo) => { + wrapExport(exports.HonoBase.prototype, "addRoute", pkgInfo, { + modifyArgs: (args, agent) => this.wrapArgs(args, agent), + }); }); } } diff --git a/library/sources/Lambda.test.ts b/library/sources/Lambda.test.ts index bc0d2173f..300782156 100644 --- a/library/sources/Lambda.test.ts +++ b/library/sources/Lambda.test.ts @@ -1,13 +1,11 @@ import * as FakeTimers from "@sinonjs/fake-timers"; import type { Context } from "aws-lambda"; import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { setInstance } from "../agent/AgentSingleton"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { getContext, updateContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { createLambdaWrapper, SQSEvent, APIGatewayProxyEvent } from "./Lambda"; +import { createTestAgent } from "../helpers/createTestAgent"; const gatewayEvent: APIGatewayProxyEvent = { resource: "/dev/{proxy+}", @@ -193,11 +191,14 @@ t.test("it passes through unknown types of events", async () => { t.test("it sends heartbeat after first and every 10 minutes", async () => { const clock = FakeTimers.install(); - const logger = new LoggerNoop(); const testing = new ReportingAPIForTesting(); - const agent = new Agent(false, logger, testing, new Token("123"), "lambda"); + const agent = createTestAgent({ + block: false, + token: new Token("token"), + serverless: "lambda", + api: testing, + }); agent.start([]); - setInstance(agent); const handler = createLambdaWrapper(async (event, context) => { return getContext(); @@ -309,11 +310,13 @@ t.test( async () => { const clock = FakeTimers.install(); - const logger = new LoggerNoop(); const testing = new ReportingAPIForTesting(); - const agent = new Agent(false, logger, testing, undefined, "lambda"); + const agent = createTestAgent({ + block: false, + serverless: "lambda", + api: testing, + }); agent.start([]); - setInstance(agent); const handler = createLambdaWrapper(async (event, context) => { return getContext(); @@ -341,11 +344,14 @@ t.test( t.test("if handler throws it still sends heartbeat", async () => { const clock = FakeTimers.install(); - const logger = new LoggerNoop(); const testing = new ReportingAPIForTesting(); - const agent = new Agent(false, logger, testing, new Token("token"), "lambda"); + const agent = createTestAgent({ + block: false, + token: new Token("token"), + serverless: "lambda", + api: testing, + }); agent.start([]); - setInstance(agent); testing.clear(); @@ -419,11 +425,12 @@ t.test("no cookie header", async () => { }); t.test("it counts attacks", async () => { - const logger = new LoggerNoop(); - const testing = new ReportingAPIForTesting(); - const agent = new Agent(false, logger, testing, new Token("token"), "lambda"); + const agent = createTestAgent({ + block: false, + token: new Token("token"), + serverless: "lambda", + }); agent.start([]); - setInstance(agent); const handler = createLambdaWrapper(async (event, context) => { const ctx = getContext(); diff --git a/library/sources/PubSub.ts b/library/sources/PubSub.ts index c7acee2b5..5fd790433 100644 --- a/library/sources/PubSub.ts +++ b/library/sources/PubSub.ts @@ -1,29 +1,30 @@ import { runWithContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { Wrapper } from "../agent/Wrapper"; import type { Message } from "@google-cloud/pubsub"; export class PubSub implements Wrapper { wrap(hooks: Hooks) { - const pubSub = hooks + hooks .addPackage("@google-cloud/pubsub") - .withVersion("^4.0.0"); + .withVersion("^4.0.0") + .onFileRequire("build/src/subscription.js", (exports, pkgInfo) => { + wrapExport(exports.Subscription.prototype, "on", pkgInfo, { + modifyArgs: (args) => { + if ( + args.length > 0 && + typeof args[0] === "string" && + args[0] === "message" && + typeof args[1] === "function" + ) { + const originalCallback = args[1]; + args[1] = handleMessage(originalCallback); + } - pubSub - .addFile("build/src/subscription.js") - .addSubject((exports) => exports.Subscription.prototype) - .modifyArguments("on", (args) => { - if ( - args.length > 0 && - typeof args[0] === "string" && - args[0] === "message" && - typeof args[1] === "function" - ) { - const originalCallback = args[1]; - args[1] = handleMessage(originalCallback); - } - - return args; + return args; + }, + }); }); } } diff --git a/library/sources/Xml2js.test.ts b/library/sources/Xml2js.test.ts index 3d0568361..9caabb3cb 100644 --- a/library/sources/Xml2js.test.ts +++ b/library/sources/Xml2js.test.ts @@ -1,19 +1,10 @@ import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Context, getContext, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { Xml2js } from "./Xml2js"; +import { createTestAgent } from "../helpers/createTestAgent"; t.test("it works", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); - + const agent = createTestAgent(); agent.start([new Xml2js()]); const { parseStringPromise, parseString } = require("xml2js"); diff --git a/library/sources/Xml2js.ts b/library/sources/Xml2js.ts index 9abb9afc6..fc4d0e0ce 100644 --- a/library/sources/Xml2js.ts +++ b/library/sources/Xml2js.ts @@ -1,6 +1,7 @@ /* eslint-disable prefer-rest-params */ import { getContext, updateContext, runWithContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { Wrapper } from "../agent/Wrapper"; import { isPlainObject } from "../helpers/isPlainObject"; @@ -46,13 +47,13 @@ export class Xml2js implements Wrapper { } wrap(hooks: Hooks) { - const xml2js = hooks + hooks .addPackage("xml2js") - .withVersion("^0.6.0 || ^0.5.0 || ^0.4.18"); - - xml2js - .addSubject((exports) => exports.Parser.prototype) - // Also wraps parseStringPromise and usage without Parser instance - .modifyArguments("parseString", (args) => this.modifyArgs(args)); + .withVersion("^0.6.0 || ^0.5.0 || ^0.4.18") + .onRequire((exports, pkgInfo) => { + wrapExport(exports.Parser.prototype, "parseString", pkgInfo, { + modifyArgs: (args) => this.modifyArgs(args), + }); + }); } } diff --git a/library/sources/XmlMinusJs.test.ts b/library/sources/XmlMinusJs.test.ts index 4d4311f42..abdc41685 100644 --- a/library/sources/XmlMinusJs.test.ts +++ b/library/sources/XmlMinusJs.test.ts @@ -1,20 +1,12 @@ import { join } from "path"; import * as t from "tap"; -import { Agent } from "../agent/Agent"; -import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { getContext, runWithContext } from "../agent/Context"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; import { XmlMinusJs } from "./XmlMinusJs"; import { readFile } from "fs/promises"; +import { createTestAgent } from "../helpers/createTestAgent"; t.test("xml2js works", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); + const agent = createTestAgent(); agent.start([new XmlMinusJs()]); @@ -53,13 +45,7 @@ t.test("xml2js works", async () => { }); t.test("xml2json works", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); + const agent = createTestAgent(); agent.start([new XmlMinusJs()]); @@ -94,14 +80,7 @@ t.test("xml2json works", async () => { }); t.test("Ignore if xml is not in the body", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - undefined, - undefined - ); - + const agent = createTestAgent(); agent.start([new XmlMinusJs()]); const xmljs = require("xml-js"); diff --git a/library/sources/XmlMinusJs.ts b/library/sources/XmlMinusJs.ts index b07bbe859..5bc1f42dd 100644 --- a/library/sources/XmlMinusJs.ts +++ b/library/sources/XmlMinusJs.ts @@ -1,6 +1,7 @@ /* eslint-disable prefer-rest-params */ import { getContext, updateContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; +import { wrapExport } from "../agent/hooks/wrapExport"; import { Wrapper } from "../agent/Wrapper"; import { isPlainObject } from "../helpers/isPlainObject"; @@ -36,15 +37,22 @@ export class XmlMinusJs implements Wrapper { } wrap(hooks: Hooks) { - const xmljs = hooks.addPackage("xml-js").withVersion("^1.0.0"); - - xmljs - .addSubject((exports) => exports) - .inspectResult("xml2js", (args, result) => { - this.inspectParse(args, result, false); - }) - .inspectResult("xml2json", (args, result) => { - this.inspectParse(args, result, true); + hooks + .addPackage("xml-js") + .withVersion("^1.0.0") + .onRequire((exports, pkgInfo) => { + wrapExport(exports, "xml2js", pkgInfo, { + modifyReturnValue: (args, result) => { + this.inspectParse(args, result, false); + return result; + }, + }); + wrapExport(exports, "xml2json", pkgInfo, { + modifyReturnValue: (args, result) => { + this.inspectParse(args, result, true); + return result; + }, + }); }); } } diff --git a/library/sources/graphql/shouldRateLimitOperation.test.ts b/library/sources/graphql/shouldRateLimitOperation.test.ts index 0b0363c79..a30118417 100644 --- a/library/sources/graphql/shouldRateLimitOperation.test.ts +++ b/library/sources/graphql/shouldRateLimitOperation.test.ts @@ -1,23 +1,17 @@ import { parse, ExecutionArgs } from "graphql"; import * as t from "tap"; -import { Agent } from "../../agent/Agent"; import { ReportingAPIForTesting } from "../../agent/api/ReportingAPIForTesting"; import { Token } from "../../agent/api/Token"; import { Context } from "../../agent/Context"; -import { LoggerNoop } from "../../agent/logger/LoggerNoop"; import { shouldRateLimitOperation } from "./shouldRateLimitOperation"; +import { createTestAgent } from "../../helpers/createTestAgent"; type Args = Pick; t.test("it does not rate limit if endpoint not found", async () => { - const token = new Token("123"); - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting(), - token, - undefined - ); + const agent = createTestAgent({ + token: new Token("123"), + }); const args: Args = { document: parse(` @@ -50,11 +44,9 @@ t.test("it does not rate limit if endpoint not found", async () => { }); t.test("it rate limits query", async () => { - const token = new Token("123"); - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting({ + const agent = createTestAgent({ + token: new Token("123"), + api: new ReportingAPIForTesting({ success: true, endpoints: [ { @@ -77,9 +69,7 @@ t.test("it rate limits query", async () => { heartbeatIntervalInMS: 10 * 60 * 1000, blockedUserIds: [], }), - token, - undefined - ); + }); agent.start([]); await new Promise((resolve) => setTimeout(resolve, 0)); @@ -119,11 +109,9 @@ t.test("it rate limits query", async () => { }); t.test("it rate limits mutation", async () => { - const token = new Token("123"); - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting({ + const agent = createTestAgent({ + token: new Token("123"), + api: new ReportingAPIForTesting({ success: true, endpoints: [ { @@ -146,9 +134,7 @@ t.test("it rate limits mutation", async () => { heartbeatIntervalInMS: 10 * 60 * 1000, blockedUserIds: [], }), - token, - undefined - ); + }); agent.start([]); await new Promise((resolve) => setTimeout(resolve, 0)); diff --git a/library/sources/http-server/ipAllowedToAccessRoute.test.ts b/library/sources/http-server/ipAllowedToAccessRoute.test.ts index 1e08630dd..f25c97300 100644 --- a/library/sources/http-server/ipAllowedToAccessRoute.test.ts +++ b/library/sources/http-server/ipAllowedToAccessRoute.test.ts @@ -5,6 +5,7 @@ import { Token } from "../../agent/api/Token"; import { Context } from "../../agent/Context"; import { LoggerNoop } from "../../agent/logger/LoggerNoop"; import { ipAllowedToAccessRoute } from "./ipAllowedToAccessRoute"; +import { createTestAgent } from "../../helpers/createTestAgent"; let agent: Agent; const context: Context = { @@ -21,10 +22,9 @@ const context: Context = { }; t.beforeEach(async () => { - agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting({ + agent = createTestAgent({ + token: new Token("123"), + api: new ReportingAPIForTesting({ success: true, allowedIPAddresses: [], configUpdatedAt: 0, @@ -41,9 +41,7 @@ t.beforeEach(async () => { ], block: true, }), - new Token("123"), - undefined - ); + }); agent.start([]); @@ -86,10 +84,9 @@ t.test("it blocks request if no IP address", async () => { }); t.test("it allows request if configuration is broken", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting({ + const agent = createTestAgent({ + token: new Token("123"), + api: new ReportingAPIForTesting({ success: true, allowedIPAddresses: [], configUpdatedAt: 0, @@ -107,9 +104,7 @@ t.test("it allows request if configuration is broken", async () => { ], block: true, }), - new Token("123"), - undefined - ); + }); agent.start([]); @@ -122,10 +117,9 @@ t.test("it allows request if configuration is broken", async () => { }); t.test("it allows request if allowed IP addresses is empty", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting({ + const agent = createTestAgent({ + token: new Token("123"), + api: new ReportingAPIForTesting({ success: true, allowedIPAddresses: [], configUpdatedAt: 0, @@ -142,9 +136,7 @@ t.test("it allows request if allowed IP addresses is empty", async () => { ], block: true, }), - new Token("123"), - undefined - ); + }); agent.start([]); @@ -164,10 +156,9 @@ t.test("it blocks request if not allowed IP address", async () => { }); t.test("it checks every matching endpoint", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting({ + const agent = createTestAgent({ + token: new Token("123"), + api: new ReportingAPIForTesting({ success: true, allowedIPAddresses: [], configUpdatedAt: 0, @@ -191,9 +182,7 @@ t.test("it checks every matching endpoint", async () => { ], block: true, }), - new Token("123"), - undefined - ); + }); agent.start([]); @@ -208,10 +197,9 @@ t.test("it checks every matching endpoint", async () => { t.test( "if allowed IPs is empty or broken, it ignores the endpoint but does check the other ones", async () => { - const agent = new Agent( - true, - new LoggerNoop(), - new ReportingAPIForTesting({ + const agent = createTestAgent({ + token: new Token("123"), + api: new ReportingAPIForTesting({ success: true, allowedIPAddresses: [], configUpdatedAt: 0, @@ -243,9 +231,7 @@ t.test( ], block: true, }), - new Token("123"), - undefined - ); + }); agent.start([]); diff --git a/library/vulnerabilities/path-traversal/checkContextForPathTraversal.ts b/library/vulnerabilities/path-traversal/checkContextForPathTraversal.ts index d703c5d3c..3fdc67931 100644 --- a/library/vulnerabilities/path-traversal/checkContextForPathTraversal.ts +++ b/library/vulnerabilities/path-traversal/checkContextForPathTraversal.ts @@ -1,5 +1,5 @@ import { Context } from "../../agent/Context"; -import { InterceptorResult } from "../../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../../agent/hooks/InterceptorResult"; import { SOURCES } from "../../agent/Source"; import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached"; import { detectPathTraversal } from "./detectPathTraversal"; diff --git a/library/vulnerabilities/prototype-pollution/preventPrototypePollution.test.ts b/library/vulnerabilities/prototype-pollution/preventPrototypePollution.test.ts index 691ef466e..8bdaaec53 100644 --- a/library/vulnerabilities/prototype-pollution/preventPrototypePollution.test.ts +++ b/library/vulnerabilities/prototype-pollution/preventPrototypePollution.test.ts @@ -1,13 +1,10 @@ import * as t from "tap"; -import { Agent } from "../../agent/Agent"; -import { setInstance } from "../../agent/AgentSingleton"; -import { ReportingAPIForTesting } from "../../agent/api/ReportingAPIForTesting"; -import { Token } from "../../agent/api/Token"; import { LoggerForTesting } from "../../agent/logger/LoggerForTesting"; import { freezeBuiltinsIfPossible, preventPrototypePollution, } from "./preventPrototypePollution"; +import { createTestAgent } from "../../helpers/createTestAgent"; t.test( "it does not freeze builtins if incompatible package is found", @@ -41,15 +38,10 @@ t.test("without agent instance", async () => { t.test("it lets agent know", async () => { const logger = new LoggerForTesting(); - const agent = new Agent( - true, + const agent = createTestAgent({ logger, - new ReportingAPIForTesting(), - new Token("123"), - undefined - ); + }); - setInstance(agent); preventPrototypePollution(); t.same(logger.getMessages(), ["Prevented prototype pollution!"]); }); diff --git a/library/vulnerabilities/shell-injection/checkContextForShellInjection.ts b/library/vulnerabilities/shell-injection/checkContextForShellInjection.ts index b12f12b8f..ae38cb137 100644 --- a/library/vulnerabilities/shell-injection/checkContextForShellInjection.ts +++ b/library/vulnerabilities/shell-injection/checkContextForShellInjection.ts @@ -1,5 +1,5 @@ import { Context } from "../../agent/Context"; -import { InterceptorResult } from "../../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../../agent/hooks/InterceptorResult"; import { SOURCES } from "../../agent/Source"; import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached"; import { detectShellInjection } from "./detectShellInjection"; diff --git a/library/vulnerabilities/sql-injection/checkContextForSqlInjection.ts b/library/vulnerabilities/sql-injection/checkContextForSqlInjection.ts index 76ee96799..d87b12a5d 100644 --- a/library/vulnerabilities/sql-injection/checkContextForSqlInjection.ts +++ b/library/vulnerabilities/sql-injection/checkContextForSqlInjection.ts @@ -1,5 +1,5 @@ import { Context } from "../../agent/Context"; -import { InterceptorResult } from "../../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../../agent/hooks/InterceptorResult"; import { SOURCES } from "../../agent/Source"; import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached"; import { detectSQLInjection } from "./detectSQLInjection"; diff --git a/library/vulnerabilities/ssrf/checkContextForSSRF.ts b/library/vulnerabilities/ssrf/checkContextForSSRF.ts index 7cbbac067..996faa581 100644 --- a/library/vulnerabilities/ssrf/checkContextForSSRF.ts +++ b/library/vulnerabilities/ssrf/checkContextForSSRF.ts @@ -1,5 +1,5 @@ import { Context } from "../../agent/Context"; -import { InterceptorResult } from "../../agent/hooks/MethodInterceptor"; +import { InterceptorResult } from "../../agent/hooks/InterceptorResult"; import { SOURCES } from "../../agent/Source"; import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached"; import { containsPrivateIPAddress } from "./containsPrivateIPAddress"; diff --git a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.test.ts b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.test.ts index ab186ea13..b1d8c3c09 100644 --- a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.test.ts +++ b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.test.ts @@ -1,12 +1,12 @@ import { LookupAddress, lookup } from "dns"; import * as t from "tap"; -import { Agent } from "../../agent/Agent"; import { ReportingAPIForTesting } from "../../agent/api/ReportingAPIForTesting"; import { Token } from "../../agent/api/Token"; import { Context, runWithContext } from "../../agent/Context"; import { LoggerNoop } from "../../agent/logger/LoggerNoop"; import { inspectDNSLookupCalls } from "./inspectDNSLookupCalls"; import { getMajorNodeVersion } from "../../helpers/getNodeVersion"; +import { createTestAgent } from "../../helpers/createTestAgent"; const context: Context = { remoteAddress: "::1", @@ -24,10 +24,9 @@ const context: Context = { }; t.test("it resolves private IPv4 without context", (t) => { - const logger = new LoggerNoop(); - const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + }); agent.start([]); const wrappedLookup = inspectDNSLookupCalls( @@ -45,10 +44,9 @@ t.test("it resolves private IPv4 without context", (t) => { }); t.test("it resolves private IPv6 without context", (t) => { - const logger = new LoggerNoop(); - const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + }); agent.start([]); const wrappedLookup = inspectDNSLookupCalls( @@ -66,10 +64,11 @@ t.test("it resolves private IPv6 without context", (t) => { }); t.test("it blocks lookup in blocking mode", (t) => { - const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + api, + }); agent.start([]); api.clear(); @@ -106,10 +105,11 @@ t.test("it blocks lookup in blocking mode", (t) => { }); t.test("it allows resolved public IP", (t) => { - const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + api, + }); agent.start([]); api.clear(); @@ -136,10 +136,11 @@ t.test("it allows resolved public IP", (t) => { t.test( "it does not block resolved private IP if not found in user input", (t) => { - const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + api, + }); agent.start([]); api.clear(); @@ -164,7 +165,6 @@ t.test( t.test( "it does not block resolved private IP if endpoint protection is turned off", async (t) => { - const logger = new LoggerNoop(); const api = new ReportingAPIForTesting({ success: true, heartbeatIntervalInMS: 10 * 60 * 1000, @@ -184,8 +184,10 @@ t.test( allowedIPAddresses: [], configUpdatedAt: 0, }); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + api, + }); agent.start([]); await new Promise((resolve) => setTimeout(resolve, 0)); @@ -213,10 +215,9 @@ t.test( ); t.test("it blocks lookup in blocking mode with all option", (t) => { - const logger = new LoggerNoop(); - const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + }); agent.start([]); const wrappedLookup = inspectDNSLookupCalls( @@ -240,10 +241,12 @@ t.test("it blocks lookup in blocking mode with all option", (t) => { }); t.test("it does not block in dry mode", (t) => { - const logger = new LoggerNoop(); const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(false, logger, api, token, undefined); + const agent = createTestAgent({ + block: false, + token: new Token("123"), + api, + }); agent.start([]); api.clear(); @@ -272,10 +275,9 @@ t.test("it does not block in dry mode", (t) => { }); t.test("it ignores invalid args", (t) => { - const logger = new LoggerNoop(); - const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + }); agent.start([]); const wrappedLookup = inspectDNSLookupCalls( @@ -294,10 +296,9 @@ t.test("it ignores invalid args", (t) => { }); t.test("it ignores if lookup returns error", (t) => { - const logger = new LoggerNoop(); - const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + }); agent.start([]); const wrappedLookup = inspectDNSLookupCalls( @@ -335,10 +336,9 @@ const imdsMockLookup = ( }; t.test("Blocks IMDS SSRF with untrusted domain", async (t) => { - const logger = new LoggerNoop(); - const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + }); agent.start([]); const wrappedLookup = inspectDNSLookupCalls( @@ -399,8 +399,10 @@ t.test( allowedIPAddresses: [], configUpdatedAt: 0, }); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + api, + }); agent.start([]); // Wait for the agent to start @@ -424,10 +426,9 @@ t.test( ); t.test("Does not block IMDS SSRF with Google metadata domain", async (t) => { - const logger = new LoggerNoop(); - const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + }); agent.start([]); const wrappedLookup = inspectDNSLookupCalls( @@ -466,10 +467,9 @@ t.test("Does not block IMDS SSRF with Google metadata domain", async (t) => { }); t.test("it ignores when the argument is an IP address", async (t) => { - const logger = new LoggerNoop(); - const api = new ReportingAPIForTesting(); - const token = new Token("123"); - const agent = new Agent(true, logger, api, token, undefined); + const agent = createTestAgent({ + token: new Token("123"), + }); agent.start([]); const wrappedLookup = inspectDNSLookupCalls( diff --git a/sample-apps/hono-xml/package-lock.json b/sample-apps/hono-xml/package-lock.json index 5c3319a7f..197b77784 100644 --- a/sample-apps/hono-xml/package-lock.json +++ b/sample-apps/hono-xml/package-lock.json @@ -27,9 +27,9 @@ "link": true }, "node_modules/@hono/node-server": { - "version": "1.12.2", - "resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.12.2.tgz", - "integrity": "sha512-xjzhqhSWUE/OhN0g3KCNVzNsQMlFUAL+/8GgPUr3TKcU7cvgZVBGswFofJ8WwGEHTqobzze1lDpGJl9ZNckDhA==", + "version": "1.12.0", + "resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.12.0.tgz", + "integrity": "sha512-e6oHjNiErRxsZRZBmc2KucuvY3btlO/XPncIpP2X75bRdTilF9GLjm3NHvKKunpJbbJJj31/FoPTksTf8djAVw==", "license": "MIT", "engines": { "node": ">=18.14.1" @@ -39,9 +39,9 @@ } }, "node_modules/aws-ssl-profiles": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/aws-ssl-profiles/-/aws-ssl-profiles-1.1.2.tgz", - "integrity": "sha512-NZKeq9AfyQvEeNlN0zSYAaWrmBffJh3IELMZfRpJVWgrpEbtEpnjvzqBPf+mxoI287JohRDoa+/nsfqqiZmF6g==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/aws-ssl-profiles/-/aws-ssl-profiles-1.1.1.tgz", + "integrity": "sha512-+H+kuK34PfMaI9PNU/NSjBKL5hh/KDM9J72kwYeYEm0A8B1AC4fuCy3qsjnA7lxklgyXsB68yn8Z2xoZEjgwCQ==", "license": "MIT", "engines": { "node": ">= 6.0.0" @@ -57,9 +57,9 @@ } }, "node_modules/fast-xml-parser": { - "version": "4.5.0", - "resolved": "https://registry.npmjs.org/fast-xml-parser/-/fast-xml-parser-4.5.0.tgz", - "integrity": "sha512-/PlTQCI96+fZMAOLMZK4CWG1ItCbfZ/0jx7UIJFChPNrx7tcEgerUgWbeieCM9MfHInUDyK8DWYZ+YrywDJuTg==", + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/fast-xml-parser/-/fast-xml-parser-4.4.1.tgz", + "integrity": "sha512-xkjOecfnKGkSsOwtZ5Pz7Us/T6mrbPQrq0nh+aCO5V9nk5NLWmasAHumTKjiPJPWANe+kAZ84Jc8ooJkzZ88Sw==", "funding": [ { "type": "github", @@ -88,9 +88,9 @@ } }, "node_modules/hono": { - "version": "4.5.11", - "resolved": "https://registry.npmjs.org/hono/-/hono-4.5.11.tgz", - "integrity": "sha512-62FcjLPtjAFwISVBUshryl+vbHOjg8rE4uIK/dxyR8GpLztunZpwFmfEvmJCUI7xoGh/Sr3CGCDPCmYxVw7wUQ==", + "version": "4.5.5", + "resolved": "https://registry.npmjs.org/hono/-/hono-4.5.5.tgz", + "integrity": "sha512-fXBXHqaVfimWofbelLXci8pZyIwBMkDIwCa4OwZvK+xVbEyYLELVP4DfbGaj1aEM6ZY3hHgs4qLvCO2ChkhgQw==", "license": "MIT", "engines": { "node": ">=16.0.0"