diff --git a/apps/web/src/app/api/openrouter/models/route.test.ts b/apps/web/src/app/api/openrouter/models/route.test.ts new file mode 100644 index 0000000000..c0b66862d9 --- /dev/null +++ b/apps/web/src/app/api/openrouter/models/route.test.ts @@ -0,0 +1,112 @@ +import { beforeEach, describe, expect, test } from '@jest/globals'; +import { NextRequest } from 'next/server'; +import type { OpenRouterModel } from '@/lib/organizations/organization-types'; +import { getEnhancedOpenRouterModels } from '@/lib/ai-gateway/providers/openrouter'; +import { getUserFromAuth } from '@/lib/user/server'; +import { getDirectByokModelsForUser } from '@/lib/ai-gateway/providers/direct-byok'; +import { listAvailableExperimentModels } from '@/lib/ai-gateway/experiments/list-available-experiment-models'; +import { addUserByokAvailability, getUserByokProviderIds } from '@/lib/ai-gateway/byok'; +import { getAvailableModelsForOrganization } from '@/lib/organizations/organization-models'; +import { GET } from './route'; + +jest.mock('@sentry/nextjs', () => ({ captureException: jest.fn() })); +jest.mock('@/lib/user/server', () => ({ getUserFromAuth: jest.fn() })); +jest.mock('@/lib/ai-gateway/providers/openrouter', () => ({ + getEnhancedOpenRouterModels: jest.fn(), +})); +jest.mock('@/lib/ai-gateway/providers/direct-byok', () => ({ + getDirectByokModelsForUser: jest.fn(), +})); +jest.mock('@/lib/ai-gateway/experiments/list-available-experiment-models', () => ({ + listAvailableExperimentModels: jest.fn(), +})); +jest.mock('@/lib/ai-gateway/byok', () => ({ + addUserByokAvailability: jest.fn(), + getUserByokProviderIds: jest.fn(), +})); +jest.mock('@/lib/organizations/organization-models', () => ({ + getAvailableModelsForOrganization: jest.fn(), +})); +jest.mock('@/lib/drizzle', () => ({ readDb: {} })); + +const mockedGetUserFromAuth = jest.mocked(getUserFromAuth); +const mockedGetEnhancedOpenRouterModels = jest.mocked(getEnhancedOpenRouterModels); +const mockedGetDirectByokModelsForUser = jest.mocked(getDirectByokModelsForUser); +const mockedListAvailableExperimentModels = jest.mocked(listAvailableExperimentModels); +const mockedAddUserByokAvailability = jest.mocked(addUserByokAvailability); +const mockedGetUserByokProviderIds = jest.mocked(getUserByokProviderIds); +const mockedGetAvailableModelsForOrganization = jest.mocked(getAvailableModelsForOrganization); + +function makeModel(id: string): OpenRouterModel { + return { + id, + name: id, + created: 0, + description: '', + architecture: { + input_modalities: ['text'], + output_modalities: ['text'], + tokenizer: 'test', + }, + top_provider: { is_moderated: false }, + pricing: { prompt: '0', completion: '0' }, + context_length: 0, + }; +} + +function request() { + return new NextRequest('http://localhost:3000/api/openrouter/models'); +} + +describe('GET /api/openrouter/models', () => { + beforeEach(() => { + jest.resetAllMocks(); + mockedGetUserFromAuth.mockResolvedValue({ + user: null, + organizationId: null, + authFailedResponse: null, + } as never); + mockedGetEnhancedOpenRouterModels.mockResolvedValue({ data: [makeModel('public/model')] }); + mockedGetDirectByokModelsForUser.mockResolvedValue([]); + mockedListAvailableExperimentModels.mockResolvedValue([]); + mockedGetUserByokProviderIds.mockResolvedValue([]); + mockedGetAvailableModelsForOrganization.mockResolvedValue(null); + }); + + test('leaves BYOK availability undefined for unauthenticated requests', async () => { + const response = await GET(request()); + + expect(response.status).toBe(200); + await expect(response.json()).resolves.toEqual({ data: [makeModel('public/model')] }); + expect(mockedGetUserByokProviderIds).not.toHaveBeenCalled(); + expect(mockedAddUserByokAvailability).not.toHaveBeenCalled(); + }); + + test('returns BYOK availability for every authenticated model', async () => { + const publicModel = makeModel('public/model'); + const directModel = { ...makeModel('direct/model'), hasUserByokAvailable: true }; + const experimentModel = makeModel('experiment/model'); + mockedGetUserFromAuth.mockResolvedValue({ + user: { id: 'user-id' }, + organizationId: null, + authFailedResponse: null, + } as never); + mockedGetDirectByokModelsForUser.mockResolvedValue([directModel] as never); + mockedListAvailableExperimentModels.mockResolvedValue([experimentModel]); + mockedGetUserByokProviderIds.mockResolvedValue(['anthropic']); + mockedAddUserByokAvailability.mockResolvedValue([ + { ...publicModel, hasUserByokAvailable: true }, + ]); + + const response = await GET(request()); + + expect(response.status).toBe(200); + await expect(response.json()).resolves.toEqual({ + data: [ + { ...publicModel, hasUserByokAvailable: true }, + directModel, + { ...experimentModel, hasUserByokAvailable: false }, + ], + }); + }); +}); diff --git a/apps/web/src/app/api/openrouter/models/route.ts b/apps/web/src/app/api/openrouter/models/route.ts index e42b9fc3c7..ba03671cc0 100644 --- a/apps/web/src/app/api/openrouter/models/route.ts +++ b/apps/web/src/app/api/openrouter/models/route.ts @@ -9,6 +9,8 @@ import { getAvailableModelsForOrganization } from '@/lib/organizations/organizat import { FEATURE_HEADER, validateFeatureHeader } from '@/lib/feature-detection'; import { filterByFeature } from '@/lib/ai-gateway/models'; import { listAvailableExperimentModels } from '@/lib/ai-gateway/experiments/list-available-experiment-models'; +import { addUserByokAvailability, getUserByokProviderIds } from '@/lib/ai-gateway/byok'; +import { readDb } from '@/lib/drizzle'; async function tryGetUserFromAuth() { try { @@ -43,10 +45,30 @@ export async function GET( if (!Array.isArray(data.data)) { return NextResponse.json(data); } - const byokModels = auth?.user ? await getDirectByokModelsForUser(auth.user.id) : []; - const experimentModels = await listAvailableExperimentModels(); + if (!auth?.user) { + const experimentModels = await listAvailableExperimentModels(); + return NextResponse.json({ + data: filterByFeature(data.data.concat(experimentModels), feature), + }); + } + + const [byokModels, experimentModels, enabledByokProviderIds] = await Promise.all([ + getDirectByokModelsForUser(auth.user.id), + listAvailableExperimentModels(), + getUserByokProviderIds(readDb, auth.user.id), + ]); + const modelsWithByokAvailability = await addUserByokAvailability( + data.data, + enabledByokProviderIds + ); return NextResponse.json({ - data: filterByFeature(data.data.concat(byokModels, experimentModels), feature), + data: filterByFeature( + modelsWithByokAvailability.concat( + byokModels, + experimentModels.map(model => ({ ...model, hasUserByokAvailable: false })) + ), + feature + ), }); } catch (error) { captureException(error, { diff --git a/apps/web/src/lib/ai-gateway/byok/index.ts b/apps/web/src/lib/ai-gateway/byok/index.ts index c2d3a15cc0..626e52470d 100644 --- a/apps/web/src/lib/ai-gateway/byok/index.ts +++ b/apps/web/src/lib/ai-gateway/byok/index.ts @@ -13,6 +13,8 @@ import { isCodestralModel } from '@/lib/ai-gateway/providers/mistral'; import { mapModelIdToVercel } from '@/lib/ai-gateway/providers/vercel/mapModelIdToVercel'; import type { BYOKResult } from '@/lib/ai-gateway/providers/types'; import { getVercelModelsMetadata } from '@/lib/ai-gateway/providers/gateway-models-cache'; +import type { OpenRouterModel } from '@/lib/organizations/organization-types'; +import { isKiloExclusiveModel } from '@/lib/ai-gateway/models'; export async function getModelUserByokProviders(modelId: string): Promise { const vercelModelMetadata = await getVercelModelsMetadata(); @@ -35,6 +37,51 @@ export async function getModelUserByokProviders(modelId: string): Promise { + const rows = await fromDb + .select({ provider_id: byok_api_keys.provider_id }) + .from(byok_api_keys) + .where(and(eq(byok_api_keys.kilo_user_id, userId), eq(byok_api_keys.is_enabled, true))); + + return rows.map(row => UserByokProviderIdSchema.parse(row.provider_id)); +} + +export async function getOrganizationByokProviderIds( + fromDb: typeof db, + organizationId: string +): Promise { + const rows = await fromDb + .select({ provider_id: byok_api_keys.provider_id }) + .from(byok_api_keys) + .where( + and(eq(byok_api_keys.organization_id, organizationId), eq(byok_api_keys.is_enabled, true)) + ); + + return rows.map(row => UserByokProviderIdSchema.parse(row.provider_id)); +} + +export async function addUserByokAvailability( + models: OpenRouterModel[], + enabledProviderIds: UserByokProviderId[] +): Promise { + const enabledProviders = new Set(enabledProviderIds); + return Promise.all( + models.map(async model => { + if (isKiloExclusiveModel(model.id)) { + return { ...model, hasUserByokAvailable: false }; + } + const supportedProviders = await getModelUserByokProviders(model.id); + return { + ...model, + hasUserByokAvailable: supportedProviders.some(provider => enabledProviders.has(provider)), + }; + }) + ); +} + export function decryptByokRow({ encrypted_api_key, provider_id, diff --git a/apps/web/src/lib/ai-gateway/providers/direct-byok/index.ts b/apps/web/src/lib/ai-gateway/providers/direct-byok/index.ts index 078dd9f3b0..851574b9f7 100644 --- a/apps/web/src/lib/ai-gateway/providers/direct-byok/index.ts +++ b/apps/web/src/lib/ai-gateway/providers/direct-byok/index.ts @@ -57,6 +57,7 @@ function convertModel( supported_parameters: ['max_tokens', 'temperature', 'tools', 'reasoning', 'include_reasoning'], default_parameters: {}, preferredIndex: model.flags?.includes('recommended') ? preferredIndex : undefined, + hasUserByokAvailable: true, opencode: { ai_sdk_provider: getAiSdkProvider(id) ?? provider.default_ai_sdk_provider, variants: getModelVariants(id), diff --git a/apps/web/src/lib/organizations/organization-models.ts b/apps/web/src/lib/organizations/organization-models.ts index fd604b91fd..27b1e42db3 100644 --- a/apps/web/src/lib/organizations/organization-models.ts +++ b/apps/web/src/lib/organizations/organization-models.ts @@ -10,6 +10,8 @@ import { getDirectByokModelsForOrganization } from '@/lib/ai-gateway/providers/d import { getOrganizationById } from '@/lib/organizations/organizations'; import { getEffectiveModelRestrictions } from '@/lib/organizations/model-restrictions'; import { listAvailableExperimentModels } from '@/lib/ai-gateway/experiments/list-available-experiment-models'; +import { addUserByokAvailability, getOrganizationByokProviderIds } from '@/lib/ai-gateway/byok'; +import { readDb } from '@/lib/drizzle'; export async function getAvailableModelsForOrganization( organizationId: string @@ -40,12 +42,27 @@ export async function getAvailableModelsForOrganization( filteredModels = models; } + filteredModels = await addUserByokAvailability( + filteredModels, + await getOrganizationByokProviderIds(readDb, organizationId) + ); + if (organization.plan !== 'enterprise' && organization.settings.data_collection !== 'deny') { - filteredModels.push(...(await listAvailableExperimentModels())); + filteredModels.push( + ...(await listAvailableExperimentModels()).map(model => ({ + ...model, + hasUserByokAvailable: false, + })) + ); } filteredModels.push(...(await getDirectByokModelsForOrganization(organizationId))); - filteredModels.push(...(await listAvailableCustomLlms(organizationId))); + filteredModels.push( + ...(await listAvailableCustomLlms(organizationId)).map(model => ({ + ...model, + hasUserByokAvailable: false, + })) + ); return { ...responseData, diff --git a/apps/web/src/lib/organizations/organization-types.ts b/apps/web/src/lib/organizations/organization-types.ts index 1888fe9c73..aa26a84f8c 100644 --- a/apps/web/src/lib/organizations/organization-types.ts +++ b/apps/web/src/lib/organizations/organization-types.ts @@ -200,6 +200,7 @@ const OpenRouterModelSchema = z.object({ preferredIndex: z.number().optional(), isFree: z.boolean().optional(), mayTrainOnYourPrompts: z.boolean().optional(), + hasUserByokAvailable: z.boolean().optional(), terminalBench: z .object({ overallScore: z.number(),