From 9388b99c2f7ccbfeb1cf8cfa32a145e032ea35d7 Mon Sep 17 00:00:00 2001 From: Konstantin Wohlwend Date: Thu, 18 Jun 2026 11:31:24 -0700 Subject: [PATCH] Better OAuth sign-up errors --- .../components-page/oauth-callback.test.tsx | 105 ++++++++++++++++++ .../src/components-page/oauth-callback.tsx | 33 +++--- packages/template/src/lib/auth.test.ts | 42 +++++-- packages/template/src/lib/auth.ts | 48 ++++++-- .../client-app-impl.cross-domain.test.ts | 65 +++++++++++ .../apps/implementations/client-app-impl.ts | 42 ++++++- 6 files changed, 296 insertions(+), 39 deletions(-) create mode 100644 packages/template/src/components-page/oauth-callback.test.tsx diff --git a/packages/template/src/components-page/oauth-callback.test.tsx b/packages/template/src/components-page/oauth-callback.test.tsx new file mode 100644 index 000000000..072b7fdf5 --- /dev/null +++ b/packages/template/src/components-page/oauth-callback.test.tsx @@ -0,0 +1,105 @@ +// @vitest-environment jsdom + +import { KnownErrors } from "@hexclave/shared"; +import React, { act } from "react"; +import { createRoot, type Root } from "react-dom/client"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import type { StackClientApp } from "../lib/hexclave-app/apps/interfaces/client-app"; +import { TranslationProviderClient } from "../providers/translation-provider-client"; +import { OAuthCallback } from "./oauth-callback"; + +const appMockState = vi.hoisted(() => ({ app: null as unknown })); + +vi.mock("..", () => ({ + useStackApp: () => { + if (appMockState.app == null) { + throw new Error("Expected test app to be set before rendering."); + } + return appMockState.app; + }, +})); + +vi.mock("@hexclave/ui", async () => { + const React = await import("react"); + return { + Button: (props: { children: React.ReactNode, onClick?: () => void }) => ( + + ), + Spinner: () =>
, + Typography: (props: { children: React.ReactNode }) =>
{props.children}
, + cn: (...classes: (string | false | null | undefined)[]) => classes.filter(Boolean).join(" "), + }; +}); + +const previousActEnvironment = Reflect.get(globalThis, "IS_REACT_ACT_ENVIRONMENT"); + +function createAppTestDouble(options: { + callOAuthCallback: () => Promise, +}) { + const app = { + callOAuthCallback: options.callOAuthCallback, + redirectToSignIn: vi.fn(async () => {}), + redirectToHome: vi.fn(async () => {}), + }; + + // This test double intentionally implements only the StackClientApp surface + // that OAuthCallback and the rendered error card touch. + return app as unknown as StackClientApp; +} + +let root: Root | null = null; +let container: HTMLDivElement | null = null; + +async function renderWithApp(app: StackClientApp) { + appMockState.app = app; + container = document.createElement("div"); + document.body.append(container); + root = createRoot(container); + await act(async () => { + root?.render( + + + + ); + }); +} + +async function flushEffects() { + await act(async () => { + await new Promise((resolve) => setTimeout(resolve, 0)); + }); +} + +describe("OAuthCallback", () => { + beforeEach(() => { + Reflect.set(globalThis, "IS_REACT_ACT_ENVIRONMENT", true); + }); + + afterEach(() => { + act(() => { + root?.unmount(); + }); + container?.remove(); + root = null; + container = null; + appMockState.app = null; + vi.restoreAllMocks(); + Reflect.set(globalThis, "IS_REACT_ACT_ENVIRONMENT", previousActEnvironment); + }); + + it("renders backend-encoded OAuth callback errors on the callback page", async () => { + const errorMessage = "Your sign up was rejected by an administrator's sign-up rule."; + const callOAuthCallback = vi.fn(async () => { + throw new KnownErrors.SignUpRejected(errorMessage); + }); + const app = createAppTestDouble({ callOAuthCallback }); + + await renderWithApp(app); + await flushEffects(); + + expect(callOAuthCallback).toHaveBeenCalledOnce(); + expect(container?.textContent).toContain("SIGN_UP_REJECTED"); + expect(container?.textContent).toContain(errorMessage); + expect(app.redirectToSignIn).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/template/src/components-page/oauth-callback.tsx b/packages/template/src/components-page/oauth-callback.tsx index a2d8b68fa..7b19d635a 100644 --- a/packages/template/src/components-page/oauth-callback.tsx +++ b/packages/template/src/components-page/oauth-callback.tsx @@ -8,27 +8,19 @@ import { useEffect, useRef, useState } from "react"; import { useStackApp } from ".."; import { MaybeFullPage } from "../components/elements/maybe-full-page"; import { StyledLink } from "../components/link"; -import { hexclaveAppInternalsSymbol } from "../lib/hexclave-app"; import { useTranslation } from "../lib/translations"; +import { ErrorPage } from "./error-page"; export function OAuthCallback({ fullPage }: { fullPage?: boolean }) { const { t } = useTranslation(); const app = useStackApp(); const called = useRef(false); const [showRedirectLink, setShowRedirectLink] = useState(false); - const [redirectUrl, setRedirectUrl] = useState(null); + const [errorSearchParams, setErrorSearchParams] = useState | null>(null); useEffect(() => runAsynchronously(async () => { if (called.current) return; called.current = true; - const redirectToError = async (url: URL) => { - const urlString = url.toString(); - if (app[hexclaveAppInternalsSymbol].getRedirectMethod() === "none") { - setRedirectUrl(urlString); - return; - } - await app[hexclaveAppInternalsSymbol].redirectToUrl(urlString, { replace: true }); - }; try { const hasRedirected = await app.callOAuthCallback(); if (!hasRedirected) { @@ -36,15 +28,15 @@ export function OAuthCallback({ fullPage }: { fullPage?: boolean }) { } } catch (e) { if (KnownError.isKnownError(e)) { - const errorUrl = new URL(app.urls.error, window.location.href); - errorUrl.searchParams.set("errorCode", e.errorCode); - errorUrl.searchParams.set("message", e.message); - errorUrl.searchParams.set("details", JSON.stringify(e.details ?? {})); - await redirectToError(errorUrl); + setErrorSearchParams({ + errorCode: e.errorCode, + message: e.message, + details: JSON.stringify(e.details ?? {}), + }); return; } captureError("", e); - await redirectToError(new URL(app.urls.error, window.location.href)); + setErrorSearchParams({}); } }), [app]); @@ -52,6 +44,10 @@ export function OAuthCallback({ fullPage }: { fullPage?: boolean }) { setTimeout(() => setShowRedirectLink(true), 3000); }, []); + if (errorSearchParams != null) { + return ; + } + return (
- {showRedirectLink || redirectUrl != null ?

{t('If you are not redirected automatically, ')}{t('If you are not redirected automatically, ')} { - if (redirectUrl != null) return; e.preventDefault(); runAsynchronously(app.redirectToHome()); }} diff --git a/packages/template/src/lib/auth.test.ts b/packages/template/src/lib/auth.test.ts index cab61a726..578f6f853 100644 --- a/packages/template/src/lib/auth.test.ts +++ b/packages/template/src/lib/auth.test.ts @@ -1,8 +1,8 @@ // @vitest-environment jsdom -import { HexclaveClientInterface } from "@hexclave/shared"; +import { HexclaveClientInterface, KnownErrors } from "@hexclave/shared"; import { describe, expect, it, vi } from "vitest"; -import { getNewOAuthProviderOrScopeUrl } from "./auth"; +import { callOAuthCallback, getNewOAuthProviderOrScopeUrl } from "./auth"; vi.mock("./cookie", async (importOriginal) => { const actual = await importOriginal(); @@ -15,18 +15,22 @@ vi.mock("./cookie", async (importOriginal) => { }; }); +function createTestInterface() { + return new HexclaveClientInterface({ + clientVersion: "test", + getBaseUrl: () => "https://api.example.com", + getApiUrls: () => ["https://api.example.com"], + extraRequestHeaders: {}, + projectId: "00000000-0000-4000-8000-000000000000", + publishableClientKey: "pck_test", + }); +} + describe("getNewOAuthProviderOrScopeUrl", () => { it("returns the OAuth URL without performing navigation", async () => { window.history.replaceState({}, "", "/account?after_auth_return_to=%2Fsettings"); - const iface = new HexclaveClientInterface({ - clientVersion: "test", - getBaseUrl: () => "https://api.example.com", - getApiUrls: () => ["https://api.example.com"], - extraRequestHeaders: {}, - projectId: "00000000-0000-4000-8000-000000000000", - publishableClientKey: "pck_test", - }); + const iface = createTestInterface(); const session = iface.createSession({ refreshToken: null, accessToken: null }); const location = await getNewOAuthProviderOrScopeUrl( @@ -61,3 +65,21 @@ describe("getNewOAuthProviderOrScopeUrl", () => { `); }); }); + +describe("callOAuthCallback", () => { + it("turns provider access denial callback params into a known error", async () => { + window.history.replaceState({}, "", "/handler/oauth-callback?error=access_denied&error_description=User+cancelled"); + + await expect(callOAuthCallback(createTestInterface(), "/handler/oauth-callback")) + .rejects.toSatisfy((error: unknown) => KnownErrors.OAuthProviderAccessDenied.isInstance(error)); + expect(window.location.href).toBe("http://localhost:3000/handler/oauth-callback"); + }); + + it("turns generic provider error callback params into a known error", async () => { + window.history.replaceState({}, "", "/handler/oauth-callback?error=server_error&error_description=Provider+failed"); + + await expect(callOAuthCallback(createTestInterface(), "/handler/oauth-callback")) + .rejects.toSatisfy((error: unknown) => KnownErrors.OAuthProviderTemporarilyUnavailable.isInstance(error)); + expect(window.location.href).toBe("http://localhost:3000/handler/oauth-callback"); + }); +}); diff --git a/packages/template/src/lib/auth.ts b/packages/template/src/lib/auth.ts index 59e992b45..f8e78248a 100644 --- a/packages/template/src/lib/auth.ts +++ b/packages/template/src/lib/auth.ts @@ -1,4 +1,4 @@ -import { KnownError, HexclaveClientInterface } from "@hexclave/shared"; +import { KnownError, KnownErrors, HexclaveClientInterface } from "@hexclave/shared"; import { InternalSession } from "@hexclave/shared/dist/sessions"; import { HexclaveAssertionError, throwErr } from "@hexclave/shared/dist/utils/errors"; import { Result } from "@hexclave/shared/dist/utils/results"; @@ -46,10 +46,39 @@ type OAuthCallbackConsumptionResult = error: KnownError, }; +const oauthErrorParams = ["error", "error_description", "errorCode", "message", "details"] as const; + +function removeOAuthErrorParamsFromHistory(originalUrl: URL): void { + const newUrl = new URL(originalUrl); + for (const param of oauthErrorParams) { + newUrl.searchParams.delete(param); + } + window.history.replaceState({}, "", newUrl.toString()); +} + +function getProviderOAuthErrorFromUrl(originalUrl: URL): KnownError | null { + const providerError = originalUrl.searchParams.get("error"); + const providerErrorDescription = originalUrl.searchParams.get("error_description"); + if (providerError == null && providerErrorDescription == null) { + return null; + } + + switch (providerError) { + case "access_denied": + case "consent_required": { + return new KnownErrors.OAuthProviderAccessDenied(); + } + case "server_error": + case "temporarily_unavailable": + default: { + return new KnownErrors.OAuthProviderTemporarilyUnavailable(); + } + } +} + function consumeOAuthCallbackQueryParams(options?: { dontWarnAboutMissingQueryParams?: boolean, }): OAuthCallbackConsumptionResult | null { - const oauthErrorParams = ["error", "error_description", "errorCode", "message", "details"] as const; const requiredParams = ["code", "state"]; const originalUrl = new URL(window.location.href); const knownErrorCode = originalUrl.searchParams.get("errorCode"); @@ -68,11 +97,7 @@ function consumeOAuthCallbackQueryParams(options?: { } } - const newUrl = new URL(originalUrl); - for (const param of oauthErrorParams) { - newUrl.searchParams.delete(param); - } - window.history.replaceState({}, "", newUrl.toString()); + removeOAuthErrorParamsFromHistory(originalUrl); return { type: "known-error", @@ -84,6 +109,15 @@ function consumeOAuthCallbackQueryParams(options?: { }; } + const providerOAuthError = getProviderOAuthErrorFromUrl(originalUrl); + if (providerOAuthError != null && !requiredParams.every(param => originalUrl.searchParams.has(param))) { + removeOAuthErrorParamsFromHistory(originalUrl); + return { + type: "known-error", + error: providerOAuthError, + }; + } + for (const param of requiredParams) { if (!originalUrl.searchParams.has(param)) { if (!options?.dontWarnAboutMissingQueryParams) { diff --git a/packages/template/src/lib/hexclave-app/apps/implementations/client-app-impl.cross-domain.test.ts b/packages/template/src/lib/hexclave-app/apps/implementations/client-app-impl.cross-domain.test.ts index 60b6d5ddb..1dd115412 100644 --- a/packages/template/src/lib/hexclave-app/apps/implementations/client-app-impl.cross-domain.test.ts +++ b/packages/template/src/lib/hexclave-app/apps/implementations/client-app-impl.cross-domain.test.ts @@ -435,6 +435,71 @@ describe("StackClientApp cross-domain auth", () => { } }); + it("redirects hosted current-page OAuth callback errors to the hosted error handler during startup", async () => { + const projectId = "00000000-0000-4000-8000-000000000010"; + const previousWindow = globalThis.window; + const previousDocument = globalThis.document; + const callbackUrl = new URL("https://demo.stack-auth.com/dashboard"); + callbackUrl.searchParams.set("errorCode", "SIGN_UP_REJECTED"); + callbackUrl.searchParams.set("message", "Your sign up was rejected by an administrator's sign-up rule."); + callbackUrl.searchParams.set("details", JSON.stringify({ + message: "Your sign up was rejected by an administrator's sign-up rule.", + })); + let currentHref = callbackUrl.toString(); + let redirectedUrl = ""; + const redirectSpy = vi.spyOn(StackClientApp.prototype as any, "_redirectTo").mockImplementation(async (options: { url: string | URL }) => { + redirectedUrl = options.url.toString(); + }); + + globalThis.document = createMockDocument(); + globalThis.window = { + location: { + get href() { + return currentHref; + }, + set href(value: string) { + currentHref = value; + }, + origin: callbackUrl.origin, + }, + history: { + replaceState: (_state: unknown, _title: string, url: string) => { + currentHref = new URL(url, currentHref).toString(); + }, + }, + addEventListener: () => {}, + removeEventListener: () => {}, + } as any; + + try { + new StackClientApp({ + baseUrl: "http://localhost:12345", + projectId, + publishableClientKey: "stack-pk-test", + tokenStore: "memory", + redirectMethod: "window", + urls: { + default: { type: "hosted" }, + }, + noAutomaticPrefetch: true, + }); + + await new Promise((resolve) => setTimeout(resolve, 0)); + await new Promise((resolve) => setTimeout(resolve, 0)); + } finally { + redirectSpy.mockRestore(); + globalThis.window = previousWindow; + globalThis.document = previousDocument; + } + + const errorUrl = new URL(redirectedUrl); + expect(errorUrl.origin).toBe(`https://${projectId}.built-with-stack-auth.com`); + expect(errorUrl.pathname).toBe("/handler/error"); + expect(errorUrl.searchParams.get("errorCode")).toBe("SIGN_UP_REJECTED"); + expect(errorUrl.searchParams.get("message")).toBe("Your sign up was rejected by an administrator's sign-up rule."); + expect(new URL(currentHref).searchParams.has("errorCode")).toBe(false); + }); + it("uses direct sign-out instead of hosted sign-out redirects when code execution is available", async () => { const clientApp = new StackClientApp({ baseUrl: "http://localhost:12345", diff --git a/packages/template/src/lib/hexclave-app/apps/implementations/client-app-impl.ts b/packages/template/src/lib/hexclave-app/apps/implementations/client-app-impl.ts index a045026a7..8855001e3 100644 --- a/packages/template/src/lib/hexclave-app/apps/implementations/client-app-impl.ts +++ b/packages/template/src/lib/hexclave-app/apps/implementations/client-app-impl.ts @@ -1,5 +1,5 @@ import { WebAuthnError, startAuthentication, startRegistration } from "@simplewebauthn/browser"; -import { KnownErrors, HexclaveClientInterface } from "@hexclave/shared"; +import { KnownError, KnownErrors, HexclaveClientInterface } from "@hexclave/shared"; import type { RequestListener } from "@hexclave/shared/dist/interface/client-interface"; import { ContactChannelsCrud } from "@hexclave/shared/dist/interface/crud/contact-channels"; import { CurrentUserCrud } from "@hexclave/shared/dist/interface/crud/current-user"; @@ -742,11 +742,11 @@ export class _HexclaveClientAppImplIncomplete { if (isBrowserLike()) { - await this.callOAuthCallback({ dontWarnAboutMissingQueryParams: true }); + await this._handleHostedOAuthCallbackDuringStartup(); } }); } @@ -850,6 +850,22 @@ export class _HexclaveClientAppImplIncomplete { + const errorUrl = new URL(this._getUrls().error, window.location.href); + errorUrl.searchParams.set("errorCode", error.errorCode); + errorUrl.searchParams.set("message", error.message); + errorUrl.searchParams.set("details", JSON.stringify(error.details ?? {})); + await this._redirectIfTrusted(errorUrl.toString(), { replace: true }); + } + + protected async _handleHostedOAuthCallbackDuringStartup(): Promise { + try { + await this.callOAuthCallback({ dontWarnAboutMissingQueryParams: true }); + } catch (error) { + if (KnownError.isKnownError(error)) { + await this._redirectToOAuthCallbackError(error); + return; + } + throw error; + } + } + protected async _fetchCurrentRefreshTokenIdIfSignedIn(options?: { awaitPendingAuthResolutions?: boolean, overrideTokenStoreInit?: TokenStoreInit,