diff --git a/packages/stack-server/prisma/schema.prisma b/packages/stack-server/prisma/schema.prisma index 98134a7c3..5b52ab70a 100644 --- a/packages/stack-server/prisma/schema.prisma +++ b/packages/stack-server/prisma/schema.prisma @@ -252,6 +252,8 @@ model OauthProviderConfig { createdAt DateTime @default(now()) updatedAt DateTime @updatedAt + enabled Boolean @default(true) + proxiedOauthConfig ProxiedOauthProviderConfig? standardOauthConfig StandardOauthProviderConfig? projectUserOauthAccounts ProjectUserOauthAccount[] diff --git a/packages/stack-server/src/app/(main)/(protected)/projects/[projectId]/auth/providers/page-client.tsx b/packages/stack-server/src/app/(main)/(protected)/projects/[projectId]/auth/providers/page-client.tsx index c0cad5bda..2dcc4ea45 100644 --- a/packages/stack-server/src/app/(main)/(protected)/projects/[projectId]/auth/providers/page-client.tsx +++ b/packages/stack-server/src/app/(main)/(protected)/projects/[projectId]/auth/providers/page-client.tsx @@ -15,12 +15,14 @@ export default function ProvidersClient() { const [invalidationCounter, setInvalidationCounter] = useState(0); const projectPromise = useStrictMemo(async () => { - return await stackAdminApp.getProject(); + return await stackAdminApp.getProject({ showDisabledOauth: true }); }, [stackAdminApp, invalidationCounter]); const project = use(projectPromise); const oauthProviders = project.evaluatedConfig.oauthProviders; + console.log(oauthProviders); + return ( <> @@ -58,11 +60,15 @@ export default function ProvidersClient() { key={id} id={id} provider={provider} - updateProvider={async (provider?: OauthProviderConfigJson) => { + updateProvider={async (provider: OauthProviderConfigJson) => { + const alreadyExist = oauthProviders.some((p) => p.id === id); + const newOauthProviders = oauthProviders.map((p) => p.id === id ? provider : p); + if (!alreadyExist) { + newOauthProviders.push(provider); + } + await stackAdminApp.updateProject({ - config: { - oauthProviders: oauthProviders.map((p) => p.id === id ? provider : p).filter((p) => p) as OauthProviderConfigJson[], - }, + config: { oauthProviders: newOauthProviders }, }); setInvalidationCounter((counter) => counter + 1); }} diff --git a/packages/stack-server/src/app/(main)/(protected)/projects/[projectId]/auth/providers/provider-accordion.tsx b/packages/stack-server/src/app/(main)/(protected)/projects/[projectId]/auth/providers/provider-accordion.tsx index 8f7d4e7dd..d09f291a0 100644 --- a/packages/stack-server/src/app/(main)/(protected)/projects/[projectId]/auth/providers/provider-accordion.tsx +++ b/packages/stack-server/src/app/(main)/(protected)/projects/[projectId]/auth/providers/provider-accordion.tsx @@ -32,7 +32,7 @@ export type ProviderType = typeof availableProviders[number]; type Props = { id: ProviderType, provider?: OauthProviderConfigJson, - updateProvider: (provider?: OauthProviderConfigJson) => Promise, + updateProvider: (provider: OauthProviderConfigJson) => Promise, }; function toTitle(id: ProviderType) { @@ -46,19 +46,22 @@ function toTitle(id: ProviderType) { function AccordionSummaryContent(props: Props) { const title = toTitle(props.id); - const [checked, setChecked] = useState(!!props.provider); + const enabled = props.provider?.enabled; + const [checked, setChecked] = useState(enabled); return ( - + { e.stopPropagation(); - if (checked) { - setChecked(false); - await props.updateProvider(); + setChecked(e.target.checked); + if (props.provider) { + await props.updateProvider({ ...props.provider, enabled: e.target.checked }); + } else { + await props.updateProvider({ id: props.id, type: toSharedProvider(props.id), enabled: e.target.checked }); } }} /> @@ -70,9 +73,9 @@ function AccordionSummaryContent(props: Props) { } export function ProviderAccordion(props: Props) { - if (!props.provider) { + if (!props.provider?.enabled) { return ( - + ); diff --git a/packages/stack-server/src/app/api/v1/projects/[projectId]/route.tsx b/packages/stack-server/src/app/api/v1/projects/[projectId]/route.tsx index 5485852b9..0d54db0fd 100644 --- a/packages/stack-server/src/app/api/v1/projects/[projectId]/route.tsx +++ b/packages/stack-server/src/app/api/v1/projects/[projectId]/route.tsx @@ -3,7 +3,7 @@ import * as yup from "yup"; import { StatusError } from "@stackframe/stack-shared/dist/utils/errors"; import { parseRequest, smartRouteHandler } from "@/lib/route-handlers"; import { checkApiKeySet, publishableClientKeyHeaderSchema, superSecretAdminKeyHeaderSchema } from "@/lib/api-keys"; -import { isProjectAdmin, updateProject } from "@/lib/projects"; +import { getProject, isProjectAdmin, updateProject } from "@/lib/projects"; import { ClientProjectJson, SharedProvider, StandardProvider, sharedProviders, standardProviders } from "@stackframe/stack-shared/dist/interface/clientInterface"; import { ProjectIdOrKeyInvalidErrorCode, KnownError } from "@stackframe/stack-shared/dist/utils/types"; import { OauthProviderUpdateOptions, ProjectUpdateOptions } from "@stackframe/stack-shared/dist/interface/adminInterface"; @@ -16,6 +16,7 @@ const putOrGetSchema = yup.object({ "x-stack-project-id": yup.string().required(), }).required(), body: yup.object({ + showDisabledOauth: yup.boolean().optional(), isProductionMode: yup.boolean().optional(), config: yup.object({ domains: yup.array(yup.object({ @@ -25,6 +26,7 @@ const putOrGetSchema = yup.object({ oauthProviders: yup.array( yup.object({ id: yup.string().required(), + enabled: yup.boolean().required(), type: yup.string().required(), clientId: yup.string().optional(), clientSecret: yup.string().optional(), @@ -47,7 +49,7 @@ const handler = smartRouteHandler(async (req: NextRequest, options: { params: { body, } = await parseRequest(req, putOrGetSchema); - const update = body ?? {}; + const { showDisabledOauth, ...update } = body ?? {}; const pkValid = await checkApiKeySet(projectId, { publishableClientKey }); const asValid = await isProjectAdmin(projectId, adminAccessToken); @@ -58,12 +60,13 @@ const handler = smartRouteHandler(async (req: NextRequest, options: { params: { config: update.config && { domains: update.config.domains, oauthProviders: update.config.oauthProviders && update.config.oauthProviders.map((provider) => { - if (sharedProviders.includes(provider.type)) { + if (sharedProviders.includes(provider.type as SharedProvider)) { return { id: provider.id, + enabled: provider.enabled, type: provider.type as SharedProvider, }; - } else if (standardProviders.includes(provider.type)) { + } else if (standardProviders.includes(provider.type as StandardProvider)) { if (!provider.clientId) { throw new StatusError(StatusError.BadRequest, "Missing clientId"); } @@ -73,6 +76,7 @@ const handler = smartRouteHandler(async (req: NextRequest, options: { params: { return { id: provider.id, + enabled: provider.enabled, type: provider.type as StandardProvider, clientId: provider.clientId, clientSecret: provider.clientSecret, @@ -90,13 +94,18 @@ const handler = smartRouteHandler(async (req: NextRequest, options: { params: { const project = await updateProject( projectId, typedUpdate, + showDisabledOauth, ); return NextResponse.json(project); } else if (asValid || pkValid) { if (Object.entries(update).length !== 0) { throw new StatusError(StatusError.Forbidden, "Can't update project with only publishable client key"); } - const project = await updateProject(projectId, {}); + if (showDisabledOauth) { + throw new StatusError(StatusError.Forbidden, "Can't show disabled oauth providers with only publishable client key"); + } + + const project = await getProject(projectId); if (!project) { throw new Error("Project not found but the API key was valid? Something weird happened"); } @@ -120,4 +129,5 @@ const handler = smartRouteHandler(async (req: NextRequest, options: { params: { }); export const GET = handler; export const PUT = handler; +export const POST = handler; export const DELETE = handler; diff --git a/packages/stack-server/src/lib/projects.tsx b/packages/stack-server/src/lib/projects.tsx index 967f6b808..0caa15204 100644 --- a/packages/stack-server/src/lib/projects.tsx +++ b/packages/stack-server/src/lib/projects.tsx @@ -7,6 +7,7 @@ import { generateUuid } from "@stackframe/stack-shared/dist/utils/uuids"; import { EmailConfigJson, SharedProvider, StandardProvider, sharedProviders, standardProviders } from "@stackframe/stack-shared/dist/interface/clientInterface"; import { typedToUppercase } from "@stackframe/stack-shared/dist/utils/strings"; import { OauthProviderUpdateOptions, ProjectUpdateOptions } from "@stackframe/stack-shared/dist/interface/adminInterface"; +import { throwErr } from "@stackframe/stack-shared/dist/utils/errors"; function toDBSharedProvider(type: SharedProvider): ProxiedOauthProviderType { @@ -208,13 +209,14 @@ export async function createProject( return projectJsonFromDbType(project); } -export async function getProject(projectId: string): Promise { - return await updateProject(projectId, {}); +export async function getProject(projectId: string, showDisabledOauth: boolean = false): Promise { + return await updateProject(projectId, {}, showDisabledOauth); } export async function updateProject( projectId: string, - options: ProjectUpdateOptions + options: ProjectUpdateOptions, + showDisabledOauth: boolean = false ): Promise { // TODO: Validate production mode consistency const transaction = []; @@ -229,11 +231,6 @@ export async function updateProject( } if (options.config?.domains) { - // Fetch current domains - const currentDomains = await prismaClient.projectDomain.findMany({ - where: { projectConfigId: project.config.id }, - }); - const newDomains = options.config.domains; // delete existing domains @@ -253,46 +250,109 @@ export async function updateProject( }); } - if (options.config?.oauthProviders) { - transaction.push(prismaClient.oauthProviderConfig.deleteMany({ - where: { projectConfigId: project.config.id }, - })); - - options.config.oauthProviders.forEach(providerConfig => { - if (sharedProviders.includes(providerConfig.type as SharedProvider)) { - transaction.push(prismaClient.oauthProviderConfig.create({ - data: { - projectConfigId: project.config.id, - id: providerConfig.id, - proxiedOauthConfig: { - create: { - type: toDBSharedProvider(providerConfig.type as SharedProvider), - }, - }, - }, - })); - } else if (standardProviders.includes(providerConfig.type as StandardProvider)) { - // make typescript happy - const typedProviderConfig = providerConfig as OauthProviderUpdateOptions & { type: StandardProvider }; - - transaction.push(prismaClient.oauthProviderConfig.create({ - data: { - projectConfigId: project.config.id, - id: providerConfig.id, - standardOauthConfig: { - create: { - type: toDBStandardProvider(providerConfig.type as StandardProvider), - clientId: typedProviderConfig.clientId, - clientSecret: typedProviderConfig.clientSecret, - tenantId: typedProviderConfig.tenantId, - }, - }, - }, - })); - } else { - console.error(`Invalid provider type '${providerConfig.type}'`); + const oauthProvidersUpdate = options.config?.oauthProviders; + if (oauthProvidersUpdate) { + const oldProviders = project.config.oauthProviderConfigs; + const providerMap = new Map(oldProviders.map((provider) => [ + provider.id, + { + providerUpdate: oauthProvidersUpdate.find((p) => p.id === provider.id) ?? throwErr(`Missing provider update for provider '${provider.id}'`), + oldProvider: provider, } - }); + ])); + + const newProviders = oauthProvidersUpdate.map((providerUpdate) => ({ + id: providerUpdate.id, + update: providerUpdate + })).filter(({ id }) => !providerMap.has(id)); + + // Update existing providers + for (const [id, { providerUpdate, oldProvider }] of providerMap) { + let providerConfigUpdate; + if (sharedProviders.includes(providerUpdate.type as SharedProvider)) { + providerConfigUpdate = { + proxiedOauthConfig: { + update: { + type: toDBSharedProvider(providerUpdate.type as SharedProvider), + }, + }, + }; + + if (oldProvider.standardOauthConfig) { + transaction.push(prismaClient.standardOauthProviderConfig.delete({ + where: { projectConfigId_id: { projectConfigId: project.config.id, id } }, + })); + } + + } else if (standardProviders.includes(providerUpdate.type as StandardProvider)) { + const typedProviderConfig = providerUpdate as OauthProviderUpdateOptions & { type: StandardProvider }; + + providerConfigUpdate = { + standardOauthConfig: { + update: { + type: toDBStandardProvider(providerUpdate.type as StandardProvider), + clientId: typedProviderConfig.clientId, + clientSecret: typedProviderConfig.clientSecret, + tenantId: typedProviderConfig.tenantId, + }, + }, + }; + + if (oldProvider.proxiedOauthConfig) { + transaction.push(prismaClient.proxiedOauthProviderConfig.delete({ + where: { projectConfigId_id: { projectConfigId: project.config.id, id } }, + })); + } + } else { + console.error(`Invalid provider type '${providerUpdate.type}'`); + } + + transaction.push(prismaClient.oauthProviderConfig.update({ + where: { projectConfigId_id: { projectConfigId: project.config.id, id } }, + data: { + enabled: providerUpdate.enabled, + ...providerConfigUpdate, + }, + })); + } + + // Create new providers + for (const provider of newProviders) { + let providerConfigData; + if (sharedProviders.includes(provider.update.type as SharedProvider)) { + providerConfigData = { + proxiedOauthConfig: { + create: { + type: toDBSharedProvider(provider.update.type as SharedProvider), + }, + }, + }; + } else if (standardProviders.includes(provider.update.type as StandardProvider)) { + const typedProviderConfig = provider.update as OauthProviderUpdateOptions & { type: StandardProvider }; + + providerConfigData = { + standardOauthConfig: { + create: { + type: toDBStandardProvider(provider.update.type as StandardProvider), + clientId: typedProviderConfig.clientId, + clientSecret: typedProviderConfig.clientSecret, + tenantId: typedProviderConfig.tenantId, + }, + }, + }; + } else { + console.error(`Invalid provider type '${provider.update.type}'`); + } + + transaction.push(prismaClient.oauthProviderConfig.create({ + data: { + id: provider.id, + projectConfigId: project.config.id, + enabled: provider.update.enabled, + ...providerConfigData, + }, + })); + } } if (options.config?.credentialEnabled !== undefined) { @@ -311,7 +371,7 @@ export async function updateProject( })); } - const result = await prismaClient.$transaction(transaction); + await prismaClient.$transaction(transaction); const updatedProject = await prismaClient.project.findUnique({ where: { id: projectId }, @@ -322,10 +382,10 @@ export async function updateProject( return null; } - return projectJsonFromDbType(updatedProject); + return projectJsonFromDbType(updatedProject, showDisabledOauth); } -function projectJsonFromDbType(project: ProjectDB): ProjectJson { +function projectJsonFromDbType(project: ProjectDB, showDisabledOauth: boolean = false): ProjectJson { let emailConfig: EmailConfigJson | undefined; const emailServiceConfig = project.config.emailServiceConfig; if (emailServiceConfig) { @@ -364,15 +424,20 @@ function projectJsonFromDbType(project: ProjectDB): ProjectJson { handlerPath: domain.handlerPath, })), oauthProviders: project.config.oauthProviderConfigs.flatMap((provider): OauthProviderConfigJson[] => { + if (!showDisabledOauth && !provider.enabled) { + return []; + } if (provider.proxiedOauthConfig) { return [{ id: provider.id, + enabled: provider.enabled, type: fromDBSharedProvider(provider.proxiedOauthConfig.type), }]; } if (provider.standardOauthConfig) { return [{ id: provider.id, + enabled: provider.enabled, type: fromDBStandardProvider(provider.standardOauthConfig.type), clientId: provider.standardOauthConfig.clientId, clientSecret: provider.standardOauthConfig.clientSecret, diff --git a/packages/stack-shared/src/interface/adminInterface.ts b/packages/stack-shared/src/interface/adminInterface.ts index c0072fe5e..2519c0cd8 100644 --- a/packages/stack-shared/src/interface/adminInterface.ts +++ b/packages/stack-shared/src/interface/adminInterface.ts @@ -18,6 +18,7 @@ export type AdminAuthApplicationOptions = Readonly< export type OauthProviderUpdateOptions = { id: string, + enabled: boolean, } & ( | { type: SharedProvider, @@ -136,10 +137,16 @@ export class StackAdminInterface extends StackServerInterface { ]); } - async getProject(): Promise { + async getProject(options?: { showDisabledOauth?: boolean }): Promise { const response = await this.sendAdminRequest( "/projects/" + encodeURIComponent(this.projectId), - {}, + { + method: "POST", + headers: { + "content-type": "application/json", + }, + body: JSON.stringify(options || {}), + }, null, ); return await response.json(); diff --git a/packages/stack-shared/src/interface/clientInterface.ts b/packages/stack-shared/src/interface/clientInterface.ts index f74636146..3d370f758 100644 --- a/packages/stack-shared/src/interface/clientInterface.ts +++ b/packages/stack-shared/src/interface/clientInterface.ts @@ -110,6 +110,7 @@ export type ProjectJson = Readonly<{ export type OauthProviderConfigJson = { id: string, + enabled: boolean, } & ( | { type: SharedProvider } | {