diff --git a/apps/backend/src/lib/ai/ai-proxy-handlers.test.ts b/apps/backend/src/lib/ai/ai-proxy-handlers.test.ts new file mode 100644 index 000000000..4142f366a --- /dev/null +++ b/apps/backend/src/lib/ai/ai-proxy-handlers.test.ts @@ -0,0 +1,137 @@ +/** + * Ensures `observeAndLog` returns upstream bytes even when proxy logging throws. + */ +import { describe, expect, it, vi } from "vitest"; + +vi.mock("./loggers/ai-proxy-logger", () => ({ + buildProxyLogRow: vi.fn(), + scheduleProxyLog: vi.fn(), +})); + +vi.mock("./openrouter-usage", async (importOriginal) => { + const original = await importOriginal>(); + return { + ...original, + extractOpenRouterUsage: vi.fn(() => ({})), + scanSseForUsage: vi.fn(async () => ({})), + }; +}); + +vi.mock("@/private", () => ({ + preprocessProxyBody: ({ parsedBody }: { parsedBody: unknown }) => parsedBody, +})); + +import { buildProxyLogRow, scheduleProxyLog } from "./loggers/ai-proxy-logger"; +import { observeAndLog, sanitizeBody } from "./ai-proxy-handlers"; + +const sanitized = sanitizeBody(new TextEncoder().encode(JSON.stringify({ + model: "anthropic/claude-sonnet-4.6", + messages: [{ role: "user", content: "hi" }], +})).buffer as ArrayBuffer); + +describe("observeAndLog body delivery", () => { + it("returns the upstream bytes even if buildProxyLogRow throws", async () => { + vi.mocked(buildProxyLogRow).mockImplementation(() => { + throw new Error("synthetic log construction failure"); + }); + vi.mocked(scheduleProxyLog).mockClear(); + + const upstreamBody = JSON.stringify({ id: "gen-1", choices: [{ message: { content: "hello" } }] }); + const upstream = new Response(upstreamBody, { + status: 200, + headers: { "Content-Type": "application/json" }, + }); + + const out = await observeAndLog({ + response: upstream, + sanitizedBody: sanitized, + callerApiKey: "stack-auth-test", + correlationId: "corr-1", + startedAt: 0, + responseHeaders: { "Content-Type": "application/json", "Cache-Control": "no-store" }, + }); + + expect(out.status).toBe(200); + expect(await out.text()).toBe(upstreamBody); + }); + + it("returns the upstream bytes even if scheduleProxyLog throws", async () => { + vi.mocked(buildProxyLogRow).mockReturnValue({ + correlationId: "corr-2", + mode: "generate", + systemPromptId: "stack-auth-test", + quality: "unknown", + speed: "unknown", + modelId: "anthropic/claude-sonnet-4.6", + isAuthenticated: false, + projectId: undefined, + userId: undefined, + requestedToolsJson: "[]", + messagesJson: "[]", + stepsJson: "[]", + finalText: "", + inputTokens: undefined, + outputTokens: undefined, + cachedInputTokens: undefined, + cacheCreationTokens: undefined, + costUsd: undefined, + cacheDiscountUsd: undefined, + openrouterGenerationId: undefined, + stepCount: 0, + durationMs: 0n, + errorMessage: undefined, + conversationId: undefined, + }); + vi.mocked(scheduleProxyLog).mockImplementation(() => { + throw new Error("synthetic schedule failure"); + }); + + const upstreamBody = JSON.stringify({ id: "gen-2", choices: [{ message: { content: "world" } }] }); + const upstream = new Response(upstreamBody, { + status: 200, + headers: { "Content-Type": "application/json" }, + }); + + const out = await observeAndLog({ + response: upstream, + sanitizedBody: sanitized, + callerApiKey: "stack-auth-test", + correlationId: "corr-2", + startedAt: 0, + responseHeaders: { "Content-Type": "application/json", "Cache-Control": "no-store" }, + }); + + expect(out.status).toBe(200); + expect(await out.text()).toBe(upstreamBody); + }); + + it("delivers full streamed bytes via tee even when the observer arm throws", async () => { + vi.mocked(buildProxyLogRow).mockImplementation(() => { + throw new Error("synthetic log failure inside async observer"); + }); + + const streamPayload = "data: {\"id\":\"gen-3\"}\n\ndata: [DONE]\n\n"; + const upstream = new Response(streamPayload, { + status: 200, + headers: { "Content-Type": "text/event-stream" }, + }); + + const streamingSanitized = sanitizeBody(new TextEncoder().encode(JSON.stringify({ + model: "anthropic/claude-sonnet-4.6", + stream: true, + messages: [{ role: "user", content: "hi" }], + })).buffer as ArrayBuffer); + + const out = await observeAndLog({ + response: upstream, + sanitizedBody: streamingSanitized, + callerApiKey: "stack-auth-test", + correlationId: "corr-3", + startedAt: 0, + responseHeaders: { "Content-Type": "text/event-stream", "Cache-Control": "no-store" }, + }); + + expect(out.status).toBe(200); + expect(await out.text()).toBe(streamPayload); + }); +}); diff --git a/apps/backend/src/lib/ai/ai-proxy-handlers.ts b/apps/backend/src/lib/ai/ai-proxy-handlers.ts index d2b985de0..c80b07799 100644 --- a/apps/backend/src/lib/ai/ai-proxy-handlers.ts +++ b/apps/backend/src/lib/ai/ai-proxy-handlers.ts @@ -51,42 +51,50 @@ export async function observeAndLog(args: { if (isStreaming && response.body) { const [clientStream, observerStream] = response.body.tee(); runAsynchronouslyAndWaitUntil((async () => { - let usage: UsageFields = {}; - const controller = new AbortController(); - const timeoutId = setTimeout(() => controller.abort(), 120_000); try { - usage = (await scanSseForUsage(observerStream, controller.signal)) ?? {}; + let usage: UsageFields = {}; + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), 120_000); + try { + usage = (await scanSseForUsage(observerStream, controller.signal)) ?? {}; + } catch (err) { + captureError("ai-proxy-scan-sse", err); + } finally { + clearTimeout(timeoutId); + } + scheduleProxyLog(buildProxyLogRow({ + correlationId, + parsed: sanitizedBody.parsed, + apiKey: callerApiKey, + durationMs: BigInt(Math.round(performance.now() - startedAt)), + responseStatus: response.status, + usage, + })); } catch (err) { - captureError("ai-proxy-scan-sse", err); - } finally { - clearTimeout(timeoutId); + captureError("ai-proxy-observer", err); } - scheduleProxyLog(buildProxyLogRow({ - correlationId, - parsed: sanitizedBody.parsed, - apiKey: callerApiKey, - durationMs: BigInt(Math.round(performance.now() - startedAt)), - responseStatus: response.status, - usage, - })); })()); return new Response(clientStream, { status: response.status, headers: responseHeaders }); } const bodyBytes = await response.arrayBuffer(); - let parsedBody: unknown; try { - parsedBody = JSON.parse(new TextDecoder().decode(bodyBytes)); - } catch { - parsedBody = undefined; + let parsedBody: unknown; + try { + parsedBody = JSON.parse(new TextDecoder().decode(bodyBytes)); + } catch { + parsedBody = undefined; + } + scheduleProxyLog(buildProxyLogRow({ + correlationId, + parsed: sanitizedBody.parsed, + apiKey: callerApiKey, + durationMs: BigInt(Math.round(performance.now() - startedAt)), + responseStatus: response.status, + usage: extractOpenRouterUsage(parsedBody), + })); + } catch (err) { + captureError("ai-proxy-log-build", err); } - scheduleProxyLog(buildProxyLogRow({ - correlationId, - parsed: sanitizedBody.parsed, - apiKey: callerApiKey, - durationMs: BigInt(Math.round(performance.now() - startedAt)), - responseStatus: response.status, - usage: extractOpenRouterUsage(parsedBody), - })); return new Response(bodyBytes, { status: response.status, headers: responseHeaders }); } diff --git a/apps/backend/src/lib/ai/loggers/ai-proxy-logger.test.ts b/apps/backend/src/lib/ai/loggers/ai-proxy-logger.test.ts new file mode 100644 index 000000000..ae5dd9801 --- /dev/null +++ b/apps/backend/src/lib/ai/loggers/ai-proxy-logger.test.ts @@ -0,0 +1,69 @@ +/** + * Tool-name extraction in `buildProxyLogRow` for Anthropic and OpenAI shapes. + */ +import { describe, expect, it } from "vitest"; +import { buildProxyLogRow } from "./ai-proxy-logger"; + +const baseInput = { + correlationId: "corr-1", + apiKey: "stack-auth-proxy", + durationMs: 0n, + responseStatus: 200, +}; + +describe("buildProxyLogRow tool-name extraction", () => { + it("captures Anthropic top-level tool names", () => { + const row = buildProxyLogRow({ + ...baseInput, + parsed: { + model: "anthropic/claude-sonnet-4.6", + tools: [ + { name: "get_weather", description: "...", input_schema: {} }, + { name: "send_email", description: "...", input_schema: {} }, + ], + }, + }); + expect(JSON.parse(row.requestedToolsJson)).toEqual(["get_weather", "send_email"]); + }); + + it("captures OpenAI/OpenRouter-format function tool names", () => { + const row = buildProxyLogRow({ + ...baseInput, + parsed: { + model: "anthropic/claude-sonnet-4.6", + tools: [ + { type: "function", function: { name: "get_weather", parameters: {} } }, + { type: "function", function: { name: "send_email", parameters: {} } }, + ], + }, + }); + expect(JSON.parse(row.requestedToolsJson)).toEqual(["get_weather", "send_email"]); + }); + + it("handles a mixed array gracefully", () => { + const row = buildProxyLogRow({ + ...baseInput, + parsed: { + model: "anthropic/claude-sonnet-4.6", + tools: [ + { name: "anthropic_tool", input_schema: {} }, + { type: "function", function: { name: "openai_tool", parameters: {} } }, + { type: "function" }, + null, + "not an object", + ], + }, + }); + expect(JSON.parse(row.requestedToolsJson)).toEqual(["anthropic_tool", "openai_tool"]); + }); + + it("returns an empty array when tools is absent or malformed", () => { + const row = buildProxyLogRow({ + ...baseInput, + parsed: { + model: "anthropic/claude-sonnet-4.6", + }, + }); + expect(JSON.parse(row.requestedToolsJson)).toEqual([]); + }); +}); diff --git a/apps/backend/src/lib/ai/loggers/ai-proxy-logger.ts b/apps/backend/src/lib/ai/loggers/ai-proxy-logger.ts index 3a5176751..cd35ee8e0 100644 --- a/apps/backend/src/lib/ai/loggers/ai-proxy-logger.ts +++ b/apps/backend/src/lib/ai/loggers/ai-proxy-logger.ts @@ -16,7 +16,13 @@ export function buildProxyLogRow(fields: ProxyLogFields): AiQueryLogEntry { const { parsed, apiKey, durationMs, responseStatus, usage, correlationId } = fields; const tools = Array.isArray(parsed.tools) ? parsed.tools : []; const toolNames = tools - .map(t => (t && typeof t === "object" && "name" in t) ? (t as { name: unknown }).name : null) + .map((t) => { + if (t == null || typeof t !== "object") return null; + const obj = t as { name?: unknown, function?: { name?: unknown } }; + if (typeof obj.function?.name === "string") return obj.function.name; + if (typeof obj.name === "string") return obj.name; + return null; + }) .filter((n): n is string => typeof n === "string"); const rawMessages = Array.isArray(parsed.messages) ? parsed.messages : []; const messages = typeof parsed.system === "string" && parsed.system.length > 0 diff --git a/apps/backend/src/lib/ai/spacetimedb-client.test.ts b/apps/backend/src/lib/ai/spacetimedb-client.test.ts new file mode 100644 index 000000000..02104308f --- /dev/null +++ b/apps/backend/src/lib/ai/spacetimedb-client.test.ts @@ -0,0 +1,123 @@ +/** + * Exercises `callSql` 401 enrollment retry (shared with `callReducer` via + * `withEnrollmentRetry`). + */ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +vi.mock("@stackframe/stack-shared/dist/utils/env", async (importOriginal) => { + const original = await importOriginal>(); + return { + ...original, + getEnvVariable: (key: string, fallback?: string) => { + switch (key) { + case "STACK_SPACETIMEDB_URL": { + return "http://spacetime.test"; + } + case "STACK_SPACETIMEDB_DB_NAME": { + return "test-db"; + } + case "STACK_SPACETIMEDB_SERVICE_TOKEN": { + return "test-service-token"; + } + case "STACK_MCP_LOG_TOKEN": { + return "test-log-token"; + } + default: { + return fallback ?? ""; + } + } + }, + }; +}); + +const fetchMock = vi.fn(); +const originalFetch = globalThis.fetch; + +function makeJsonResponse(body: unknown, status = 200): Response { + return new Response(JSON.stringify(body), { + status, + headers: { "Content-Type": "application/json" }, + }); +} + +function makeSqlSuccess(): Response { + return makeJsonResponse([{ + schema: { elements: [{ name: { some: "question" } }, { name: { some: "answer" } }] }, + rows: [["q1", "a1"]], + }]); +} + +function isSqlRequest(url: unknown): boolean { + return typeof url === "string" && url.endsWith("/sql"); +} + +function isEnrollRequest(url: unknown): boolean { + return typeof url === "string" && url.endsWith("/call/enroll_service"); +} + +beforeEach(() => { + fetchMock.mockReset(); + globalThis.fetch = fetchMock as unknown as typeof fetch; + // The module-level enrollmentPromise cache survives across tests; reset it + // by re-importing the module fresh for each test. + vi.resetModules(); +}); + +afterEach(() => { + globalThis.fetch = originalFetch; +}); + +describe("callSql 401 enrollment retry", () => { + it("retries SQL after a 401 by re-enrolling, mirroring callReducer", async () => { + fetchMock.mockImplementation((url: unknown) => { + if (isEnrollRequest(url)) return Promise.resolve(makeJsonResponse({}, 200)); + if (isSqlRequest(url)) { + if (fetchMock.mock.calls.filter((c) => isSqlRequest(c[0])).length === 1) { + return Promise.resolve(new Response("not enrolled", { status: 401 })); + } + return Promise.resolve(makeSqlSuccess()); + } + return Promise.reject(new Error(`Unexpected URL: ${String(url)}`)); + }); + + const { callSql } = await import("./spacetimedb-client"); + const rows = await callSql<{ question: string, answer: string }>("SELECT question, answer FROM published_qa"); + expect(rows).toEqual([{ question: "q1", answer: "a1" }]); + + const sqlCalls = fetchMock.mock.calls.filter((c) => isSqlRequest(c[0])); + expect(sqlCalls.length).toBe(2); + const enrollCalls = fetchMock.mock.calls.filter((c) => isEnrollRequest(c[0])); + expect(enrollCalls.length).toBe(2); + }); + + it("returns rows on first-attempt success without re-enrolling", async () => { + fetchMock.mockImplementation((url: unknown) => { + if (isEnrollRequest(url)) return Promise.resolve(makeJsonResponse({}, 200)); + if (isSqlRequest(url)) return Promise.resolve(makeSqlSuccess()); + return Promise.reject(new Error(`Unexpected URL: ${String(url)}`)); + }); + + const { callSql } = await import("./spacetimedb-client"); + const rows = await callSql<{ question: string, answer: string }>("SELECT question, answer FROM published_qa"); + expect(rows).toEqual([{ question: "q1", answer: "a1" }]); + + const sqlCalls = fetchMock.mock.calls.filter((c) => isSqlRequest(c[0])); + expect(sqlCalls.length).toBe(1); + const enrollCalls = fetchMock.mock.calls.filter((c) => isEnrollRequest(c[0])); + expect(enrollCalls.length).toBe(1); + }); + + it("propagates non-401 SQL errors without retrying", async () => { + fetchMock.mockImplementation((url: unknown) => { + if (isEnrollRequest(url)) return Promise.resolve(makeJsonResponse({}, 200)); + if (isSqlRequest(url)) return Promise.resolve(new Response("syntax error", { status: 400 })); + return Promise.reject(new Error(`Unexpected URL: ${String(url)}`)); + }); + + const { callSql } = await import("./spacetimedb-client"); + await expect(callSql("SELECT BAD")).rejects.toThrow(); + + const sqlCalls = fetchMock.mock.calls.filter((c) => isSqlRequest(c[0])); + expect(sqlCalls.length).toBe(1); + }); +}); diff --git a/apps/backend/src/lib/ai/spacetimedb-client.ts b/apps/backend/src/lib/ai/spacetimedb-client.ts index 7d142dd2c..41ae20204 100644 --- a/apps/backend/src/lib/ai/spacetimedb-client.ts +++ b/apps/backend/src/lib/ai/spacetimedb-client.ts @@ -67,29 +67,27 @@ function spacetimeDbError(label: string, status: number, preview: string): Error return new StackAssertionError(detail); } -async function callWithEnrollmentRetry(reducer: string, args: unknown[]): Promise { +async function withEnrollmentRetry(op: (token: string) => Promise): Promise { const token = await getServiceToken(); - if (!token) return false; + if (!token) return null; try { - await rawCallReducer(token, reducer, args); - return true; + return await op(token); } catch (err) { if (!(err instanceof StatusError) || err.statusCode !== 401) throw err; enrollmentPromise = null; const fresh = await getServiceToken(); if (!fresh) throw err; - await rawCallReducer(fresh, reducer, args); - return true; + return await op(fresh); } } export async function callReducer(reducer: string, args: unknown[]): Promise { - await callWithEnrollmentRetry(reducer, args); + await withEnrollmentRetry((token) => rawCallReducer(token, reducer, args)); } export async function callReducerStrict(reducer: string, args: unknown[]): Promise { - const ran = await callWithEnrollmentRetry(reducer, args); - if (!ran) { + const ran = await withEnrollmentRetry((token) => rawCallReducer(token, reducer, args)); + if (ran === null) { throw new StackAssertionError( `SpacetimeDB is not configured. Reducer ${reducer} cannot run. ` + `Check STACK_SPACETIMEDB_URL and STACK_SPACETIMEDB_SERVICE_TOKEN.` @@ -106,11 +104,12 @@ export function opt(value: T | null | undefined): { some: T } | { none: [] } return value == null ? { none: [] } : { some: value }; } -export async function callSql>(sql: string): Promise { - const token = await getServiceToken(); - if (!token) return []; +async function rawCallSql(token: string, sql: string): Promise }, + rows: unknown[][], +}>> { const base = httpBase(); - if (!base) return []; + if (!base) throw new StackAssertionError("SpacetimeDB not configured"); const dbName = getEnvVariable("STACK_SPACETIMEDB_DB_NAME"); const res = await fetch(`${base}/v1/database/${encodeURIComponent(dbName)}/sql`, { method: "POST", @@ -122,11 +121,15 @@ export async function callSql>(sql: string): Promise const preview = (await res.text()).slice(0, 200); throw spacetimeDbError("SQL query failed", res.status, preview); } - const parsed = await res.json() as Array<{ + return await res.json() as Array<{ schema: { elements: Array<{ name: { some?: string } | null }> }, rows: unknown[][], }>; - if (parsed.length === 0) return []; +} + +export async function callSql>(sql: string): Promise { + const parsed = await withEnrollmentRetry((token) => rawCallSql(token, sql)); + if (parsed == null || parsed.length === 0) return []; const first = parsed[0]; const cols = first.schema.elements.map(e => e.name?.some ?? ""); return first.rows.map(row => { diff --git a/apps/backend/src/lib/ai/tools/index.test.ts b/apps/backend/src/lib/ai/tools/index.test.ts new file mode 100644 index 000000000..aef3ba977 --- /dev/null +++ b/apps/backend/src/lib/ai/tools/index.test.ts @@ -0,0 +1,26 @@ +import { describe, expect, it } from "vitest"; +import { TOOL_NAMES, validateToolNames } from "./index"; + +describe("validateToolNames stays in sync with TOOL_NAMES", () => { + it("accepts the full TOOL_NAMES list", () => { + expect(validateToolNames([...TOOL_NAMES])).toBe(true); + }); + + it("accepts each tool name individually", () => { + for (const name of TOOL_NAMES) { + expect(validateToolNames([name])).toBe(true); + } + }); + + it("rejects unknown names", () => { + expect(validateToolNames(["does-not-exist"])).toBe(false); + expect(validateToolNames([...TOOL_NAMES, "does-not-exist"])).toBe(false); + }); + + it("rejects non-array inputs", () => { + expect(validateToolNames("docs")).toBe(false); + expect(validateToolNames(null)).toBe(false); + expect(validateToolNames(undefined)).toBe(false); + expect(validateToolNames({})).toBe(false); + }); +}); diff --git a/apps/backend/src/lib/ai/tools/index.ts b/apps/backend/src/lib/ai/tools/index.ts index ba0d52226..f66cecf21 100644 --- a/apps/backend/src/lib/ai/tools/index.ts +++ b/apps/backend/src/lib/ai/tools/index.ts @@ -81,22 +81,13 @@ export async function getTools( /** * Validates that all requested tool names are valid. - * Throws an error if any tool name is invalid. + * Returns false if any tool name is not in `TOOL_NAMES`. */ export function validateToolNames(toolNames: unknown): toolNames is ToolName[] { if (!Array.isArray(toolNames)) { return false; } - - const validToolNames: ToolName[] = [ - "docs", - "sql-query", - "create-email-theme", - "create-email-template", - "create-email-draft", - "update-dashboard", - "patch-dashboard" - ]; - - return toolNames.every((name) => validToolNames.includes(name as ToolName)); + return toolNames.every((name): name is ToolName => + typeof name === "string" && (TOOL_NAMES as readonly string[]).includes(name) + ); }