diff --git a/apps/backend/src/oauth/ssrf-protection.test.ts b/apps/backend/src/oauth/ssrf-protection.test.ts index b429fb8a8..330fb7f25 100644 --- a/apps/backend/src/oauth/ssrf-protection.test.ts +++ b/apps/backend/src/oauth/ssrf-protection.test.ts @@ -1,6 +1,21 @@ import { StatusError } from "@hexclave/shared/dist/utils/errors"; import { describe, expect, it } from "vitest"; -import { assertSafeOAuthResolvedAddress, assertSafeOAuthUrlWithoutDns, isBlockedOAuthIpAddress } from "./ssrf-protection"; +import dns from "node:dns"; +import { assertSafeOAuthResolvedAddress, assertSafeOAuthUrlWithoutDns, isBlockedOAuthIpAddress, safeOAuthDnsLookup } from "./ssrf-protection"; + +async function withProductionNodeEnv(callback: () => Promise): Promise { + const previousNodeEnv = process.env.NODE_ENV; + process.env.NODE_ENV = "production"; + try { + return await callback(); + } finally { + if (previousNodeEnv === undefined) { + delete process.env.NODE_ENV; + } else { + process.env.NODE_ENV = previousNodeEnv; + } + } +} describe("isBlockedOAuthIpAddress", () => { it("blocks AWS metadata, loopback, and private IPv4 ranges", () => { @@ -50,3 +65,25 @@ describe("assertSafeOAuthResolvedAddress", () => { }); }); +describe("safeOAuthDnsLookup", () => { + it("reports blocked single-address lookup results through the callback", async () => { + const error = await withProductionNodeEnv(async () => await new Promise((resolve) => { + safeOAuthDnsLookup("127.0.0.1", {}, (lookupError) => { + resolve(lookupError); + }); + })); + + expect(error).toBeInstanceOf(StatusError); + }); + + it("reports blocked all-address lookup results through the callback", async () => { + const error = await withProductionNodeEnv(async () => await new Promise((resolve) => { + safeOAuthDnsLookup("127.0.0.1", { all: true, verbatim: true } satisfies dns.LookupAllOptions, (lookupError) => { + resolve(lookupError); + }); + })); + + expect(error).toBeInstanceOf(StatusError); + }); +}); + diff --git a/apps/backend/src/oauth/ssrf-protection.ts b/apps/backend/src/oauth/ssrf-protection.ts index 4e601abe3..d9ee4b9d4 100644 --- a/apps/backend/src/oauth/ssrf-protection.ts +++ b/apps/backend/src/oauth/ssrf-protection.ts @@ -119,6 +119,18 @@ type DnsLookupCallback = ( family?: number, ) => void; +function getLookupValidationError(validate: () => void): NodeJS.ErrnoException | null { + try { + validate(); + return null; + } catch (error) { + if (error instanceof Error) { + return error; + } + return new Error("OAuth DNS lookup failed while validating resolved address."); + } +} + export function safeOAuthDnsLookup(hostname: string, options: dns.LookupOptions, callback: DnsLookupCallback): void { if (!shouldEnforceOAuthSsrfProtection()) { dns.lookup(hostname, options, callback); @@ -133,8 +145,14 @@ export function safeOAuthDnsLookup(hostname: string, options: dns.LookupOptions, return; } - for (const address of addresses) { - assertSafeOAuthResolvedAddress(address.address); + const validationError = getLookupValidationError(() => { + for (const address of addresses) { + assertSafeOAuthResolvedAddress(address.address); + } + }); + if (validationError !== null) { + callback(validationError, []); + return; } callback(null, addresses); }); @@ -148,7 +166,11 @@ export function safeOAuthDnsLookup(hostname: string, options: dns.LookupOptions, return; } - assertSafeOAuthResolvedAddress(address); + const validationError = getLookupValidationError(() => assertSafeOAuthResolvedAddress(address)); + if (validationError !== null) { + callback(validationError, "", 0); + return; + } callback(null, address, family); }); }