import { prisma } from "@langfuse/shared/src/db";
import { withMiddlewares } from "@/src/features/public-api/server/withMiddlewares";
import { clearModelCacheForProject } from "@langfuse/shared/src/server";
import { createAuthedProjectAPIRoute } from "@/src/features/public-api/server/createAuthedProjectAPIRoute";
import {
  GetModelsV1Query,
  GetModelsV1Response,
  PostModelsV1Body,
  PostModelsV1Response,
  prismaToApiModelDefinition,
} from "@/src/features/public-api/types/models";
import { InvalidRequestError } from "@langfuse/shared";
import { isValidPostgresRegex } from "@/src/features/models/server/isValidPostgresRegex";
import { auditLog } from "@/src/features/audit-logs/auditLog";
import { type Decimal } from "decimal.js";

/**
 * Pairs each price field (present both on the POST body and on the created
 * model row) with the `usageType` written to / reported for its Price entry.
 * Keeping one table ensures the rows we create and the response we return
 * cannot drift apart.
 */
const PRICE_FIELDS = [
  { usageType: "input", field: "inputPrice" },
  { usageType: "output", field: "outputPrice" },
  { usageType: "total", field: "totalPrice" },
] as const;

/**
 * Public API route for model definitions.
 *
 * - `GET`  lists model definitions visible to the authenticated project
 *   (rows whose `projectId` is the project's id or `null`), paginated.
 * - `POST` creates a model definition owned by the authenticated project,
 *   together with one Price row per non-null price on the request body.
 */
export default withMiddlewares({
  GET: createAuthedProjectAPIRoute({
    name: "Get model definitions",
    querySchema: GetModelsV1Query,
    responseSchema: GetModelsV1Response,
    fn: async ({ query, auth }) => {
      // Single source of truth for the filter so the page query and the total
      // count can never disagree.
      const where = {
        OR: [{ projectId: auth.scope.projectId }, { projectId: null }],
      };

      // The two queries are independent, so issue them concurrently.
      const [models, totalItems] = await Promise.all([
        prisma.model.findMany({
          where,
          orderBy: [
            { modelName: "asc" },
            { unit: "asc" },
            {
              startDate: {
                sort: "desc",
                nulls: "last",
              },
            },
            // Tiebreaker: the keys above are not unique (e.g. a project model
            // and a `projectId: null` model may share name, unit and start
            // date). Without a unique final key, skip/take pagination may
            // repeat or drop rows between pages.
            { id: "asc" },
          ],
          include: {
            Price: {
              select: { usageType: true, price: true },
            },
          },
          take: query.limit,
          skip: (query.page - 1) * query.limit,
        }),
        prisma.model.count({ where }),
      ]);

      return {
        data: models.map(prismaToApiModelDefinition),
        meta: {
          page: query.page,
          limit: query.limit,
          totalItems,
          totalPages: Math.ceil(totalItems / query.limit),
        },
      };
    },
  }),
  POST: createAuthedProjectAPIRoute({
    name: "Create custom model definition",
    bodySchema: PostModelsV1Body,
    responseSchema: PostModelsV1Response,
    fn: async ({ body, auth }) => {
      // Reject patterns up front; validity is decided by isValidPostgresRegex.
      const validRegex = await isValidPostgresRegex(body.matchPattern, prisma);
      if (!validRegex) {
        throw new InvalidRequestError(
          "matchPattern is not a valid regex pattern (Postgres)",
        );
      }

      const { tokenizerConfig, ...rest } = body;

      const model = await prisma.$transaction(async (tx) => {
        const createdModel = await tx.model.create({
          data: {
            ...rest,
            // A null/absent tokenizerConfig is passed as `undefined`.
            tokenizerConfig: tokenizerConfig ?? undefined,
            projectId: auth.scope.projectId,
          },
        });

        // One Price row per price supplied on the body. `flatMap` drops the
        // null/undefined entries and narrows `price`, so no cast is needed.
        await Promise.all(
          PRICE_FIELDS.flatMap(({ usageType, field }) => {
            const price = body[field];
            return price == null
              ? []
              : [
                  tx.price.create({
                    data: {
                      modelId: createdModel.id,
                      projectId: createdModel.projectId,
                      usageType,
                      price,
                    },
                  }),
                ];
          }),
        );

        // Awaited inside the transaction callback: if this rejects, the
        // callback rejects and Prisma rolls the interactive transaction back.
        await auditLog({
          action: "create",
          resourceType: "model",
          resourceId: createdModel.id,
          projectId: auth.scope.projectId,
          orgId: auth.scope.orgId,
          apiKeyId: auth.scope.apiKeyId,
          after: createdModel,
        });

        return createdModel;
      });

      // Clear model cache for the project after successful creation
      await clearModelCacheForProject(auth.scope.projectId);

      // The created row was not loaded with its Price relation, so rebuild the
      // Price list from the price columns stored on the model itself.
      return prismaToApiModelDefinition({
        ...model,
        Price: PRICE_FIELDS.flatMap(({ usageType, field }) => {
          const price = model[field];
          return price == null ? [] : [{ usageType, price: price as Decimal }];
        }),
      });
    },
  }),
});