diff --git a/library/sources/HTTPServer.test.ts b/library/sources/HTTPServer.test.ts index 191b0090..992859c2 100644 --- a/library/sources/HTTPServer.test.ts +++ b/library/sources/HTTPServer.test.ts @@ -1,12 +1,14 @@ import { Token } from "../agent/api/Token"; import { getMajorNodeVersion } from "../helpers/getNodeVersion"; - import * as t from "tap"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { getContext } from "../agent/Context"; import { fetch } from "../helpers/fetch"; +import { wrap } from "../helpers/wrap"; import { HTTPServer } from "./HTTPServer"; import { createTestAgent } from "../helpers/createTestAgent"; +import type { Blocklist } from "../agent/api/fetchBlockedLists"; +import * as fetchBlockedLists from "../agent/api/fetchBlockedLists"; // Before require("http") const api = new ReportingAPIForTesting({ @@ -43,6 +45,24 @@ const agent = createTestAgent({ }); agent.start([new HTTPServer()]); +wrap(fetchBlockedLists, "fetchBlockedLists", function fetchBlockedLists() { + return async function fetchBlockedLists(): Promise<{ + blockedIPAddresses: Blocklist[]; + blockedUserAgents: string; + }> { + return { + blockedIPAddresses: [ + { + source: "geoip", + ips: ["9.9.9.9"], + description: "geo restrictions", + }, + ], + blockedUserAgents: "", + }; + }; +}); + t.setTimeout(30 * 1000); t.beforeEach(() => { @@ -576,3 +596,79 @@ t.test("it checks if IP can access route", async (t) => { }); }); }); + +t.test("it blocks IP address", async (t) => { + const server = http.createServer((req, res) => { + res.setHeader("Content-Type", "text/plain"); + res.end("OK"); + }); + + await new Promise((resolve) => { + server.listen(3325, () => { + Promise.all([ + fetch({ + url: new URL("http://localhost:3325"), + method: "GET", + headers: { + "x-forwarded-for": "9.9.9.9", + }, + timeoutInMS: 500, + }), + fetch({ + url: new URL("http://localhost:3325"), + method: "GET", + timeoutInMS: 500, + }), + ]).then(([response1, response2]) => { + t.equal(response1.statusCode, 403); + t.equal( + response1.body, + "Your IP address is blocked due to geo restrictions. (Your IP: 9.9.9.9)" + ); + t.equal(response2.statusCode, 200); + server.close(); + resolve(); + }); + }); + }); +}); + +t.test( + "it blocks IP address when there are multiple request handlers on server", + async (t) => { + const server = http.createServer((req, res) => { + res.setHeader("Content-Type", "text/plain"); + res.end("OK"); + }); + + server.on("request", (req, res) => { + if (res.headersSent) { + return; + } + + res.setHeader("Content-Type", "text/plain"); + res.end("OK"); + }); + + await new Promise((resolve) => { + server.listen(3326, () => { + fetch({ + url: new URL("http://localhost:3326"), + method: "GET", + headers: { + "x-forwarded-for": "9.9.9.9", + }, + timeoutInMS: 500, + }).then(({ statusCode, body }) => { + t.equal(statusCode, 403); + t.equal( + body, + "Your IP address is blocked due to geo restrictions. (Your IP: 9.9.9.9)" + ); + server.close(); + resolve(); + }); + }); + }); + } +); diff --git a/library/sources/http-server/checkIfRequestIsBlocked.ts b/library/sources/http-server/checkIfRequestIsBlocked.ts index 00940dc9..4be81c7f 100644 --- a/library/sources/http-server/checkIfRequestIsBlocked.ts +++ b/library/sources/http-server/checkIfRequestIsBlocked.ts @@ -13,6 +13,12 @@ export function checkIfRequestIsBlocked( res: ServerResponse, agent: Agent ): boolean { + if (res.headersSent) { + // The headers have already been sent, so we can't block the request + // This might happen if the server has multiple listeners + return false; + } + const context = getContext(); if (!context) { diff --git a/library/sources/http-server/createRequestListener.ts b/library/sources/http-server/createRequestListener.ts index 5eda6abb..4f74448a 100644 --- a/library/sources/http-server/createRequestListener.ts +++ b/library/sources/http-server/createRequestListener.ts @@ -57,7 +57,10 @@ function callListenerWithContext( // This method is called when the response is finished and discovers the routes for display in the dashboard // The bindContext function is used to ensure that the context is available in the callback // If using http2, the context is not available in the callback without this - res.on("finish", bindContext(createOnFinishRequestHandler(res, agent))); + res.on( + "finish", + bindContext(createOnFinishRequestHandler(req, res, agent)) + ); if (checkIfRequestIsBlocked(res, agent)) { // The return is necessary to prevent the listener from being called @@ -68,8 +71,24 @@ function callListenerWithContext( }); } -function createOnFinishRequestHandler(res: ServerResponse, agent: Agent) { +// Use symbol to avoid conflicts with other properties +const countedRequest = Symbol("__zen_request_counted__"); + +function createOnFinishRequestHandler( + req: IncomingMessage, + res: ServerResponse, + agent: Agent +) { return function onFinishRequest() { + if ((req as any)[countedRequest]) { + // The request has already been counted + // This might happen if the server has multiple listeners + return; + } + + // Mark the request as counted + (req as any)[countedRequest] = true; + const context = getContext(); if (