Add tests for AI proxy handlers and SpacetimeDB client

- Introduced unit tests for `observeAndLog` in `ai-proxy-handlers.test.ts` to ensure correct behavior when logging functions throw errors.
- Added tests for `callSql` in `spacetimedb-client.test.ts` to verify 401 enrollment retry logic and error handling.
- Created tests for `buildProxyLogRow` in `ai-proxy-logger.test.ts` to validate tool-name extraction from parsed logs.
- Implemented validation tests for tool names in `index.test.ts` to ensure consistency with defined tool names.

These additions enhance test coverage and reliability of the AI-related functionalities.
This commit is contained in:
mantrakp04 2026-05-23 11:31:22 -07:00
parent 1a8e38e511
commit 29fa318525
8 changed files with 419 additions and 56 deletions

View File

@ -0,0 +1,137 @@
/**
* Ensures `observeAndLog` returns upstream bytes even when proxy logging throws.
*/
import { describe, expect, it, vi } from "vitest";
vi.mock("./loggers/ai-proxy-logger", () => ({
buildProxyLogRow: vi.fn(),
scheduleProxyLog: vi.fn(),
}));
vi.mock("./openrouter-usage", async (importOriginal) => {
const original = await importOriginal<Record<string, unknown>>();
return {
...original,
extractOpenRouterUsage: vi.fn(() => ({})),
scanSseForUsage: vi.fn(async () => ({})),
};
});
vi.mock("@/private", () => ({
preprocessProxyBody: ({ parsedBody }: { parsedBody: unknown }) => parsedBody,
}));
import { buildProxyLogRow, scheduleProxyLog } from "./loggers/ai-proxy-logger";
import { observeAndLog, sanitizeBody } from "./ai-proxy-handlers";
const sanitized = sanitizeBody(new TextEncoder().encode(JSON.stringify({
model: "anthropic/claude-sonnet-4.6",
messages: [{ role: "user", content: "hi" }],
})).buffer as ArrayBuffer);
describe("observeAndLog body delivery", () => {
it("returns the upstream bytes even if buildProxyLogRow throws", async () => {
vi.mocked(buildProxyLogRow).mockImplementation(() => {
throw new Error("synthetic log construction failure");
});
vi.mocked(scheduleProxyLog).mockClear();
const upstreamBody = JSON.stringify({ id: "gen-1", choices: [{ message: { content: "hello" } }] });
const upstream = new Response(upstreamBody, {
status: 200,
headers: { "Content-Type": "application/json" },
});
const out = await observeAndLog({
response: upstream,
sanitizedBody: sanitized,
callerApiKey: "stack-auth-test",
correlationId: "corr-1",
startedAt: 0,
responseHeaders: { "Content-Type": "application/json", "Cache-Control": "no-store" },
});
expect(out.status).toBe(200);
expect(await out.text()).toBe(upstreamBody);
});
it("returns the upstream bytes even if scheduleProxyLog throws", async () => {
vi.mocked(buildProxyLogRow).mockReturnValue({
correlationId: "corr-2",
mode: "generate",
systemPromptId: "stack-auth-test",
quality: "unknown",
speed: "unknown",
modelId: "anthropic/claude-sonnet-4.6",
isAuthenticated: false,
projectId: undefined,
userId: undefined,
requestedToolsJson: "[]",
messagesJson: "[]",
stepsJson: "[]",
finalText: "",
inputTokens: undefined,
outputTokens: undefined,
cachedInputTokens: undefined,
cacheCreationTokens: undefined,
costUsd: undefined,
cacheDiscountUsd: undefined,
openrouterGenerationId: undefined,
stepCount: 0,
durationMs: 0n,
errorMessage: undefined,
conversationId: undefined,
});
vi.mocked(scheduleProxyLog).mockImplementation(() => {
throw new Error("synthetic schedule failure");
});
const upstreamBody = JSON.stringify({ id: "gen-2", choices: [{ message: { content: "world" } }] });
const upstream = new Response(upstreamBody, {
status: 200,
headers: { "Content-Type": "application/json" },
});
const out = await observeAndLog({
response: upstream,
sanitizedBody: sanitized,
callerApiKey: "stack-auth-test",
correlationId: "corr-2",
startedAt: 0,
responseHeaders: { "Content-Type": "application/json", "Cache-Control": "no-store" },
});
expect(out.status).toBe(200);
expect(await out.text()).toBe(upstreamBody);
});
it("delivers full streamed bytes via tee even when the observer arm throws", async () => {
vi.mocked(buildProxyLogRow).mockImplementation(() => {
throw new Error("synthetic log failure inside async observer");
});
const streamPayload = "data: {\"id\":\"gen-3\"}\n\ndata: [DONE]\n\n";
const upstream = new Response(streamPayload, {
status: 200,
headers: { "Content-Type": "text/event-stream" },
});
const streamingSanitized = sanitizeBody(new TextEncoder().encode(JSON.stringify({
model: "anthropic/claude-sonnet-4.6",
stream: true,
messages: [{ role: "user", content: "hi" }],
})).buffer as ArrayBuffer);
const out = await observeAndLog({
response: upstream,
sanitizedBody: streamingSanitized,
callerApiKey: "stack-auth-test",
correlationId: "corr-3",
startedAt: 0,
responseHeaders: { "Content-Type": "text/event-stream", "Cache-Control": "no-store" },
});
expect(out.status).toBe(200);
expect(await out.text()).toBe(streamPayload);
});
});

View File

@ -51,42 +51,50 @@ export async function observeAndLog(args: {
if (isStreaming && response.body) {
const [clientStream, observerStream] = response.body.tee();
runAsynchronouslyAndWaitUntil((async () => {
let usage: UsageFields = {};
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 120_000);
try {
usage = (await scanSseForUsage(observerStream, controller.signal)) ?? {};
let usage: UsageFields = {};
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 120_000);
try {
usage = (await scanSseForUsage(observerStream, controller.signal)) ?? {};
} catch (err) {
captureError("ai-proxy-scan-sse", err);
} finally {
clearTimeout(timeoutId);
}
scheduleProxyLog(buildProxyLogRow({
correlationId,
parsed: sanitizedBody.parsed,
apiKey: callerApiKey,
durationMs: BigInt(Math.round(performance.now() - startedAt)),
responseStatus: response.status,
usage,
}));
} catch (err) {
captureError("ai-proxy-scan-sse", err);
} finally {
clearTimeout(timeoutId);
captureError("ai-proxy-observer", err);
}
scheduleProxyLog(buildProxyLogRow({
correlationId,
parsed: sanitizedBody.parsed,
apiKey: callerApiKey,
durationMs: BigInt(Math.round(performance.now() - startedAt)),
responseStatus: response.status,
usage,
}));
})());
return new Response(clientStream, { status: response.status, headers: responseHeaders });
}
const bodyBytes = await response.arrayBuffer();
let parsedBody: unknown;
try {
parsedBody = JSON.parse(new TextDecoder().decode(bodyBytes));
} catch {
parsedBody = undefined;
let parsedBody: unknown;
try {
parsedBody = JSON.parse(new TextDecoder().decode(bodyBytes));
} catch {
parsedBody = undefined;
}
scheduleProxyLog(buildProxyLogRow({
correlationId,
parsed: sanitizedBody.parsed,
apiKey: callerApiKey,
durationMs: BigInt(Math.round(performance.now() - startedAt)),
responseStatus: response.status,
usage: extractOpenRouterUsage(parsedBody),
}));
} catch (err) {
captureError("ai-proxy-log-build", err);
}
scheduleProxyLog(buildProxyLogRow({
correlationId,
parsed: sanitizedBody.parsed,
apiKey: callerApiKey,
durationMs: BigInt(Math.round(performance.now() - startedAt)),
responseStatus: response.status,
usage: extractOpenRouterUsage(parsedBody),
}));
return new Response(bodyBytes, { status: response.status, headers: responseHeaders });
}

View File

@ -0,0 +1,69 @@
/**
* Tool-name extraction in `buildProxyLogRow` for Anthropic and OpenAI shapes.
*/
import { describe, expect, it } from "vitest";
import { buildProxyLogRow } from "./ai-proxy-logger";
const baseInput = {
correlationId: "corr-1",
apiKey: "stack-auth-proxy",
durationMs: 0n,
responseStatus: 200,
};
describe("buildProxyLogRow tool-name extraction", () => {
it("captures Anthropic top-level tool names", () => {
const row = buildProxyLogRow({
...baseInput,
parsed: {
model: "anthropic/claude-sonnet-4.6",
tools: [
{ name: "get_weather", description: "...", input_schema: {} },
{ name: "send_email", description: "...", input_schema: {} },
],
},
});
expect(JSON.parse(row.requestedToolsJson)).toEqual(["get_weather", "send_email"]);
});
it("captures OpenAI/OpenRouter-format function tool names", () => {
const row = buildProxyLogRow({
...baseInput,
parsed: {
model: "anthropic/claude-sonnet-4.6",
tools: [
{ type: "function", function: { name: "get_weather", parameters: {} } },
{ type: "function", function: { name: "send_email", parameters: {} } },
],
},
});
expect(JSON.parse(row.requestedToolsJson)).toEqual(["get_weather", "send_email"]);
});
it("handles a mixed array gracefully", () => {
const row = buildProxyLogRow({
...baseInput,
parsed: {
model: "anthropic/claude-sonnet-4.6",
tools: [
{ name: "anthropic_tool", input_schema: {} },
{ type: "function", function: { name: "openai_tool", parameters: {} } },
{ type: "function" },
null,
"not an object",
],
},
});
expect(JSON.parse(row.requestedToolsJson)).toEqual(["anthropic_tool", "openai_tool"]);
});
it("returns an empty array when tools is absent or malformed", () => {
const row = buildProxyLogRow({
...baseInput,
parsed: {
model: "anthropic/claude-sonnet-4.6",
},
});
expect(JSON.parse(row.requestedToolsJson)).toEqual([]);
});
});

View File

@ -16,7 +16,13 @@ export function buildProxyLogRow(fields: ProxyLogFields): AiQueryLogEntry {
const { parsed, apiKey, durationMs, responseStatus, usage, correlationId } = fields;
const tools = Array.isArray(parsed.tools) ? parsed.tools : [];
const toolNames = tools
.map(t => (t && typeof t === "object" && "name" in t) ? (t as { name: unknown }).name : null)
.map((t) => {
if (t == null || typeof t !== "object") return null;
const obj = t as { name?: unknown, function?: { name?: unknown } };
if (typeof obj.function?.name === "string") return obj.function.name;
if (typeof obj.name === "string") return obj.name;
return null;
})
.filter((n): n is string => typeof n === "string");
const rawMessages = Array.isArray(parsed.messages) ? parsed.messages : [];
const messages = typeof parsed.system === "string" && parsed.system.length > 0

View File

@ -0,0 +1,123 @@
/**
* Exercises `callSql` 401 enrollment retry (shared with `callReducer` via
* `withEnrollmentRetry`).
*/
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
vi.mock("@stackframe/stack-shared/dist/utils/env", async (importOriginal) => {
const original = await importOriginal<Record<string, unknown>>();
return {
...original,
getEnvVariable: (key: string, fallback?: string) => {
switch (key) {
case "STACK_SPACETIMEDB_URL": {
return "http://spacetime.test";
}
case "STACK_SPACETIMEDB_DB_NAME": {
return "test-db";
}
case "STACK_SPACETIMEDB_SERVICE_TOKEN": {
return "test-service-token";
}
case "STACK_MCP_LOG_TOKEN": {
return "test-log-token";
}
default: {
return fallback ?? "";
}
}
},
};
});
const fetchMock = vi.fn();
const originalFetch = globalThis.fetch;
function makeJsonResponse(body: unknown, status = 200): Response {
return new Response(JSON.stringify(body), {
status,
headers: { "Content-Type": "application/json" },
});
}
function makeSqlSuccess(): Response {
return makeJsonResponse([{
schema: { elements: [{ name: { some: "question" } }, { name: { some: "answer" } }] },
rows: [["q1", "a1"]],
}]);
}
function isSqlRequest(url: unknown): boolean {
return typeof url === "string" && url.endsWith("/sql");
}
function isEnrollRequest(url: unknown): boolean {
return typeof url === "string" && url.endsWith("/call/enroll_service");
}
beforeEach(() => {
fetchMock.mockReset();
globalThis.fetch = fetchMock as unknown as typeof fetch;
// The module-level enrollmentPromise cache survives across tests; reset it
// by re-importing the module fresh for each test.
vi.resetModules();
});
afterEach(() => {
globalThis.fetch = originalFetch;
});
describe("callSql 401 enrollment retry", () => {
it("retries SQL after a 401 by re-enrolling, mirroring callReducer", async () => {
fetchMock.mockImplementation((url: unknown) => {
if (isEnrollRequest(url)) return Promise.resolve(makeJsonResponse({}, 200));
if (isSqlRequest(url)) {
if (fetchMock.mock.calls.filter((c) => isSqlRequest(c[0])).length === 1) {
return Promise.resolve(new Response("not enrolled", { status: 401 }));
}
return Promise.resolve(makeSqlSuccess());
}
return Promise.reject(new Error(`Unexpected URL: ${String(url)}`));
});
const { callSql } = await import("./spacetimedb-client");
const rows = await callSql<{ question: string, answer: string }>("SELECT question, answer FROM published_qa");
expect(rows).toEqual([{ question: "q1", answer: "a1" }]);
const sqlCalls = fetchMock.mock.calls.filter((c) => isSqlRequest(c[0]));
expect(sqlCalls.length).toBe(2);
const enrollCalls = fetchMock.mock.calls.filter((c) => isEnrollRequest(c[0]));
expect(enrollCalls.length).toBe(2);
});
it("returns rows on first-attempt success without re-enrolling", async () => {
fetchMock.mockImplementation((url: unknown) => {
if (isEnrollRequest(url)) return Promise.resolve(makeJsonResponse({}, 200));
if (isSqlRequest(url)) return Promise.resolve(makeSqlSuccess());
return Promise.reject(new Error(`Unexpected URL: ${String(url)}`));
});
const { callSql } = await import("./spacetimedb-client");
const rows = await callSql<{ question: string, answer: string }>("SELECT question, answer FROM published_qa");
expect(rows).toEqual([{ question: "q1", answer: "a1" }]);
const sqlCalls = fetchMock.mock.calls.filter((c) => isSqlRequest(c[0]));
expect(sqlCalls.length).toBe(1);
const enrollCalls = fetchMock.mock.calls.filter((c) => isEnrollRequest(c[0]));
expect(enrollCalls.length).toBe(1);
});
it("propagates non-401 SQL errors without retrying", async () => {
fetchMock.mockImplementation((url: unknown) => {
if (isEnrollRequest(url)) return Promise.resolve(makeJsonResponse({}, 200));
if (isSqlRequest(url)) return Promise.resolve(new Response("syntax error", { status: 400 }));
return Promise.reject(new Error(`Unexpected URL: ${String(url)}`));
});
const { callSql } = await import("./spacetimedb-client");
await expect(callSql("SELECT BAD")).rejects.toThrow();
const sqlCalls = fetchMock.mock.calls.filter((c) => isSqlRequest(c[0]));
expect(sqlCalls.length).toBe(1);
});
});

View File

@ -67,29 +67,27 @@ function spacetimeDbError(label: string, status: number, preview: string): Error
return new StackAssertionError(detail);
}
async function callWithEnrollmentRetry(reducer: string, args: unknown[]): Promise<boolean> {
async function withEnrollmentRetry<T>(op: (token: string) => Promise<T>): Promise<T | null> {
const token = await getServiceToken();
if (!token) return false;
if (!token) return null;
try {
await rawCallReducer(token, reducer, args);
return true;
return await op(token);
} catch (err) {
if (!(err instanceof StatusError) || err.statusCode !== 401) throw err;
enrollmentPromise = null;
const fresh = await getServiceToken();
if (!fresh) throw err;
await rawCallReducer(fresh, reducer, args);
return true;
return await op(fresh);
}
}
export async function callReducer(reducer: string, args: unknown[]): Promise<void> {
await callWithEnrollmentRetry(reducer, args);
await withEnrollmentRetry((token) => rawCallReducer(token, reducer, args));
}
export async function callReducerStrict(reducer: string, args: unknown[]): Promise<void> {
const ran = await callWithEnrollmentRetry(reducer, args);
if (!ran) {
const ran = await withEnrollmentRetry((token) => rawCallReducer(token, reducer, args));
if (ran === null) {
throw new StackAssertionError(
`SpacetimeDB is not configured. Reducer ${reducer} cannot run. ` +
`Check STACK_SPACETIMEDB_URL and STACK_SPACETIMEDB_SERVICE_TOKEN.`
@ -106,11 +104,12 @@ export function opt<T>(value: T | null | undefined): { some: T } | { none: [] }
return value == null ? { none: [] } : { some: value };
}
export async function callSql<T = Record<string, unknown>>(sql: string): Promise<T[]> {
const token = await getServiceToken();
if (!token) return [];
async function rawCallSql(token: string, sql: string): Promise<Array<{
schema: { elements: Array<{ name: { some?: string } | null }> },
rows: unknown[][],
}>> {
const base = httpBase();
if (!base) return [];
if (!base) throw new StackAssertionError("SpacetimeDB not configured");
const dbName = getEnvVariable("STACK_SPACETIMEDB_DB_NAME");
const res = await fetch(`${base}/v1/database/${encodeURIComponent(dbName)}/sql`, {
method: "POST",
@ -122,11 +121,15 @@ export async function callSql<T = Record<string, unknown>>(sql: string): Promise
const preview = (await res.text()).slice(0, 200);
throw spacetimeDbError("SQL query failed", res.status, preview);
}
const parsed = await res.json() as Array<{
return await res.json() as Array<{
schema: { elements: Array<{ name: { some?: string } | null }> },
rows: unknown[][],
}>;
if (parsed.length === 0) return [];
}
export async function callSql<T = Record<string, unknown>>(sql: string): Promise<T[]> {
const parsed = await withEnrollmentRetry((token) => rawCallSql(token, sql));
if (parsed == null || parsed.length === 0) return [];
const first = parsed[0];
const cols = first.schema.elements.map(e => e.name?.some ?? "");
return first.rows.map(row => {

View File

@ -0,0 +1,26 @@
import { describe, expect, it } from "vitest";
import { TOOL_NAMES, validateToolNames } from "./index";
describe("validateToolNames stays in sync with TOOL_NAMES", () => {
it("accepts the full TOOL_NAMES list", () => {
expect(validateToolNames([...TOOL_NAMES])).toBe(true);
});
it("accepts each tool name individually", () => {
for (const name of TOOL_NAMES) {
expect(validateToolNames([name])).toBe(true);
}
});
it("rejects unknown names", () => {
expect(validateToolNames(["does-not-exist"])).toBe(false);
expect(validateToolNames([...TOOL_NAMES, "does-not-exist"])).toBe(false);
});
it("rejects non-array inputs", () => {
expect(validateToolNames("docs")).toBe(false);
expect(validateToolNames(null)).toBe(false);
expect(validateToolNames(undefined)).toBe(false);
expect(validateToolNames({})).toBe(false);
});
});

View File

@ -81,22 +81,13 @@ export async function getTools(
/**
* Validates that all requested tool names are valid.
* Throws an error if any tool name is invalid.
* Returns false if any tool name is not in `TOOL_NAMES`.
*/
export function validateToolNames(toolNames: unknown): toolNames is ToolName[] {
if (!Array.isArray(toolNames)) {
return false;
}
const validToolNames: ToolName[] = [
"docs",
"sql-query",
"create-email-theme",
"create-email-template",
"create-email-draft",
"update-dashboard",
"patch-dashboard"
];
return toolNames.every((name) => validToolNames.includes(name as ToolName));
return toolNames.every((name): name is ToolName =>
typeof name === "string" && (TOOL_NAMES as readonly string[]).includes(name)
);
}