diff --git a/apps/api/src/app/api/chat/[sessionId]/route.ts b/apps/api/src/app/api/chat/[sessionId]/route.ts index 32efba3a1..eaecf9380 100644 --- a/apps/api/src/app/api/chat/[sessionId]/route.ts +++ b/apps/api/src/app/api/chat/[sessionId]/route.ts @@ -1,7 +1,7 @@ import { db } from "@o3dotdev/code-db/client"; import { chatSessions } from "@o3dotdev/code-db/schema"; import { and, eq, isNull } from "drizzle-orm"; -import { getDurableStream, requireAuth } from "../lib"; +import { findChatSessionOwner, getDurableStream, requireAuth } from "../lib"; function errorMessage(error: unknown): string { if (error instanceof Error) return error.message; @@ -44,6 +44,11 @@ export async function PUT( ); } + const existingOwner = await findChatSessionOwner(sessionId); + if (existingOwner && existingOwner.createdBy !== session.user.id) { + return new Response("Not found", { status: 404 }); + } + const stream = getDurableStream(sessionId); try { await stream.create({ contentType: "application/json" }); @@ -139,10 +144,20 @@ export async function PATCH( const body = (await request.json()) as { title?: string }; if (body.title !== undefined) { - await db + const [updated] = await db .update(chatSessions) .set({ title: body.title }) - .where(eq(chatSessions.id, sessionId)); + .where( + and( + eq(chatSessions.id, sessionId), + eq(chatSessions.createdBy, session.user.id), + ), + ) + .returning({ id: chatSessions.id }); + + if (!updated) { + return new Response("Not found", { status: 404 }); + } } return Response.json({ success: true }, { status: 200 }); diff --git a/apps/api/src/app/api/chat/[sessionId]/stream/route.ts b/apps/api/src/app/api/chat/[sessionId]/stream/route.ts index 57723458c..0ec4c6241 100644 --- a/apps/api/src/app/api/chat/[sessionId]/stream/route.ts +++ b/apps/api/src/app/api/chat/[sessionId]/stream/route.ts @@ -1,8 +1,9 @@ import { db } from "@o3dotdev/code-db/client"; import { chatSessions } from "@o3dotdev/code-db/schema"; -import { eq } from "drizzle-orm"; +import { and, eq } from "drizzle-orm"; import { env } from "@/env"; import { + loadOwnedChatSession, PRODUCER_RESPONSE_HEADERS, PROTOCOL_QUERY_PARAMS, PROTOCOL_RESPONSE_HEADERS, @@ -23,6 +24,10 @@ export async function GET( if (!session) return new Response("Unauthorized", { status: 401 }); const { sessionId } = await params; + + const owned = await loadOwnedChatSession(sessionId, session.user.id); + if (!owned) return new Response("Not found", { status: 404 }); + const url = new URL(request.url); const upstream = new URL(streamUrl(sessionId)); @@ -83,6 +88,10 @@ export async function POST( if (!session) return new Response("Unauthorized", { status: 401 }); const { sessionId } = await params; + + const owned = await loadOwnedChatSession(sessionId, session.user.id); + if (!owned) return new Response("Not found", { status: 404 }); + const upstream = streamUrl(sessionId); const headers: Record = { @@ -137,6 +146,9 @@ export async function DELETE( const { sessionId } = await params; + const owned = await loadOwnedChatSession(sessionId, session.user.id); + if (!owned) return new Response("Not found", { status: 404 }); + const response = await fetch(streamUrl(sessionId), { method: "DELETE", headers: { @@ -144,7 +156,14 @@ export async function DELETE( }, }); - await db.delete(chatSessions).where(eq(chatSessions.id, sessionId)); + await db + .delete(chatSessions) + .where( + and( + eq(chatSessions.id, sessionId), + eq(chatSessions.createdBy, session.user.id), + ), + ); const headers = new Headers(); for (const [key, value] of response.headers.entries()) { @@ -172,6 +191,9 @@ export async function HEAD( const { sessionId } = await params; + const owned = await loadOwnedChatSession(sessionId, session.user.id); + if (!owned) return new Response("Not found", { status: 404 }); + const response = await fetch(streamUrl(sessionId), { method: "HEAD", headers: { diff --git a/apps/api/src/app/api/chat/lib.ts b/apps/api/src/app/api/chat/lib.ts index 2ac5af528..d591424ff 100644 --- a/apps/api/src/app/api/chat/lib.ts +++ b/apps/api/src/app/api/chat/lib.ts @@ -1,5 +1,8 @@ import { DurableStream } from "@durable-streams/client"; import { auth } from "@o3dotdev/code-auth/server"; +import { db } from "@o3dotdev/code-db/client"; +import { chatSessions } from "@o3dotdev/code-db/schema"; +import { and, eq } from "drizzle-orm"; import { env } from "@/env"; export const PROTOCOL_QUERY_PARAMS = ["offset", "live", "cursor"]; @@ -36,6 +39,26 @@ export async function requireAuth(request: Request) { return sessionData; } +export async function loadOwnedChatSession(sessionId: string, userId: string) { + const [row] = await db + .select({ id: chatSessions.id, createdBy: chatSessions.createdBy }) + .from(chatSessions) + .where( + and(eq(chatSessions.id, sessionId), eq(chatSessions.createdBy, userId)), + ) + .limit(1); + return row ?? null; +} + +export async function findChatSessionOwner(sessionId: string) { + const [row] = await db + .select({ createdBy: chatSessions.createdBy }) + .from(chatSessions) + .where(eq(chatSessions.id, sessionId)) + .limit(1); + return row ?? null; +} + export function streamUrl(sessionId: string) { return `${env.DURABLE_STREAMS_URL}/sessions/${sessionId}`; } diff --git a/apps/api/src/trpc/context.ts b/apps/api/src/trpc/context.ts index cffb581a4..9c029f818 100644 --- a/apps/api/src/trpc/context.ts +++ b/apps/api/src/trpc/context.ts @@ -8,6 +8,8 @@ import { env } from "@/env"; const apiUrl = env.NEXT_PUBLIC_API_URL.replace(/\/+$/, ""); +const TRUSTED_API_CLIENTS = new Set(["o3-code-cli"]); + function looksLikeJwt(token: string): boolean { const parts = token.split("."); return parts.length === 3 && parts.every(Boolean); @@ -34,6 +36,12 @@ async function sessionFromOAuthBearer( return null; } + const authorizedClientId = + typeof payload.azp === "string" ? payload.azp : null; + if (authorizedClientId && !TRUSTED_API_CLIENTS.has(authorizedClientId)) { + return null; + } + const userId = typeof payload.sub === "string" ? payload.sub : null; if (!userId) return null; diff --git a/apps/relay/src/directory.ts b/apps/relay/src/directory.ts index 25081cbcd..d38fdb175 100644 --- a/apps/relay/src/directory.ts +++ b/apps/relay/src/directory.ts @@ -11,6 +11,7 @@ const TTL_GRACE_MS = 90_000; const redis = new Redis({ url: env.KV_REST_API_URL, token: env.KV_REST_API_TOKEN, + readYourWrites: false, }); export interface TunnelOwner {