diff --git a/apps/dev-playground/app.yaml b/apps/dev-playground/app.yaml index 7b57e4ff8..cb8111ae9 100644 --- a/apps/dev-playground/app.yaml +++ b/apps/dev-playground/app.yaml @@ -29,3 +29,11 @@ env: valueFrom: volume - name: DATABRICKS_VOLUME_IMPLICIT valueFrom: volume + # OBO demo: same physical volume; auth: "on-behalf-of-user" routes + # HTTP traffic through runInUserContext so SDK calls execute as the + # end user. + - name: DATABRICKS_VOLUME_OBO_DEMO + valueFrom: volume + # Lakebase database resource + - name: LAKEBASE_ENDPOINT + valueFrom: database diff --git a/apps/dev-playground/client/src/components/lakebase/OboProductsPanel.tsx b/apps/dev-playground/client/src/components/lakebase/OboProductsPanel.tsx new file mode 100644 index 000000000..673570a3a --- /dev/null +++ b/apps/dev-playground/client/src/components/lakebase/OboProductsPanel.tsx @@ -0,0 +1,354 @@ +import { + Badge, + Button, + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, + Input, +} from "@databricks/appkit-ui/react"; +import { Loader2, Package, ShieldCheck } from "lucide-react"; +import { useId, useState } from "react"; +import { useLakebaseData, useLakebasePost } from "@/hooks/use-lakebase-data"; + +interface Product { + id: string; + name: string; + category: string; + price: number | string; + stock: number; + created_by: string | null; + created_at: string; +} + +interface CreateProductRequest { + name: string; + category: string; + price: number; + stock: number; +} + +export function OboProductsPanel() { + const nameId = useId(); + const categoryId = useId(); + const priceId = useId(); + const stockId = useId(); + + const { + data: myProducts, + loading: myLoading, + error: myError, + refetch: refetchMy, + } = useLakebaseData("/api/lakebase-examples/raw/my-products"); + + const { + data: allProducts, + loading: allLoading, + error: allError, + refetch: refetchAll, + } = useLakebaseData("/api/lakebase-examples/raw/products"); + + const { post, loading: creating } = useLakebasePost< + CreateProductRequest, + Product + >("/api/lakebase-examples/raw/my-products"); + + const generateRandomProduct = () => { + const products = [ + "Ergonomic Keyboard", + "Wireless Mouse", + "USB-C Hub", + "Laptop Stand", + "Monitor Arm", + "Mechanical Keyboard", + "Gaming Headset", + "Webcam HD", + ]; + const categories = ["Electronics", "Accessories", "Peripherals", "Office"]; + const price = (Math.random() * (199.99 - 29.99) + 29.99).toFixed(2); + const stock = Math.floor(Math.random() * (500 - 50) + 50); + + return { + name: products[Math.floor(Math.random() * products.length)], + category: categories[Math.floor(Math.random() * categories.length)], + price, + stock: String(stock), + }; + }; + + const [formData, setFormData] = useState(generateRandomProduct()); + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + const result = await post({ + name: formData.name, + category: formData.category, + price: Number(formData.price), + stock: Number(formData.stock), + }); + + if (result) { + setFormData(generateRandomProduct()); + refetchMy(); + refetchAll(); + } + }; + + const myProductsList = myProducts ?? []; + + return ( +
+ {/* Header */} + + +
+
+ +
+
+ Raw Driver — On-Behalf-Of (OBO) + + Per-user connection pool with Row-Level Security (RLS). Each + user gets their own pg.Pool authenticated with their Databricks + identity. The database filters rows based on{" "} + current_user. + +
+
+
+
+ + {/* Create product as user */} + + + + Create Product (as current user) + + + This product will have created_by set to your identity. + RLS will make it visible only to you. + + + +
+
+
+ + + setFormData({ ...formData, name: e.target.value }) + } + placeholder="Wireless Mouse" + required + /> +
+
+ + + setFormData({ ...formData, category: e.target.value }) + } + placeholder="Electronics" + required + /> +
+
+ + + setFormData({ ...formData, price: e.target.value }) + } + placeholder="29.99" + required + /> +
+
+ + + setFormData({ ...formData, stock: e.target.value }) + } + placeholder="100" + required + /> +
+
+ +
+
+
+ + {/* Side-by-side comparison */} +
+ {/* My products (OBO, RLS filtered) */} + + +
+
+ + My Products (OBO pool) + + + RLS-filtered via per-user pool. Users with{" "} + databricks_superuser role bypass RLS. + +
+ +
+
+ + {myLoading && ( +
+
+ Loading... +
+ )} + {myError && ( +
+ {myError.message} +
+ )} + {!myLoading && myProductsList.length === 0 && ( +
+ +

No products yet. Create one above.

+
+ )} + {myProductsList.length > 0 && ( + + )} + + + + {/* All products (SP, bypasses RLS) */} + + +
+
+ + All Products (SP pool) + + + Service principal bypasses RLS + +
+ +
+
+ + {allLoading && ( +
+
+ Loading... +
+ )} + {allError && ( +
+ {allError.message} +
+ )} + {allProducts && allProducts.length > 0 && ( + + )} + + +
+
+ ); +} + +function ProductTable({ + products, + showCreatedBy, +}: { + products: Product[]; + showCreatedBy?: boolean; +}) { + return ( +
+ + + + + + + {showCreatedBy && ( + + )} + + + + {products.map((p) => ( + + + + + {showCreatedBy && ( + + )} + + ))} + +
+ Name + + Category + + Price + + Created By +
{p.name} + {p.category} + + ${Number(p.price).toFixed(2)} + + {p.created_by ?? "—"} +
+
+ ); +} diff --git a/apps/dev-playground/client/src/components/lakebase/ProductsPanel.tsx b/apps/dev-playground/client/src/components/lakebase/ProductsPanel.tsx deleted file mode 100644 index d1b6c690a..000000000 --- a/apps/dev-playground/client/src/components/lakebase/ProductsPanel.tsx +++ /dev/null @@ -1,305 +0,0 @@ -import { - Badge, - Button, - Card, - CardContent, - CardDescription, - CardHeader, - CardTitle, - Input, -} from "@databricks/appkit-ui/react"; -import { Database, Loader2, Package } from "lucide-react"; -import { useId, useState } from "react"; -import { useLakebaseData, useLakebasePost } from "@/hooks/use-lakebase-data"; - -interface Product { - id: number; - name: string; - category: string; - price: number | string; // PostgreSQL DECIMAL returns as string - stock: number; - created_by?: string; - created_at: string; -} - -interface CreateProductRequest { - name: string; - category: string; - price: number; - stock: number; -} - -interface HealthStatus { - status: string; - connected: boolean; - message: string; -} - -export function ProductsPanel() { - const nameId = useId(); - const categoryId = useId(); - const priceId = useId(); - const stockId = useId(); - - const { - data: products, - loading: productsLoading, - error: productsError, - refetch, - } = useLakebaseData("/api/lakebase-examples/raw/products"); - - const { data: health } = useLakebaseData( - "/api/lakebase-examples/raw/health", - ); - - const { post, loading: creating } = useLakebasePost< - CreateProductRequest, - Product - >("/api/lakebase-examples/raw/products"); - - const generateRandomProduct = () => { - const products = [ - "Ergonomic Keyboard", - "Wireless Mouse", - "USB-C Hub", - "Laptop Stand", - "Monitor Arm", - "Mechanical Keyboard", - "Gaming Headset", - "Webcam HD", - ]; - const categories = ["Electronics", "Accessories", "Peripherals", "Office"]; - const price = (Math.random() * (199.99 - 29.99) + 29.99).toFixed(2); - const stock = Math.floor(Math.random() * (500 - 50) + 50); - - return { - name: products[Math.floor(Math.random() * products.length)], - category: categories[Math.floor(Math.random() * categories.length)], - price, - stock: String(stock), - }; - }; - - const [formData, setFormData] = useState(generateRandomProduct()); - - const handleSubmit = async (e: React.FormEvent) => { - e.preventDefault(); - const result = await post({ - name: formData.name, - category: formData.category, - price: Number(formData.price), - stock: Number(formData.stock), - }); - - if (result) { - setFormData(generateRandomProduct()); - refetch(); - } - }; - - return ( -
- {/* Header with connection status */} - - -
-
-
- -
-
- Raw Driver Example - - Direct PostgreSQL connection using pg.Pool with automatic - OAuth token refresh - -
-
- {health && ( - - {health.connected ? "Connected" : "Disconnected"} - - )} -
-
-
- - {/* Create product form */} - - - Create Product - - -
-
-
- - - setFormData({ ...formData, name: e.target.value }) - } - placeholder="Wireless Mouse" - required - /> -
-
- - - setFormData({ ...formData, category: e.target.value }) - } - placeholder="Electronics" - required - /> -
-
- - - setFormData({ ...formData, price: e.target.value }) - } - placeholder="29.99" - required - /> -
-
- - - setFormData({ ...formData, stock: e.target.value }) - } - placeholder="100" - required - /> -
-
- -
-
-
- - {/* Products list */} - - -
- Products Catalog - -
-
- - {productsLoading && ( -
-
- Loading products... -
- )} - - {productsError && ( -
- Error:{" "} - {productsError.message} -
- )} - - {products && products.length === 0 && ( -
- -

No products available. Create your first product above.

-
- )} - - {products && products.length > 0 && ( -
- - - - - - - - - - - - {products.map((product) => ( - - - - - - - - ))} - -
- ID - - Name - - Category - - Price - - Stock -
{product.id}{product.name} - {product.category} - - ${Number(product.price).toFixed(2)} - - {product.stock} -
-
- )} - - -
- ); -} diff --git a/apps/dev-playground/client/src/components/lakebase/index.ts b/apps/dev-playground/client/src/components/lakebase/index.ts index 6ba528a63..d64d2e3d1 100644 --- a/apps/dev-playground/client/src/components/lakebase/index.ts +++ b/apps/dev-playground/client/src/components/lakebase/index.ts @@ -1,4 +1,4 @@ export { ActivityLogsPanel } from "./ActivityLogsPanel"; +export { OboProductsPanel } from "./OboProductsPanel"; export { OrdersPanel } from "./OrdersPanel"; -export { ProductsPanel } from "./ProductsPanel"; export { TasksPanel } from "./TasksPanel"; diff --git a/apps/dev-playground/client/src/routes/lakebase.route.tsx b/apps/dev-playground/client/src/routes/lakebase.route.tsx index 59694e248..9ea8117e6 100644 --- a/apps/dev-playground/client/src/routes/lakebase.route.tsx +++ b/apps/dev-playground/client/src/routes/lakebase.route.tsx @@ -7,8 +7,8 @@ import { import { createFileRoute, retainSearchParams } from "@tanstack/react-router"; import { ActivityLogsPanel, + OboProductsPanel, OrdersPanel, - ProductsPanel, TasksPanel, } from "@/components/lakebase"; @@ -23,7 +23,6 @@ function LakebaseRoute() { return (
- {/* Page header */}

Lakebase Examples

@@ -33,17 +32,16 @@ function LakebaseRoute() {

- {/* Tabs for different examples */} - Raw Driver + Raw Driver (OBO) Drizzle ORM TypeORM Sequelize - + diff --git a/apps/dev-playground/server/index.ts b/apps/dev-playground/server/index.ts index 74d813a90..ecbd18e78 100644 --- a/apps/dev-playground/server/index.ts +++ b/apps/dev-playground/server/index.ts @@ -5,6 +5,7 @@ import { type FilePolicy, files, genie, + lakebase, PolicyDeniedError, server, serving, @@ -327,6 +328,15 @@ const dashboard_pilot = createAgent({ }, }); +/** + * OBO demo policy: deny anything running as the SP (including the dev + * fallback when no `x-forwarded-access-token` is present). Only real + * end-users (`isServicePrincipal: false`) get through. + */ +const usersOnly: FilePolicy = (_action, _resource, user) => { + return user.isServicePrincipal !== true; +}; + createApp({ plugins: [ server(), @@ -336,6 +346,7 @@ createApp({ genie({ spaces: { demo: process.env.DATABRICKS_GENIE_SPACE_ID ?? "placeholder" }, }), + ...(process.env.LAKEBASE_ENDPOINT ? [lakebase()] : []), lakebaseExamples(), files({ volumes: { @@ -362,6 +373,14 @@ createApp({ write_only: { policy: files.policy.not(files.policy.publicRead()) }, // no explicit policy → falls back to publicRead() + startup warning implicit: {}, + // OBO demo volume — auth: "on-behalf-of-user" routes HTTP traffic + // through `runInUserContext` so SDK calls execute with the end + // user's access token. The `usersOnly` policy denies any traffic + // that wasn't authenticated via `x-forwarded-access-token`. + obo_demo: { + auth: "on-behalf-of-user", + policy: usersOnly, + }, }, }), serving(), @@ -390,6 +409,51 @@ createApp({ ...(process.env.APPKIT_E2E_TEST && { client: createMockClient() }), async onPluginsReady(appkit) { appkit.server.extend((app) => { + // ── Lakebase OBO routes (per-user pool, RLS enforced) ────────── + + if ("lakebase" in appkit) { + // GET /api/lakebase-examples/raw/my-products — RLS-filtered list + app.get("/api/lakebase-examples/raw/my-products", async (req, res) => { + try { + const result = await appkit.lakebase + .asUser(req) + .query( + "SELECT * FROM raw_example.products ORDER BY created_at DESC", + ); + res.json(result.rows); + } catch (error: unknown) { + const err = error as Error; + res.status(500).json({ + error: "Failed to fetch user products", + message: err.message, + }); + } + }); + + // POST /api/lakebase-examples/raw/my-products — create as user + // created_by is set to current_user by the per-user pool's identity + app.post("/api/lakebase-examples/raw/my-products", async (req, res) => { + try { + const { name, category, price, stock } = req.body; + + const result = await appkit.lakebase.asUser(req).query( + `INSERT INTO raw_example.products (name, category, price, stock, created_by) + VALUES ($1, $2, $3, $4, current_user) RETURNING *`, + [name, category, Number(price), Number(stock)], + ); + res.json(result.rows[0]); + } catch (error: unknown) { + const err = error as Error; + res.status(500).json({ + error: "Failed to create product", + message: err.message, + }); + } + }); + } + + // ── Analytics examples ────────── + app.get("/sp", (_req, res) => { appkit.analytics .query("SELECT * FROM samples.nyctaxi.trips;") @@ -683,6 +747,43 @@ createApp({ res.status(404).json({ error: msg }); } }); + + /** + * Per-volume OBO mode demo. Hits the `obo_demo` volume — configured + * with `auth: "on-behalf-of-user"` — to confirm: + * + * 1. With a forwarded user identity, HTTP routes execute the SDK + * call as the end user (request goes through `runInUserContext`). + * 2. 
Without `x-forwarded-access-token`, production returns 401; + * development falls back to the SP and the `usersOnly` policy + * rejects with 403. + * 3. Programmatic `appkit.files("obo_demo").asUser(req).list()` runs + * inside the same user context. + * + * Returns the HTTP status, body, and the user identity the server + * observes — so the policy-matrix client can render a clear + * pass/fail panel. + */ + app.get("/policy/obo-volume", async (req, res) => { + const xForwardedUser = req.header("x-forwarded-user") ?? null; + const xForwardedToken = + (req.header("x-forwarded-access-token")?.length ?? 0) > 0; + + const programmatic: ProbeResult[] = await runProbes([ + [ + "obo_demo", + "list", + () => appkit.files("obo_demo").asUser(req).list(), + ], + ]); + + res.json({ + mode: "on-behalf-of-user", + xForwardedUser, + xForwardedAccessTokenPresent: xForwardedToken, + programmatic, + }); + }); }); }, }).catch(console.error); diff --git a/apps/dev-playground/server/lakebase-examples/raw-driver-example.ts b/apps/dev-playground/server/lakebase-examples/raw-driver-example.ts index 43b2ca3b2..327b5a601 100644 --- a/apps/dev-playground/server/lakebase-examples/raw-driver-example.ts +++ b/apps/dev-playground/server/lakebase-examples/raw-driver-example.ts @@ -11,22 +11,25 @@ let pool: Pool; * - Direct pg.Pool usage without ORM abstraction * - Manual SQL query writing with parameterized queries * - Schema and table creation (idempotent) - * - Basic CRUD operations - * - Connection health checking + * - Row-Level Security (RLS) setup + * - Basic CRUD operations (SP pool) + * + * OBO routes are registered separately in index.ts via the Lakebase plugin's + * `asUser(req)` pattern — see `onPluginsReady`. */ interface Product { - id: number; + id: string; name: string; category: string; price: number; stock: number; - created_by?: string; + created_by: string | null; created_at: Date; } export async function setup(user?: string) { - // Create pool with automatic OAuth token refresh + // Create service principal pool with automatic OAuth token refresh pool = createLakebasePool({ user }); // Create schema and table (idempotent) @@ -34,15 +37,47 @@ export async function setup(user?: string) { CREATE SCHEMA IF NOT EXISTS raw_example; CREATE TABLE IF NOT EXISTS raw_example.products ( - id SERIAL PRIMARY KEY, + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), name VARCHAR(255) NOT NULL, category VARCHAR(100), price DECIMAL(10, 2), stock INTEGER DEFAULT 0, + created_by VARCHAR(255) DEFAULT current_user, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); `); + // Enable Row-Level Security (idempotent) + await pool.query(` + ALTER TABLE raw_example.products ENABLE ROW LEVEL SECURITY; + `); + + // Create RLS policy (idempotent via IF NOT EXISTS-like pattern) + // Users see only rows they created (or rows with NULL created_by for seed data). + // The table owner (service principal) bypasses RLS automatically. 
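+  // To inspect the policy later: SELECT * FROM pg_policies
+  //   WHERE schemaname = 'raw_example' AND tablename = 'products';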
+  await pool.query(`
+    DO $$
+    BEGIN
+      IF NOT EXISTS (
+        SELECT 1 FROM pg_policies
+        WHERE schemaname = 'raw_example'
+          AND tablename = 'products'
+          AND policyname = 'user_products_policy'
+      ) THEN
+        CREATE POLICY user_products_policy ON raw_example.products
+          FOR ALL TO PUBLIC
+          USING (created_by = current_user OR created_by IS NULL);
+      END IF;
+    END
+    $$;
+  `);
+
+  // Grant schema/table access to PUBLIC so OBO users can SELECT/INSERT
+  await pool.query(`
+    GRANT USAGE ON SCHEMA raw_example TO PUBLIC;
+    GRANT ALL ON ALL TABLES IN SCHEMA raw_example TO PUBLIC;
+  `);
+
   // Seed sample data if table is empty
   const { rows } = await pool.query<{ count: string }>(
     "SELECT COUNT(*) as count FROM raw_example.products",
@@ -53,7 +88,9 @@
 }
 
 export function registerRoutes(router: IAppRouter, basePath: string) {
-  // GET /api/lakebase-examples/raw/products - List all products
+  // ── Service principal routes (bypass RLS as table owner) ──────────
+
+  // GET /raw/products - List ALL products (SP pool, bypasses RLS)
   router.get(`${basePath}/products`, async (_req, res) => {
     try {
       const result = await pool.query(
@@ -69,7 +106,7 @@
     }
   });
 
-  // POST /api/lakebase-examples/raw/products - Create new product
+  // POST /raw/products - Create product as SP (no created_by)
   router.post(`${basePath}/products`, async (req, res) => {
     try {
       const { name, category, price, stock } = req.body;
@@ -89,7 +126,7 @@
     }
   });
 
-  // GET /api/lakebase-examples/raw/health - Connection health check
+  // GET /raw/health - Connection health check
   router.get(`${basePath}/health`, async (_req, res) => {
     try {
       await pool.query("SELECT 1");
diff --git a/docs/docs/api/appkit/Function.createLakebasePoolManager.md b/docs/docs/api/appkit/Function.createLakebasePoolManager.md
new file mode 100644
index 000000000..9bcf25ef1
--- /dev/null
+++ b/docs/docs/api/appkit/Function.createLakebasePoolManager.md
@@ -0,0 +1,36 @@
+# Function: createLakebasePoolManager()
+
+```ts
+function createLakebasePoolManager(baseConfig?: Partial<LakebasePoolConfig>): LakebasePoolManager;
+```
+
+Create a pool manager that maintains per-key Lakebase connection pools.
+
+Each pool is created via `createLakebasePool` with the base config merged
+with per-pool overrides (e.g. a user's `workspaceClient` and `user`).
+
+A periodic cleanup removes empty Pool objects (where all connections have
+been closed by pg's built-in `idleTimeoutMillis`) from the internal Map.
+
+## Parameters
+
+| Parameter | Type |
+| ------ | ------ |
+| `baseConfig?` | `Partial`\<[`LakebasePoolConfig`](Interface.LakebasePoolConfig.md)\> |
+
+## Returns
+
+[`LakebasePoolManager`](Interface.LakebasePoolManager.md)
+
+## Example
+
+```typescript
+const poolManager = createLakebasePoolManager();
+
+// In a route handler:
+const userPool = poolManager.getPool(userName, {
+  workspaceClient: new WorkspaceClient({ token: userToken, host, authType: "pat" }),
+  user: userName,
+});
+const result = await userPool.query("SELECT * FROM products");
+```
diff --git a/docs/docs/api/appkit/Interface.LakebasePool.md b/docs/docs/api/appkit/Interface.LakebasePool.md
new file mode 100644
index 000000000..52fe12a73
--- /dev/null
+++ b/docs/docs/api/appkit/Interface.LakebasePool.md
@@ -0,0 +1,80 @@
+# Interface: LakebasePool
+
+Subset of `pg.Pool` exposed by the Lakebase plugin.
+
+RoutingPool does not extend EventEmitter — event listener methods
+like `on('error', ...)` are not available. Use `query()`, `connect()`,
+and `end()` for all pool operations.
+
+## Properties
+
+### idleCount
+
+```ts
+readonly idleCount: number;
+```
+
+***
+
+### totalCount
+
+```ts
+readonly totalCount: number;
+```
+
+***
+
+### waitingCount
+
+```ts
+readonly waitingCount: number;
+```
+
+## Methods
+
+### connect()
+
+```ts
+connect(): Promise<PoolClient>;
+```
+
+#### Returns
+
+`Promise`\<`PoolClient`\>
+
+***
+
+### end()
+
+```ts
+end(): Promise<void>;
+```
+
+#### Returns
+
+`Promise`\<`void`\>
+
+***
+
+### query()
+
+```ts
+query<T extends QueryResultRow = any>(text: string, values?: unknown[]): Promise<QueryResult<T>>;
+```
+
+#### Type Parameters
+
+| Type Parameter | Default type |
+| ------ | ------ |
+| `T` *extends* `QueryResultRow` | `any` |
+
+#### Parameters
+
+| Parameter | Type |
+| ------ | ------ |
+| `text` | `string` |
+| `values?` | `unknown`[] |
+
+#### Returns
+
+`Promise`\<`QueryResult`\<`T`\>\>
diff --git a/docs/docs/api/appkit/Interface.LakebasePoolManager.md b/docs/docs/api/appkit/Interface.LakebasePoolManager.md
new file mode 100644
index 000000000..31c605d64
--- /dev/null
+++ b/docs/docs/api/appkit/Interface.LakebasePoolManager.md
@@ -0,0 +1,100 @@
+# Interface: LakebasePoolManager
+
+Manages multiple Lakebase connection pools keyed by an identifier (e.g. userId).
+
+Used for On-Behalf-Of (OBO) scenarios where each user needs their own pool
+with their own OAuth token refresh, enabling features like Row-Level Security.
+
+## Properties
+
+### size
+
+```ts
+readonly size: number;
+```
+
+Number of active pools.
+
+## Methods
+
+### closeAll()
+
+```ts
+closeAll(): Promise<void>;
+```
+
+Close all managed pools and stop cleanup (for graceful shutdown).
+
+#### Returns
+
+`Promise`\<`void`\>
+
+***
+
+### closePool()
+
+```ts
+closePool(key: string): Promise<void>;
+```
+
+Close and remove a specific pool.
+
+#### Parameters
+
+| Parameter | Type |
+| ------ | ------ |
+| `key` | `string` |
+
+#### Returns
+
+`Promise`\<`void`\>
+
+***
+
+### getPool()
+
+```ts
+getPool(
+   key: string,
+   perPoolConfig: Partial<LakebasePoolConfig>,
+   tokenFingerprint?: string): Pool;
+```
+
+Get an existing pool or create a new one for the given key.
+When creating, merges `perPoolConfig` with the base config passed to the factory.
+
+If `tokenFingerprint` is provided and differs from the cached pool's
+fingerprint, the stale pool is closed and a fresh one is created with
+the new config (including the updated `workspaceClient`).
+
+#### Parameters
+
+| Parameter | Type |
+| ------ | ------ |
+| `key` | `string` |
+| `perPoolConfig` | `Partial`\<[`LakebasePoolConfig`](Interface.LakebasePoolConfig.md)\> |
+| `tokenFingerprint?` | `string` |
+
+#### Returns
+
+`Pool`
+
+***
+
+### hasPool()
+
+```ts
+hasPool(key: string): boolean;
+```
+
+Check whether a pool exists for the given key.
+
+#### Parameters
+
+| Parameter | Type |
+| ------ | ------ |
+| `key` | `string` |
+
+#### Returns
+
+`boolean`
diff --git a/docs/docs/api/appkit/index.md b/docs/docs/api/appkit/index.md
index 66b826495..6ac54fa47 100644
--- a/docs/docs/api/appkit/index.md
+++ b/docs/docs/api/appkit/index.md
@@ -53,7 +53,9 @@ surface with `@databricks/appkit/beta`. Not meant for application imports.
 | [JobAPI](Interface.JobAPI.md) | User-facing API for a single configured job. |
 | [JobConfig](Interface.JobConfig.md) | Per-job configuration options.
| | [JobsConnectorConfig](Interface.JobsConnectorConfig.md) | - | +| [LakebasePool](Interface.LakebasePool.md) | Subset of `pg.Pool` exposed by the Lakebase plugin. | | [LakebasePoolConfig](Interface.LakebasePoolConfig.md) | Configuration for creating a Lakebase connection pool | +| [LakebasePoolManager](Interface.LakebasePoolManager.md) | Manages multiple Lakebase connection pools keyed by an identifier (e.g. userId). | | [McpConnectAllResult](Interface.McpConnectAllResult.md) | Per-endpoint outcome of [AppKitMcpClient.connectAll](Class.AppKitMcpClient.md#connectall). Callers (the agents plugin in particular) use the split to warn at startup when some MCP servers are unreachable without aborting boot for the rest. | | [Message](Interface.Message.md) | - | | [PluginManifest](Interface.PluginManifest.md) | Plugin manifest that declares metadata and resource requirements. Attached to plugin classes as a static property. Extends the shared PluginManifest with strict resource types. | @@ -125,6 +127,7 @@ surface with `@databricks/appkit/beta`. Not meant for application imports. | [createAgent](Function.createAgent.md) | Pure factory for agent definitions. Returns the passed-in definition after cycle-detecting the sub-agent graph. Accepts the full `AgentDefinition` shape and is safe to call at module top-level. | | [createApp](Function.createApp.md) | Bootstraps AppKit with the provided configuration. | | [createLakebasePool](Function.createLakebasePool.md) | Create a Lakebase pool with appkit's logger integration. Telemetry automatically uses appkit's OpenTelemetry configuration via global registry. | +| [createLakebasePoolManager](Function.createLakebasePoolManager.md) | Create a pool manager that maintains per-key Lakebase connection pools. | | [defineTool](Function.defineTool.md) | Defines a single tool entry for a plugin's internal registry. | | [executeFromRegistry](Function.executeFromRegistry.md) | Validates tool-call arguments against the entry's schema and invokes its handler. On validation failure, returns an LLM-friendly error string (matching the behavior of `tool()`) rather than throwing, so the model can self-correct on its next turn. | | [extractServingEndpoints](Function.extractServingEndpoints.md) | Extract serving endpoint config from a server file by AST-parsing it. Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls and extracts the endpoint alias names and their environment variable mappings. 
| diff --git a/docs/docs/api/appkit/typedoc-sidebar.ts b/docs/docs/api/appkit/typedoc-sidebar.ts index fd91d60ce..e7c06eefc 100644 --- a/docs/docs/api/appkit/typedoc-sidebar.ts +++ b/docs/docs/api/appkit/typedoc-sidebar.ts @@ -197,11 +197,21 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Interface.JobsConnectorConfig", label: "JobsConnectorConfig" }, + { + type: "doc", + id: "api/appkit/Interface.LakebasePool", + label: "LakebasePool" + }, { type: "doc", id: "api/appkit/Interface.LakebasePoolConfig", label: "LakebasePoolConfig" }, + { + type: "doc", + id: "api/appkit/Interface.LakebasePoolManager", + label: "LakebasePoolManager" + }, { type: "doc", id: "api/appkit/Interface.McpConnectAllResult", @@ -500,6 +510,11 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Function.createLakebasePool", label: "createLakebasePool" }, + { + type: "doc", + id: "api/appkit/Function.createLakebasePoolManager", + label: "createLakebasePoolManager" + }, { type: "doc", id: "api/appkit/Function.defineTool", diff --git a/docs/docs/plugins/execution-context.md b/docs/docs/plugins/execution-context.md index 98d2815bd..bb2400492 100644 --- a/docs/docs/plugins/execution-context.md +++ b/docs/docs/plugins/execution-context.md @@ -54,6 +54,12 @@ The `plugin.execute` span created by the execution interceptor chain includes th These attributes are automatically added when your plugin uses `execute()` or `executeStream()`. All built-in plugins use these methods for their OBO operations. Custom plugins should do the same to get automatic telemetry instrumentation. +## Lakebase per-user connections + +The Lakebase plugin uses a different mechanism for `asUser(req)`: instead of swapping the `WorkspaceClient` via AsyncLocalStorage, it creates a **separate `pg.Pool` per user**, each with its own OAuth token refresh. This is necessary because PostgreSQL connections are authenticated at connection time — the pool itself is the authentication boundary. + +See [Lakebase plugin — per-user connections](./lakebase.md#on-behalf-of-obo--per-user-connections) for details. + ## Development mode behavior In local development (`NODE_ENV=development`), if `asUser(req)` is called without a user token, it logs a warning and skips user impersonation — the operation runs with the default credentials configured for the app instead. The telemetry span will show `execution.context: "service"` with `execution.obo_dev_fallback: true` to distinguish these from regular service principal calls. diff --git a/docs/docs/plugins/lakebase.md b/docs/docs/plugins/lakebase.md index 768da3c2f..e6d728f9b 100644 --- a/docs/docs/plugins/lakebase.md +++ b/docs/docs/plugins/lakebase.md @@ -113,6 +113,88 @@ await createApp({ }); ``` +## On-Behalf-Of (OBO) — per-user connections + +When your app needs Row-Level Security (RLS) or per-user data isolation, use `asUser(req)` to execute queries using a per-user Lakebase connection pool. Each user's pool is authenticated with their Databricks identity, so PostgreSQL's `current_user` reflects the actual user. + +### Prerequisites + +1. **Enable user authorization** in your Databricks App with the **`postgres`** scope. See [User authorization](https://docs.databricks.com/aws/en/dev-tools/databricks-apps/auth#user-authorization) for setup instructions. In your `databricks.yml`: + ```yaml + resources: + apps: + app: + user_api_scopes: + - postgres + ``` + Apps scaffolded with `databricks apps init` and the Lakebase plugin include this automatically. + +2. 
Each app user needs a **Postgres role** in Lakebase. Create one with the Databricks CLI: + + ```bash + databricks postgres create-role "projects/{project_id}/branches/{branch_id}" \ + --json '{"spec": {"identity_type": "USER", "postgres_role": "user@example.com"}}' + ``` + + Alternatively, create roles in the Lakebase UI under **Branch Overview** → **Add role**. + + :::note + Do not grant `databricks_superuser` to OBO users — superusers bypass RLS. Use [fine-grained grants](#fine-grained-permissions) instead. + ::: + +### Usage + +No configuration needed — just call `asUser(req)`: + +```ts +const AppKit = await createApp({ + plugins: [server(), lakebase()], +}); + +// Service principal query (default — bypasses RLS as table owner) +const all = await AppKit.lakebase.query("SELECT * FROM app.orders"); + +// User-scoped query (per-user pool, RLS enforced) +app.get("/api/my-orders", async (req, res) => { + const result = await AppKit.lakebase + .asUser(req) + .query("SELECT * FROM app.orders ORDER BY created_at DESC"); + res.json(result.rows); +}); +``` + +When `asUser(req)` is called: +1. The user's token and identity are extracted from `x-forwarded-access-token` and `x-forwarded-email` headers (set automatically by Databricks Apps). +2. A per-user `pg.Pool` is created (or reused) with the user's OAuth credentials. +3. `query()` and `pool` use the user's pool — `current_user` in PostgreSQL reflects the user's identity. + +### Row-Level Security example + +```sql +-- As the service principal (during app setup): +ALTER TABLE app.orders ENABLE ROW LEVEL SECURITY; + +CREATE POLICY user_orders ON app.orders + FOR ALL TO PUBLIC + USING (owner = current_user); + +-- Grant access so OBO users can query +GRANT USAGE ON SCHEMA app TO PUBLIC; +GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA app TO PUBLIC; +``` + +### How it works + +- The **service principal pool** (`AppKit.lakebase.pool`) is always created and used for DDL operations, seeding, and admin queries. +- **Per-user pools** are created on the first `asUser(req)` call and cached by user identity. Each pool has its own OAuth token refresh cycle. +- Idle connections within per-user pools close automatically (30s idle timeout). Empty pool objects are cleaned up periodically. +- On shutdown, all pools (SP + user) are closed gracefully. +- In development mode (`NODE_ENV=development`), if no user token is available, `asUser(req)` falls back to the SP pool with a warning. + +:::caution[RLS and superusers] +PostgreSQL superusers bypass Row-Level Security entirely. Users with the `databricks_superuser` role will see all rows regardless of RLS policies. For RLS enforcement, use [fine-grained grants](#fine-grained-permissions) instead of the superuser role. +::: + ## Database Permissions When you create the app with the Lakebase resource using the [Getting started](#getting-started-with-the-lakebase) guide, the Service Principal is automatically granted `CONNECT_AND_CREATE` permission on the `postgres` resource. This lets the Service Principal connect to the database and create new objects, but **not access any existing schemas or tables.** @@ -123,21 +205,37 @@ To develop locally against a deployed Lakebase database: 1. **Deploy the app first.** The Service Principal creates the database schema and tables on first deploy. Apps generated from `databricks apps init` handle this automatically - they check if tables exist on startup and skip creation if they do. -2. **Grant `databricks_superuser` via the Lakebase UI:** - 1. 
Open the Lakebase Autoscaling UI and navigate to your project's **Branch Overview** page. - 2. Click **Add role** (or **Edit role** if your OAuth role already exists). - 3. Select your Databricks identity as the principal and check the **`databricks_superuser`** system role. +2. **Grant `databricks_superuser`** (skip if you are the Lakebase project owner — you already have full access): + + ```bash + # Create a new role with databricks_superuser + databricks postgres create-role "projects/{project_id}/branches/{branch_id}" \ + --json '{"spec": {"identity_type": "USER", "postgres_role": "user@example.com", "membership_roles": ["DATABRICKS_SUPERUSER"]}}' + ``` + + To grant superuser to an existing role, use [`update-role`](https://docs.databricks.com/aws/en/dev-tools/cli/reference/postgres-commands#databricks-postgres-update-role): + + ```bash + databricks postgres update-role \ + "projects/{project_id}/branches/{branch_id}/roles/{role_id}" \ + "spec.membership_roles" \ + --json '{"spec": {"membership_roles": ["DATABRICKS_SUPERUSER"]}}' + ``` + + Alternatively, you can manage roles in the Lakebase Autoscaling UI under your project's **Branch Overview** page → **Add role** / **Edit role**. 3. **Run locally** - your Databricks user identity (email) is used for OAuth authentication. The `databricks_superuser` role gives full **DML access** (read/write data) but **not DDL** (creating schemas or tables) - that's why deploying first matters (see note below). -For other users, use the same **Add role** flow in the Lakebase UI to create an OAuth role with `databricks_superuser` for each user. +For other users, repeat step 2 to create an OAuth role with `databricks_superuser` for each user. :::tip [Postgres password authentication](https://docs.databricks.com/aws/en/oltp/projects/authentication#overview) is a simpler alternative that avoids OAuth role permission complexity. However, it requires you to set up a password for the user in the **Branch Overview** page in the Lakebase Autoscaling UI. ::: :::info[Why deploy first?] -When the app is deployed, the Service Principal creates schemas and tables and becomes their owner. A `databricks_superuser` has full **DML access** (SELECT, INSERT, UPDATE, DELETE) to these objects, but **cannot run DDL** (CREATE SCHEMA, CREATE TABLE) on schemas owned by the Service Principal. Deploying first ensures all objects exist before local development begins. +When the app is deployed, the Service Principal creates schemas and tables and becomes their owner. `databricks_superuser` gives full DML access (read/write) but not DDL, so local development works only after the schema exists. + +If you run `npm run dev` first, your credentials own the schema and the deployed app hits `permission denied`. To recover, export any data first (`pg_dump` or a temporary schema copy), then drop the schema and redeploy. After redeploying, the Service Principal recreates the schema on startup. (PostgreSQL schema ownership is tied to the role that created it and cannot be reassigned by regular users.) 
 :::
 
 ### Fine-grained permissions
diff --git a/packages/appkit/src/connectors/lakebase/index.ts b/packages/appkit/src/connectors/lakebase/index.ts
index c58b7a8cb..17d7491f6 100644
--- a/packages/appkit/src/connectors/lakebase/index.ts
+++ b/packages/appkit/src/connectors/lakebase/index.ts
@@ -35,3 +35,10 @@ export {
   RequestedClaimsPermissionSet,
   type RequestedResource,
 } from "@databricks/lakebase";
+
+export {
+  createLakebasePoolManager,
+  type LakebasePoolManager,
+} from "./pool-manager";
+
+export { type LakebasePool, RoutingPool } from "./routing-pool";
diff --git a/packages/appkit/src/connectors/lakebase/pool-manager.ts b/packages/appkit/src/connectors/lakebase/pool-manager.ts
new file mode 100644
index 000000000..e48fa70cb
--- /dev/null
+++ b/packages/appkit/src/connectors/lakebase/pool-manager.ts
@@ -0,0 +1,140 @@
+import type { LakebasePoolConfig } from "@databricks/lakebase";
+import type { Pool } from "pg";
+import { createLakebasePool } from "./index";
+
+/** Interval for removing empty (connectionless) pools from the Map. */
+const CLEANUP_INTERVAL_MS = 5 * 60 * 1000; // 5 minutes
+
+/**
+ * Manages multiple Lakebase connection pools keyed by an identifier (e.g. userId).
+ *
+ * Used for On-Behalf-Of (OBO) scenarios where each user needs their own pool
+ * with their own OAuth token refresh, enabling features like Row-Level Security.
+ */
+export interface LakebasePoolManager {
+  /**
+   * Get an existing pool or create a new one for the given key.
+   * When creating, merges `perPoolConfig` with the base config passed to the factory.
+   *
+   * If `tokenFingerprint` is provided and differs from the cached pool's
+   * fingerprint, the stale pool is closed and a fresh one is created with
+   * the new config (including the updated `workspaceClient`).
+   */
+  getPool(
+    key: string,
+    perPoolConfig: Partial<LakebasePoolConfig>,
+    tokenFingerprint?: string,
+  ): Pool;
+
+  /** Check whether a pool exists for the given key. */
+  hasPool(key: string): boolean;
+
+  /** Close and remove a specific pool. */
+  closePool(key: string): Promise<void>;
+
+  /** Close all managed pools and stop cleanup (for graceful shutdown). */
+  closeAll(): Promise<void>;
+
+  /** Number of active pools. */
+  readonly size: number;
+}
+
+/**
+ * Create a pool manager that maintains per-key Lakebase connection pools.
+ *
+ * Each pool is created via `createLakebasePool` with the base config merged
+ * with per-pool overrides (e.g. a user's `workspaceClient` and `user`).
+ *
+ * A periodic cleanup removes empty Pool objects (where all connections have
+ * been closed by pg's built-in `idleTimeoutMillis`) from the internal Map.
+ *
+ * @example OBO usage
+ * ```typescript
+ * const poolManager = createLakebasePoolManager();
+ *
+ * // In a route handler:
+ * const userPool = poolManager.getPool(userName, {
+ *   workspaceClient: new WorkspaceClient({ token: userToken, host, authType: "pat" }),
+ *   user: userName,
+ * });
+ * const result = await userPool.query("SELECT * FROM products");
+ * ```
+ */
+export function createLakebasePoolManager(
+  baseConfig?: Partial<LakebasePoolConfig>,
+): LakebasePoolManager {
+  interface PoolEntry {
+    pool: Pool;
+    tokenFingerprint?: string;
+  }
+
+  const entries = new Map<string, PoolEntry>();
+
+  // Periodically remove empty Pool objects from the Map.
+  // pg.Pool's idleTimeoutMillis closes idle connections automatically;
+  // this just cleans up the Map entries once all connections are gone.
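+  // A pool with live clients (idle or checked out) keeps totalCount > 0
+  // and is left alone by this sweep.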
+  const cleanupTimer = setInterval(() => {
+    for (const [key, entry] of entries) {
+      if (entry.pool.totalCount === 0) {
+        entry.pool.end().catch(() => {});
+        entries.delete(key);
+      }
+    }
+  }, CLEANUP_INTERVAL_MS);
+  cleanupTimer.unref();
+
+  return {
+    getPool(
+      key: string,
+      perPoolConfig: Partial<LakebasePoolConfig>,
+      tokenFingerprint?: string,
+    ): Pool {
+      const existing = entries.get(key);
+
+      if (existing) {
+        // When the caller provides a fingerprint that differs from the
+        // cached one, the underlying OBO token has rotated. The pool's
+        // password callback holds a stale WorkspaceClient (authType: "pat",
+        // static token) that will fail once the Lakebase Postgres token
+        // needs refreshing. Drain the old pool and create a fresh one.
+        const stale =
+          tokenFingerprint &&
+          existing.tokenFingerprint &&
+          tokenFingerprint !== existing.tokenFingerprint;
+
+        if (!stale) return existing.pool;
+
+        existing.pool.end().catch(() => {});
+      }
+
+      // Safe without locking: createLakebasePool is synchronous and Node.js
+      // is single-threaded, so no preemption between get() and set().
+      const pool = createLakebasePool({ ...baseConfig, ...perPoolConfig });
+      entries.set(key, { pool, tokenFingerprint });
+      return pool;
+    },
+
+    hasPool(key: string): boolean {
+      return entries.has(key);
+    },
+
+    async closePool(key: string): Promise<void> {
+      const entry = entries.get(key);
+      if (entry) {
+        await entry.pool.end();
+        entries.delete(key);
+      }
+    },
+
+    async closeAll(): Promise<void> {
+      clearInterval(cleanupTimer);
+      const endPromises = [...entries.values()].map((e) => e.pool.end());
+      await Promise.all(endPromises);
+      entries.clear();
+    },
+
+    get size() {
+      return entries.size;
+    },
+  };
+}
diff --git a/packages/appkit/src/connectors/lakebase/routing-pool.ts b/packages/appkit/src/connectors/lakebase/routing-pool.ts
new file mode 100644
index 000000000..55dd74fa0
--- /dev/null
+++ b/packages/appkit/src/connectors/lakebase/routing-pool.ts
@@ -0,0 +1,71 @@
+import type { Pool, PoolClient, QueryResult, QueryResultRow } from "pg";
+import { getUserContext } from "../../context/execution-context";
+import type { UserContext } from "../../context/user-context";
+
+/**
+ * Subset of `pg.Pool` exposed by the Lakebase plugin.
+ *
+ * RoutingPool does not extend EventEmitter — event listener methods
+ * like `on('error', ...)` are not available. Use `query()`, `connect()`,
+ * and `end()` for all pool operations.
+ */
+export interface LakebasePool {
+  query<T extends QueryResultRow = any>(
+    text: string,
+    values?: unknown[],
+  ): Promise<QueryResult<T>>;
+  connect(): Promise<PoolClient>;
+  end(): Promise<void>;
+  readonly totalCount: number;
+  readonly idleCount: number;
+  readonly waitingCount: number;
+}
+
+/**
+ * A `pg.Pool`-like wrapper that routes queries to the appropriate pool
+ * based on the current execution context.
+ *
+ * When called inside `runInUserContext()` (set up by `Plugin.asUser(req)`),
+ * queries route to the per-user pool returned by `resolveUserPool`.
+ * Otherwise, queries route to the service-principal pool.
+ *
+ * This enables OBO (On-Behalf-Of) without custom `asUser()` overrides —
+ * the base class sets up AsyncLocalStorage context, and the RoutingPool
+ * reads it transparently.
+ */
+export class RoutingPool implements LakebasePool {
+  constructor(
+    private spPool: Pool,
+    private resolveUserPool: (ctx: UserContext) => Pool,
+  ) {}
+
+  private activePool(): Pool {
+    const userCtx = getUserContext();
+    return userCtx ? this.resolveUserPool(userCtx) : this.spPool;
+  }
+
+  query<T extends QueryResultRow = any>(
+    text: string,
+    values?: unknown[],
+  ): Promise<QueryResult<T>> {
+    return this.activePool().query(text, values);
+  }
+
+  connect(): Promise<PoolClient> {
+    return this.activePool().connect();
+  }
+
+  async end(): Promise<void> {
+    await this.spPool.end();
+  }
+
+  get totalCount() {
+    return this.spPool.totalCount;
+  }
+  get idleCount() {
+    return this.spPool.idleCount;
+  }
+  get waitingCount() {
+    return this.spPool.waitingCount;
+  }
+}
diff --git a/packages/appkit/src/connectors/lakebase/tests/pool-manager.test.ts b/packages/appkit/src/connectors/lakebase/tests/pool-manager.test.ts
new file mode 100644
index 000000000..2b0efff77
--- /dev/null
+++ b/packages/appkit/src/connectors/lakebase/tests/pool-manager.test.ts
@@ -0,0 +1,139 @@
+import type { Pool } from "pg";
+import { afterEach, describe, expect, test, vi } from "vitest";
+
+vi.mock("../../../cache", () => ({
+  CacheManager: {
+    getInstanceSync: vi.fn(() => ({
+      get: vi.fn(),
+      set: vi.fn(),
+      delete: vi.fn(),
+      getOrExecute: vi.fn(async <T>(_k: unknown[], fn: () => Promise<T>) =>
+        fn(),
+      ),
+      generateKey: vi.fn(() => "test-key"),
+    })),
+  },
+}));
+
+const mockPools: Pool[] = [];
+
+vi.mock("../index", () => ({
+  createLakebasePool: vi.fn(() => {
+    const pool = {
+      query: vi.fn(async () => ({ rows: [] })),
+      connect: vi.fn(),
+      end: vi.fn(async () => {}),
+      totalCount: 1,
+      idleCount: 0,
+      waitingCount: 0,
+    } as unknown as Pool;
+    mockPools.push(pool);
+    return pool;
+  }),
+}));
+
+import { createLakebasePoolManager } from "../pool-manager";
+
+afterEach(() => {
+  mockPools.length = 0;
+  vi.restoreAllMocks();
+});
+
+describe("createLakebasePoolManager", () => {
+  test("creates and caches a pool for a key", () => {
+    const manager = createLakebasePoolManager();
+    const pool1 = manager.getPool("user-a", { user: "user-a" });
+    const pool2 = manager.getPool("user-a", { user: "user-a" });
+
+    expect(pool1).toBe(pool2);
+    expect(mockPools).toHaveLength(1);
+    expect(manager.size).toBe(1);
+  });
+
+  test("creates separate pools for different keys", () => {
+    const manager = createLakebasePoolManager();
+    const poolA = manager.getPool("user-a", { user: "user-a" });
+    const poolB = manager.getPool("user-b", { user: "user-b" });
+
+    expect(poolA).not.toBe(poolB);
+    expect(mockPools).toHaveLength(2);
+    expect(manager.size).toBe(2);
+  });
+
+  test("hasPool returns correct state", () => {
+    const manager = createLakebasePoolManager();
+
+    expect(manager.hasPool("user-a")).toBe(false);
+    manager.getPool("user-a", { user: "user-a" });
+    expect(manager.hasPool("user-a")).toBe(true);
+  });
+
+  test("closePool closes and removes a specific pool", async () => {
+    const manager = createLakebasePoolManager();
+    const pool = manager.getPool("user-a", { user: "user-a" });
+
+    await manager.closePool("user-a");
+
+    expect(pool.end).toHaveBeenCalled();
+    expect(manager.hasPool("user-a")).toBe(false);
+    expect(manager.size).toBe(0);
+  });
+
+  test("closePool is a no-op for unknown keys", async () => {
+    const manager = createLakebasePoolManager();
+    await manager.closePool("nonexistent");
+    expect(manager.size).toBe(0);
+  });
+
+  test("closeAll closes all pools and clears the map", async () => {
+    const manager = createLakebasePoolManager();
+    manager.getPool("user-a", { user: "user-a" });
+    manager.getPool("user-b", { user: "user-b" });
+
+    await manager.closeAll();
+
+    expect(mockPools[0].end).toHaveBeenCalled();
+    expect(mockPools[1].end).toHaveBeenCalled();
+    expect(manager.size).toBe(0);
+  });
+
+  test("getPool after closeAll creates a fresh pool", async () => {
+    const manager = createLakebasePoolManager();
+    const first = manager.getPool("user-a", { user: "user-a" });
+
+    await manager.closeAll();
+    const second = manager.getPool("user-a", { user: "user-a" });
+
+    expect(second).not.toBe(first);
+    expect(manager.size).toBe(1);
+  });
+
+  test("returns cached pool when tokenFingerprint matches", () => {
+    const manager = createLakebasePoolManager();
+    const pool1 = manager.getPool("user-a", { user: "user-a" }, "fp-aaa");
+    const pool2 = manager.getPool("user-a", { user: "user-a" }, "fp-aaa");
+
+    expect(pool1).toBe(pool2);
+    expect(mockPools).toHaveLength(1);
+  });
+
+  test("rebuilds pool when tokenFingerprint changes", () => {
+    const manager = createLakebasePoolManager();
+    const pool1 = manager.getPool("user-a", { user: "user-a" }, "fp-aaa");
+    const pool2 = manager.getPool("user-a", { user: "user-a" }, "fp-bbb");
+
+    expect(pool2).not.toBe(pool1);
+    expect(pool1.end).toHaveBeenCalled();
+    expect(mockPools).toHaveLength(2);
+    expect(manager.size).toBe(1);
+  });
+
+  test("returns cached pool when no tokenFingerprint is provided", () => {
+    const manager = createLakebasePoolManager();
+    const pool1 = manager.getPool("user-a", { user: "user-a" });
+    const pool2 = manager.getPool("user-a", { user: "user-a" });
+
+    expect(pool1).toBe(pool2);
+    expect(mockPools).toHaveLength(1);
+  });
+});
diff --git a/packages/appkit/src/connectors/lakebase/tests/routing-pool.test.ts b/packages/appkit/src/connectors/lakebase/tests/routing-pool.test.ts
new file mode 100644
index 000000000..f87277616
--- /dev/null
+++ b/packages/appkit/src/connectors/lakebase/tests/routing-pool.test.ts
@@ -0,0 +1,137 @@
+import type { Pool } from "pg";
+import { describe, expect, test, vi } from "vitest";
+import { RoutingPool } from "../routing-pool";
+
+vi.mock("../../../cache", () => ({
+  CacheManager: {
+    getInstanceSync: vi.fn(() => ({
+      get: vi.fn(),
+      set: vi.fn(),
+      delete: vi.fn(),
+      getOrExecute: vi.fn(async <T>(_k: unknown[], fn: () => Promise<T>) =>
+        fn(),
+      ),
+      generateKey: vi.fn(() => "test-key"),
+    })),
+  },
+}));
+
+function makeMockPool(label: string) {
+  return {
+    query: vi.fn(async () => ({ rows: [{ source: label }] })),
+    connect: vi.fn(async () => ({
+      query: vi.fn(async () => ({ rows: [{ source: `${label}-client` }] })),
+      release: vi.fn(),
+    })),
+    end: vi.fn(async () => {}),
+    totalCount: 5,
+    idleCount: 3,
+    waitingCount: 0,
+  } as unknown as Pool;
+}
+
+describe("RoutingPool", () => {
+  test("routes to SP pool when no user context is active", async () => {
+    const spPool = makeMockPool("sp");
+    const userPool = makeMockPool("user");
+    const pool = new RoutingPool(spPool, () => userPool);
+
+    const result = await pool.query("SELECT 1");
+
+    expect(result.rows).toEqual([{ source: "sp" }]);
+    expect(spPool.query).toHaveBeenCalledWith("SELECT 1", undefined);
+    expect(userPool.query).not.toHaveBeenCalled();
+  });
+
+  test("routes to user pool inside runInUserContext", async () => {
+    const { runInUserContext } = await import(
+      "../../../context/execution-context"
+    );
+
+    const spPool = makeMockPool("sp");
+    const userPool = makeMockPool("user");
+    const resolveUserPool = vi.fn(() => userPool);
+    const pool = new RoutingPool(spPool, resolveUserPool);
+
+    const userCtx = {
+      client: {} as any,
+      userId: "user-1",
+      workspaceId: Promise.resolve("ws-1"),
+      isUserContext: true as const,
+    };
+    const result = await runInUserContext(userCtx, () =>
+      pool.query("SELECT 1"),
+    );
+
+    expect(result.rows).toEqual([{ source: "user"
}]); + expect(userPool.query).toHaveBeenCalledWith("SELECT 1", undefined); + expect(spPool.query).not.toHaveBeenCalled(); + expect(resolveUserPool).toHaveBeenCalledWith(userCtx); + }); + + test("connect() routes to user pool inside runInUserContext", async () => { + const { runInUserContext } = await import( + "../../../context/execution-context" + ); + + const spPool = makeMockPool("sp"); + const userPool = makeMockPool("user"); + const pool = new RoutingPool(spPool, () => userPool); + + const userCtx = { + client: {} as any, + userId: "user-1", + workspaceId: Promise.resolve("ws-1"), + isUserContext: true as const, + }; + const client = await runInUserContext(userCtx, () => pool.connect()); + + expect(userPool.connect).toHaveBeenCalled(); + expect(spPool.connect).not.toHaveBeenCalled(); + expect(client).toBeDefined(); + }); + + test("forwards query values to user pool inside runInUserContext", async () => { + const { runInUserContext } = await import( + "../../../context/execution-context" + ); + + const spPool = makeMockPool("sp"); + const userPool = makeMockPool("user"); + const pool = new RoutingPool(spPool, () => userPool); + + const userCtx = { + client: {} as any, + userId: "user-1", + workspaceId: Promise.resolve("ws-1"), + isUserContext: true as const, + }; + await runInUserContext(userCtx, () => + pool.query("SELECT * FROM t WHERE id = $1", [42]), + ); + + expect(userPool.query).toHaveBeenCalledWith( + "SELECT * FROM t WHERE id = $1", + [42], + ); + expect(spPool.query).not.toHaveBeenCalled(); + }); + + test("end() closes SP pool", async () => { + const spPool = makeMockPool("sp"); + const pool = new RoutingPool(spPool, () => makeMockPool("user")); + + await pool.end(); + + expect(spPool.end).toHaveBeenCalled(); + }); + + test("forwards monitoring properties from SP pool", () => { + const spPool = makeMockPool("sp"); + const pool = new RoutingPool(spPool, () => makeMockPool("user")); + + expect(pool.totalCount).toBe(5); + expect(pool.idleCount).toBe(3); + expect(pool.waitingCount).toBe(0); + }); +}); diff --git a/packages/appkit/src/context/execution-context.ts b/packages/appkit/src/context/execution-context.ts index d707f52de..8202b9bd0 100644 --- a/packages/appkit/src/context/execution-context.ts +++ b/packages/appkit/src/context/execution-context.ts @@ -89,3 +89,12 @@ export function isInUserContext(): boolean { const ctx = executionContextStorage.getStore(); return ctx !== undefined; } + +/** + * Get the user context if one is active, otherwise `undefined`. + * Unlike `getExecutionContext()`, this does not require `ServiceContext` + * to be initialized and never throws. 
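+ * (RoutingPool, for example, calls this per operation to choose between
+ * the service-principal pool and a per-user pool.)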
+ */ +export function getUserContext(): UserContext | undefined { + return executionContextStorage.getStore(); +} diff --git a/packages/appkit/src/context/service-context.ts b/packages/appkit/src/context/service-context.ts index 1d860e3d2..fa2f9c3ef 100644 --- a/packages/appkit/src/context/service-context.ts +++ b/packages/appkit/src/context/service-context.ts @@ -1,3 +1,4 @@ +import { createHash } from "node:crypto"; import { type ClientOptions, ConfigError, @@ -113,6 +114,7 @@ export class ServiceContext { token: string, userId: string, userName?: string, + userEmail?: string, ): UserContext { if (!token) { throw AuthenticationError.missingToken("user token"); @@ -137,10 +139,17 @@ export class ServiceContext { getClientOptions(), ); + const tokenFingerprint = createHash("sha256") + .update(token) + .digest("hex") + .slice(0, 16); + return { client: userClient, userId, userName, + userEmail, + tokenFingerprint, warehouseId: serviceCtx.warehouseId, workspaceId: serviceCtx.workspaceId, isUserContext: true, diff --git a/packages/appkit/src/context/tests/service-context.test.ts b/packages/appkit/src/context/tests/service-context.test.ts index 6901cd2eb..8e655721e 100644 --- a/packages/appkit/src/context/tests/service-context.test.ts +++ b/packages/appkit/src/context/tests/service-context.test.ts @@ -216,6 +216,21 @@ describe("ServiceContext", () => { }); }); + test("should include tokenFingerprint derived from the token", () => { + const userCtx = ServiceContext.createUserContext("user-token", "user-1"); + + expect(userCtx.tokenFingerprint).toBeDefined(); + expect(typeof userCtx.tokenFingerprint).toBe("string"); + expect(userCtx.tokenFingerprint).toHaveLength(16); + }); + + test("should produce different fingerprints for different tokens", () => { + const ctxA = ServiceContext.createUserContext("token-aaa", "user-1"); + const ctxB = ServiceContext.createUserContext("token-bbb", "user-1"); + + expect(ctxA.tokenFingerprint).not.toBe(ctxB.tokenFingerprint); + }); + test("should handle missing userName gracefully", () => { const userCtx = ServiceContext.createUserContext("user-token", "user-1"); diff --git a/packages/appkit/src/context/user-context.ts b/packages/appkit/src/context/user-context.ts index 20746c919..dddd9b4bd 100644 --- a/packages/appkit/src/context/user-context.ts +++ b/packages/appkit/src/context/user-context.ts @@ -11,6 +11,10 @@ export interface UserContext { userId: string; /** The user's name (from request headers) */ userName?: string; + /** The user's email (from `x-forwarded-email` header) */ + userEmail?: string; + /** Truncated SHA-256 hash of the user's OBO token, used to detect token rotation */ + tokenFingerprint?: string; /** Promise that resolves to the warehouse ID (inherited from service context, only present when a plugin requires `SQL_WAREHOUSE` resource) */ warehouseId?: Promise; /** Promise that resolves to the workspace ID (inherited from service context) */ diff --git a/packages/appkit/src/core/appkit.ts b/packages/appkit/src/core/appkit.ts index 3421e03bb..24345c6e7 100644 --- a/packages/appkit/src/core/appkit.ts +++ b/packages/appkit/src/core/appkit.ts @@ -10,12 +10,14 @@ import type { } from "shared"; import { version as productVersion } from "../../package.json"; import { CacheManager } from "../cache"; -import { ServiceContext } from "../context"; +import { runInUserContext, ServiceContext } from "../context"; +import type { UserContext } from "../context/user-context"; import { isInternalTelemetryEnabled, TelemetryReporter, } from 
"../internal-telemetry"; import { createLogger } from "../logging/logger"; +import { USER_CONTEXT_SYMBOL } from "../plugin/plugin"; import { ResourceRegistry, ResourceType } from "../registry"; import type { TelemetryConfig } from "../telemetry"; import { TelemetryManager } from "../telemetry"; @@ -132,6 +134,32 @@ export class AppKit { } } + /** + * Wraps all function properties in an exports object so they run + * inside the given user context (via AsyncLocalStorage). + * This ensures RoutingPool and other context-aware code sees the + * user identity even though the function was obtained outside the proxy. + */ + private wrapExportsInUserContext( + exports: Record, + userContext: UserContext, + ) { + for (const key in exports) { + if (!Object.hasOwn(exports, key)) continue; + const val = exports[key]; + if (typeof val === "function") { + const fn = val as (...args: unknown[]) => unknown; + exports[key] = (...args: unknown[]) => + runInUserContext(userContext, () => fn(...args)); + } else if (AppKit.isPlainObject(val)) { + this.wrapExportsInUserContext( + val as Record, + userContext, + ); + } + } + } + /** * Wraps a plugin's exports with an `asUser` method that returns * a user-scoped version of the exports. @@ -166,11 +194,22 @@ export class AppKit { */ asUser: (req: import("express").Request) => { const userPlugin = (plugin as any).asUser(req); - const userExports = (userPlugin.exports?.() ?? {}) as Record< + const userContext = (userPlugin as any)[ + USER_CONTEXT_SYMBOL + ] as UserContext; + const userExports = (plugin.exports?.() ?? {}) as Record< string, unknown >; - this.bindExportMethods(userExports, userPlugin); + // Wrap each export in runInUserContext instead of bind. + // bind() bypasses the Proxy get trap, so methods called via bind + // would not run inside the user's AsyncLocalStorage context. + if (userContext) { + this.wrapExportsInUserContext(userExports, userContext); + } else { + // Fallback for dev mode proxy (no userContext symbol) + this.bindExportMethods(userExports, userPlugin); + } return userExports; }, }; diff --git a/packages/appkit/src/core/tests/appkit-as-user-exports.test.ts b/packages/appkit/src/core/tests/appkit-as-user-exports.test.ts new file mode 100644 index 000000000..968c064e1 --- /dev/null +++ b/packages/appkit/src/core/tests/appkit-as-user-exports.test.ts @@ -0,0 +1,210 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import type { UserContext } from "../../context/user-context"; + +/** + * Tests the exports-level asUser(req) flow: + * appkit.plugin.asUser(req).method() + * + * Verifies that exported functions are wrapped in runInUserContext(), + * so getUserContext() returns user context during the call — regardless + * of whether the export is a class method or an inline arrow function. 
+ */
+
+// ── Mock heavy dependencies ─────────────────────────────────────────
+
+vi.mock("../../cache", () => ({
+  CacheManager: {
+    getInstance: vi.fn(async () => ({
+      get: vi.fn(),
+      set: vi.fn(),
+      delete: vi.fn(),
+      getOrExecute: vi.fn(async (_k: unknown[], fn: () => Promise<unknown>) =>
+        fn(),
+      ),
+      generateKey: vi.fn(() => "test-key"),
+    })),
+    getInstanceSync: vi.fn(() => ({
+      get: vi.fn(),
+      set: vi.fn(),
+      delete: vi.fn(),
+      getOrExecute: vi.fn(async (_k: unknown[], fn: () => Promise<unknown>) =>
+        fn(),
+      ),
+      generateKey: vi.fn(() => "test-key"),
+    })),
+  },
+}));
+
+vi.mock("../../telemetry", async () => {
+  const actual =
+    await vi.importActual<typeof import("../../telemetry")>("../../telemetry");
+  return {
+    ...actual,
+    TelemetryManager: {
+      initialize: vi.fn(),
+      getProvider: () => ({
+        getTracer: () => ({
+          startActiveSpan: vi.fn((_name: string, fn: (span: any) => any) =>
+            fn({ end: vi.fn(), setStatus: vi.fn(), recordException: vi.fn() }),
+          ),
+        }),
+        getMeter: () => ({
+          createCounter: vi.fn(() => ({ add: vi.fn() })),
+          createHistogram: vi.fn(() => ({ record: vi.fn() })),
+        }),
+        getLogger: () => ({ emit: vi.fn() }),
+        emit: vi.fn(),
+        startActiveSpan: vi.fn(
+          async (_n: string, _o: any, fn: (s: any) => any) =>
+            fn({ end: vi.fn() }),
+        ),
+        registerInstrumentations: vi.fn(),
+      }),
+    },
+  };
+});
+
+vi.mock("../../logging/logger", () => ({
+  createLogger: () => ({
+    info: vi.fn(),
+    warn: vi.fn(),
+    error: vi.fn(),
+    debug: vi.fn(),
+  }),
+}));
+
+vi.mock("../../internal-telemetry", () => ({
+  isInternalTelemetryEnabled: vi.fn(() => false),
+  TelemetryReporter: { report: vi.fn() },
+}));
+
+// ── Imports (after mocks) ───────────────────────────────────────────
+
+import { createMockRequest, setupDatabricksEnv } from "@tools/test-helpers";
+import { getUserContext } from "../../context/execution-context";
+import { ServiceContext } from "../../context/service-context";
+import { Plugin } from "../../plugin/plugin";
+import { toPlugin } from "../../plugin/to-plugin";
+import { createApp } from "../appkit";
+
+// ── Mock SDK ────────────────────────────────────────────────────────
+
+const { MockWorkspaceClient } = vi.hoisted(() => {
+  const MockWorkspaceClient = vi.fn().mockImplementation(() => ({
+    currentUser: { me: vi.fn().mockResolvedValue({ id: "sp-user-123" }) },
+    apiClient: {
+      request: vi.fn().mockResolvedValue({ "x-databricks-org-id": "ws-456" }),
+    },
+  }));
+  return { MockWorkspaceClient };
+});
+
+vi.mock("@databricks/sdk-experimental", () => ({
+  WorkspaceClient: MockWorkspaceClient,
+  ConfigError: class extends Error {},
+}));
+
+// ── Test plugin ─────────────────────────────────────────────────────
+
+/** Captures getUserContext() at call time and returns it. */
+class ContextProbePlugin extends Plugin {
+  static manifest = {
+    name: "probe" as const,
+    displayName: "Context Probe",
+    description: "Test plugin that captures user context",
+    resources: { required: [], optional: [] },
+  };
+
+  /**
+   * Class method — discoverable by the proxy.
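+   * Exported below via bind(this); the runInUserContext wrapper added by
+   * wrapWithAsUser must cover this case as well as inline arrows.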
*/ + getContext() { + return getUserContext(); + } + + exports() { + return { + // Class method bound to this + getContext: this.getContext.bind(this), + // Inline arrow function — the key case this fix addresses + getContextArrow: () => getUserContext(), + }; + } +} + +const probe = toPlugin(ContextProbePlugin); + +// ── Tests ─────────────────────────────────────────────────────────── + +describe("exports-level asUser(req)", () => { + const originalEnv = { ...process.env }; + + beforeEach(() => { + vi.clearAllMocks(); + ServiceContext.reset(); + setupDatabricksEnv(); + }); + + afterEach(() => { + process.env = { ...originalEnv }; + ServiceContext.reset(); + }); + + test("class method export runs in user context via asUser(req)", async () => { + const appkit = (await createApp({ plugins: [probe()] })) as any; + + const req = createMockRequest({ + headers: { + "x-forwarded-access-token": "user-token-abc", + "x-forwarded-user": "alice", + "x-forwarded-email": "alice@example.com", + }, + }); + + const userExports = appkit.probe.asUser(req); + const ctx = userExports.getContext() as UserContext; + + expect(ctx).toBeDefined(); + expect(ctx.isUserContext).toBe(true); + expect(ctx.userId).toBe("alice"); + }); + + test("inline arrow function export runs in user context via asUser(req)", async () => { + const appkit = (await createApp({ plugins: [probe()] })) as any; + + const req = createMockRequest({ + headers: { + "x-forwarded-access-token": "user-token-abc", + "x-forwarded-user": "bob", + "x-forwarded-email": "bob@example.com", + }, + }); + + const userExports = appkit.probe.asUser(req); + const ctx = userExports.getContextArrow() as UserContext; + + expect(ctx).toBeDefined(); + expect(ctx.isUserContext).toBe(true); + expect(ctx.userId).toBe("bob"); + }); + + test("SP exports (without asUser) do not have user context", async () => { + const appkit = (await createApp({ plugins: [probe()] })) as any; + + const ctx = appkit.probe.getContext(); + expect(ctx).toBeUndefined(); + }); + + test("dev mode fallback works when no token is present", async () => { + process.env.NODE_ENV = "development"; + + const appkit = (await createApp({ plugins: [probe()] })) as any; + + const req = createMockRequest({ + headers: {}, // No token + }); + + // Should not throw in dev mode + const userExports = appkit.probe.asUser(req); + expect(userExports.getContext).toBeDefined(); + expect(typeof userExports.getContext).toBe("function"); + }); +}); diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index 00fd6ff86..b25380737 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -19,13 +19,16 @@ export type { JobsConnectorConfig } from "./connectors/jobs"; export type { DatabaseCredential, GenerateDatabaseCredentialRequest, + LakebasePool, LakebasePoolConfig, + LakebasePoolManager, RequestedClaims, RequestedResource, } from "./connectors/lakebase"; // Lakebase Autoscaling connector export { createLakebasePool, + createLakebasePoolManager, generateDatabaseCredential, getLakebaseOrmConfig, getLakebasePgConfig, diff --git a/packages/appkit/src/plugin/plugin.ts b/packages/appkit/src/plugin/plugin.ts index 49d211913..4c00c41e1 100644 --- a/packages/appkit/src/plugin/plugin.ts +++ b/packages/appkit/src/plugin/plugin.ts @@ -43,6 +43,13 @@ import type { const logger = createLogger("plugin"); +/** + * Symbol used to expose the UserContext from an asUser() proxy. + * Allows wrapWithAsUser in appkit.ts to retrieve the context and + * wrap export methods in runInUserContext(). 
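+ *
+ * Usage sketch (mirrors the lookup in wrapWithAsUser):
+ *
+ *   const userPlugin = plugin.asUser(req);
+ *   const ctx = (userPlugin as any)[USER_CONTEXT_SYMBOL] as UserContext;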
+ */ +export const USER_CONTEXT_SYMBOL = Symbol("appkit.userContext"); + /** * OTel context key for marking OBO dev mode fallback. * Set when asUser() is called in development mode without a user token. @@ -393,6 +400,7 @@ export abstract class Plugin< asUser(req: express.Request): this { const token = req.header("x-forwarded-access-token"); const userId = req.header("x-forwarded-user"); + const userEmail = req.header("x-forwarded-email"); const isDev = process.env.NODE_ENV === "development"; // In local development, skip user impersonation @@ -434,6 +442,8 @@ export abstract class Plugin< const userContext = ServiceContext.createUserContext( token, effectiveUserId, + undefined, + userEmail ?? undefined, ); // Return a proxy that wraps method calls in user context @@ -448,6 +458,9 @@ export abstract class Plugin< private _createUserContextProxy(userContext: UserContext): this { return new Proxy(this, { get: (target, prop, receiver) => { + // Expose userContext via symbol so wrapWithAsUser can wrap exports + if (prop === USER_CONTEXT_SYMBOL) return userContext; + const value = Reflect.get(target, prop, receiver); if (typeof value !== "function") { diff --git a/packages/appkit/src/plugins/lakebase/lakebase.ts b/packages/appkit/src/plugins/lakebase/lakebase.ts index 49930355f..b8b1b16be 100644 --- a/packages/appkit/src/plugins/lakebase/lakebase.ts +++ b/packages/appkit/src/plugins/lakebase/lakebase.ts @@ -1,12 +1,17 @@ -import type { Pool, QueryResult, QueryResultRow } from "pg"; +import type { QueryResult, QueryResultRow } from "pg"; import type { AgentToolDefinition, ToolProvider } from "shared"; import { z } from "zod"; import { createLakebasePool, + createLakebasePoolManager, getLakebaseOrmConfig, getLakebasePgConfig, getUsernameWithApiLookup, + type LakebasePool, + type LakebasePoolManager, + RoutingPool, } from "../../connectors/lakebase"; +import { getUserContext } from "../../context/execution-context"; import { buildToolkitEntries } from "../../core/agent/build-toolkit"; import { defineTool, @@ -22,12 +27,25 @@ import type { ILakebaseConfig } from "./types"; const logger = createLogger("lakebase"); +/** Default pool settings for per-user OBO pools. */ +const OBO_POOL_DEFAULTS = { + max: 3, + allowExitOnIdle: true, + idleTimeoutMillis: 30_000, +}; + /** * AppKit plugin for Databricks Lakebase Autoscaling. * * Wraps `@databricks/lakebase` to provide a standard `pg.Pool` with automatic * OAuth token refresh, integrated with AppKit's logger and OpenTelemetry setup. * + * Supports On-Behalf-Of (OBO) via `asUser(req)` — each user gets a separate + * `pg.Pool` authenticated with their Databricks identity, enabling features + * like Row-Level Security (RLS). Routing is handled transparently by + * {@link RoutingPool}, which reads the execution context set by the base + * class `asUser()`. 
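+ *
+ * Per-user pools are keyed by the user's email when available (Lakebase
+ * OAuth roles are email-based), with the user ID as the fallback key.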
+ * * @example * ```ts * import { createApp, lakebase, server } from "@databricks/appkit"; @@ -36,7 +54,11 @@ const logger = createLogger("lakebase"); * plugins: [server(), lakebase()], * }); * + * // Service principal query * const result = await AppKit.lakebase.query("SELECT * FROM users WHERE id = $1", [userId]); + * + * // User-scoped query (per-user pool, RLS enforced) + * const mine = await AppKit.lakebase.asUser(req).query("SELECT * FROM my_data"); * ``` */ export class LakebasePlugin extends Plugin implements ToolProvider { @@ -44,25 +66,54 @@ export class LakebasePlugin extends Plugin implements ToolProvider { static manifest = manifest as PluginManifest<"lakebase">; protected declare config: ILakebaseConfig; - private pool: Pool | null = null; + private pool: RoutingPool | null = null; + private oboPoolManager: LakebasePoolManager | null = null; /** - * Initializes the Lakebase connection pool. + * Initializes the Lakebase connection pool and OBO pool manager. * Called automatically by AppKit during the plugin setup phase. * - * Resolves the PostgreSQL username via {@link getUsernameWithApiLookup}, - * which tries config, env vars, and finally the Databricks workspace API. + * Creates a {@link RoutingPool} that automatically routes queries to either + * the service-principal pool or a per-user pool based on the execution + * context (set by `Plugin.asUser(req)` via AsyncLocalStorage). */ async setup() { const poolConfig = this.config.pool; const user = await getUsernameWithApiLookup(poolConfig); - this.pool = createLakebasePool({ ...poolConfig, user }); - logger.info("Lakebase pool initialized"); + + const spPool = createLakebasePool({ ...poolConfig, user }); + logger.info("Lakebase SP pool initialized"); + + this.oboPoolManager = createLakebasePoolManager({ + ...poolConfig, + ...OBO_POOL_DEFAULTS, + }); + logger.info("Lakebase OBO pool manager initialized"); + + const oboManager = this.oboPoolManager; + this.pool = new RoutingPool(spPool, (ctx) => { + if (!oboManager) throw new Error("OBO pool manager not initialized"); + // Lakebase OAuth roles use email as the postgres role when available + const userKey = ctx.userEmail ?? ctx.userId; + const isNew = !oboManager.hasPool(userKey); + const pool = oboManager.getPool( + userKey, + { workspaceClient: ctx.client, user: userKey }, + ctx.tokenFingerprint, + ); + if (isNew) { + logger.debug("Created OBO pool for user (total: %d)", oboManager.size); + } + return pool; + }); } /** * Executes a parameterized SQL query against the Lakebase pool. * + * When called inside `asUser(req)`, the query automatically routes to + * the per-user pool via {@link RoutingPool}. + * * @param text - SQL query string, using `$1`, `$2`, ... placeholders * @param values - Parameter values corresponding to placeholders * @returns Query result with typed rows @@ -117,29 +168,30 @@ export class LakebasePlugin extends Plugin implements ToolProvider { } /** - * Gracefully drains and closes the connection pool. + * Gracefully drains and closes all connection pools (SP + OBO). * Called automatically by AppKit during shutdown. 
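+   * Close failures are logged rather than rethrown so shutdown can proceed.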
   */
  abortActiveOperations(): void {
    super.abortActiveOperations();
    if (this.pool) {
-      logger.info("Closing Lakebase pool");
+      logger.info("Closing Lakebase SP pool");
      this.pool.end().catch((err) => {
-        logger.error("Error closing Lakebase pool: %O", err);
+        logger.error("Error closing Lakebase SP pool: %O", err);
      });
      this.pool = null;
    }
+    if (this.oboPoolManager) {
+      logger.info(
+        "Closing all Lakebase OBO pools (%d)",
+        this.oboPoolManager.size,
+      );
+      this.oboPoolManager.closeAll().catch((err) => {
+        logger.error("Error closing Lakebase OBO pools: %O", err);
+      });
+      this.oboPoolManager = null;
+    }
  }
 
-  /**
-   * Returns the plugin's public API, accessible via `AppKit.lakebase`.
-   *
-   * - `pool` — The raw `pg.Pool` instance, for use with ORMs or advanced scenarios
-   * - `query` — Convenience method for executing parameterized SQL queries
-   * - `getOrmConfig()` — Returns a config object compatible with Drizzle, TypeORM, Sequelize, etc.
-   * - `getPgConfig()` — Returns a `pg.PoolConfig` object for manual pool construction
-   */
-
  /**
   * Agent tool registry. Empty by default — the Lakebase plugin does NOT
   * expose its SQL connection to LLM agents unless the developer explicitly
@@ -153,7 +205,7 @@
    if (config.exposeAsAgentTool) {
      this.tools = { query: this.buildQueryTool(config.exposeAsAgentTool) };
      logger.warn(
-        "Lakebase agent tool is enabled (readOnly=%s). Every agent with access to this plugin can execute SQL against the Lakebase database as the service principal.",
+        "Lakebase agent tool is enabled (readOnly=%s). Every agent with access to this plugin can execute SQL against the Lakebase database with the caller's execution identity (the service principal, or the end user inside asUser(req)).",
        config.exposeAsAgentTool.readOnly !== false,
      );
    }
@@ -165,8 +217,8 @@
    const readOnly = opt.readOnly !== false;
    return defineTool({
      description: readOnly
-        ? "Execute a read-only SQL query against the Lakebase PostgreSQL database. Only SELECT, WITH, SHOW, EXPLAIN, and DESCRIBE statements are accepted. Use $1, $2, etc. as placeholders and pass values separately. Runs as the application's service principal."
-        : "Execute a parameterized SQL statement against the Lakebase PostgreSQL database. Use $1, $2, etc. as placeholders and pass values separately. Runs as the application's service principal. This tool can modify data; every invocation requires explicit human approval.",
+        ? "Execute a read-only SQL query against the Lakebase PostgreSQL database. Only SELECT, WITH, SHOW, EXPLAIN, and DESCRIBE statements are accepted. Use $1, $2, etc. as placeholders and pass values separately."
+        : "Execute a parameterized SQL statement against the Lakebase PostgreSQL database. Use $1, $2, etc. as placeholders and pass values separately. This tool can modify data; every invocation requires explicit human approval.",
      schema: z.object({
        text: z
          .string()
@@ -181,6 +233,7 @@
      annotations: {
        effect: readOnly ? "read" : "destructive",
        idempotent: false,
+        requiresUserContext: true,
      },
      execute: async (args, signal) => {
        // Matches the files plugin pattern: the pg connection API
@@ -217,13 +270,38 @@
    return buildToolkitEntries(this.name, this.tools, opts);
  }
 
+  /**
+   * Returns the pool config for the current execution context.
+   * Inside `asUser(req)`, returns user-scoped config; otherwise SP config.
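+   *
+   * Illustrative use from the public API (inside an Express handler):
+   *
+   *   const cfg = AppKit.lakebase.asUser(req).getOrmConfig();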
+ */ + private activePoolConfig() { + const ctx = getUserContext(); + if (ctx) { + const user = ctx.userEmail ?? ctx.userId; + return { ...this.config.pool, workspaceClient: ctx.client, user }; + } + return this.config.pool; + } + + /** + * Returns the plugin's public API, accessible via `AppKit.lakebase`. + * + * - `pool` — The connection pool (routes to per-user pool when inside `asUser(req)`) + * - `query` — Convenience method for executing parameterized SQL queries + * - `getOrmConfig()` — Returns a config object compatible with Drizzle, TypeORM, Sequelize, etc. + * Inside `asUser(req)`, returns user-scoped config. + * - `getPgConfig()` — Returns a `pg.PoolConfig` object for manual pool construction. + * Inside `asUser(req)`, returns user-scoped config. + * + * Use `AppKit.lakebase.asUser(req)` to get the same API backed by a per-user pool. + */ exports() { return { // biome-ignore lint/style/noNonNullAssertion: pool is guaranteed non-null after setup(), which AppKit always awaits before exposing the plugin API - pool: this.pool!, + pool: this.pool! as LakebasePool, query: this.query.bind(this), - getOrmConfig: () => getLakebaseOrmConfig(this.config.pool), - getPgConfig: () => getLakebasePgConfig(this.config.pool), + getOrmConfig: () => getLakebaseOrmConfig(this.activePoolConfig()), + getPgConfig: () => getLakebasePgConfig(this.activePoolConfig()), }; } } diff --git a/packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts b/packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts index 756423178..24d37341f 100644 --- a/packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts +++ b/packages/appkit/src/plugins/lakebase/tests/lakebase-agent-tool.test.ts @@ -5,8 +5,8 @@ import { beforeEach, describe, expect, test, vi } from "vitest"; * * The plugin defaults to **not** exposing an agent tool at all. Enabling the * tool is an explicit opt-in (`exposeAsAgentTool` with an acknowledgement - * flag) because every invocation runs with the application's service- - * principal credentials regardless of which end user initiated the request. + * flag) because every invocation runs with the caller's execution context + * (SP or per-user via RoutingPool). 
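+ * The "OBO via RoutingPool" describe block below exercises the per-user path.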
  */
 
 vi.mock("../../../cache", () => ({
@@ -30,28 +30,42 @@
 const clientQueries: Array<{ text: string; values?: unknown[] }> = [];
 const clientReleases: number[] = [];
 
-vi.mock("../../../connectors/lakebase", () => ({
-  createLakebasePool: vi.fn(() => ({
-    query: vi.fn(),
-    connect: vi.fn(async () => {
-      let releaseCalls = 0;
-      return {
-        query: vi.fn(async (text: string, values?: unknown[]) => {
-          clientQueries.push({ text, values });
-          return { rows: [{ n: 1 }] };
-        }),
-        release: vi.fn(() => {
-          releaseCalls += 1;
-          clientReleases.push(releaseCalls);
-        }),
-      };
-    }),
-    end: vi.fn(),
-  })),
-  getLakebaseOrmConfig: vi.fn(() => ({})),
-  getLakebasePgConfig: vi.fn(() => ({})),
-  getUsernameWithApiLookup: vi.fn(async () => "test-user"),
-}));
+vi.mock("../../../connectors/lakebase", async (importOriginal) => {
+  const actual =
+    await importOriginal<typeof import("../../../connectors/lakebase")>();
+  return {
+    ...actual,
+    createLakebasePool: vi.fn(() => ({
+      query: vi.fn(),
+      connect: vi.fn(async () => {
+        let releaseCalls = 0;
+        return {
+          query: vi.fn(async (text: string, values?: unknown[]) => {
+            clientQueries.push({ text, values });
+            return { rows: [{ n: 1 }] };
+          }),
+          release: vi.fn(() => {
+            releaseCalls += 1;
+            clientReleases.push(releaseCalls);
+          }),
+        };
+      }),
+      end: vi.fn(),
+      totalCount: 0,
+      idleCount: 0,
+      waitingCount: 0,
+    })),
+    createLakebasePoolManager: vi.fn(() => ({
+      getPool: vi.fn(),
+      hasPool: vi.fn(() => false),
+      closeAll: vi.fn(async () => {}),
+      size: 0,
+    })),
+    getLakebaseOrmConfig: vi.fn(() => ({})),
+    getLakebasePgConfig: vi.fn(() => ({})),
+    getUsernameWithApiLookup: vi.fn(async () => "test-user"),
+  };
+});
 
 import type { Pool, PoolClient } from "pg";
 import { LakebasePlugin } from "../lakebase";
@@ -83,6 +97,7 @@ describe("LakebasePlugin — agent tool opt-in", () => {
     expect(defs[0].annotations).toEqual({
       effect: "read",
       idempotent: false,
+      requiresUserContext: true,
     });
   });
 
@@ -94,6 +109,7 @@
     expect(defs[0].annotations).toEqual({
       effect: "destructive",
       idempotent: false,
+      requiresUserContext: true,
     });
   });
 });
@@ -145,9 +161,6 @@ describe("LakebasePlugin — readOnly enforcement", () => {
   });
 
   test("forwards parameter values to the user statement only (the regression fix)", async () => {
-    // Prior to the fix this would have failed with "cannot insert multiple
-    // commands into a prepared statement" because pg's Extended Query
-    // protocol rejects multi-statement batches when values are supplied.
     await plugin.executeAgentTool("query", {
       text: "SELECT * FROM users WHERE id = $1",
       values: [42],
@@ -160,8 +173,6 @@
   });
 
   test("releases the client even when the user statement throws", async () => {
-    // Poison the client so the middle query throws (simulates a Postgres
-    // error like "cannot execute UPDATE in a read-only transaction").
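+    // The client mocked below rejects the middle statement; release()
+    // must still be invoked on the error path.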
const { createLakebasePool } = await import("../../../connectors/lakebase"); const fakeClient = { query: vi @@ -179,6 +190,9 @@ describe("LakebasePlugin — readOnly enforcement", () => { query: vi.fn(), connect: vi.fn(async (): Promise => fakeClient), end: vi.fn(), + totalCount: 0, + idleCount: 0, + waitingCount: 0, } as unknown as Pool); clientQueries.length = 0; @@ -222,3 +236,131 @@ describe("LakebasePlugin — destructive mode", () => { ); }); }); + +describe("LakebasePlugin — OBO via RoutingPool", () => { + const userPoolQueries: Array<{ text: string; values?: unknown[] }> = []; + const userClientQueries: Array<{ text: string; values?: unknown[] }> = []; + + function makeUserPool() { + return { + query: vi.fn(async (text: string, values?: unknown[]) => { + userPoolQueries.push({ text, values }); + return { rows: [{ from: "user-pool" }] }; + }), + connect: vi.fn(async () => ({ + query: vi.fn(async (text: string, values?: unknown[]) => { + userClientQueries.push({ text, values }); + return { rows: [{ from: "user-pool-client" }] }; + }), + release: vi.fn(), + })), + end: vi.fn(), + totalCount: 0, + idleCount: 0, + waitingCount: 0, + }; + } + + beforeEach(async () => { + userPoolQueries.length = 0; + userClientQueries.length = 0; + clientQueries.length = 0; + + const { createLakebasePoolManager } = await import( + "../../../connectors/lakebase" + ); + vi.mocked(createLakebasePoolManager).mockReturnValue({ + getPool: vi.fn(() => makeUserPool() as unknown as Pool), + hasPool: vi.fn(() => false), + closePool: vi.fn(async () => {}), + closeAll: vi.fn(async () => {}), + get size() { + return 1; + }, + }); + }); + + test("read-only query routes to user pool inside runInUserContext", async () => { + const { runInUserContext } = await import( + "../../../context/execution-context" + ); + const plugin = makePlugin({ exposeAsAgentTool: {} }); + await plugin.setup(); + + const userCtx = { + client: {} as any, + userId: "user-123", + userEmail: "alice@example.com", + workspaceId: Promise.resolve("ws-1"), + isUserContext: true as const, + }; + + const result = await runInUserContext(userCtx, () => + plugin.executeAgentTool("query", { text: "SELECT 1" }), + ); + + expect(result).toEqual([{ from: "user-pool-client" }]); + expect(userClientQueries.map((c) => c.text)).toEqual([ + "BEGIN READ ONLY", + "SELECT 1", + "ROLLBACK", + ]); + // SP pool should NOT have been touched + expect(clientQueries).toHaveLength(0); + }); + + test("destructive query routes to user pool inside runInUserContext", async () => { + const { runInUserContext } = await import( + "../../../context/execution-context" + ); + + const plugin = makePlugin({ exposeAsAgentTool: { readOnly: false } }); + await plugin.setup(); + + const userCtx = { + client: {} as any, + userId: "user-123", + userEmail: "alice@example.com", + workspaceId: Promise.resolve("ws-1"), + isUserContext: true as const, + }; + + const result = await runInUserContext(userCtx, () => + plugin.executeAgentTool("query", { + text: "UPDATE t SET x=1", + values: [42], + }), + ); + + expect(result).toEqual([{ from: "user-pool" }]); + expect(userPoolQueries).toEqual([ + { text: "UPDATE t SET x=1", values: [42] }, + ]); + expect(clientQueries).toHaveLength(0); + }); + + test("read-only policy still enforced in user context", async () => { + const { runInUserContext } = await import( + "../../../context/execution-context" + ); + + const plugin = makePlugin({ exposeAsAgentTool: {} }); + await plugin.setup(); + + const userCtx = { + client: {} as any, + userId: "user-123", + 
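+      // No userEmail here: the pool factory falls back to ctx.userId.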
workspaceId: Promise.resolve("ws-1"), + isUserContext: true as const, + }; + + await expect( + runInUserContext(userCtx, () => + plugin.executeAgentTool("query", { text: "DROP TABLE users" }), + ), + ).rejects.toThrow(/read-only policy violation/i); + + expect(userClientQueries).toHaveLength(0); + expect(clientQueries).toHaveLength(0); + }); +}); diff --git a/template/server/routes/lakebase/todo-routes.ts b/template/server/routes/lakebase/todo-routes.ts index 32c47ab8b..7183a4c17 100644 --- a/template/server/routes/lakebase/todo-routes.ts +++ b/template/server/routes/lakebase/todo-routes.ts @@ -1,4 +1,7 @@ {{if .plugins.lakebase -}} +// For per-user connections (OBO) with Row-Level Security, see: +// https://www.databricks.com/devhub/docs/appkit/v0/plugins/lakebase#on-behalf-of-obo--per-user-connections + import { z } from 'zod'; import { Application } from 'express';