mirror of
https://github.com/stack-auth/stack.git
synced 2026-06-13 21:01:21 +08:00
Add tests for AI proxy handlers and SpacetimeDB client
- Introduced unit tests for `observeAndLog` in `ai-proxy-handlers.test.ts` to ensure correct behavior when logging functions throw errors. - Added tests for `callSql` in `spacetimedb-client.test.ts` to verify 401 enrollment retry logic and error handling. - Created tests for `buildProxyLogRow` in `ai-proxy-logger.test.ts` to validate tool-name extraction from parsed logs. - Implemented validation tests for tool names in `index.test.ts` to ensure consistency with defined tool names. These additions enhance test coverage and reliability of the AI-related functionalities.
This commit is contained in:
parent
1a8e38e511
commit
29fa318525
137
apps/backend/src/lib/ai/ai-proxy-handlers.test.ts
Normal file
137
apps/backend/src/lib/ai/ai-proxy-handlers.test.ts
Normal file
@ -0,0 +1,137 @@
|
||||
/**
|
||||
* Ensures `observeAndLog` returns upstream bytes even when proxy logging throws.
|
||||
*/
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
|
||||
vi.mock("./loggers/ai-proxy-logger", () => ({
|
||||
buildProxyLogRow: vi.fn(),
|
||||
scheduleProxyLog: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("./openrouter-usage", async (importOriginal) => {
|
||||
const original = await importOriginal<Record<string, unknown>>();
|
||||
return {
|
||||
...original,
|
||||
extractOpenRouterUsage: vi.fn(() => ({})),
|
||||
scanSseForUsage: vi.fn(async () => ({})),
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("@/private", () => ({
|
||||
preprocessProxyBody: ({ parsedBody }: { parsedBody: unknown }) => parsedBody,
|
||||
}));
|
||||
|
||||
import { buildProxyLogRow, scheduleProxyLog } from "./loggers/ai-proxy-logger";
|
||||
import { observeAndLog, sanitizeBody } from "./ai-proxy-handlers";
|
||||
|
||||
const sanitized = sanitizeBody(new TextEncoder().encode(JSON.stringify({
|
||||
model: "anthropic/claude-sonnet-4.6",
|
||||
messages: [{ role: "user", content: "hi" }],
|
||||
})).buffer as ArrayBuffer);
|
||||
|
||||
describe("observeAndLog body delivery", () => {
|
||||
it("returns the upstream bytes even if buildProxyLogRow throws", async () => {
|
||||
vi.mocked(buildProxyLogRow).mockImplementation(() => {
|
||||
throw new Error("synthetic log construction failure");
|
||||
});
|
||||
vi.mocked(scheduleProxyLog).mockClear();
|
||||
|
||||
const upstreamBody = JSON.stringify({ id: "gen-1", choices: [{ message: { content: "hello" } }] });
|
||||
const upstream = new Response(upstreamBody, {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
|
||||
const out = await observeAndLog({
|
||||
response: upstream,
|
||||
sanitizedBody: sanitized,
|
||||
callerApiKey: "stack-auth-test",
|
||||
correlationId: "corr-1",
|
||||
startedAt: 0,
|
||||
responseHeaders: { "Content-Type": "application/json", "Cache-Control": "no-store" },
|
||||
});
|
||||
|
||||
expect(out.status).toBe(200);
|
||||
expect(await out.text()).toBe(upstreamBody);
|
||||
});
|
||||
|
||||
it("returns the upstream bytes even if scheduleProxyLog throws", async () => {
|
||||
vi.mocked(buildProxyLogRow).mockReturnValue({
|
||||
correlationId: "corr-2",
|
||||
mode: "generate",
|
||||
systemPromptId: "stack-auth-test",
|
||||
quality: "unknown",
|
||||
speed: "unknown",
|
||||
modelId: "anthropic/claude-sonnet-4.6",
|
||||
isAuthenticated: false,
|
||||
projectId: undefined,
|
||||
userId: undefined,
|
||||
requestedToolsJson: "[]",
|
||||
messagesJson: "[]",
|
||||
stepsJson: "[]",
|
||||
finalText: "",
|
||||
inputTokens: undefined,
|
||||
outputTokens: undefined,
|
||||
cachedInputTokens: undefined,
|
||||
cacheCreationTokens: undefined,
|
||||
costUsd: undefined,
|
||||
cacheDiscountUsd: undefined,
|
||||
openrouterGenerationId: undefined,
|
||||
stepCount: 0,
|
||||
durationMs: 0n,
|
||||
errorMessage: undefined,
|
||||
conversationId: undefined,
|
||||
});
|
||||
vi.mocked(scheduleProxyLog).mockImplementation(() => {
|
||||
throw new Error("synthetic schedule failure");
|
||||
});
|
||||
|
||||
const upstreamBody = JSON.stringify({ id: "gen-2", choices: [{ message: { content: "world" } }] });
|
||||
const upstream = new Response(upstreamBody, {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
|
||||
const out = await observeAndLog({
|
||||
response: upstream,
|
||||
sanitizedBody: sanitized,
|
||||
callerApiKey: "stack-auth-test",
|
||||
correlationId: "corr-2",
|
||||
startedAt: 0,
|
||||
responseHeaders: { "Content-Type": "application/json", "Cache-Control": "no-store" },
|
||||
});
|
||||
|
||||
expect(out.status).toBe(200);
|
||||
expect(await out.text()).toBe(upstreamBody);
|
||||
});
|
||||
|
||||
it("delivers full streamed bytes via tee even when the observer arm throws", async () => {
|
||||
vi.mocked(buildProxyLogRow).mockImplementation(() => {
|
||||
throw new Error("synthetic log failure inside async observer");
|
||||
});
|
||||
|
||||
const streamPayload = "data: {\"id\":\"gen-3\"}\n\ndata: [DONE]\n\n";
|
||||
const upstream = new Response(streamPayload, {
|
||||
status: 200,
|
||||
headers: { "Content-Type": "text/event-stream" },
|
||||
});
|
||||
|
||||
const streamingSanitized = sanitizeBody(new TextEncoder().encode(JSON.stringify({
|
||||
model: "anthropic/claude-sonnet-4.6",
|
||||
stream: true,
|
||||
messages: [{ role: "user", content: "hi" }],
|
||||
})).buffer as ArrayBuffer);
|
||||
|
||||
const out = await observeAndLog({
|
||||
response: upstream,
|
||||
sanitizedBody: streamingSanitized,
|
||||
callerApiKey: "stack-auth-test",
|
||||
correlationId: "corr-3",
|
||||
startedAt: 0,
|
||||
responseHeaders: { "Content-Type": "text/event-stream", "Cache-Control": "no-store" },
|
||||
});
|
||||
|
||||
expect(out.status).toBe(200);
|
||||
expect(await out.text()).toBe(streamPayload);
|
||||
});
|
||||
});
|
||||
@ -51,42 +51,50 @@ export async function observeAndLog(args: {
|
||||
if (isStreaming && response.body) {
|
||||
const [clientStream, observerStream] = response.body.tee();
|
||||
runAsynchronouslyAndWaitUntil((async () => {
|
||||
let usage: UsageFields = {};
|
||||
const controller = new AbortController();
|
||||
const timeoutId = setTimeout(() => controller.abort(), 120_000);
|
||||
try {
|
||||
usage = (await scanSseForUsage(observerStream, controller.signal)) ?? {};
|
||||
let usage: UsageFields = {};
|
||||
const controller = new AbortController();
|
||||
const timeoutId = setTimeout(() => controller.abort(), 120_000);
|
||||
try {
|
||||
usage = (await scanSseForUsage(observerStream, controller.signal)) ?? {};
|
||||
} catch (err) {
|
||||
captureError("ai-proxy-scan-sse", err);
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
scheduleProxyLog(buildProxyLogRow({
|
||||
correlationId,
|
||||
parsed: sanitizedBody.parsed,
|
||||
apiKey: callerApiKey,
|
||||
durationMs: BigInt(Math.round(performance.now() - startedAt)),
|
||||
responseStatus: response.status,
|
||||
usage,
|
||||
}));
|
||||
} catch (err) {
|
||||
captureError("ai-proxy-scan-sse", err);
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
captureError("ai-proxy-observer", err);
|
||||
}
|
||||
scheduleProxyLog(buildProxyLogRow({
|
||||
correlationId,
|
||||
parsed: sanitizedBody.parsed,
|
||||
apiKey: callerApiKey,
|
||||
durationMs: BigInt(Math.round(performance.now() - startedAt)),
|
||||
responseStatus: response.status,
|
||||
usage,
|
||||
}));
|
||||
})());
|
||||
return new Response(clientStream, { status: response.status, headers: responseHeaders });
|
||||
}
|
||||
|
||||
const bodyBytes = await response.arrayBuffer();
|
||||
let parsedBody: unknown;
|
||||
try {
|
||||
parsedBody = JSON.parse(new TextDecoder().decode(bodyBytes));
|
||||
} catch {
|
||||
parsedBody = undefined;
|
||||
let parsedBody: unknown;
|
||||
try {
|
||||
parsedBody = JSON.parse(new TextDecoder().decode(bodyBytes));
|
||||
} catch {
|
||||
parsedBody = undefined;
|
||||
}
|
||||
scheduleProxyLog(buildProxyLogRow({
|
||||
correlationId,
|
||||
parsed: sanitizedBody.parsed,
|
||||
apiKey: callerApiKey,
|
||||
durationMs: BigInt(Math.round(performance.now() - startedAt)),
|
||||
responseStatus: response.status,
|
||||
usage: extractOpenRouterUsage(parsedBody),
|
||||
}));
|
||||
} catch (err) {
|
||||
captureError("ai-proxy-log-build", err);
|
||||
}
|
||||
scheduleProxyLog(buildProxyLogRow({
|
||||
correlationId,
|
||||
parsed: sanitizedBody.parsed,
|
||||
apiKey: callerApiKey,
|
||||
durationMs: BigInt(Math.round(performance.now() - startedAt)),
|
||||
responseStatus: response.status,
|
||||
usage: extractOpenRouterUsage(parsedBody),
|
||||
}));
|
||||
return new Response(bodyBytes, { status: response.status, headers: responseHeaders });
|
||||
}
|
||||
|
||||
69
apps/backend/src/lib/ai/loggers/ai-proxy-logger.test.ts
Normal file
69
apps/backend/src/lib/ai/loggers/ai-proxy-logger.test.ts
Normal file
@ -0,0 +1,69 @@
|
||||
/**
|
||||
* Tool-name extraction in `buildProxyLogRow` for Anthropic and OpenAI shapes.
|
||||
*/
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { buildProxyLogRow } from "./ai-proxy-logger";
|
||||
|
||||
const baseInput = {
|
||||
correlationId: "corr-1",
|
||||
apiKey: "stack-auth-proxy",
|
||||
durationMs: 0n,
|
||||
responseStatus: 200,
|
||||
};
|
||||
|
||||
describe("buildProxyLogRow tool-name extraction", () => {
|
||||
it("captures Anthropic top-level tool names", () => {
|
||||
const row = buildProxyLogRow({
|
||||
...baseInput,
|
||||
parsed: {
|
||||
model: "anthropic/claude-sonnet-4.6",
|
||||
tools: [
|
||||
{ name: "get_weather", description: "...", input_schema: {} },
|
||||
{ name: "send_email", description: "...", input_schema: {} },
|
||||
],
|
||||
},
|
||||
});
|
||||
expect(JSON.parse(row.requestedToolsJson)).toEqual(["get_weather", "send_email"]);
|
||||
});
|
||||
|
||||
it("captures OpenAI/OpenRouter-format function tool names", () => {
|
||||
const row = buildProxyLogRow({
|
||||
...baseInput,
|
||||
parsed: {
|
||||
model: "anthropic/claude-sonnet-4.6",
|
||||
tools: [
|
||||
{ type: "function", function: { name: "get_weather", parameters: {} } },
|
||||
{ type: "function", function: { name: "send_email", parameters: {} } },
|
||||
],
|
||||
},
|
||||
});
|
||||
expect(JSON.parse(row.requestedToolsJson)).toEqual(["get_weather", "send_email"]);
|
||||
});
|
||||
|
||||
it("handles a mixed array gracefully", () => {
|
||||
const row = buildProxyLogRow({
|
||||
...baseInput,
|
||||
parsed: {
|
||||
model: "anthropic/claude-sonnet-4.6",
|
||||
tools: [
|
||||
{ name: "anthropic_tool", input_schema: {} },
|
||||
{ type: "function", function: { name: "openai_tool", parameters: {} } },
|
||||
{ type: "function" },
|
||||
null,
|
||||
"not an object",
|
||||
],
|
||||
},
|
||||
});
|
||||
expect(JSON.parse(row.requestedToolsJson)).toEqual(["anthropic_tool", "openai_tool"]);
|
||||
});
|
||||
|
||||
it("returns an empty array when tools is absent or malformed", () => {
|
||||
const row = buildProxyLogRow({
|
||||
...baseInput,
|
||||
parsed: {
|
||||
model: "anthropic/claude-sonnet-4.6",
|
||||
},
|
||||
});
|
||||
expect(JSON.parse(row.requestedToolsJson)).toEqual([]);
|
||||
});
|
||||
});
|
||||
@ -16,7 +16,13 @@ export function buildProxyLogRow(fields: ProxyLogFields): AiQueryLogEntry {
|
||||
const { parsed, apiKey, durationMs, responseStatus, usage, correlationId } = fields;
|
||||
const tools = Array.isArray(parsed.tools) ? parsed.tools : [];
|
||||
const toolNames = tools
|
||||
.map(t => (t && typeof t === "object" && "name" in t) ? (t as { name: unknown }).name : null)
|
||||
.map((t) => {
|
||||
if (t == null || typeof t !== "object") return null;
|
||||
const obj = t as { name?: unknown, function?: { name?: unknown } };
|
||||
if (typeof obj.function?.name === "string") return obj.function.name;
|
||||
if (typeof obj.name === "string") return obj.name;
|
||||
return null;
|
||||
})
|
||||
.filter((n): n is string => typeof n === "string");
|
||||
const rawMessages = Array.isArray(parsed.messages) ? parsed.messages : [];
|
||||
const messages = typeof parsed.system === "string" && parsed.system.length > 0
|
||||
|
||||
123
apps/backend/src/lib/ai/spacetimedb-client.test.ts
Normal file
123
apps/backend/src/lib/ai/spacetimedb-client.test.ts
Normal file
@ -0,0 +1,123 @@
|
||||
/**
|
||||
* Exercises `callSql` 401 enrollment retry (shared with `callReducer` via
|
||||
* `withEnrollmentRetry`).
|
||||
*/
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
vi.mock("@stackframe/stack-shared/dist/utils/env", async (importOriginal) => {
|
||||
const original = await importOriginal<Record<string, unknown>>();
|
||||
return {
|
||||
...original,
|
||||
getEnvVariable: (key: string, fallback?: string) => {
|
||||
switch (key) {
|
||||
case "STACK_SPACETIMEDB_URL": {
|
||||
return "http://spacetime.test";
|
||||
}
|
||||
case "STACK_SPACETIMEDB_DB_NAME": {
|
||||
return "test-db";
|
||||
}
|
||||
case "STACK_SPACETIMEDB_SERVICE_TOKEN": {
|
||||
return "test-service-token";
|
||||
}
|
||||
case "STACK_MCP_LOG_TOKEN": {
|
||||
return "test-log-token";
|
||||
}
|
||||
default: {
|
||||
return fallback ?? "";
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const fetchMock = vi.fn();
|
||||
const originalFetch = globalThis.fetch;
|
||||
|
||||
function makeJsonResponse(body: unknown, status = 200): Response {
|
||||
return new Response(JSON.stringify(body), {
|
||||
status,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
}
|
||||
|
||||
function makeSqlSuccess(): Response {
|
||||
return makeJsonResponse([{
|
||||
schema: { elements: [{ name: { some: "question" } }, { name: { some: "answer" } }] },
|
||||
rows: [["q1", "a1"]],
|
||||
}]);
|
||||
}
|
||||
|
||||
function isSqlRequest(url: unknown): boolean {
|
||||
return typeof url === "string" && url.endsWith("/sql");
|
||||
}
|
||||
|
||||
function isEnrollRequest(url: unknown): boolean {
|
||||
return typeof url === "string" && url.endsWith("/call/enroll_service");
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
fetchMock.mockReset();
|
||||
globalThis.fetch = fetchMock as unknown as typeof fetch;
|
||||
// The module-level enrollmentPromise cache survives across tests; reset it
|
||||
// by re-importing the module fresh for each test.
|
||||
vi.resetModules();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
globalThis.fetch = originalFetch;
|
||||
});
|
||||
|
||||
describe("callSql 401 enrollment retry", () => {
|
||||
it("retries SQL after a 401 by re-enrolling, mirroring callReducer", async () => {
|
||||
fetchMock.mockImplementation((url: unknown) => {
|
||||
if (isEnrollRequest(url)) return Promise.resolve(makeJsonResponse({}, 200));
|
||||
if (isSqlRequest(url)) {
|
||||
if (fetchMock.mock.calls.filter((c) => isSqlRequest(c[0])).length === 1) {
|
||||
return Promise.resolve(new Response("not enrolled", { status: 401 }));
|
||||
}
|
||||
return Promise.resolve(makeSqlSuccess());
|
||||
}
|
||||
return Promise.reject(new Error(`Unexpected URL: ${String(url)}`));
|
||||
});
|
||||
|
||||
const { callSql } = await import("./spacetimedb-client");
|
||||
const rows = await callSql<{ question: string, answer: string }>("SELECT question, answer FROM published_qa");
|
||||
expect(rows).toEqual([{ question: "q1", answer: "a1" }]);
|
||||
|
||||
const sqlCalls = fetchMock.mock.calls.filter((c) => isSqlRequest(c[0]));
|
||||
expect(sqlCalls.length).toBe(2);
|
||||
const enrollCalls = fetchMock.mock.calls.filter((c) => isEnrollRequest(c[0]));
|
||||
expect(enrollCalls.length).toBe(2);
|
||||
});
|
||||
|
||||
it("returns rows on first-attempt success without re-enrolling", async () => {
|
||||
fetchMock.mockImplementation((url: unknown) => {
|
||||
if (isEnrollRequest(url)) return Promise.resolve(makeJsonResponse({}, 200));
|
||||
if (isSqlRequest(url)) return Promise.resolve(makeSqlSuccess());
|
||||
return Promise.reject(new Error(`Unexpected URL: ${String(url)}`));
|
||||
});
|
||||
|
||||
const { callSql } = await import("./spacetimedb-client");
|
||||
const rows = await callSql<{ question: string, answer: string }>("SELECT question, answer FROM published_qa");
|
||||
expect(rows).toEqual([{ question: "q1", answer: "a1" }]);
|
||||
|
||||
const sqlCalls = fetchMock.mock.calls.filter((c) => isSqlRequest(c[0]));
|
||||
expect(sqlCalls.length).toBe(1);
|
||||
const enrollCalls = fetchMock.mock.calls.filter((c) => isEnrollRequest(c[0]));
|
||||
expect(enrollCalls.length).toBe(1);
|
||||
});
|
||||
|
||||
it("propagates non-401 SQL errors without retrying", async () => {
|
||||
fetchMock.mockImplementation((url: unknown) => {
|
||||
if (isEnrollRequest(url)) return Promise.resolve(makeJsonResponse({}, 200));
|
||||
if (isSqlRequest(url)) return Promise.resolve(new Response("syntax error", { status: 400 }));
|
||||
return Promise.reject(new Error(`Unexpected URL: ${String(url)}`));
|
||||
});
|
||||
|
||||
const { callSql } = await import("./spacetimedb-client");
|
||||
await expect(callSql("SELECT BAD")).rejects.toThrow();
|
||||
|
||||
const sqlCalls = fetchMock.mock.calls.filter((c) => isSqlRequest(c[0]));
|
||||
expect(sqlCalls.length).toBe(1);
|
||||
});
|
||||
});
|
||||
@ -67,29 +67,27 @@ function spacetimeDbError(label: string, status: number, preview: string): Error
|
||||
return new StackAssertionError(detail);
|
||||
}
|
||||
|
||||
async function callWithEnrollmentRetry(reducer: string, args: unknown[]): Promise<boolean> {
|
||||
async function withEnrollmentRetry<T>(op: (token: string) => Promise<T>): Promise<T | null> {
|
||||
const token = await getServiceToken();
|
||||
if (!token) return false;
|
||||
if (!token) return null;
|
||||
try {
|
||||
await rawCallReducer(token, reducer, args);
|
||||
return true;
|
||||
return await op(token);
|
||||
} catch (err) {
|
||||
if (!(err instanceof StatusError) || err.statusCode !== 401) throw err;
|
||||
enrollmentPromise = null;
|
||||
const fresh = await getServiceToken();
|
||||
if (!fresh) throw err;
|
||||
await rawCallReducer(fresh, reducer, args);
|
||||
return true;
|
||||
return await op(fresh);
|
||||
}
|
||||
}
|
||||
|
||||
export async function callReducer(reducer: string, args: unknown[]): Promise<void> {
|
||||
await callWithEnrollmentRetry(reducer, args);
|
||||
await withEnrollmentRetry((token) => rawCallReducer(token, reducer, args));
|
||||
}
|
||||
|
||||
export async function callReducerStrict(reducer: string, args: unknown[]): Promise<void> {
|
||||
const ran = await callWithEnrollmentRetry(reducer, args);
|
||||
if (!ran) {
|
||||
const ran = await withEnrollmentRetry((token) => rawCallReducer(token, reducer, args));
|
||||
if (ran === null) {
|
||||
throw new StackAssertionError(
|
||||
`SpacetimeDB is not configured. Reducer ${reducer} cannot run. ` +
|
||||
`Check STACK_SPACETIMEDB_URL and STACK_SPACETIMEDB_SERVICE_TOKEN.`
|
||||
@ -106,11 +104,12 @@ export function opt<T>(value: T | null | undefined): { some: T } | { none: [] }
|
||||
return value == null ? { none: [] } : { some: value };
|
||||
}
|
||||
|
||||
export async function callSql<T = Record<string, unknown>>(sql: string): Promise<T[]> {
|
||||
const token = await getServiceToken();
|
||||
if (!token) return [];
|
||||
async function rawCallSql(token: string, sql: string): Promise<Array<{
|
||||
schema: { elements: Array<{ name: { some?: string } | null }> },
|
||||
rows: unknown[][],
|
||||
}>> {
|
||||
const base = httpBase();
|
||||
if (!base) return [];
|
||||
if (!base) throw new StackAssertionError("SpacetimeDB not configured");
|
||||
const dbName = getEnvVariable("STACK_SPACETIMEDB_DB_NAME");
|
||||
const res = await fetch(`${base}/v1/database/${encodeURIComponent(dbName)}/sql`, {
|
||||
method: "POST",
|
||||
@ -122,11 +121,15 @@ export async function callSql<T = Record<string, unknown>>(sql: string): Promise
|
||||
const preview = (await res.text()).slice(0, 200);
|
||||
throw spacetimeDbError("SQL query failed", res.status, preview);
|
||||
}
|
||||
const parsed = await res.json() as Array<{
|
||||
return await res.json() as Array<{
|
||||
schema: { elements: Array<{ name: { some?: string } | null }> },
|
||||
rows: unknown[][],
|
||||
}>;
|
||||
if (parsed.length === 0) return [];
|
||||
}
|
||||
|
||||
export async function callSql<T = Record<string, unknown>>(sql: string): Promise<T[]> {
|
||||
const parsed = await withEnrollmentRetry((token) => rawCallSql(token, sql));
|
||||
if (parsed == null || parsed.length === 0) return [];
|
||||
const first = parsed[0];
|
||||
const cols = first.schema.elements.map(e => e.name?.some ?? "");
|
||||
return first.rows.map(row => {
|
||||
|
||||
26
apps/backend/src/lib/ai/tools/index.test.ts
Normal file
26
apps/backend/src/lib/ai/tools/index.test.ts
Normal file
@ -0,0 +1,26 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { TOOL_NAMES, validateToolNames } from "./index";
|
||||
|
||||
describe("validateToolNames stays in sync with TOOL_NAMES", () => {
|
||||
it("accepts the full TOOL_NAMES list", () => {
|
||||
expect(validateToolNames([...TOOL_NAMES])).toBe(true);
|
||||
});
|
||||
|
||||
it("accepts each tool name individually", () => {
|
||||
for (const name of TOOL_NAMES) {
|
||||
expect(validateToolNames([name])).toBe(true);
|
||||
}
|
||||
});
|
||||
|
||||
it("rejects unknown names", () => {
|
||||
expect(validateToolNames(["does-not-exist"])).toBe(false);
|
||||
expect(validateToolNames([...TOOL_NAMES, "does-not-exist"])).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects non-array inputs", () => {
|
||||
expect(validateToolNames("docs")).toBe(false);
|
||||
expect(validateToolNames(null)).toBe(false);
|
||||
expect(validateToolNames(undefined)).toBe(false);
|
||||
expect(validateToolNames({})).toBe(false);
|
||||
});
|
||||
});
|
||||
@ -81,22 +81,13 @@ export async function getTools(
|
||||
|
||||
/**
|
||||
* Validates that all requested tool names are valid.
|
||||
* Throws an error if any tool name is invalid.
|
||||
* Returns false if any tool name is not in `TOOL_NAMES`.
|
||||
*/
|
||||
export function validateToolNames(toolNames: unknown): toolNames is ToolName[] {
|
||||
if (!Array.isArray(toolNames)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const validToolNames: ToolName[] = [
|
||||
"docs",
|
||||
"sql-query",
|
||||
"create-email-theme",
|
||||
"create-email-template",
|
||||
"create-email-draft",
|
||||
"update-dashboard",
|
||||
"patch-dashboard"
|
||||
];
|
||||
|
||||
return toolNames.every((name) => validToolNames.includes(name as ToolName));
|
||||
return toolNames.every((name): name is ToolName =>
|
||||
typeof name === "string" && (TOOL_NAMES as readonly string[]).includes(name)
|
||||
);
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user