diff --git a/src/util/utils.js b/src/util/utils.js index ee4b0f46be..4157d144f0 100644 --- a/src/util/utils.js +++ b/src/util/utils.js @@ -10,17 +10,22 @@ const stats = require('./stats'); const resolver = new Resolver(); +const BLOCK_HOST_NAMES = process.env.BLOCK_HOST_NAMES || ''; +const BLOCK_HOST_NAMES_LIST = BLOCK_HOST_NAMES.split(','); +const LOCAL_HOST_NAMES_LIST = ['localhost', '127.0.0.1', '[::]', '[::1]']; const LOCALHOST_IP = '127.0.0.1'; -const LOCALHOST_URL = `http://localhost`; const RECORD_TYPE_A = 4; // ipv4 const staticLookup = (transformerVersionId) => async (hostname, _, cb) => { let ips; const resolveStartTime = new Date(); try { - ips = await resolver.resolve(hostname); + ips = await resolver.resolve4(hostname); } catch (error) { - stats.timing('fetch_dns_resolve_time', resolveStartTime, { transformerVersionId, error: 'true' }); + stats.timing('fetch_dns_resolve_time', resolveStartTime, { + transformerVersionId, + error: 'true', + }); cb(null, `unable to resolve IP address for ${hostname}`, RECORD_TYPE_A); return; } @@ -47,8 +52,17 @@ const httpAgentWithDnsLookup = (scheme, transformerVersionId) => { }; const blockLocalhostRequests = (url) => { - if (url.includes(LOCALHOST_URL) || url.includes(LOCALHOST_IP)) { - throw new Error('localhost requests are not allowed'); + try { + const parseUrl = new URL(url); + const { hostname } = parseUrl; + if (LOCAL_HOST_NAMES_LIST.includes(hostname)) { + throw new Error('localhost requests are not allowed'); + } + if (BLOCK_HOST_NAMES_LIST.includes(hostname)) { + throw new Error('blocked host requests are not allowed'); + } + } catch (error) { + throw new Error(`invalid url ${url} :: ${error.message}`); } }; @@ -163,14 +177,14 @@ const extractStackTraceUptoLastSubstringMatch = (trace, stringLiterals) => { const traceLines = trace.split('\n'); let lastRelevantIndex = 0; - for(let i = traceLines.length - 1; i >= 0; i -= 1) { - if (stringLiterals.some(str => traceLines[i].includes(str))) { + for (let i = traceLines.length - 1; i >= 0; i -= 1) { + if (stringLiterals.some((str) => traceLines[i].includes(str))) { lastRelevantIndex = i; break; } } - return traceLines.slice(0, lastRelevantIndex + 1).join("\n"); + return traceLines.slice(0, lastRelevantIndex + 1).join('\n'); }; module.exports = {