Better OAuth sign-up errors

This commit is contained in:
Konstantin Wohlwend 2026-06-18 11:31:24 -07:00
parent 2220e89939
commit 9388b99c2f
6 changed files with 296 additions and 39 deletions

View File

@ -0,0 +1,105 @@
// @vitest-environment jsdom
import { KnownErrors } from "@hexclave/shared";
import React, { act } from "react";
import { createRoot, type Root } from "react-dom/client";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import type { StackClientApp } from "../lib/hexclave-app/apps/interfaces/client-app";
import { TranslationProviderClient } from "../providers/translation-provider-client";
import { OAuthCallback } from "./oauth-callback";
const appMockState = vi.hoisted(() => ({ app: null as unknown }));
vi.mock("..", () => ({
useStackApp: () => {
if (appMockState.app == null) {
throw new Error("Expected test app to be set before rendering.");
}
return appMockState.app;
},
}));
vi.mock("@hexclave/ui", async () => {
const React = await import("react");
return {
Button: (props: { children: React.ReactNode, onClick?: () => void }) => (
<button type="button" onClick={props.onClick}>{props.children}</button>
),
Spinner: () => <div data-testid="spinner" />,
Typography: (props: { children: React.ReactNode }) => <div>{props.children}</div>,
cn: (...classes: (string | false | null | undefined)[]) => classes.filter(Boolean).join(" "),
};
});
const previousActEnvironment = Reflect.get(globalThis, "IS_REACT_ACT_ENVIRONMENT");
function createAppTestDouble(options: {
callOAuthCallback: () => Promise<boolean>,
}) {
const app = {
callOAuthCallback: options.callOAuthCallback,
redirectToSignIn: vi.fn(async () => {}),
redirectToHome: vi.fn(async () => {}),
};
// This test double intentionally implements only the StackClientApp surface
// that OAuthCallback and the rendered error card touch.
return app as unknown as StackClientApp<true>;
}
let root: Root | null = null;
let container: HTMLDivElement | null = null;
async function renderWithApp(app: StackClientApp<true>) {
appMockState.app = app;
container = document.createElement("div");
document.body.append(container);
root = createRoot(container);
await act(async () => {
root?.render(
<TranslationProviderClient quetzalKeys={new Map()} quetzalLocale={new Map()}>
<OAuthCallback />
</TranslationProviderClient>
);
});
}
async function flushEffects() {
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 0));
});
}
describe("OAuthCallback", () => {
beforeEach(() => {
Reflect.set(globalThis, "IS_REACT_ACT_ENVIRONMENT", true);
});
afterEach(() => {
act(() => {
root?.unmount();
});
container?.remove();
root = null;
container = null;
appMockState.app = null;
vi.restoreAllMocks();
Reflect.set(globalThis, "IS_REACT_ACT_ENVIRONMENT", previousActEnvironment);
});
it("renders backend-encoded OAuth callback errors on the callback page", async () => {
const errorMessage = "Your sign up was rejected by an administrator's sign-up rule.";
const callOAuthCallback = vi.fn(async () => {
throw new KnownErrors.SignUpRejected(errorMessage);
});
const app = createAppTestDouble({ callOAuthCallback });
await renderWithApp(app);
await flushEffects();
expect(callOAuthCallback).toHaveBeenCalledOnce();
expect(container?.textContent).toContain("SIGN_UP_REJECTED");
expect(container?.textContent).toContain(errorMessage);
expect(app.redirectToSignIn).not.toHaveBeenCalled();
});
});

View File

@ -8,27 +8,19 @@ import { useEffect, useRef, useState } from "react";
import { useStackApp } from "..";
import { MaybeFullPage } from "../components/elements/maybe-full-page";
import { StyledLink } from "../components/link";
import { hexclaveAppInternalsSymbol } from "../lib/hexclave-app";
import { useTranslation } from "../lib/translations";
import { ErrorPage } from "./error-page";
export function OAuthCallback({ fullPage }: { fullPage?: boolean }) {
const { t } = useTranslation();
const app = useStackApp();
const called = useRef(false);
const [showRedirectLink, setShowRedirectLink] = useState(false);
const [redirectUrl, setRedirectUrl] = useState<string | null>(null);
const [errorSearchParams, setErrorSearchParams] = useState<Record<string, string> | null>(null);
useEffect(() => runAsynchronously(async () => {
if (called.current) return;
called.current = true;
const redirectToError = async (url: URL) => {
const urlString = url.toString();
if (app[hexclaveAppInternalsSymbol].getRedirectMethod() === "none") {
setRedirectUrl(urlString);
return;
}
await app[hexclaveAppInternalsSymbol].redirectToUrl(urlString, { replace: true });
};
try {
const hasRedirected = await app.callOAuthCallback();
if (!hasRedirected) {
@ -36,15 +28,15 @@ export function OAuthCallback({ fullPage }: { fullPage?: boolean }) {
}
} catch (e) {
if (KnownError.isKnownError(e)) {
const errorUrl = new URL(app.urls.error, window.location.href);
errorUrl.searchParams.set("errorCode", e.errorCode);
errorUrl.searchParams.set("message", e.message);
errorUrl.searchParams.set("details", JSON.stringify(e.details ?? {}));
await redirectToError(errorUrl);
setErrorSearchParams({
errorCode: e.errorCode,
message: e.message,
details: JSON.stringify(e.details ?? {}),
});
return;
}
captureError("<OAuthCallback />", e);
await redirectToError(new URL(app.urls.error, window.location.href));
setErrorSearchParams({});
}
}), [app]);
@ -52,6 +44,10 @@ export function OAuthCallback({ fullPage }: { fullPage?: boolean }) {
setTimeout(() => setShowRedirectLink(true), 3000);
}, []);
if (errorSearchParams != null) {
return <ErrorPage searchParams={errorSearchParams} fullPage={fullPage} />;
}
return (
<MaybeFullPage
fullPage={fullPage ?? false}
@ -66,11 +62,10 @@ export function OAuthCallback({ fullPage }: { fullPage?: boolean }) {
<div className="flex flex-col justify-center items-center gap-4">
<Spinner size={20} />
</div>
{showRedirectLink || redirectUrl != null ? <p>{t('If you are not redirected automatically, ')}<StyledLink
{showRedirectLink ? <p>{t('If you are not redirected automatically, ')}<StyledLink
className="whitespace-nowrap"
href={redirectUrl ?? "#"}
href="#"
onClick={(e) => {
if (redirectUrl != null) return;
e.preventDefault();
runAsynchronously(app.redirectToHome());
}}

View File

@ -1,8 +1,8 @@
// @vitest-environment jsdom
import { HexclaveClientInterface } from "@hexclave/shared";
import { HexclaveClientInterface, KnownErrors } from "@hexclave/shared";
import { describe, expect, it, vi } from "vitest";
import { getNewOAuthProviderOrScopeUrl } from "./auth";
import { callOAuthCallback, getNewOAuthProviderOrScopeUrl } from "./auth";
vi.mock("./cookie", async (importOriginal) => {
const actual = await importOriginal<typeof import("./cookie")>();
@ -15,18 +15,22 @@ vi.mock("./cookie", async (importOriginal) => {
};
});
function createTestInterface() {
return new HexclaveClientInterface({
clientVersion: "test",
getBaseUrl: () => "https://api.example.com",
getApiUrls: () => ["https://api.example.com"],
extraRequestHeaders: {},
projectId: "00000000-0000-4000-8000-000000000000",
publishableClientKey: "pck_test",
});
}
describe("getNewOAuthProviderOrScopeUrl", () => {
it("returns the OAuth URL without performing navigation", async () => {
window.history.replaceState({}, "", "/account?after_auth_return_to=%2Fsettings");
const iface = new HexclaveClientInterface({
clientVersion: "test",
getBaseUrl: () => "https://api.example.com",
getApiUrls: () => ["https://api.example.com"],
extraRequestHeaders: {},
projectId: "00000000-0000-4000-8000-000000000000",
publishableClientKey: "pck_test",
});
const iface = createTestInterface();
const session = iface.createSession({ refreshToken: null, accessToken: null });
const location = await getNewOAuthProviderOrScopeUrl(
@ -61,3 +65,21 @@ describe("getNewOAuthProviderOrScopeUrl", () => {
`);
});
});
describe("callOAuthCallback", () => {
it("turns provider access denial callback params into a known error", async () => {
window.history.replaceState({}, "", "/handler/oauth-callback?error=access_denied&error_description=User+cancelled");
await expect(callOAuthCallback(createTestInterface(), "/handler/oauth-callback"))
.rejects.toSatisfy((error: unknown) => KnownErrors.OAuthProviderAccessDenied.isInstance(error));
expect(window.location.href).toBe("http://localhost:3000/handler/oauth-callback");
});
it("turns generic provider error callback params into a known error", async () => {
window.history.replaceState({}, "", "/handler/oauth-callback?error=server_error&error_description=Provider+failed");
await expect(callOAuthCallback(createTestInterface(), "/handler/oauth-callback"))
.rejects.toSatisfy((error: unknown) => KnownErrors.OAuthProviderTemporarilyUnavailable.isInstance(error));
expect(window.location.href).toBe("http://localhost:3000/handler/oauth-callback");
});
});

View File

@ -1,4 +1,4 @@
import { KnownError, HexclaveClientInterface } from "@hexclave/shared";
import { KnownError, KnownErrors, HexclaveClientInterface } from "@hexclave/shared";
import { InternalSession } from "@hexclave/shared/dist/sessions";
import { HexclaveAssertionError, throwErr } from "@hexclave/shared/dist/utils/errors";
import { Result } from "@hexclave/shared/dist/utils/results";
@ -46,10 +46,39 @@ type OAuthCallbackConsumptionResult =
error: KnownError,
};
const oauthErrorParams = ["error", "error_description", "errorCode", "message", "details"] as const;
function removeOAuthErrorParamsFromHistory(originalUrl: URL): void {
const newUrl = new URL(originalUrl);
for (const param of oauthErrorParams) {
newUrl.searchParams.delete(param);
}
window.history.replaceState({}, "", newUrl.toString());
}
function getProviderOAuthErrorFromUrl(originalUrl: URL): KnownError | null {
const providerError = originalUrl.searchParams.get("error");
const providerErrorDescription = originalUrl.searchParams.get("error_description");
if (providerError == null && providerErrorDescription == null) {
return null;
}
switch (providerError) {
case "access_denied":
case "consent_required": {
return new KnownErrors.OAuthProviderAccessDenied();
}
case "server_error":
case "temporarily_unavailable":
default: {
return new KnownErrors.OAuthProviderTemporarilyUnavailable();
}
}
}
function consumeOAuthCallbackQueryParams(options?: {
dontWarnAboutMissingQueryParams?: boolean,
}): OAuthCallbackConsumptionResult | null {
const oauthErrorParams = ["error", "error_description", "errorCode", "message", "details"] as const;
const requiredParams = ["code", "state"];
const originalUrl = new URL(window.location.href);
const knownErrorCode = originalUrl.searchParams.get("errorCode");
@ -68,11 +97,7 @@ function consumeOAuthCallbackQueryParams(options?: {
}
}
const newUrl = new URL(originalUrl);
for (const param of oauthErrorParams) {
newUrl.searchParams.delete(param);
}
window.history.replaceState({}, "", newUrl.toString());
removeOAuthErrorParamsFromHistory(originalUrl);
return {
type: "known-error",
@ -84,6 +109,15 @@ function consumeOAuthCallbackQueryParams(options?: {
};
}
const providerOAuthError = getProviderOAuthErrorFromUrl(originalUrl);
if (providerOAuthError != null && !requiredParams.every(param => originalUrl.searchParams.has(param))) {
removeOAuthErrorParamsFromHistory(originalUrl);
return {
type: "known-error",
error: providerOAuthError,
};
}
for (const param of requiredParams) {
if (!originalUrl.searchParams.has(param)) {
if (!options?.dontWarnAboutMissingQueryParams) {

View File

@ -435,6 +435,71 @@ describe("StackClientApp cross-domain auth", () => {
}
});
it("redirects hosted current-page OAuth callback errors to the hosted error handler during startup", async () => {
const projectId = "00000000-0000-4000-8000-000000000010";
const previousWindow = globalThis.window;
const previousDocument = globalThis.document;
const callbackUrl = new URL("https://demo.stack-auth.com/dashboard");
callbackUrl.searchParams.set("errorCode", "SIGN_UP_REJECTED");
callbackUrl.searchParams.set("message", "Your sign up was rejected by an administrator's sign-up rule.");
callbackUrl.searchParams.set("details", JSON.stringify({
message: "Your sign up was rejected by an administrator's sign-up rule.",
}));
let currentHref = callbackUrl.toString();
let redirectedUrl = "";
const redirectSpy = vi.spyOn(StackClientApp.prototype as any, "_redirectTo").mockImplementation(async (options: { url: string | URL }) => {
redirectedUrl = options.url.toString();
});
globalThis.document = createMockDocument();
globalThis.window = {
location: {
get href() {
return currentHref;
},
set href(value: string) {
currentHref = value;
},
origin: callbackUrl.origin,
},
history: {
replaceState: (_state: unknown, _title: string, url: string) => {
currentHref = new URL(url, currentHref).toString();
},
},
addEventListener: () => {},
removeEventListener: () => {},
} as any;
try {
new StackClientApp({
baseUrl: "http://localhost:12345",
projectId,
publishableClientKey: "stack-pk-test",
tokenStore: "memory",
redirectMethod: "window",
urls: {
default: { type: "hosted" },
},
noAutomaticPrefetch: true,
});
await new Promise((resolve) => setTimeout(resolve, 0));
await new Promise((resolve) => setTimeout(resolve, 0));
} finally {
redirectSpy.mockRestore();
globalThis.window = previousWindow;
globalThis.document = previousDocument;
}
const errorUrl = new URL(redirectedUrl);
expect(errorUrl.origin).toBe(`https://${projectId}.built-with-stack-auth.com`);
expect(errorUrl.pathname).toBe("/handler/error");
expect(errorUrl.searchParams.get("errorCode")).toBe("SIGN_UP_REJECTED");
expect(errorUrl.searchParams.get("message")).toBe("Your sign up was rejected by an administrator's sign-up rule.");
expect(new URL(currentHref).searchParams.has("errorCode")).toBe(false);
});
it("uses direct sign-out instead of hosted sign-out redirects when code execution is available", async () => {
const clientApp = new StackClientApp({
baseUrl: "http://localhost:12345",

View File

@ -1,5 +1,5 @@
import { WebAuthnError, startAuthentication, startRegistration } from "@simplewebauthn/browser";
import { KnownErrors, HexclaveClientInterface } from "@hexclave/shared";
import { KnownError, KnownErrors, HexclaveClientInterface } from "@hexclave/shared";
import type { RequestListener } from "@hexclave/shared/dist/interface/client-interface";
import { ContactChannelsCrud } from "@hexclave/shared/dist/interface/crud/contact-channels";
import { CurrentUserCrud } from "@hexclave/shared/dist/interface/crud/current-user";
@ -742,11 +742,11 @@ export class _HexclaveClientAppImplIncomplete<HasTokenStore extends boolean, Pro
if (
isBrowserLike()
&& (this._isOAuthCallbackUrlHosted() || this._currentUrlLooksLikeNestedCrossDomainOAuthCallback())
&& this._currentUrlLooksLikeHexclaveOAuthCallback()
&& (this._currentUrlLooksLikeHexclaveOAuthCallback() || this._currentUrlLooksLikeOAuthCallbackError())
) {
this._trackPendingAuthResolution(async () => {
if (isBrowserLike()) {
await this.callOAuthCallback({ dontWarnAboutMissingQueryParams: true });
await this._handleHostedOAuthCallbackDuringStartup();
}
});
}
@ -850,6 +850,22 @@ export class _HexclaveClientAppImplIncomplete<HasTokenStore extends boolean, Pro
currentUrl.searchParams.has("code") && currentUrl.searchParams.has("state")
) || (
currentUrl.searchParams.has("errorCode") && currentUrl.searchParams.has("message")
) || (
this._currentUrlLooksLikeOAuthCallbackError()
);
}
protected _currentUrlLooksLikeOAuthCallbackError(): boolean {
if (typeof window === "undefined") {
return false;
}
const currentUrl = new URL(window.location.href);
if (currentUrl.searchParams.has("errorCode") && currentUrl.searchParams.has("message")) {
return true;
}
return (
(currentUrl.searchParams.has("error") || currentUrl.searchParams.has("error_description"))
&& !(currentUrl.searchParams.has("code") && currentUrl.searchParams.has("state"))
);
}
@ -891,6 +907,26 @@ export class _HexclaveClientAppImplIncomplete<HasTokenStore extends boolean, Pro
return currentUrl.toString();
}
protected async _redirectToOAuthCallbackError(error: KnownError): Promise<void> {
const errorUrl = new URL(this._getUrls().error, window.location.href);
errorUrl.searchParams.set("errorCode", error.errorCode);
errorUrl.searchParams.set("message", error.message);
errorUrl.searchParams.set("details", JSON.stringify(error.details ?? {}));
await this._redirectIfTrusted(errorUrl.toString(), { replace: true });
}
protected async _handleHostedOAuthCallbackDuringStartup(): Promise<void> {
try {
await this.callOAuthCallback({ dontWarnAboutMissingQueryParams: true });
} catch (error) {
if (KnownError.isKnownError(error)) {
await this._redirectToOAuthCallbackError(error);
return;
}
throw error;
}
}
protected async _fetchCurrentRefreshTokenIdIfSignedIn(options?: {
awaitPendingAuthResolutions?: boolean,
overrideTokenStoreInit?: TokenStoreInit,