diff --git a/apps/backend/src/lib/plan-usage.ts b/apps/backend/src/lib/plan-usage.ts index 50b8be096..99d7de8f2 100644 --- a/apps/backend/src/lib/plan-usage.ts +++ b/apps/backend/src/lib/plan-usage.ts @@ -12,6 +12,7 @@ import { getPrismaClientForTenancy, getPrismaSchemaForTenancy, globalPrismaClien import { BASE_PLAN_IDS_BY_TIER, ITEM_IDS, PLAN_LIMITS, UNLIMITED, type ItemId, type PlanId } from "@hexclave/shared/dist/plans"; import type { PlanUsageResponse } from "@hexclave/shared/dist/interface/admin-interface"; import { HexclaveAssertionError, throwErr } from "@hexclave/shared/dist/utils/errors"; +import { mapWithConcurrency } from "@hexclave/shared/dist/utils/promises"; import type { SubscriptionRow } from "./payments/schema/types"; type PlanUsageKind = PlanUsageResponse["rows"][number]["kind"]; @@ -207,29 +208,72 @@ async function getOwnerTeamDisplayName(internalTenancy: Tenancy, ownerTeamId: st return team?.displayName ?? throwErr(`Owner team ${ownerTeamId} not found in the internal tenancy`); } -async function countMeteredUsageForTenancy(tenancyId: string, period: UsagePeriod): Promise { - const tenancy = await getTenancy(tenancyId) ?? throwErr(`Tenancy ${tenancyId} not found while counting plan usage`); - const schema = await getPrismaSchemaForTenancy(tenancy); - const prisma = await getPrismaClientForTenancy(tenancy); - const rows = await prisma.$replica().$queryRaw>` +type TenancyPrismaClient = Awaited>; + +type TenancyMeteredUsageGroup = { + prisma: TenancyPrismaClient, + schema: string, + tenancyIds: string[], +}; + +// Tenancies can route to different source-of-truth databases/schemas, so we can't assume a single +// query covers every tenancy. We group tenancies that share a (client, schema) and run one aggregate +// COUNT per group: the common case (all projects on one database) collapses to a single round trip, +// while multi-database teams fan out to one query per distinct database instead of one per tenancy. +async function groupTenanciesByMeteredUsageSource(tenancyIds: string[]): Promise { + const resolved = await mapWithConcurrency(tenancyIds, PLAN_USAGE_TENANCY_COUNTER_CONCURRENCY, async (tenancyId) => { + const tenancy = await getTenancy(tenancyId) ?? throwErr(`Tenancy ${tenancyId} not found while counting plan usage`); + const [schema, prisma] = await Promise.all([ + getPrismaSchemaForTenancy(tenancy), + getPrismaClientForTenancy(tenancy), + ]); + return { tenancyId: tenancy.id, schema, prisma }; + }); + + const byClient = new Map>(); + for (const { tenancyId, schema, prisma } of resolved) { + let bySchema = byClient.get(prisma); + if (bySchema == null) { + bySchema = new Map(); + byClient.set(prisma, bySchema); + } + const existing = bySchema.get(schema); + if (existing == null) { + bySchema.set(schema, [tenancyId]); + } else { + existing.push(tenancyId); + } + } + + const groups: TenancyMeteredUsageGroup[] = []; + for (const [prisma, bySchema] of byClient) { + for (const [schema, groupTenancyIds] of bySchema) { + groups.push({ prisma, schema, tenancyIds: groupTenancyIds }); + } + } + return groups; +} + +async function countMeteredUsageForGroup(group: TenancyMeteredUsageGroup, period: UsagePeriod): Promise { + const rows = await group.prisma.$replica().$queryRaw>` SELECT ( SELECT COUNT(*)::int - FROM ${sqlQuoteIdent(schema)}."EmailOutbox" - WHERE "tenancyId" = ${tenancy.id}::uuid + FROM ${sqlQuoteIdent(group.schema)}."EmailOutbox" + WHERE "tenancyId" = ANY(${group.tenancyIds}::uuid[]) AND "startedSendingAt" IS NOT NULL AND "startedSendingAt" >= ${period.start} AND "startedSendingAt" < ${period.end} ) AS "emails", ( SELECT COUNT(*)::int - FROM ${sqlQuoteIdent(schema)}."SessionReplay" - WHERE "tenancyId" = ${tenancy.id}::uuid + FROM ${sqlQuoteIdent(group.schema)}."SessionReplay" + WHERE "tenancyId" = ANY(${group.tenancyIds}::uuid[]) AND "startedAt" >= ${period.start} AND "startedAt" < ${period.end} ) AS "sessionReplays" `; - const row = rows[0] ?? throwErr(`Missing plan usage count row for tenancy ${tenancy.id}`); + const row = rows[0] ?? throwErr(`Missing plan usage count row for metered usage group on schema ${group.schema}`); return { emails: Number(row.emails), sessionReplays: Number(row.sessionReplays), @@ -237,27 +281,26 @@ async function countMeteredUsageForTenancy(tenancyId: string, period: UsagePerio } async function sumTenancyMeteredUsage(tenancyIds: string[], period: UsagePeriod): Promise { - const totals: TenancyMeteredUsage = { - emails: 0, - sessionReplays: 0, - }; - let nextIndex = 0; - - // Keep this page from turning a team with many tenancies into an unbounded burst of replica COUNTs. - async function worker(): Promise { - while (nextIndex < tenancyIds.length) { - const index = nextIndex; - nextIndex++; - const tenancyId = tenancyIds[index] ?? throwErr(`Missing tenancy ID at index ${index} while counting plan usage`); - const usage = await countMeteredUsageForTenancy(tenancyId, period); - totals.emails += usage.emails; - totals.sessionReplays += usage.sessionReplays; - } + if (tenancyIds.length === 0) { + return { emails: 0, sessionReplays: 0 }; } - const workerCount = Math.min(PLAN_USAGE_TENANCY_COUNTER_CONCURRENCY, tenancyIds.length); - await Promise.all(Array.from({ length: workerCount }, async () => await worker())); - return totals; + const groups = await groupTenanciesByMeteredUsageSource(tenancyIds); + // The group count equals the number of distinct databases (usually 1), so concurrency mostly guards + // the pathological multi-database team rather than the per-tenancy fan-out it used to. + const subtotals = await mapWithConcurrency( + groups, + PLAN_USAGE_TENANCY_COUNTER_CONCURRENCY, + async (group) => await countMeteredUsageForGroup(group, period), + ); + + return subtotals.reduce( + (totals, subtotal) => ({ + emails: totals.emails + subtotal.emails, + sessionReplays: totals.sessionReplays + subtotal.sessionReplays, + }), + { emails: 0, sessionReplays: 0 }, + ); } async function countAnalyticsEventsForProjects(projectIds: string[], period: UsagePeriod): Promise { diff --git a/packages/shared/src/utils/promises.tsx b/packages/shared/src/utils/promises.tsx index 0cfb7df3b..580dade88 100644 --- a/packages/shared/src/utils/promises.tsx +++ b/packages/shared/src/utils/promises.tsx @@ -434,6 +434,64 @@ import.meta.vitest?.test("timeoutThrow", async ({ expect }) => { }); +/** + * Maps over `items` with `fn`, running at most `concurrency` invocations at a time. + * + * Unlike `Promise.all(items.map(fn))`, this bounds the number of in-flight + * promises, which matters when `fn` hits a shared resource (e.g. a database) and + * an unbounded fan-out could exhaust connections or overload a replica. Results + * are returned in input order regardless of completion order, and the first + * rejection propagates (in-flight workers still settle but their results are + * discarded). + */ +export async function mapWithConcurrency( + items: readonly T[], + concurrency: number, + fn: (item: T, index: number) => Promise, +): Promise { + if (!Number.isInteger(concurrency) || concurrency < 1) { + throw new HexclaveAssertionError(`mapWithConcurrency requires a positive integer concurrency, got ${concurrency}`); + } + const results = new Array(items.length); + let nextIndex = 0; + const worker = async () => { + while (true) { + // Claim an index synchronously before awaiting so workers never process the same item. + const index = nextIndex++; + if (index >= items.length) return; + results[index] = await fn(items[index]!, index); + } + }; + const workerCount = Math.min(concurrency, items.length); + await Promise.all(Array.from({ length: workerCount }, () => worker())); + return results; +} +import.meta.vitest?.test("mapWithConcurrency", async ({ expect }) => { + // Preserves input order regardless of completion order. + const ordered = await mapWithConcurrency([30, 10, 20], 3, async (ms, index) => { + await wait(ms); + return `${index}:${ms}`; + }); + expect(ordered).toEqual(["0:30", "1:10", "2:20"]); + + // Never exceeds the configured concurrency. + let inFlight = 0; + let maxInFlight = 0; + await mapWithConcurrency(Array.from({ length: 10 }, (_, i) => i), 3, async () => { + inFlight++; + maxInFlight = Math.max(maxInFlight, inFlight); + await wait(5); + inFlight--; + }); + expect(maxInFlight).toBe(3); + + // Empty input spawns no workers and returns an empty array. + expect(await mapWithConcurrency([], 4, async () => 1)).toEqual([]); + + // Invalid concurrency fails loudly. + await expect(mapWithConcurrency([1], 0, async (x) => x)).rejects.toThrow("positive integer concurrency"); +}); + export type RateLimitOptions = { /** * The number of requests to process in parallel. Currently only 1 is supported.