Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions apps/web/src/app/api/openrouter/models/route.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import { beforeEach, describe, expect, test } from '@jest/globals';
import { NextRequest } from 'next/server';
import type { OpenRouterModel } from '@/lib/organizations/organization-types';
import { getEnhancedOpenRouterModels } from '@/lib/ai-gateway/providers/openrouter';
import { getUserFromAuth } from '@/lib/user/server';
import { getDirectByokModelsForUser } from '@/lib/ai-gateway/providers/direct-byok';
import { listAvailableExperimentModels } from '@/lib/ai-gateway/experiments/list-available-experiment-models';
import { addUserByokAvailability, getUserByokProviderIds } from '@/lib/ai-gateway/byok';
import { getAvailableModelsForOrganization } from '@/lib/organizations/organization-models';
import { GET } from './route';

jest.mock('@sentry/nextjs', () => ({ captureException: jest.fn() }));
jest.mock('@/lib/user/server', () => ({ getUserFromAuth: jest.fn() }));
jest.mock('@/lib/ai-gateway/providers/openrouter', () => ({
getEnhancedOpenRouterModels: jest.fn(),
}));
jest.mock('@/lib/ai-gateway/providers/direct-byok', () => ({
getDirectByokModelsForUser: jest.fn(),
}));
jest.mock('@/lib/ai-gateway/experiments/list-available-experiment-models', () => ({
listAvailableExperimentModels: jest.fn(),
}));
jest.mock('@/lib/ai-gateway/byok', () => ({
addUserByokAvailability: jest.fn(),
getUserByokProviderIds: jest.fn(),
}));
jest.mock('@/lib/organizations/organization-models', () => ({
getAvailableModelsForOrganization: jest.fn(),
}));
jest.mock('@/lib/drizzle', () => ({ readDb: {} }));

const mockedGetUserFromAuth = jest.mocked(getUserFromAuth);
const mockedGetEnhancedOpenRouterModels = jest.mocked(getEnhancedOpenRouterModels);
const mockedGetDirectByokModelsForUser = jest.mocked(getDirectByokModelsForUser);
const mockedListAvailableExperimentModels = jest.mocked(listAvailableExperimentModels);
const mockedAddUserByokAvailability = jest.mocked(addUserByokAvailability);
const mockedGetUserByokProviderIds = jest.mocked(getUserByokProviderIds);
const mockedGetAvailableModelsForOrganization = jest.mocked(getAvailableModelsForOrganization);

function makeModel(id: string): OpenRouterModel {
return {
id,
name: id,
created: 0,
description: '',
architecture: {
input_modalities: ['text'],
output_modalities: ['text'],
tokenizer: 'test',
},
top_provider: { is_moderated: false },
pricing: { prompt: '0', completion: '0' },
context_length: 0,
};
}

function request() {
return new NextRequest('http://localhost:3000/api/openrouter/models');
}

describe('GET /api/openrouter/models', () => {
beforeEach(() => {
jest.resetAllMocks();
mockedGetUserFromAuth.mockResolvedValue({
user: null,
organizationId: null,
authFailedResponse: null,
} as never);
mockedGetEnhancedOpenRouterModels.mockResolvedValue({ data: [makeModel('public/model')] });
mockedGetDirectByokModelsForUser.mockResolvedValue([]);
mockedListAvailableExperimentModels.mockResolvedValue([]);
mockedGetUserByokProviderIds.mockResolvedValue([]);
mockedGetAvailableModelsForOrganization.mockResolvedValue(null);
});

test('leaves BYOK availability undefined for unauthenticated requests', async () => {
const response = await GET(request());

expect(response.status).toBe(200);
await expect(response.json()).resolves.toEqual({ data: [makeModel('public/model')] });
expect(mockedGetUserByokProviderIds).not.toHaveBeenCalled();
expect(mockedAddUserByokAvailability).not.toHaveBeenCalled();
});

test('returns BYOK availability for every authenticated model', async () => {
const publicModel = makeModel('public/model');
const directModel = { ...makeModel('direct/model'), hasUserByokAvailable: true };
const experimentModel = makeModel('experiment/model');
mockedGetUserFromAuth.mockResolvedValue({
user: { id: 'user-id' },
organizationId: null,
authFailedResponse: null,
} as never);
mockedGetDirectByokModelsForUser.mockResolvedValue([directModel] as never);
mockedListAvailableExperimentModels.mockResolvedValue([experimentModel]);
mockedGetUserByokProviderIds.mockResolvedValue(['anthropic']);
mockedAddUserByokAvailability.mockResolvedValue([
{ ...publicModel, hasUserByokAvailable: true },
]);

const response = await GET(request());

expect(response.status).toBe(200);
await expect(response.json()).resolves.toEqual({
data: [
{ ...publicModel, hasUserByokAvailable: true },
directModel,
{ ...experimentModel, hasUserByokAvailable: false },
],
});
});
});
28 changes: 25 additions & 3 deletions apps/web/src/app/api/openrouter/models/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import { getAvailableModelsForOrganization } from '@/lib/organizations/organizat
import { FEATURE_HEADER, validateFeatureHeader } from '@/lib/feature-detection';
import { filterByFeature } from '@/lib/ai-gateway/models';
import { listAvailableExperimentModels } from '@/lib/ai-gateway/experiments/list-available-experiment-models';
import { addUserByokAvailability, getUserByokProviderIds } from '@/lib/ai-gateway/byok';
import { readDb } from '@/lib/drizzle';

async function tryGetUserFromAuth() {
try {
Expand Down Expand Up @@ -43,10 +45,30 @@ export async function GET(
if (!Array.isArray(data.data)) {
return NextResponse.json(data);
}
const byokModels = auth?.user ? await getDirectByokModelsForUser(auth.user.id) : [];
const experimentModels = await listAvailableExperimentModels();
if (!auth?.user) {
const experimentModels = await listAvailableExperimentModels();
return NextResponse.json({
data: filterByFeature(data.data.concat(experimentModels), feature),
});
}

const [byokModels, experimentModels, enabledByokProviderIds] = await Promise.all([
getDirectByokModelsForUser(auth.user.id),
listAvailableExperimentModels(),
getUserByokProviderIds(readDb, auth.user.id),
]);
const modelsWithByokAvailability = await addUserByokAvailability(
data.data,
enabledByokProviderIds
);
return NextResponse.json({
data: filterByFeature(data.data.concat(byokModels, experimentModels), feature),
data: filterByFeature(
modelsWithByokAvailability.concat(
byokModels,
experimentModels.map(model => ({ ...model, hasUserByokAvailable: false }))
),
feature
),
});
} catch (error) {
captureException(error, {
Expand Down
47 changes: 47 additions & 0 deletions apps/web/src/lib/ai-gateway/byok/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import { isCodestralModel } from '@/lib/ai-gateway/providers/mistral';
import { mapModelIdToVercel } from '@/lib/ai-gateway/providers/vercel/mapModelIdToVercel';
import type { BYOKResult } from '@/lib/ai-gateway/providers/types';
import { getVercelModelsMetadata } from '@/lib/ai-gateway/providers/gateway-models-cache';
import type { OpenRouterModel } from '@/lib/organizations/organization-types';
import { isKiloExclusiveModel } from '@/lib/ai-gateway/models';

export async function getModelUserByokProviders(modelId: string): Promise<UserByokProviderId[]> {
const vercelModelMetadata = await getVercelModelsMetadata();
Expand All @@ -35,6 +37,51 @@ export async function getModelUserByokProviders(modelId: string): Promise<UserBy
return providers;
}

export async function getUserByokProviderIds(
fromDb: typeof db,
userId: string
): Promise<UserByokProviderId[]> {
const rows = await fromDb
.select({ provider_id: byok_api_keys.provider_id })
.from(byok_api_keys)
.where(and(eq(byok_api_keys.kilo_user_id, userId), eq(byok_api_keys.is_enabled, true)));

return rows.map(row => UserByokProviderIdSchema.parse(row.provider_id));
}

export async function getOrganizationByokProviderIds(
fromDb: typeof db,
organizationId: string
): Promise<UserByokProviderId[]> {
const rows = await fromDb
.select({ provider_id: byok_api_keys.provider_id })
.from(byok_api_keys)
.where(
and(eq(byok_api_keys.organization_id, organizationId), eq(byok_api_keys.is_enabled, true))
);

return rows.map(row => UserByokProviderIdSchema.parse(row.provider_id));
}

export async function addUserByokAvailability(
models: OpenRouterModel[],
enabledProviderIds: UserByokProviderId[]
): Promise<OpenRouterModel[]> {
const enabledProviders = new Set(enabledProviderIds);
return Promise.all(
models.map(async model => {
if (isKiloExclusiveModel(model.id)) {
return { ...model, hasUserByokAvailable: false };
}
const supportedProviders = await getModelUserByokProviders(model.id);
return {
...model,
hasUserByokAvailable: supportedProviders.some(provider => enabledProviders.has(provider)),
};
})
);
}

export function decryptByokRow({
encrypted_api_key,
provider_id,
Expand Down
1 change: 1 addition & 0 deletions apps/web/src/lib/ai-gateway/providers/direct-byok/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ function convertModel(
supported_parameters: ['max_tokens', 'temperature', 'tools', 'reasoning', 'include_reasoning'],
default_parameters: {},
preferredIndex: model.flags?.includes('recommended') ? preferredIndex : undefined,
hasUserByokAvailable: true,
opencode: {
ai_sdk_provider: getAiSdkProvider(id) ?? provider.default_ai_sdk_provider,
variants: getModelVariants(id),
Expand Down
21 changes: 19 additions & 2 deletions apps/web/src/lib/organizations/organization-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import { getDirectByokModelsForOrganization } from '@/lib/ai-gateway/providers/d
import { getOrganizationById } from '@/lib/organizations/organizations';
import { getEffectiveModelRestrictions } from '@/lib/organizations/model-restrictions';
import { listAvailableExperimentModels } from '@/lib/ai-gateway/experiments/list-available-experiment-models';
import { addUserByokAvailability, getOrganizationByokProviderIds } from '@/lib/ai-gateway/byok';
import { readDb } from '@/lib/drizzle';

export async function getAvailableModelsForOrganization(
organizationId: string
Expand Down Expand Up @@ -40,12 +42,27 @@ export async function getAvailableModelsForOrganization(
filteredModels = models;
}

filteredModels = await addUserByokAvailability(
filteredModels,
await getOrganizationByokProviderIds(readDb, organizationId)
);

if (organization.plan !== 'enterprise' && organization.settings.data_collection !== 'deny') {
filteredModels.push(...(await listAvailableExperimentModels()));
filteredModels.push(
...(await listAvailableExperimentModels()).map(model => ({
...model,
hasUserByokAvailable: false,
}))
);
}

filteredModels.push(...(await getDirectByokModelsForOrganization(organizationId)));
filteredModels.push(...(await listAvailableCustomLlms(organizationId)));
filteredModels.push(
...(await listAvailableCustomLlms(organizationId)).map(model => ({
...model,
hasUserByokAvailable: false,
}))
);

return {
...responseData,
Expand Down
1 change: 1 addition & 0 deletions apps/web/src/lib/organizations/organization-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ const OpenRouterModelSchema = z.object({
preferredIndex: z.number().optional(),
isFree: z.boolean().optional(),
mayTrainOnYourPrompts: z.boolean().optional(),
hasUserByokAvailable: z.boolean().optional(),
terminalBench: z
.object({
overallScore: z.number(),
Expand Down