From 4cd8ea8b4f194a4745bb9824a774fe1740715e9e Mon Sep 17 00:00:00 2001 From: whoisasx Date: Fri, 19 Jun 2026 02:43:39 +0530 Subject: [PATCH 1/9] fix: keep terminal mux persistent across navigation --- backend/internal/terminal/manager.go | 6 + backend/internal/terminal/manager_test.go | 155 ++++++++++++ .../src/renderer/components/SessionView.tsx | 7 +- .../src/renderer/components/TerminalPane.tsx | 6 +- .../hooks/useTerminalSession.test.tsx | 168 +++++++------ .../src/renderer/hooks/useTerminalSession.ts | 134 +++++----- frontend/src/renderer/lib/shell-context.ts | 2 + .../lib/terminal-mux-transport.test.ts | 156 ++++++++++++ .../renderer/lib/terminal-mux-transport.ts | 237 ++++++++++++++++++ frontend/src/renderer/lib/terminal-mux.ts | 2 +- frontend/src/renderer/routes/_shell.tsx | 12 +- 11 files changed, 727 insertions(+), 158 deletions(-) create mode 100644 frontend/src/renderer/lib/terminal-mux-transport.test.ts create mode 100644 frontend/src/renderer/lib/terminal-mux-transport.ts diff --git a/backend/internal/terminal/manager.go b/backend/internal/terminal/manager.go index 7b22608b..f88d5b4d 100644 --- a/backend/internal/terminal/manager.go +++ b/backend/internal/terminal/manager.go @@ -225,6 +225,12 @@ func (c *connState) openTerminal(id string, rows, cols uint16) { var a *attachment a = newAttachment(id, ports.RuntimeHandle{ID: id}, c.mgr.src, c.mgr.spawn, func(data []byte) { + c.mu.Lock() + current := c.terms[id] + c.mu.Unlock() + if current != a { + return + } c.enqueue(serverMsg{ Ch: chTerminal, ID: id, diff --git a/backend/internal/terminal/manager_test.go b/backend/internal/terminal/manager_test.go index dc0730b8..3f55060f 100644 --- a/backend/internal/terminal/manager_test.go +++ b/backend/internal/terminal/manager_test.go @@ -112,6 +112,18 @@ func nextTerminal(t *testing.T, c *fakeConn) serverMsg { } } +func assertNoTerminalFrame(t *testing.T, c *fakeConn, typ string, d time.Duration) { + t.Helper() + select { + case m := <-c.out: + if m.Ch == chTerminal && m.Type == typ { + t.Fatalf("received unexpected terminal/%s frame", typ) + } + t.Fatalf("received unexpected frame %s/%s", m.Ch, m.Type) + case <-time.After(d): + } +} + // Opening a pane whose runtime is already dead must (1) send opened before // exited (the dead pane is reported, not errored) and (2) clear the conn's // entry, so a later open for the same id on this connection is still served @@ -170,6 +182,149 @@ func TestServeExitAfterOpenClearsEntryAllowingReopen(t *testing.T) { recv(t, conn, chTerminal, msgOpened, 2*time.Second) } +func TestServeSameSocketSwitchesTerminalByCloseThenOpen(t *testing.T) { + src := &fakeSource{alive: true} + p1, p2 := newFakePTY(), newFakePTY() + sp := &fakeSpawner{ptys: []*fakePTY{p1, p2}} + mgr := NewManager(src, nil, testLogger(), WithSpawn(sp.spawn), WithHeartbeat(0)) + defer mgr.Close() + + conn := newFakeConn() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go mgr.Serve(ctx, conn) + + conn.in <- clientMsg{Ch: chTerminal, ID: "t1", Type: msgOpen} + recv(t, conn, chTerminal, msgOpened, time.Second) + eventually(t, time.Second, func() bool { return sp.calls() == 1 }) + + conn.in <- clientMsg{Ch: chTerminal, ID: "t1", Type: msgClose} + eventually(t, time.Second, func() bool { + select { + case <-p1.closed: + return true + default: + return false + } + }) + assertNoTerminalFrame(t, conn, msgExited, 50*time.Millisecond) + + conn.in <- clientMsg{Ch: chTerminal, ID: "t2", Type: msgOpen} + msg := recv(t, conn, chTerminal, msgOpened, time.Second) + if msg.ID != "t2" { + t.Fatalf("opened id = %q, want t2", msg.ID) + } + eventually(t, time.Second, func() bool { return sp.calls() == 2 }) + select { + case <-p2.closed: + t.Fatal("opening t2 on the same socket must not immediately close t2") + default: + } +} + +func TestServeConnectionCleanupClosesAllOpenAttachments(t *testing.T) { + src := &fakeSource{alive: true} + p1, p2 := newFakePTY(), newFakePTY() + sp := &fakeSpawner{ptys: []*fakePTY{p1, p2}} + mgr := NewManager(src, nil, testLogger(), WithSpawn(sp.spawn), WithHeartbeat(0)) + defer mgr.Close() + + conn := newFakeConn() + ctx, cancel := context.WithCancel(context.Background()) + go mgr.Serve(ctx, conn) + + conn.in <- clientMsg{Ch: chTerminal, ID: "t1", Type: msgOpen} + recv(t, conn, chTerminal, msgOpened, time.Second) + conn.in <- clientMsg{Ch: chTerminal, ID: "t2", Type: msgOpen} + recv(t, conn, chTerminal, msgOpened, time.Second) + eventually(t, time.Second, func() bool { return sp.calls() == 2 }) + + cancel() + eventually(t, time.Second, func() bool { + select { + case <-p1.closed: + default: + return false + } + select { + case <-p2.closed: + return true + default: + return false + } + }) +} + +type latePTY struct { + out chan []byte + done chan struct{} + doneOnce sync.Once +} + +func newLatePTY() *latePTY { + return &latePTY{out: make(chan []byte, 4), done: make(chan struct{})} +} + +func (p *latePTY) push(b []byte) { p.out <- b } + +func (p *latePTY) finish() { p.doneOnce.Do(func() { close(p.done) }) } + +func (p *latePTY) Read(b []byte) (int, error) { + select { + case chunk := <-p.out: + return copy(b, chunk), nil + case <-p.done: + return 0, context.Canceled + } +} + +func (p *latePTY) Write(b []byte) (int, error) { return len(b), nil } + +func (p *latePTY) Resize(uint16, uint16) error { return nil } + +func (p *latePTY) Close() error { return nil } + +func TestServeDropsLateDataFromSupersededAttachment(t *testing.T) { + src := &fakeSource{alive: true} + oldPTY := newLatePTY() + newPTY := newFakePTY() + var mu sync.Mutex + ptys := []ptyProcess{oldPTY, newPTY} + calls := 0 + spawn := func(context.Context, []string, uint16, uint16) (ptyProcess, error) { + mu.Lock() + defer mu.Unlock() + p := ptys[calls] + calls++ + return p, nil + } + spawnCalls := func() int { + mu.Lock() + defer mu.Unlock() + return calls + } + mgr := NewManager(src, nil, testLogger(), WithSpawn(spawn), WithHeartbeat(0)) + defer mgr.Close() + defer oldPTY.finish() + + conn := newFakeConn() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go mgr.Serve(ctx, conn) + + conn.in <- clientMsg{Ch: chTerminal, ID: "t1", Type: msgOpen} + recv(t, conn, chTerminal, msgOpened, time.Second) + eventually(t, time.Second, func() bool { return spawnCalls() == 1 }) + + conn.in <- clientMsg{Ch: chTerminal, ID: "t1", Type: msgClose} + conn.in <- clientMsg{Ch: chTerminal, ID: "t1", Type: msgOpen} + recv(t, conn, chTerminal, msgOpened, time.Second) + eventually(t, time.Second, func() bool { return spawnCalls() == 2 }) + + oldPTY.push([]byte("stale output")) + assertNoTerminalFrame(t, conn, msgData, 50*time.Millisecond) +} + // An attachment that exits the moment it is opened (dead runtime) fires onExit // from its run goroutine, racing the reopen that follows the exited frame. The // identity-guarded delete in onExit must never evict a successor attachment diff --git a/frontend/src/renderer/components/SessionView.tsx b/frontend/src/renderer/components/SessionView.tsx index 9acf8b0d..f267c93c 100644 --- a/frontend/src/renderer/components/SessionView.tsx +++ b/frontend/src/renderer/components/SessionView.tsx @@ -28,9 +28,10 @@ type SessionViewProps = { // shell-owned ShellTopbar. Rendered by both the project-scoped and // cross-project session routes. The terminal lives here (not in the shell) — // switching sessions only changes route params, so TanStack Router keeps this -// component mounted and the terminal re-points its mux without remounting -// (useTerminalSession). Leaving for the board unmounts it; a fresh server-side -// zellij attach repaints the pane on return. +// component mounted and the terminal switches visible handles over the +// shell-owned mux without remounting (useTerminalSession). Leaving for the +// board unmounts the visible attachment and sends close(handle), while the +// shell keeps the /mux websocket alive. // // The split is shadcn's resizable (react-resizable-panels v4) with a fully // collapsible inspector: the panel is `collapsible` and driven to 0% via the diff --git a/frontend/src/renderer/components/TerminalPane.tsx b/frontend/src/renderer/components/TerminalPane.tsx index 59a6b9ff..8808c327 100644 --- a/frontend/src/renderer/components/TerminalPane.tsx +++ b/frontend/src/renderer/components/TerminalPane.tsx @@ -3,6 +3,7 @@ import type { TerminalTarget } from "../types/terminal"; import type { WorkspaceSession } from "../types/workspace"; import type { Theme } from "../stores/ui-store"; import { useTerminalSession, type AttachableTerminal, type TerminalSessionState } from "../hooks/useTerminalSession"; +import { useShell } from "../lib/shell-context"; import { XtermTerminal } from "./XtermTerminal"; type TerminalPaneProps = { @@ -38,7 +39,8 @@ function bannerText(state: TerminalSessionState, error?: string): string | undef return undefined; } -function AttachedTerminal({ session, theme, daemonReady, terminalTarget }: TerminalPaneProps) { +function AttachedTerminal({ session, theme, terminalTarget }: TerminalPaneProps) { + const { terminalMux } = useShell(); const attachSession = session && terminalTarget?.kind === "reviewer" ? { ...session, terminalHandleId: terminalTarget.handleId } @@ -49,7 +51,7 @@ function AttachedTerminal({ session, theme, daemonReady, terminalTarget }: Termi // renderer mid-switch and lose the warm GPU surface. const [terminal, setTerminal] = useState(null); const [initFailed, setInitFailed] = useState(false); - const { attach, state, error } = useTerminalSession(attachSession, { daemonReady }); + const { attach, state, error } = useTerminalSession(attachSession, { mux: terminalMux }); const handleId = attachSession?.terminalHandleId; const hadAttachmentRef = useRef(false); diff --git a/frontend/src/renderer/hooks/useTerminalSession.test.tsx b/frontend/src/renderer/hooks/useTerminalSession.test.tsx index 06f76aaf..d10743de 100644 --- a/frontend/src/renderer/hooks/useTerminalSession.test.tsx +++ b/frontend/src/renderer/hooks/useTerminalSession.test.tsx @@ -24,7 +24,7 @@ type FakeMux = { opens: Array<[string, number, number]>; resizes: Array<[string, number, number]>; inputs: Array<[string, string]>; - disposed: boolean; + closes: string[]; emitData(id: string, text: string): void; emitOpened(id: string): void; emitExit(id: string): void; @@ -45,17 +45,26 @@ function createFakeMux(): FakeMux { const opened = new Map void>>(); const error = new Map void>>(); const connection = new Set<(state: MuxConnectionState) => void>(); + let connectionState: MuxConnectionState = "open"; const fake: FakeMux = { opens: [], resizes: [], inputs: [], - disposed: false, + closes: [], mux: { - open: (id, cols, rows) => fake.opens.push([id, cols, rows]), - sendInput: (id, input) => fake.inputs.push([id, input]), - resize: (id, cols, rows) => fake.resizes.push([id, cols, rows]), - close: () => undefined, + open: (id, cols, rows) => { + if (connectionState === "open") fake.opens.push([id, cols, rows]); + }, + sendInput: (id, input) => { + if (connectionState === "open") fake.inputs.push([id, input]); + }, + resize: (id, cols, rows) => { + if (connectionState === "open") fake.resizes.push([id, cols, rows]); + }, + close: (id) => { + if (connectionState === "open") fake.closes.push(id); + }, onData: (id, listener) => subscribe(data, id, listener), onExit: (id, listener) => subscribe(exit, id, listener), onOpened: (id, listener) => subscribe(opened, id, listener), @@ -64,15 +73,16 @@ function createFakeMux(): FakeMux { connection.add(listener); return () => connection.delete(listener); }, - dispose: () => { - fake.disposed = true; - }, + dispose: () => undefined, }, emitData: (id, text) => data.get(id)?.forEach((listener) => listener(new TextEncoder().encode(text))), emitOpened: (id) => opened.get(id)?.forEach((listener) => listener()), emitExit: (id) => exit.get(id)?.forEach((listener) => listener()), emitError: (id, message) => error.get(id)?.forEach((listener) => listener(message)), - emitConnection: (state) => connection.forEach((listener) => listener(state)), + emitConnection: (state) => { + connectionState = state; + connection.forEach((listener) => listener(state)); + }, }; return fake; } @@ -111,28 +121,22 @@ function createFakeTerminal(): FakeTerminal { return terminal; } -function setup({ daemonReady = true, attachedSession = session as WorkspaceSession | undefined } = {}) { - const muxes: FakeMux[] = []; - const createMux = () => { - const fake = createFakeMux(); - muxes.push(fake); - return fake.mux; - }; +function setup({ attachedSession = session as WorkspaceSession | undefined, mux = createFakeMux() } = {}) { const queryClient = new QueryClient({ defaultOptions: { queries: { retry: false } } }); const invalidateSpy = vi.spyOn(queryClient, "invalidateQueries"); const wrapper = ({ children }: { children: ReactNode }) => ( {children} ); const view = renderHook( - ({ daemonReady: ready }) => useTerminalSession(attachedSession, { daemonReady: ready, createMux }), - { initialProps: { daemonReady }, wrapper }, + ({ currentSession }) => useTerminalSession(currentSession, { mux: mux.mux }), + { initialProps: { currentSession: attachedSession }, wrapper }, ); const terminal = createFakeTerminal(); let detach: () => void = () => undefined; act(() => { detach = view.result.current.attach(terminal); }); - return { view, terminal, muxes, invalidateSpy, detach: () => detach() }; + return { view, terminal, mux, invalidateSpy, detach: () => detach() }; } beforeEach(() => { @@ -146,58 +150,57 @@ afterEach(() => { describe("useTerminalSession", () => { it("opens the pane at the terminal's size and reaches attached on the server ack", () => { - const { view, muxes } = setup(); + const { view, mux } = setup(); expect(view.result.current.state).toBe("connecting"); - expect(muxes).toHaveLength(1); - expect(muxes[0].opens).toEqual([["handle-1", 80, 24]]); - act(() => muxes[0].emitOpened("handle-1")); + expect(mux.opens).toEqual([["handle-1", 80, 24]]); + act(() => mux.emitOpened("handle-1")); expect(view.result.current.state).toBe("attached"); }); it("stays idle when the session has no terminal handle", () => { - const { view, muxes } = setup({ attachedSession: { ...session, terminalHandleId: undefined } }); + const { view, mux } = setup({ attachedSession: { ...session, terminalHandleId: undefined } }); expect(view.result.current.state).toBe("idle"); - expect(muxes).toHaveLength(0); + expect(mux.opens).toHaveLength(0); }); it("forwards PTY output, keystrokes, and resizes across the attachment", () => { - const { terminal, muxes } = setup(); - act(() => muxes[0].emitData("handle-1", "hello")); + const { terminal, mux } = setup(); + act(() => mux.emitData("handle-1", "hello")); expect(terminal.lines).toContain("hello"); terminal.typeKeys("ls\r"); - expect(muxes[0].inputs).toEqual([["handle-1", "ls\r"]]); + expect(mux.inputs).toEqual([["handle-1", "ls\r"]]); terminal.emitResize(120, 40); act(() => void vi.advanceTimersByTime(100)); - expect(muxes[0].resizes).toContainEqual(["handle-1", 120, 40]); + expect(mux.resizes).toContainEqual(["handle-1", 120, 40]); }); it("collapses a drag's burst of grid changes into one trailing PTY resize, then re-asserts it", () => { - const { terminal, muxes } = setup(); - const initialResizes = muxes[0].resizes.length; // connect() sends the opening size + const { terminal, mux } = setup(); + const initialResizes = mux.resizes.length; // attach sends the opening size terminal.emitResize(100, 30); terminal.emitResize(110, 34); terminal.emitResize(120, 40); act(() => void vi.advanceTimersByTime(100)); - expect(muxes[0].resizes.slice(initialResizes)).toEqual([["handle-1", 120, 40]]); + expect(mux.resizes.slice(initialResizes)).toEqual([["handle-1", 120, 40]]); // The settled grid goes out once more: paired with the backend's explicit // SIGWINCH (pty_unix.go) it re-syncs a zellij client that lost the // original update, which otherwise kept the session laid out for the old // size until the next real grid change. act(() => void vi.advanceTimersByTime(250)); - expect(muxes[0].resizes.slice(initialResizes)).toEqual([ + expect(mux.resizes.slice(initialResizes)).toEqual([ ["handle-1", 120, 40], ["handle-1", 120, 40], ]); }); it("a new resize burst supersedes a pending re-assert", () => { - const { terminal, muxes } = setup(); - const initialResizes = muxes[0].resizes.length; + const { terminal, mux } = setup(); + const initialResizes = mux.resizes.length; terminal.emitResize(100, 30); act(() => void vi.advanceTimersByTime(100)); // settles -> sent, re-assert pending terminal.emitResize(120, 40); // user keeps dragging before the re-assert fires act(() => void vi.advanceTimersByTime(100 + 250)); - expect(muxes[0].resizes.slice(initialResizes)).toEqual([ + expect(mux.resizes.slice(initialResizes)).toEqual([ ["handle-1", 100, 30], ["handle-1", 120, 40], ["handle-1", 120, 40], @@ -205,69 +208,78 @@ describe("useTerminalSession", () => { }); it("marks exit in the terminal and refetches workspace state instead of writing status", () => { - const { view, terminal, muxes, invalidateSpy } = setup(); - act(() => muxes[0].emitExit("handle-1")); + const { view, terminal, mux, invalidateSpy } = setup(); + act(() => mux.emitExit("handle-1")); expect(view.result.current.state).toBe("exited"); expect(terminal.lines.some((line) => line.includes("[process exited]"))).toBe(true); expect(invalidateSpy).toHaveBeenCalledWith({ queryKey: workspaceQueryKey }); }); it("surfaces pane errors and refetches, with no automatic retry", () => { - const { view, muxes, invalidateSpy } = setup(); - act(() => muxes[0].emitError("handle-1", "no such pane")); + const { view, mux, invalidateSpy } = setup(); + act(() => mux.emitError("handle-1", "no such pane")); expect(view.result.current.state).toBe("error"); expect(view.result.current.error).toBe("no such pane"); expect(invalidateSpy).toHaveBeenCalledWith({ queryKey: workspaceQueryKey }); - act(() => muxes[0].emitConnection("closed")); + act(() => mux.emitConnection("closed")); act(() => void vi.advanceTimersByTime(60_000)); - expect(muxes).toHaveLength(1); + expect(mux.opens).toHaveLength(1); }); - it("reattaches with a fresh mux after a socket drop, clearing the stale screen", () => { - const { view, terminal, muxes } = setup(); - act(() => muxes[0].emitOpened("handle-1")); - act(() => muxes[0].emitConnection("closed")); + it("reattaches on the same mux after a socket drop, clearing the stale screen", () => { + const { view, terminal, mux } = setup(); + act(() => mux.emitOpened("handle-1")); + act(() => mux.emitConnection("closed")); expect(view.result.current.state).toBe("reattaching"); - act(() => void vi.advanceTimersByTime(500)); - expect(muxes).toHaveLength(2); - expect(muxes[0].disposed).toBe(true); + act(() => mux.emitConnection("open")); expect(terminal.clears).toBe(1); // the fresh zellij attach repaints over a blank grid - expect(muxes[1].opens).toEqual([["handle-1", 80, 24]]); - act(() => muxes[1].emitOpened("handle-1")); + expect(mux.opens).toEqual([ + ["handle-1", 80, 24], + ["handle-1", 80, 24], + ]); + act(() => mux.emitOpened("handle-1")); expect(view.result.current.state).toBe("attached"); }); - it("backs off between failed reconnect attempts", () => { - const { muxes } = setup(); - act(() => muxes[0].emitConnection("closed")); - act(() => void vi.advanceTimersByTime(500)); // attempt 1 after 500ms - expect(muxes).toHaveLength(2); - act(() => muxes[1].emitConnection("closed")); - act(() => void vi.advanceTimersByTime(500)); // attempt 2 needs 1000ms - expect(muxes).toHaveLength(2); - act(() => void vi.advanceTimersByTime(500)); - expect(muxes).toHaveLength(3); + it("does not replay user input typed while the mux is disconnected", () => { + const { terminal, mux } = setup(); + act(() => mux.emitConnection("closed")); + terminal.typeKeys("hidden\r"); + expect(mux.inputs).toEqual([]); + act(() => mux.emitConnection("open")); + expect(mux.inputs).toEqual([]); }); - it("waits for daemon readiness instead of retrying, then reconnects when it flips", () => { - const { view, muxes } = setup({ daemonReady: false }); - act(() => muxes[0].emitConnection("closed")); - expect(view.result.current.state).toBe("reattaching"); + it("detach closes the visible handle, stops reattach, and returns to idle", () => { + const { view, mux, detach } = setup(); + act(() => detach()); + expect(view.result.current.state).toBe("idle"); + expect(mux.closes).toEqual(["handle-1"]); + act(() => mux.emitConnection("closed")); act(() => void vi.advanceTimersByTime(60_000)); - expect(muxes).toHaveLength(1); // no retries against a dead daemon - view.rerender({ daemonReady: true }); - expect(muxes).toHaveLength(2); // reconnects immediately, without backoff debt - act(() => muxes[1].emitOpened("handle-1")); - expect(view.result.current.state).toBe("attached"); + expect(mux.opens).toHaveLength(1); }); - it("detach disposes the mux, stops reattach, and returns to idle", () => { - const { view, muxes, detach } = setup(); + it("closes the old handle and opens the new handle on session switch", () => { + const nextSession = { ...session, id: "sess-2", terminalHandleId: "handle-2" }; + const { view, terminal, mux, detach } = setup(); + act(() => mux.emitData("handle-1", "one")); + expect(terminal.lines).toContain("one"); + + view.rerender({ currentSession: nextSession }); act(() => detach()); - expect(view.result.current.state).toBe("idle"); - expect(muxes[0].disposed).toBe(true); - act(() => muxes[0].emitConnection("closed")); - act(() => void vi.advanceTimersByTime(60_000)); - expect(muxes).toHaveLength(1); + act(() => { + view.result.current.attach(terminal); + }); + act(() => mux.emitData("handle-1", "late")); + act(() => mux.emitData("handle-2", "two")); + + expect(mux.closes).toEqual(["handle-1"]); + expect(mux.opens).toEqual([ + ["handle-1", 80, 24], + ["handle-2", 80, 24], + ]); + expect(terminal.lines).not.toContain("late"); + expect(terminal.lines).toContain("two"); }); }); diff --git a/frontend/src/renderer/hooks/useTerminalSession.ts b/frontend/src/renderer/hooks/useTerminalSession.ts index 8829c70b..c9ed2829 100644 --- a/frontend/src/renderer/hooks/useTerminalSession.ts +++ b/frontend/src/renderer/hooks/useTerminalSession.ts @@ -1,7 +1,7 @@ // Terminal Attachment (see CONTEXT.md): the live binding between a terminal -// pane and a session's PTY over the mux. The hook owns the whole attachment -// lifecycle — open ordering, auto-reattach with backoff, error surfacing, and -// exit handling — so the pane component only renders. +// pane and a session's PTY over the shell-owned mux. The hook owns the visible +// attachment lifecycle — open/close ordering, xterm event listeners, error +// surfacing, and exit handling — so the pane component only renders. // // Status rule: the frontend never writes a session's display status. On mux // `exited`/`error` it invalidates the workspaces query and lets the daemon's @@ -9,8 +9,7 @@ import { useQueryClient } from "@tanstack/react-query"; import { useCallback, useEffect, useRef, useState } from "react"; -import { getApiBaseUrl } from "../lib/api-client"; -import { createTerminalMux, muxUrlFromApiBase, type TerminalMux } from "../lib/terminal-mux"; +import type { TerminalMux } from "../lib/terminal-mux"; import type { WorkspaceSession } from "../types/workspace"; import { workspaceQueryKey } from "./useWorkspaceQuery"; @@ -43,14 +42,10 @@ export type TerminalSessionState = | "error"; // server reported a pane error; no automatic retry export type UseTerminalSessionOptions = { - /** Gates auto-reattach: when false, a dropped socket waits instead of retrying. */ - daemonReady: boolean; - /** Test seam: build the mux client. Defaults to a fresh socket against the current API base. */ - createMux?: () => TerminalMux; + /** Shell-lifetime mux transport. Browser preview passes null and renders a static pane. */ + mux: TerminalMux | null; }; -const RETRY_BASE_MS = 500; -const RETRY_MAX_MS = 8_000; // Trailing debounce on grid changes: a pane drag emits a burst of intermediate // sizes; the attached program should get one SIGWINCH when the drag settles, // not dozens (yyork's terminal-panel does the same at its socket layer). @@ -65,11 +60,6 @@ const RESIZE_DEBOUNCE_MS = 100; // and re-report its grid; when everything is already in sync it's a no-op. const RESIZE_REASSERT_MS = 250; -function defaultCreateMux(): TerminalMux { - // Resolved per connect, not per hook: a daemon restart can change the port. - return createTerminalMux(muxUrlFromApiBase(getApiBaseUrl())); -} - export function useTerminalSession(session: WorkspaceSession | undefined, options: UseTerminalSessionOptions) { const queryClient = useQueryClient(); const [state, setState] = useState("idle"); @@ -80,16 +70,14 @@ export function useTerminalSession(session: WorkspaceSession | undefined, option const optionsRef = useRef(options); optionsRef.current = options; const stateRef = useRef(state); - const connectRef = useRef<() => void>(() => undefined); + const openVisibleRef = useRef<(clearBeforeOpen?: boolean) => void>(() => undefined); const runtime = useRef({ terminal: null as AttachableTerminal | null, mux: null as TerminalMux | null, handle: null as string | null, disposers: [] as Array<() => void>, - retryTimer: null as ReturnType | null, resizeTimer: null as ReturnType | null, - attempts: 0, firstAttach: true, detached: true, }); @@ -103,52 +91,48 @@ export function useTerminalSession(session: WorkspaceSession | undefined, option void queryClient.invalidateQueries({ queryKey: workspaceQueryKey }); }, [queryClient]); - const teardownMux = useCallback(() => { + const detachVisible = useCallback(() => { const r = runtime.current; - if (r.retryTimer) { - clearTimeout(r.retryTimer); - r.retryTimer = null; - } if (r.resizeTimer) { clearTimeout(r.resizeTimer); r.resizeTimer = null; } + const mux = r.mux; + const handle = r.handle; r.disposers.forEach((dispose) => dispose()); r.disposers = []; - r.mux?.dispose(); + if (mux && handle) { + mux.close(handle); + } r.mux = null; }, []); - const scheduleReattach = useCallback(() => { - const r = runtime.current; - if (r.detached || !r.terminal || !r.handle) return; - // A socket dropping after the PTY ended (or errored) changes nothing. - if (stateRef.current === "exited" || stateRef.current === "error") return; - transition("reattaching"); - // Not ready → no timer; the daemonReady effect reconnects when it flips. - if (!optionsRef.current.daemonReady) return; - if (r.retryTimer) return; - const delay = Math.min(RETRY_BASE_MS * 2 ** r.attempts, RETRY_MAX_MS); - r.attempts += 1; - r.retryTimer = setTimeout(() => { - r.retryTimer = null; - connectRef.current(); - }, delay); - }, [transition]); + const openVisible = useCallback( + (clearBeforeOpen = false) => { + const r = runtime.current; + const { terminal, handle, mux } = r; + if (!terminal || !handle || !mux || r.detached) return; + if (clearBeforeOpen || !r.firstAttach) { + terminal.clear(); + } + r.firstAttach = false; + mux.open(handle, terminal.cols, terminal.rows); + mux.resize(handle, terminal.cols, terminal.rows); + }, + [], + ); + openVisibleRef.current = openVisible; - const connect = useCallback(() => { + const bindVisible = useCallback(() => { const r = runtime.current; const { terminal, handle } = r; - if (!terminal || !handle || r.detached) return; - teardownMux(); - - const mux = (optionsRef.current.createMux ?? defaultCreateMux)(); + const mux = optionsRef.current.mux; + if (!terminal || !handle || !mux || r.detached) return; r.mux = mux; r.disposers.push( mux.onData(handle, (bytes) => terminal.write(bytes)), mux.onOpened(handle, () => { - r.attempts = 0; setError(undefined); transition("attached"); }), @@ -164,7 +148,13 @@ export function useTerminalSession(session: WorkspaceSession | undefined, option invalidateWorkspaces(); }), mux.onConnectionChange((connectionState) => { - if (connectionState === "closed") scheduleReattach(); + if (connectionState === "closed") { + if (r.detached || !r.terminal || !r.handle) return; + if (stateRef.current === "exited" || stateRef.current === "error") return; + transition("reattaching"); + } else { + openVisibleRef.current(true); + } }), ); const input = terminal.onData((data) => mux.sendInput(handle, data)); @@ -193,15 +183,8 @@ export function useTerminalSession(session: WorkspaceSession | undefined, option // init handshake + a full repaint; clear the stale screen so the repaint // lands on a blank grid. Screen-clear only, never reset(): RIS would drop // zellij's mouse-tracking mode until the handshake lands. - if (!r.firstAttach) { - terminal.clear(); - } - r.firstAttach = false; - - mux.open(handle, terminal.cols, terminal.rows); - mux.resize(handle, terminal.cols, terminal.rows); - }, [invalidateWorkspaces, scheduleReattach, teardownMux, transition]); - connectRef.current = connect; + openVisible(false); + }, [invalidateWorkspaces, openVisible, transition]); /** * Bind a terminal to the current session's PTY. Call once the terminal is @@ -213,47 +196,54 @@ export function useTerminalSession(session: WorkspaceSession | undefined, option const handle = sessionRef.current?.terminalHandleId ?? null; r.terminal = terminal; r.handle = handle; + r.mux = null; r.detached = false; - r.attempts = 0; r.firstAttach = true; setError(undefined); - if (handle) { + if (handle && optionsRef.current.mux) { transition("connecting"); - connect(); + bindVisible(); + } else if (handle) { + transition("reattaching"); } else { transition("idle"); } return () => { r.detached = true; - teardownMux(); + detachVisible(); r.terminal = null; r.handle = null; setError(undefined); transition("idle"); }; }, - [connect, teardownMux, transition], + [bindVisible, detachVisible, transition], ); - // Daemon came back while we were waiting: reconnect immediately, without - // backoff debt from attempts made against the dead daemon. - const daemonReady = options.daemonReady; + const mux = options.mux; useEffect(() => { const r = runtime.current; - if (!daemonReady || r.detached) return; - if (stateRef.current !== "reattaching" || r.retryTimer) return; - r.attempts = 0; - connect(); - }, [daemonReady, connect]); + if (r.detached || r.mux === mux) return; + detachVisible(); + r.mux = null; + if (!mux) { + if (r.handle) transition("reattaching"); + return; + } + if (r.handle) { + transition("connecting"); + bindVisible(); + } + }, [bindVisible, detachVisible, mux, transition]); // Belt-and-braces: never leak a socket past unmount, even if the owner // forgot to call detach. useEffect( () => () => { runtime.current.detached = true; - teardownMux(); + detachVisible(); }, - [teardownMux], + [detachVisible], ); return { attach, state, error }; diff --git a/frontend/src/renderer/lib/shell-context.ts b/frontend/src/renderer/lib/shell-context.ts index b92d590d..80111fd8 100644 --- a/frontend/src/renderer/lib/shell-context.ts +++ b/frontend/src/renderer/lib/shell-context.ts @@ -1,11 +1,13 @@ import { createContext, useContext } from "react"; import type { useDaemonStatus } from "../hooks/useDaemonStatus"; +import type { TerminalMuxTransport } from "./terminal-mux-transport"; // Shared state the persistent _shell layout owns and route content reads. The // daemon status effect (IPC poll + event transport) must run exactly once, so // it lives in the shell and is handed down here rather than re-run per route. export type ShellContextValue = { daemonStatus: ReturnType; + terminalMux: TerminalMuxTransport | null; createProject: (input: { path: string }) => Promise; }; diff --git a/frontend/src/renderer/lib/terminal-mux-transport.test.ts b/frontend/src/renderer/lib/terminal-mux-transport.test.ts new file mode 100644 index 00000000..1c38a13f --- /dev/null +++ b/frontend/src/renderer/lib/terminal-mux-transport.test.ts @@ -0,0 +1,156 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import type { MuxConnectionState, TerminalMux } from "./terminal-mux"; +import { createTerminalMuxTransport } from "./terminal-mux-transport"; + +type FakeMux = TerminalMux & { + url: string; + opens: Array<[string, number, number]>; + inputs: Array<[string, string]>; + disposed: boolean; + emitConnection(state: MuxConnectionState): void; + emitData(id: string, bytes: Uint8Array): void; +}; + +function subscribe(map: Map>, id: string, listener: T): () => void { + const set = map.get(id) ?? new Set(); + set.add(listener); + map.set(id, set); + return () => set.delete(listener); +} + +function createFakeMux(url: string): FakeMux { + const data = new Map void>>(); + const connection = new Set<(state: MuxConnectionState) => void>(); + const fake: FakeMux = { + url, + opens: [], + inputs: [], + disposed: false, + open: (id, cols, rows) => fake.opens.push([id, cols, rows]), + sendInput: (id, input) => fake.inputs.push([id, input]), + resize: () => undefined, + close: () => undefined, + onData: (id, listener) => subscribe(data, id, listener), + onExit: () => () => undefined, + onOpened: () => () => undefined, + onError: () => () => undefined, + onConnectionChange: (listener) => { + connection.add(listener); + return () => connection.delete(listener); + }, + dispose: () => { + fake.disposed = true; + }, + emitConnection: (state) => connection.forEach((listener) => listener(state)), + emitData: (id, bytes) => data.get(id)?.forEach((listener) => listener(bytes)), + }; + return fake; +} + +function setup({ daemonReady = true } = {}) { + let baseUrl = "http://127.0.0.1:3001"; + let onBaseUrlChange: (() => void) | undefined; + const muxes: FakeMux[] = []; + const transport = createTerminalMuxTransport({ + daemonReady, + retryBaseMs: 500, + retryMaxMs: 8_000, + getApiBaseUrl: () => baseUrl, + subscribeApiBaseUrl: (listener) => { + onBaseUrlChange = listener; + return () => { + onBaseUrlChange = undefined; + }; + }, + createMux: (url) => { + const mux = createFakeMux(url); + muxes.push(mux); + return mux; + }, + }); + return { + transport, + muxes, + setBaseUrl(next: string) { + baseUrl = next; + onBaseUrlChange?.(); + }, + }; +} + +beforeEach(() => { + vi.useFakeTimers(); +}); + +afterEach(() => { + vi.useRealTimers(); +}); + +describe("createTerminalMuxTransport", () => { + it("waits for daemon readiness before opening /mux", () => { + const { transport, muxes } = setup({ daemonReady: false }); + + expect(muxes).toHaveLength(0); + transport.setDaemonReady(true); + + expect(muxes).toHaveLength(1); + expect(muxes[0].url).toBe("ws://127.0.0.1:3001/mux"); + }); + + it("rebinds to the latest API base URL", () => { + const { setBaseUrl, muxes } = setup(); + const first = muxes[0]; + + setBaseUrl("http://127.0.0.1:4555"); + + expect(first.disposed).toBe(true); + expect(muxes).toHaveLength(2); + expect(muxes[1].url).toBe("ws://127.0.0.1:4555/mux"); + }); + + it("keeps listeners across socket replacement", () => { + const { transport, muxes } = setup(); + const chunks: string[] = []; + transport.onData("t1", (bytes) => chunks.push(new TextDecoder().decode(bytes))); + + muxes[0].emitConnection("open"); + muxes[0].emitData("t1", new TextEncoder().encode("one")); + muxes[0].emitConnection("closed"); + vi.advanceTimersByTime(500); + muxes[1].emitConnection("open"); + muxes[1].emitData("t1", new TextEncoder().encode("two")); + + expect(chunks).toEqual(["one", "two"]); + }); + + it("backs off between reconnect attempts while the daemon is ready", () => { + const { muxes } = setup(); + + muxes[0].emitConnection("closed"); + vi.advanceTimersByTime(499); + expect(muxes).toHaveLength(1); + vi.advanceTimersByTime(1); + expect(muxes).toHaveLength(2); + + muxes[1].emitConnection("closed"); + vi.advanceTimersByTime(999); + expect(muxes).toHaveLength(2); + vi.advanceTimersByTime(1); + expect(muxes).toHaveLength(3); + }); + + it("drops user input while disconnected instead of replaying it on reconnect", () => { + const { transport, muxes } = setup(); + + muxes[0].emitConnection("open"); + transport.sendInput("t1", "before"); + muxes[0].emitConnection("closed"); + transport.sendInput("t1", "during"); + vi.advanceTimersByTime(500); + muxes[1].emitConnection("open"); + transport.sendInput("t1", "after"); + + expect(muxes[0].inputs).toEqual([["t1", "before"]]); + expect(muxes[1].inputs).toEqual([["t1", "after"]]); + }); +}); diff --git a/frontend/src/renderer/lib/terminal-mux-transport.ts b/frontend/src/renderer/lib/terminal-mux-transport.ts new file mode 100644 index 00000000..1aaab137 --- /dev/null +++ b/frontend/src/renderer/lib/terminal-mux-transport.ts @@ -0,0 +1,237 @@ +import { getApiBaseUrl, subscribeApiBaseUrl } from "./api-client"; +import { createTerminalMux, muxUrlFromApiBase, type MuxConnectionState, type TerminalMux } from "./terminal-mux"; + +type DataListener = (bytes: Uint8Array) => void; +type ExitListener = () => void; +type OpenedListener = () => void; +type ErrorListener = (message: string) => void; +type ConnectionListener = (state: MuxConnectionState) => void; + +export type TerminalMuxTransport = TerminalMux & { + setDaemonReady: (ready: boolean) => void; +}; + +export type TerminalMuxTransportOptions = { + daemonReady?: boolean; + createMux?: (url: string) => TerminalMux; + getApiBaseUrl?: () => string; + subscribeApiBaseUrl?: (listener: () => void) => () => void; + retryBaseMs?: number; + retryMaxMs?: number; +}; + +const RETRY_BASE_MS = 500; +const RETRY_MAX_MS = 8_000; + +function subscribeById(map: Map>, id: string, listener: T): () => void { + const set = map.get(id) ?? new Set(); + set.add(listener); + map.set(id, set); + return () => { + set.delete(listener); + if (set.size === 0) map.delete(id); + }; +} + +/** + * Shell-lifetime wrapper around the single-socket mux client. It owns socket + * replacement and keeps pane listeners registered across reconnects; visible + * terminal attachments still decide which handle is opened on each live socket. + */ +export function createTerminalMuxTransport(options: TerminalMuxTransportOptions = {}): TerminalMuxTransport { + const buildMux = options.createMux ?? ((url: string) => createTerminalMux(url)); + const readApiBaseUrl = options.getApiBaseUrl ?? getApiBaseUrl; + const subscribeBaseUrl = options.subscribeApiBaseUrl ?? subscribeApiBaseUrl; + const retryBaseMs = options.retryBaseMs ?? RETRY_BASE_MS; + const retryMaxMs = options.retryMaxMs ?? RETRY_MAX_MS; + + const dataListeners = new Map>(); + const exitListeners = new Map>(); + const openedListeners = new Map>(); + const errorListeners = new Map>(); + const connectionListeners = new Set(); + + let mux: TerminalMux | null = null; + let muxDisposers: Array<() => void> = []; + let socketBindings = new Set(); + let state: MuxConnectionState = "closed"; + let daemonReady = options.daemonReady ?? false; + let disposed = false; + let retryTimer: ReturnType | null = null; + let attempts = 0; + let currentBaseUrl: string | null = null; + + const setState = (next: MuxConnectionState, options: { force?: boolean } = {}) => { + if (disposed || (!options.force && state === next)) return; + state = next; + connectionListeners.forEach((listener) => listener(next)); + }; + + const clearRetry = () => { + if (retryTimer) { + clearTimeout(retryTimer); + retryTimer = null; + } + }; + + const disposeMux = () => { + muxDisposers.forEach((dispose) => dispose()); + muxDisposers = []; + socketBindings = new Set(); + mux?.dispose(); + mux = null; + }; + + const ensureDataBinding = (id: string) => { + if (!mux) return; + const key = `data:${id}`; + if (socketBindings.has(key)) return; + socketBindings.add(key); + muxDisposers.push(mux.onData(id, (bytes) => dataListeners.get(id)?.forEach((listener) => listener(bytes)))); + }; + + const ensureExitBinding = (id: string) => { + if (!mux) return; + const key = `exit:${id}`; + if (socketBindings.has(key)) return; + socketBindings.add(key); + muxDisposers.push(mux.onExit(id, () => exitListeners.get(id)?.forEach((listener) => listener()))); + }; + + const ensureOpenedBinding = (id: string) => { + if (!mux) return; + const key = `opened:${id}`; + if (socketBindings.has(key)) return; + socketBindings.add(key); + muxDisposers.push(mux.onOpened(id, () => openedListeners.get(id)?.forEach((listener) => listener()))); + }; + + const ensureErrorBinding = (id: string) => { + if (!mux) return; + const key = `error:${id}`; + if (socketBindings.has(key)) return; + socketBindings.add(key); + muxDisposers.push(mux.onError(id, (message) => errorListeners.get(id)?.forEach((listener) => listener(message)))); + }; + + const bindSocketListeners = (nextMux: TerminalMux) => { + mux = nextMux; + for (const id of dataListeners.keys()) { + ensureDataBinding(id); + } + for (const id of exitListeners.keys()) { + ensureExitBinding(id); + } + for (const id of openedListeners.keys()) { + ensureOpenedBinding(id); + } + for (const id of errorListeners.keys()) { + ensureErrorBinding(id); + } + muxDisposers.push( + nextMux.onConnectionChange((nextState) => { + if (nextState === "open") { + attempts = 0; + clearRetry(); + setState("open"); + } else { + disposeMux(); + setState("closed", { force: true }); + scheduleReconnect(); + } + }), + ); + }; + + const connect = () => { + if (disposed || !daemonReady || mux) return; + currentBaseUrl = readApiBaseUrl(); + const nextMux = buildMux(muxUrlFromApiBase(currentBaseUrl)); + bindSocketListeners(nextMux); + }; + + const scheduleReconnect = () => { + if (disposed || !daemonReady || retryTimer) return; + const delay = Math.min(retryBaseMs * 2 ** attempts, retryMaxMs); + attempts += 1; + retryTimer = setTimeout(() => { + retryTimer = null; + connect(); + }, delay); + }; + + const rebindBaseUrl = () => { + if (disposed) return; + const nextBaseUrl = readApiBaseUrl(); + if (nextBaseUrl === currentBaseUrl && mux) return; + clearRetry(); + disposeMux(); + setState("closed"); + currentBaseUrl = nextBaseUrl; + attempts = 0; + connect(); + }; + + const removeBaseUrlListener = subscribeBaseUrl(rebindBaseUrl); + if (daemonReady) connect(); + + const sendWhenOpen = (send: (activeMux: TerminalMux) => void) => { + if (state !== "open" || !mux) return; + send(mux); + }; + + return { + open: (id, cols, rows) => sendWhenOpen((activeMux) => activeMux.open(id, cols, rows)), + sendInput: (id, input) => sendWhenOpen((activeMux) => activeMux.sendInput(id, input)), + resize: (id, cols, rows) => sendWhenOpen((activeMux) => activeMux.resize(id, cols, rows)), + close: (id) => sendWhenOpen((activeMux) => activeMux.close(id)), + onData: (id, listener) => { + const unsubscribe = subscribeById(dataListeners, id, listener); + ensureDataBinding(id); + return unsubscribe; + }, + onExit: (id, listener) => { + const unsubscribe = subscribeById(exitListeners, id, listener); + ensureExitBinding(id); + return unsubscribe; + }, + onOpened: (id, listener) => { + const unsubscribe = subscribeById(openedListeners, id, listener); + ensureOpenedBinding(id); + return unsubscribe; + }, + onError: (id, listener) => { + const unsubscribe = subscribeById(errorListeners, id, listener); + ensureErrorBinding(id); + return unsubscribe; + }, + onConnectionChange: (listener) => { + connectionListeners.add(listener); + return () => connectionListeners.delete(listener); + }, + setDaemonReady: (ready) => { + if (daemonReady === ready) return; + daemonReady = ready; + clearRetry(); + if (!daemonReady) { + disposeMux(); + setState("closed"); + return; + } + attempts = 0; + connect(); + }, + dispose: () => { + if (disposed) return; + disposed = true; + clearRetry(); + removeBaseUrlListener(); + disposeMux(); + connectionListeners.clear(); + dataListeners.clear(); + exitListeners.clear(); + openedListeners.clear(); + errorListeners.clear(); + }, + }; +} diff --git a/frontend/src/renderer/lib/terminal-mux.ts b/frontend/src/renderer/lib/terminal-mux.ts index 76cd8c20..0c97d987 100644 --- a/frontend/src/renderer/lib/terminal-mux.ts +++ b/frontend/src/renderer/lib/terminal-mux.ts @@ -120,7 +120,7 @@ function subscribeById(map: Map>, id: string, listener: T): () * Create a mux client over a single WebSocket. Frames sent before the socket is * OPEN are queued and flushed on connect. There is no auto-reconnect at this * layer: a dropped socket is reported through onConnectionChange("closed") and - * the owner (useTerminalSession) decides whether to build a fresh client. + * the shell-lifetime transport decides whether to build a fresh client. */ export function createTerminalMux(url: string, WebSocketImpl: typeof WebSocket = WebSocket): TerminalMux { const socket = new WebSocketImpl(url); diff --git a/frontend/src/renderer/routes/_shell.tsx b/frontend/src/renderer/routes/_shell.tsx index fc276ee4..e17cee2e 100644 --- a/frontend/src/renderer/routes/_shell.tsx +++ b/frontend/src/renderer/routes/_shell.tsx @@ -1,6 +1,6 @@ import { createFileRoute, Outlet, useNavigate } from "@tanstack/react-router"; import { useQueryClient } from "@tanstack/react-query"; -import { type CSSProperties, useCallback, useEffect } from "react"; +import { type CSSProperties, useCallback, useEffect, useMemo } from "react"; import { ShellTopbar } from "../components/ShellTopbar"; import { Sidebar } from "../components/Sidebar"; import { SidebarProvider } from "../components/ui/sidebar"; @@ -9,6 +9,7 @@ import { useDaemonStatus } from "../hooks/useDaemonStatus"; import { useWorkspaceQuery, workspaceQueryKey, workspaceQueryOptions } from "../hooks/useWorkspaceQuery"; import { apiClient, apiErrorMessage } from "../lib/api-client"; import { ShellProvider } from "../lib/shell-context"; +import { createTerminalMuxTransport } from "../lib/terminal-mux-transport"; import { readStoredTheme, type Theme, useUiStore } from "../stores/ui-store"; import type { WorkspaceSummary } from "../types/workspace"; @@ -39,6 +40,7 @@ function ShellLayout() { const workspaces = workspaceQuery.data ?? []; const daemonStatus = useDaemonStatus(queryClient); const { theme, setTheme, isSidebarOpen, toggleSidebar } = useUiStore(); + const terminalMux = useMemo(() => (window.ao ? createTerminalMuxTransport() : null), []); const updateWorkspaces = useCallback( (updater: (workspaces: WorkspaceSummary[]) => WorkspaceSummary[]) => { @@ -82,6 +84,12 @@ function ShellLayout() { document.documentElement.style.colorScheme = theme; }, [theme]); + useEffect(() => { + terminalMux?.setDaemonReady(daemonStatus.state === "ready"); + }, [daemonStatus.state, terminalMux]); + + useEffect(() => () => terminalMux?.dispose(), [terminalMux]); + // Follow OS appearance only until the user picks a theme explicitly. useEffect(() => { if (readStoredTheme()) return; @@ -109,7 +117,7 @@ function ShellLayout() { }, [navigate, workspaces]); return ( - + {/* The topbar spans the full window width above the sidebar row (the macOS traffic lights + TitlebarNav cluster sit in its left inset), and the sidebar hangs below it — so the sidebar border stops at the From 7b66258249350cc90e8d4eb25e33748b7d0dcb75 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 18 Jun 2026 21:14:07 +0000 Subject: [PATCH 2/9] chore: format with prettier [skip ci] --- .../hooks/useTerminalSession.test.tsx | 8 +++--- .../src/renderer/hooks/useTerminalSession.ts | 25 ++++++++----------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/frontend/src/renderer/hooks/useTerminalSession.test.tsx b/frontend/src/renderer/hooks/useTerminalSession.test.tsx index d10743de..d36e3e74 100644 --- a/frontend/src/renderer/hooks/useTerminalSession.test.tsx +++ b/frontend/src/renderer/hooks/useTerminalSession.test.tsx @@ -127,10 +127,10 @@ function setup({ attachedSession = session as WorkspaceSession | undefined, mux const wrapper = ({ children }: { children: ReactNode }) => ( {children} ); - const view = renderHook( - ({ currentSession }) => useTerminalSession(currentSession, { mux: mux.mux }), - { initialProps: { currentSession: attachedSession }, wrapper }, - ); + const view = renderHook(({ currentSession }) => useTerminalSession(currentSession, { mux: mux.mux }), { + initialProps: { currentSession: attachedSession }, + wrapper, + }); const terminal = createFakeTerminal(); let detach: () => void = () => undefined; act(() => { diff --git a/frontend/src/renderer/hooks/useTerminalSession.ts b/frontend/src/renderer/hooks/useTerminalSession.ts index c9ed2829..a956f486 100644 --- a/frontend/src/renderer/hooks/useTerminalSession.ts +++ b/frontend/src/renderer/hooks/useTerminalSession.ts @@ -107,20 +107,17 @@ export function useTerminalSession(session: WorkspaceSession | undefined, option r.mux = null; }, []); - const openVisible = useCallback( - (clearBeforeOpen = false) => { - const r = runtime.current; - const { terminal, handle, mux } = r; - if (!terminal || !handle || !mux || r.detached) return; - if (clearBeforeOpen || !r.firstAttach) { - terminal.clear(); - } - r.firstAttach = false; - mux.open(handle, terminal.cols, terminal.rows); - mux.resize(handle, terminal.cols, terminal.rows); - }, - [], - ); + const openVisible = useCallback((clearBeforeOpen = false) => { + const r = runtime.current; + const { terminal, handle, mux } = r; + if (!terminal || !handle || !mux || r.detached) return; + if (clearBeforeOpen || !r.firstAttach) { + terminal.clear(); + } + r.firstAttach = false; + mux.open(handle, terminal.cols, terminal.rows); + mux.resize(handle, terminal.cols, terminal.rows); + }, []); openVisibleRef.current = openVisible; const bindVisible = useCallback(() => { From 018a33911957ad811f9eb369b7a507c3ae6d5407 Mon Sep 17 00:00:00 2001 From: whoisasx Date: Fri, 19 Jun 2026 03:17:00 +0530 Subject: [PATCH 3/9] fix: require configured agent defaults --- README.md | 6 +- backend/internal/cli/dto_drift_e2e_test.go | 4 +- backend/internal/cli/spawn.go | 5 +- backend/internal/config/config.go | 13 -- backend/internal/config/config_test.go | 2 +- backend/internal/daemon/daemon.go | 7 +- backend/internal/daemon/lifecycle_wiring.go | 58 ++--- backend/internal/daemon/wiring_test.go | 11 +- backend/internal/domain/agentdefaults.go | 43 ++++ backend/internal/httpd/api.go | 4 + backend/internal/httpd/apispec/openapi.yaml | 203 ++++++++++++++++++ .../internal/httpd/apispec/specgen/build.go | 31 +++ backend/internal/httpd/controllers/dto.go | 16 +- .../internal/httpd/controllers/sessions.go | 4 +- .../httpd/controllers/sessions_test.go | 4 +- .../internal/httpd/controllers/settings.go | 67 ++++++ .../httpd/controllers/settings_test.go | 80 +++++++ .../integration/lifecycle_sqlite_test.go | 8 +- backend/internal/service/session/service.go | 6 +- .../internal/service/session/service_test.go | 23 +- backend/internal/service/settings/service.go | 46 ++++ .../internal/service/settings/service_test.go | 59 +++++ backend/internal/session_manager/manager.go | 85 +++++--- .../internal/session_manager/manager_test.go | 77 +++++-- backend/internal/storage/sqlite/gen/models.go | 6 + .../storage/sqlite/gen/settings.sql.go | 48 +++++ .../sqlite/migrations/0014_app_settings.sql | 17 ++ .../storage/sqlite/queries/settings.sql | 11 + .../storage/sqlite/store/settings_store.go | 40 ++++ .../sqlite/store/settings_store_test.go | 37 ++++ backend/sqlc.yaml | 8 + frontend/src/api/schema.ts | 122 +++++++++++ .../components/AgentDefaultsDialog.test.tsx | 94 ++++++++ .../components/AgentDefaultsDialog.tsx | 148 +++++++++++++ .../components/ProjectSettingsForm.tsx | 10 +- .../src/renderer/components/Sidebar.test.tsx | 14 +- frontend/src/renderer/components/Sidebar.tsx | 19 +- frontend/src/renderer/lib/agent-defaults.ts | 21 ++ frontend/src/renderer/lib/agent-options.ts | 27 +++ frontend/src/renderer/routes/_shell.tsx | 10 +- 40 files changed, 1359 insertions(+), 135 deletions(-) create mode 100644 backend/internal/domain/agentdefaults.go create mode 100644 backend/internal/httpd/controllers/settings.go create mode 100644 backend/internal/httpd/controllers/settings_test.go create mode 100644 backend/internal/service/settings/service.go create mode 100644 backend/internal/service/settings/service_test.go create mode 100644 backend/internal/storage/sqlite/gen/settings.sql.go create mode 100644 backend/internal/storage/sqlite/migrations/0014_app_settings.sql create mode 100644 backend/internal/storage/sqlite/queries/settings.sql create mode 100644 backend/internal/storage/sqlite/store/settings_store.go create mode 100644 backend/internal/storage/sqlite/store/settings_store_test.go create mode 100644 frontend/src/renderer/components/AgentDefaultsDialog.test.tsx create mode 100644 frontend/src/renderer/components/AgentDefaultsDialog.tsx create mode 100644 frontend/src/renderer/lib/agent-defaults.ts create mode 100644 frontend/src/renderer/lib/agent-options.ts diff --git a/README.md b/README.md index 281cb376..7adf5678 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,8 @@ progress (what's shipped vs. in flight) see [`docs/STATUS.md`](docs/STATUS.md). `opencode`, `aider`, `amp`, `goose`, `copilot`, `grok`, `qwen`, `kimi`, `crush`, `cline`, `droid`, `devin`, `auggie`, `continue`, `kiro`, `kilocode`, and more), registered through a shared registry with common - activity-dispatch / hook utilities. The default is set by `AO_AGENT`. + activity-dispatch / hook utilities. App-wide worker and orchestrator defaults + are stored in daemon-backed settings. - **Isolated workspaces.** Worker and orchestrator sessions spawn into their own `git worktree` (`backend/internal/adapters/workspace/gitworktree/`), launched inside a `zellij` runtime adapter (`backend/internal/adapters/runtime/`) so @@ -94,7 +95,7 @@ go build -o /tmp/ao ./cmd/ao # base of --path; pass --id explicitly when the directory name doesn't match. /tmp/ao project add --path /path/to/your/repo --id your-repo --name your-repo -# Spawn a worker session running the default agent. +# Spawn a worker session running the configured default agent. /tmp/ao spawn --project your-repo --prompt "Refactor the auth module" # Inspect what's running. @@ -167,7 +168,6 @@ exposing it beyond loopback would be a security regression. | `AO_SHUTDOWN_TIMEOUT` | `10s` | Graceful-shutdown hard cap. | | `AO_RUN_FILE` | `/agent-orchestrator/running.json` | PID + port handshake path. | | `AO_DATA_DIR` | `/agent-orchestrator/data` | SQLite DB, WAL files, managed state. | -| `AO_AGENT` | `claude-code` | Default agent adapter id used by `ao spawn`. | | `AO_SESSION_ID` | _(unset)_ | Set inside spawned sessions; read by `ao send` and `ao hooks`. | | `GITHUB_TOKEN` | _(unset)_ | Used by the GitHub SCM and tracker adapters. Falls back to `gh auth token`. | diff --git a/backend/internal/cli/dto_drift_e2e_test.go b/backend/internal/cli/dto_drift_e2e_test.go index 2488d8f4..2114d95e 100644 --- a/backend/internal/cli/dto_drift_e2e_test.go +++ b/backend/internal/cli/dto_drift_e2e_test.go @@ -62,8 +62,8 @@ func (f *fakeSessionService) Spawn(_ context.Context, cfg ports.SpawnConfig) (do }, nil } -func (f *fakeSessionService) SpawnOrchestrator(ctx context.Context, projectID domain.ProjectID, _ bool) (domain.Session, error) { - return f.Spawn(ctx, ports.SpawnConfig{ProjectID: projectID, Kind: domain.KindOrchestrator}) +func (f *fakeSessionService) SpawnOrchestrator(ctx context.Context, projectID domain.ProjectID, _ bool, harness domain.AgentHarness) (domain.Session, error) { + return f.Spawn(ctx, ports.SpawnConfig{ProjectID: projectID, Kind: domain.KindOrchestrator, Harness: harness}) } func (f *fakeSessionService) Get(context.Context, domain.SessionID) (domain.Session, error) { diff --git a/backend/internal/cli/spawn.go b/backend/internal/cli/spawn.go index efa4f3f1..b1c79e4f 100644 --- a/backend/internal/cli/spawn.go +++ b/backend/internal/cli/spawn.go @@ -44,7 +44,8 @@ func newSpawnCommand(ctx *commandContext) *cobra.Command { Use: "spawn", Short: "Spawn a worker agent session in a registered project", Long: "Spawn a worker agent session in a registered project.\n\n" + - "The session runs the chosen agent (default: the daemon's AO_AGENT) in a\n" + + "The session runs the chosen agent, the project's role override, or the\n" + + "app-wide default agent configured in Settings. It uses a\n" + "fresh git worktree. Register the project first with `ao project add`.", Args: noArgs, RunE: func(cmd *cobra.Command, args []string) error { @@ -120,7 +121,7 @@ func newSpawnCommand(ctx *commandContext) *cobra.Command { return pflag.NormalizedName(name) }) f.StringVar(&opts.project, "project", "", "Project id to spawn the session in (required)") - f.StringVar(&opts.harness, "harness", "", "Agent harness / --agent: claude-code, codex, aider, opencode, grok, droid, amp, agy, crush, cursor, qwen, copilot, goose, auggie, continue, devin, cline, kimi, kiro, kilocode, vibe, pi, autohand (default: the daemon's AO_AGENT)") + f.StringVar(&opts.harness, "harness", "", "Agent harness / --agent: claude-code, codex, aider, opencode, grok, droid, amp, agy, crush, cursor, qwen, copilot, goose, auggie, continue, devin, cline, kimi, kiro, kilocode, vibe, pi, autohand (default: app setting)") f.StringVar(&opts.branch, "branch", "", "Branch for the session worktree (default: ao//root)") f.StringVar(&opts.prompt, "prompt", "", "Initial prompt for the agent") f.StringVar(&opts.issue, "issue", "", "Issue id to associate with the session") diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index e2a9386c..1296ea42 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -29,9 +29,6 @@ const ( // DefaultShutdownTimeout is the hard cap on graceful shutdown. After this // the process exits even if connections are still draining. DefaultShutdownTimeout = 10 * time.Second - // DefaultAgent is the agent adapter id the daemon wires when AO_AGENT is - // unset. It matches the claude-code adapter's manifest id. - DefaultAgent = "claude-code" ) // DefaultAllowedOrigins are the browser origins the daemon's CORS boundary @@ -63,10 +60,6 @@ type Config struct { // DataDir is the directory holding durable SQLite state: DB and WAL files. // It is created on first use by the storage layer. DataDir string - // Agent is the id of the agent adapter the daemon wires into the Session - // Manager (see DefaultAgent). Selected by AO_AGENT; startSession fails fast - // if no adapter with this id is registered. - Agent string // AllowedOrigins are the browser origins granted CORS read access (see // DefaultAllowedOrigins). Overridden by AO_ALLOWED_ORIGINS. AllowedOrigins []string @@ -89,7 +82,6 @@ func (c Config) Addr() string { // AO_SHUTDOWN_TIMEOUT shutdown deadline (Go duration > 0, default 10s) // AO_RUN_FILE running.json path (default ~/.ao/running.json) // AO_DATA_DIR durable state dir (default ~/.ao/data) -// AO_AGENT agent adapter id (default claude-code) // AO_ALLOWED_ORIGINS CORS origins, comma-separated (default DefaultAllowedOrigins) // // The bind host is not configurable: the daemon is loopback-only by design. @@ -99,7 +91,6 @@ func Load() (Config, error) { Port: DefaultPort, RequestTimeout: DefaultRequestTimeout, ShutdownTimeout: DefaultShutdownTimeout, - Agent: DefaultAgent, AllowedOrigins: DefaultAllowedOrigins, } @@ -130,10 +121,6 @@ func Load() (Config, error) { cfg.ShutdownTimeout = d } - if raw := os.Getenv("AO_AGENT"); raw != "" { - cfg.Agent = raw - } - if raw, ok := os.LookupEnv("AO_ALLOWED_ORIGINS"); ok && raw != "" { // Explicit override replaces the defaults entirely so a deployment can // also narrow the list. The "null" origin is rejected, never silently diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 4ce22512..30be606f 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -10,7 +10,7 @@ import ( func TestLoadDefaults(t *testing.T) { // Clear every recognised var so we observe pure defaults regardless of the // surrounding environment. - for _, k := range []string{"AO_PORT", "AO_REQUEST_TIMEOUT", "AO_SHUTDOWN_TIMEOUT", "AO_RUN_FILE", "AO_DATA_DIR", "AO_AGENT", "AO_ALLOWED_ORIGINS"} { + for _, k := range []string{"AO_PORT", "AO_REQUEST_TIMEOUT", "AO_SHUTDOWN_TIMEOUT", "AO_RUN_FILE", "AO_DATA_DIR", "AO_ALLOWED_ORIGINS"} { t.Setenv(k, "") } diff --git a/backend/internal/daemon/daemon.go b/backend/internal/daemon/daemon.go index 747f5251..4c4f1e9f 100644 --- a/backend/internal/daemon/daemon.go +++ b/backend/internal/daemon/daemon.go @@ -18,6 +18,7 @@ import ( "github.com/aoagents/agent-orchestrator/backend/internal/runfile" notificationsvc "github.com/aoagents/agent-orchestrator/backend/internal/service/notification" projectsvc "github.com/aoagents/agent-orchestrator/backend/internal/service/project" + settingssvc "github.com/aoagents/agent-orchestrator/backend/internal/service/settings" "github.com/aoagents/agent-orchestrator/backend/internal/storage/sqlite" "github.com/aoagents/agent-orchestrator/backend/internal/terminal" ) @@ -94,9 +95,8 @@ func Run() error { lcStack.scmDone = startSCMObserver(ctx, store, lcStack.LCM, log) // Wire the controller-facing session service over the same store + LCM, the - // zellij runtime, a gitworktree workspace, the per-session agent resolver - // (AO_AGENT default, validated here), and the agent messenger, then mount it - // on the API. + // zellij runtime, a gitworktree workspace, the per-session agent resolver, + // and the agent messenger, then mount it on the API. sessionSvc, reviewSvc, err := startSession(cfg, runtimeAdapter, store, lcStack.LCM, messenger, log) if err != nil { stop() @@ -111,6 +111,7 @@ func Run() error { Projects: projectsvc.NewWithDeps(projectsvc.Deps{Store: store, Sessions: sessionSvc}), Sessions: sessionSvc, Reviews: reviewSvc, + Settings: settingssvc.New(store), Notifications: notifier, NotificationStream: notificationHub, CDC: store, diff --git a/backend/internal/daemon/lifecycle_wiring.go b/backend/internal/daemon/lifecycle_wiring.go index a66d3dea..45a9eb91 100644 --- a/backend/internal/daemon/lifecycle_wiring.go +++ b/backend/internal/daemon/lifecycle_wiring.go @@ -60,17 +60,10 @@ func (l *lifecycleStack) Stop() { // startSession builds the controller-facing session service: a session manager // over the real zellij runtime, a per-session gitworktree workspace, the shared -// store + LCM, the per-session agent resolver (AO_AGENT default), and the +// store + LCM, the per-session agent resolver, and the // agent messenger. The returned service is mounted at httpd APIDeps.Sessions. func startSession(cfg config.Config, runtime *zellij.Runtime, store *sqlite.Store, lcm *lifecycle.Manager, messenger ports.AgentMessenger, log *slog.Logger) (*sessionsvc.Service, reviewsvc.Manager, error) { - // Resolve the default agent once and share it with both the resolver (which - // launches it for an unspecified harness) and the session manager (which - // persists it onto the seed row), so the stored harness matches what runs. - defaultAgent := cfg.Agent - if defaultAgent == "" { - defaultAgent = config.DefaultAgent - } - agents, err := buildAgentResolver(defaultAgent, log) + agents, err := buildAgentResolver(log) if err != nil { return nil, nil, err } @@ -87,15 +80,15 @@ func startSession(cfg config.Config, runtime *zellij.Runtime, store *sqlite.Stor return nil, nil, fmt.Errorf("session workspace: %w", err) } mgr := sessionmanager.New(sessionmanager.Deps{ - Runtime: runtime, - Agents: agents, - Workspace: ws, - Store: store, - Messenger: messenger, - Lifecycle: lcm, - DataDir: cfg.DataDir, - DefaultHarness: domain.AgentHarness(defaultAgent), - Logger: log, + Runtime: runtime, + Agents: agents, + Workspace: ws, + Store: store, + Messenger: messenger, + Lifecycle: lcm, + DataDir: cfg.DataDir, + AgentDefaults: store, + Logger: log, }) scmProvider, err := newGitHubSCMProvider(log) if err != nil { @@ -178,20 +171,15 @@ func buildAgentRegistry() (*adapters.Registry, error) { // agentRegistry adapts the generic adapter Registry to ports.AgentResolver: it // maps a session's harness onto the registered adapter of the same id and -// asserts that adapter drives an agent. An empty harness falls back to the -// daemon's configured default (AO_AGENT), so a spawn that names no harness still -// gets a real agent. +// asserts that adapter drives an agent. Empty harnesses are misses: session +// spawns must resolve defaults before they reach the registry. type agentRegistry struct { - reg *adapters.Registry - defaultHarness domain.AgentHarness + reg *adapters.Registry } var _ ports.AgentResolver = agentRegistry{} func (a agentRegistry) Agent(harness domain.AgentHarness) (ports.Agent, bool) { - if harness == "" { - harness = a.defaultHarness - } adapter, ok := a.reg.Get(string(harness)) if !ok { return nil, false @@ -202,27 +190,19 @@ func (a agentRegistry) Agent(harness domain.AgentHarness) (ports.Agent, bool) { // buildAgentResolver constructs the per-session agent resolver the Session // Manager consumes (sessionmanager.Deps.Agents): a registry of the shipped -// adapters plus the configured default harness. It fails fast if the default -// does not resolve, so a typo'd AO_AGENT surfaces at startup. The session lane -// plugs this in when it mounts the controller-facing session service at the -// httpd APIDeps.Sessions slot. -func buildAgentResolver(defaultAgent string, log *slog.Logger) (ports.AgentResolver, error) { - if defaultAgent == "" { - defaultAgent = config.DefaultAgent - } +// adapters. The session lane plugs this in when it mounts the controller-facing +// session service at the httpd APIDeps.Sessions slot. +func buildAgentResolver(log *slog.Logger) (ports.AgentResolver, error) { reg, err := buildAgentRegistry() if err != nil { return nil, err } - resolver := agentRegistry{reg: reg, defaultHarness: domain.AgentHarness(defaultAgent)} - if _, ok := resolver.Agent(""); !ok { - return nil, fmt.Errorf("configured default agent %q is not a registered adapter", defaultAgent) - } + resolver := agentRegistry{reg: reg} ids := make([]string, 0) for _, mf := range reg.Manifests() { ids = append(ids, mf.ID) } - log.Info("built per-session agent resolver", "default", defaultAgent, "registered", ids) + log.Info("built per-session agent resolver", "registered", ids) return resolver, nil } diff --git a/backend/internal/daemon/wiring_test.go b/backend/internal/daemon/wiring_test.go index 36e67344..ed5a9dfc 100644 --- a/backend/internal/daemon/wiring_test.go +++ b/backend/internal/daemon/wiring_test.go @@ -77,12 +77,11 @@ func TestWiring_WriteFlowsToBroadcaster(t *testing.T) { } // TestWiring_AgentResolverResolvesRealAdapters asserts buildAgentResolver wires a -// real registry-backed per-session resolver: each harness resolves to the -// matching registered adapter, an empty harness falls back to the AO_AGENT -// default, and an unknown harness misses. +// real registry-backed per-session resolver: each concrete harness resolves to +// the matching registered adapter, while empty/unknown harnesses miss. func TestWiring_AgentResolverResolvesRealAdapters(t *testing.T) { log := slog.New(slog.NewTextHandler(io.Discard, nil)) - resolver, err := buildAgentResolver("", log) // empty default → claude-code + resolver, err := buildAgentResolver(log) if err != nil { t.Fatal(err) } @@ -113,7 +112,6 @@ func TestWiring_AgentResolverResolvesRealAdapters(t *testing.T) { {domain.HarnessVibe, "vibe"}, {domain.HarnessPi, "pi"}, {domain.HarnessAutohand, "autohand"}, - {"", config.DefaultAgent}, // empty harness falls back to the AO_AGENT default } { agent, ok := resolver.Agent(tc.harness) if !ok { @@ -130,6 +128,9 @@ func TestWiring_AgentResolverResolvesRealAdapters(t *testing.T) { if _, ok := resolver.Agent("definitely-not-an-agent"); ok { t.Fatal("unknown harness resolved to an agent; want a miss") } + if _, ok := resolver.Agent(""); ok { + t.Fatal("empty harness resolved to an agent; want a miss") + } } // TestWiring_StartSessionBuildsSessionService asserts the daemon's startSession diff --git a/backend/internal/domain/agentdefaults.go b/backend/internal/domain/agentdefaults.go new file mode 100644 index 00000000..16129b02 --- /dev/null +++ b/backend/internal/domain/agentdefaults.go @@ -0,0 +1,43 @@ +package domain + +import "fmt" + +// AgentDefaults are the app-wide fallback harnesses used when a spawn does not +// name an explicit harness and the project has no role override. +type AgentDefaults struct { + DefaultWorkerAgent AgentHarness `json:"defaultWorkerAgent,omitempty" enum:"claude-code,codex,aider,opencode,grok,droid,amp,agy,crush,cursor,qwen,copilot,goose,auggie,continue,devin,cline,kimi,kiro,kilocode,vibe,pi,autohand"` + DefaultOrchestratorAgent AgentHarness `json:"defaultOrchestratorAgent,omitempty" enum:"claude-code,codex,aider,opencode,grok,droid,amp,agy,crush,cursor,qwen,copilot,goose,auggie,continue,devin,cline,kimi,kiro,kilocode,vibe,pi,autohand"` +} + +// HarnessFor returns the default harness for a session kind. Any non-worker +// role is treated as an orchestrator role because the domain only has those two +// concrete spawn roles today. +func (d AgentDefaults) HarnessFor(kind SessionKind) AgentHarness { + if kind == KindWorker || kind == "" { + return d.DefaultWorkerAgent + } + return d.DefaultOrchestratorAgent +} + +// Complete reports whether both app-wide defaults have been configured. +func (d AgentDefaults) Complete() bool { + return d.DefaultWorkerAgent != "" && d.DefaultOrchestratorAgent != "" +} + +// ValidateComplete rejects missing or unknown defaults. Settings writes use +// this strict path so the app never persists a half-configured default state. +func (d AgentDefaults) ValidateComplete() error { + if d.DefaultWorkerAgent == "" { + return fmt.Errorf("defaultWorkerAgent is required") + } + if !d.DefaultWorkerAgent.IsKnown() { + return fmt.Errorf("defaultWorkerAgent: unknown harness %q", d.DefaultWorkerAgent) + } + if d.DefaultOrchestratorAgent == "" { + return fmt.Errorf("defaultOrchestratorAgent is required") + } + if !d.DefaultOrchestratorAgent.IsKnown() { + return fmt.Errorf("defaultOrchestratorAgent: unknown harness %q", d.DefaultOrchestratorAgent) + } + return nil +} diff --git a/backend/internal/httpd/api.go b/backend/internal/httpd/api.go index 40b65d8a..2924624d 100644 --- a/backend/internal/httpd/api.go +++ b/backend/internal/httpd/api.go @@ -23,6 +23,7 @@ type APIDeps struct { Activity controllers.ActivityRecorder PRs prsvc.ActionManager Reviews reviewsvc.Manager + Settings controllers.SettingsService Notifications controllers.NotificationService NotificationStream controllers.NotificationStream CDC cdc.Source @@ -37,6 +38,7 @@ type API struct { sessions *controllers.SessionsController prs *controllers.PRsController reviews *controllers.ReviewsController + settings *controllers.SettingsController notifications *controllers.NotificationsController events *EventsController } @@ -56,6 +58,7 @@ func NewAPI(cfg config.Config, deps APIDeps) *API { }, prs: &controllers.PRsController{Svc: deps.PRs}, reviews: &controllers.ReviewsController{Svc: deps.Reviews}, + settings: &controllers.SettingsController{Svc: deps.Settings}, notifications: &controllers.NotificationsController{Svc: deps.Notifications, Stream: deps.NotificationStream}, events: &EventsController{Source: deps.CDC, Live: deps.Events}, } @@ -79,6 +82,7 @@ func (a *API) Register(root chi.Router) { a.sessions.Register(r) a.prs.Register(r) a.reviews.Register(r) + a.settings.Register(r) a.notifications.Register(r) // Sibling REST controllers plug in here. }) diff --git a/backend/internal/httpd/apispec/openapi.yaml b/backend/internal/httpd/apispec/openapi.yaml index a279d460..766810bc 100644 --- a/backend/internal/httpd/apispec/openapi.yaml +++ b/backend/internal/httpd/apispec/openapi.yaml @@ -1149,6 +1149,67 @@ paths: summary: Clean up terminated session workspaces tags: - sessions + /api/v1/settings/agents: + get: + operationId: getAgentDefaults + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/AgentDefaultsResponse' + description: OK + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/APIError' + description: Internal Server Error + "501": + content: + application/json: + schema: + $ref: '#/components/schemas/APIError' + description: Not Implemented + summary: Get app-wide default agents + tags: + - settings + put: + operationId: setAgentDefaults + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/AgentDefaultsRequest' + required: true + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/AgentDefaultsResponse' + description: OK + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/APIError' + description: Bad Request + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/APIError' + description: Internal Server Error + "501": + content: + application/json: + schema: + $ref: '#/components/schemas/APIError' + description: Not Implemented + summary: Set app-wide default agents + tags: + - settings components: schemas: APIError: @@ -1195,6 +1256,120 @@ components: permissions: type: string type: object + AgentDefaultsRequest: + properties: + defaultOrchestratorAgent: + enum: + - claude-code + - codex + - aider + - opencode + - grok + - droid + - amp + - agy + - crush + - cursor + - qwen + - copilot + - goose + - auggie + - continue + - devin + - cline + - kimi + - kiro + - kilocode + - vibe + - pi + - autohand + type: string + defaultWorkerAgent: + enum: + - claude-code + - codex + - aider + - opencode + - grok + - droid + - amp + - agy + - crush + - cursor + - qwen + - copilot + - goose + - auggie + - continue + - devin + - cline + - kimi + - kiro + - kilocode + - vibe + - pi + - autohand + type: string + type: object + AgentDefaultsResponse: + properties: + configured: + type: boolean + defaultOrchestratorAgent: + enum: + - claude-code + - codex + - aider + - opencode + - grok + - droid + - amp + - agy + - crush + - cursor + - qwen + - copilot + - goose + - auggie + - continue + - devin + - cline + - kimi + - kiro + - kilocode + - vibe + - pi + - autohand + type: string + defaultWorkerAgent: + enum: + - claude-code + - codex + - aider + - opencode + - grok + - droid + - amp + - agy + - crush + - cursor + - qwen + - copilot + - goose + - auggie + - continue + - devin + - cline + - kimi + - kiro + - kilocode + - vibe + - pi + - autohand + type: string + required: + - configured + type: object ClaimPRRequest: properties: allowTakeover: @@ -1808,6 +1983,32 @@ components: properties: clean: type: boolean + harness: + enum: + - claude-code + - codex + - aider + - opencode + - grok + - droid + - amp + - agy + - crush + - cursor + - qwen + - copilot + - goose + - auggie + - continue + - devin + - cline + - kimi + - kiro + - kilocode + - vibe + - pi + - autohand + type: string projectId: type: string required: @@ -1903,6 +2104,8 @@ tags: name: prs - description: Code-review runs and findings name: reviews +- description: App-wide user settings + name: settings - description: Durable dashboard notifications name: notifications - description: Server-sent CDC event stream with durable replay diff --git a/backend/internal/httpd/apispec/specgen/build.go b/backend/internal/httpd/apispec/specgen/build.go index 2aeca734..a906710f 100644 --- a/backend/internal/httpd/apispec/specgen/build.go +++ b/backend/internal/httpd/apispec/specgen/build.go @@ -63,6 +63,8 @@ func Build() ([]byte, error) { "Pull-request actions (SCM lane)"), *(&openapi31.Tag{Name: "reviews"}).WithDescription( "Code-review runs and findings"), + *(&openapi31.Tag{Name: "settings"}).WithDescription( + "App-wide user settings"), *(&openapi31.Tag{Name: "notifications"}).WithDescription( "Durable dashboard notifications"), *(&openapi31.Tag{Name: "events"}).WithDescription( @@ -129,6 +131,7 @@ var schemaNames = map[string]string{ "DomainProjectConfig": "ProjectConfig", "DomainAgentConfig": "AgentConfig", "DomainRoleOverride": "RoleOverride", + "DomainAgentDefaults": "AgentDefaults", // httpd/controllers (wire envelopes) "ControllersListProjectsResponse": "ListProjectsResponse", "ControllersProjectResponse": "ProjectResponse", @@ -157,6 +160,8 @@ var schemaNames = map[string]string{ "ControllersSpawnOrchestratorRequest": "SpawnOrchestratorRequest", "ControllersSpawnOrchestratorResponse": "SpawnOrchestratorResponse", "ControllersOrchestratorResponse": "OrchestratorResponse", + "ControllersAgentDefaultsRequest": "AgentDefaultsRequest", + "ControllersAgentDefaultsResponse": "AgentDefaultsResponse", "ControllersListNotificationsQuery": "ListNotificationsQuery", "ControllersNotificationStreamQuery": "NotificationStreamQuery", "ControllersNotificationTarget": "NotificationTarget", @@ -258,10 +263,36 @@ func operations() []operation { ops = append(ops, sessionOperations()...) ops = append(ops, prOperations()...) ops = append(ops, reviewOperations()...) + ops = append(ops, settingsOperations()...) ops = append(ops, notificationOperations()...) return ops } +func settingsOperations() []operation { + return []operation{ + { + method: http.MethodGet, path: "/api/v1/settings/agents", id: "getAgentDefaults", tag: "settings", + summary: "Get app-wide default agents", + resps: []respUnit{ + {http.StatusOK, controllers.AgentDefaultsResponse{}}, + {http.StatusInternalServerError, envelope.APIError{}}, + {http.StatusNotImplemented, envelope.APIError{}}, + }, + }, + { + method: http.MethodPut, path: "/api/v1/settings/agents", id: "setAgentDefaults", tag: "settings", + summary: "Set app-wide default agents", + reqBody: controllers.AgentDefaultsRequest{}, + resps: []respUnit{ + {http.StatusOK, controllers.AgentDefaultsResponse{}}, + {http.StatusBadRequest, envelope.APIError{}}, + {http.StatusInternalServerError, envelope.APIError{}}, + {http.StatusNotImplemented, envelope.APIError{}}, + }, + }, + } +} + func notificationOperations() []operation { return []operation{ { diff --git a/backend/internal/httpd/controllers/dto.go b/backend/internal/httpd/controllers/dto.go index 2da2fe91..277a10b6 100644 --- a/backend/internal/httpd/controllers/dto.go +++ b/backend/internal/httpd/controllers/dto.go @@ -255,8 +255,9 @@ type OrchestratorIDParam struct { // SpawnOrchestratorRequest is the body of POST /api/v1/orchestrators. type SpawnOrchestratorRequest struct { - ProjectID domain.ProjectID `json:"projectId"` - Clean bool `json:"clean,omitempty"` + ProjectID domain.ProjectID `json:"projectId"` + Clean bool `json:"clean,omitempty"` + Harness domain.AgentHarness `json:"harness,omitempty" enum:"claude-code,codex,aider,opencode,grok,droid,amp,agy,crush,cursor,qwen,copilot,goose,auggie,continue,devin,cline,kimi,kiro,kilocode,vibe,pi,autohand"` } // SpawnOrchestratorResponse is the body of POST /api/v1/orchestrators. @@ -271,6 +272,17 @@ type OrchestratorResponse struct { ProjectName string `json:"projectName,omitempty"` } +// AgentDefaultsRequest is the body of PUT /api/v1/settings/agents. +type AgentDefaultsRequest struct { + domain.AgentDefaults +} + +// AgentDefaultsResponse is the body of GET/PUT /api/v1/settings/agents. +type AgentDefaultsResponse struct { + domain.AgentDefaults + Configured bool `json:"configured"` +} + // ListNotificationsQuery is the query string accepted by GET /api/v1/notifications. type ListNotificationsQuery struct { Status string `query:"status,omitempty" enum:"unread" description:"Notification status filter. V1 supports only unread."` diff --git a/backend/internal/httpd/controllers/sessions.go b/backend/internal/httpd/controllers/sessions.go index bc0a4504..7fbcb651 100644 --- a/backend/internal/httpd/controllers/sessions.go +++ b/backend/internal/httpd/controllers/sessions.go @@ -26,7 +26,7 @@ const ( type SessionService interface { List(ctx context.Context, filter sessionsvc.ListFilter) ([]domain.Session, error) Spawn(ctx context.Context, cfg ports.SpawnConfig) (domain.Session, error) - SpawnOrchestrator(ctx context.Context, projectID domain.ProjectID, clean bool) (domain.Session, error) + SpawnOrchestrator(ctx context.Context, projectID domain.ProjectID, clean bool, harness domain.AgentHarness) (domain.Session, error) Get(ctx context.Context, id domain.SessionID) (domain.Session, error) Restore(ctx context.Context, id domain.SessionID) (domain.Session, error) Kill(ctx context.Context, id domain.SessionID) (bool, error) @@ -328,7 +328,7 @@ func (c *SessionsController) spawnOrchestrator(w http.ResponseWriter, r *http.Re envelope.WriteAPIError(w, r, http.StatusBadRequest, "bad_request", "PROJECT_ID_REQUIRED", "projectId is required", nil) return } - sess, err := c.Svc.SpawnOrchestrator(r.Context(), in.ProjectID, in.Clean) + sess, err := c.Svc.SpawnOrchestrator(r.Context(), in.ProjectID, in.Clean, in.Harness) if err != nil { envelope.WriteError(w, r, err) return diff --git a/backend/internal/httpd/controllers/sessions_test.go b/backend/internal/httpd/controllers/sessions_test.go index 3bc53b60..622a570d 100644 --- a/backend/internal/httpd/controllers/sessions_test.go +++ b/backend/internal/httpd/controllers/sessions_test.go @@ -57,7 +57,7 @@ func (f *fakeSessionService) Spawn(_ context.Context, cfg ports.SpawnConfig) (do return s, nil } -func (f *fakeSessionService) SpawnOrchestrator(ctx context.Context, projectID domain.ProjectID, clean bool) (domain.Session, error) { +func (f *fakeSessionService) SpawnOrchestrator(ctx context.Context, projectID domain.ProjectID, clean bool, harness domain.AgentHarness) (domain.Session, error) { if clean { active := true existing, err := f.List(ctx, sessionsvc.ListFilter{ProjectID: projectID, Active: &active, OrchestratorOnly: true}) @@ -70,7 +70,7 @@ func (f *fakeSessionService) SpawnOrchestrator(ctx context.Context, projectID do } } } - return f.Spawn(ctx, ports.SpawnConfig{ProjectID: projectID, Kind: domain.KindOrchestrator}) + return f.Spawn(ctx, ports.SpawnConfig{ProjectID: projectID, Kind: domain.KindOrchestrator, Harness: harness}) } func (f *fakeSessionService) Get(_ context.Context, id domain.SessionID) (domain.Session, error) { diff --git a/backend/internal/httpd/controllers/settings.go b/backend/internal/httpd/controllers/settings.go new file mode 100644 index 00000000..a79c17f0 --- /dev/null +++ b/backend/internal/httpd/controllers/settings.go @@ -0,0 +1,67 @@ +package controllers + +import ( + "context" + "net/http" + + "github.com/go-chi/chi/v5" + + "github.com/aoagents/agent-orchestrator/backend/internal/domain" + "github.com/aoagents/agent-orchestrator/backend/internal/httpd/apispec" + "github.com/aoagents/agent-orchestrator/backend/internal/httpd/envelope" +) + +// SettingsService is the controller-facing app settings contract. +type SettingsService interface { + GetAgentDefaults(ctx context.Context) (domain.AgentDefaults, error) + SetAgentDefaults(ctx context.Context, defaults domain.AgentDefaults) (domain.AgentDefaults, error) +} + +// SettingsController owns app-wide user settings routes. Nil keeps routes +// registered but returns OpenAPI-backed 501s. +type SettingsController struct { + Svc SettingsService +} + +// Register mounts the settings routes on the supplied router. +func (c *SettingsController) Register(r chi.Router) { + r.Get("/settings/agents", c.getAgentDefaults) + r.Put("/settings/agents", c.setAgentDefaults) +} + +func (c *SettingsController) getAgentDefaults(w http.ResponseWriter, r *http.Request) { + if c.Svc == nil { + apispec.NotImplemented(w, r, "GET", "/api/v1/settings/agents") + return + } + defaults, err := c.Svc.GetAgentDefaults(r.Context()) + if err != nil { + envelope.WriteError(w, r, err) + return + } + envelope.WriteJSON(w, http.StatusOK, AgentDefaultsResponse{ + AgentDefaults: defaults, + Configured: defaults.Complete(), + }) +} + +func (c *SettingsController) setAgentDefaults(w http.ResponseWriter, r *http.Request) { + if c.Svc == nil { + apispec.NotImplemented(w, r, "PUT", "/api/v1/settings/agents") + return + } + var in AgentDefaultsRequest + if err := decodeJSONStrict(r, &in); err != nil { + envelope.WriteAPIError(w, r, http.StatusBadRequest, "bad_request", "INVALID_JSON", "Invalid JSON body", nil) + return + } + defaults, err := c.Svc.SetAgentDefaults(r.Context(), in.AgentDefaults) + if err != nil { + envelope.WriteError(w, r, err) + return + } + envelope.WriteJSON(w, http.StatusOK, AgentDefaultsResponse{ + AgentDefaults: defaults, + Configured: defaults.Complete(), + }) +} diff --git a/backend/internal/httpd/controllers/settings_test.go b/backend/internal/httpd/controllers/settings_test.go new file mode 100644 index 00000000..efa4dd5a --- /dev/null +++ b/backend/internal/httpd/controllers/settings_test.go @@ -0,0 +1,80 @@ +package controllers_test + +import ( + "context" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/aoagents/agent-orchestrator/backend/internal/config" + "github.com/aoagents/agent-orchestrator/backend/internal/domain" + "github.com/aoagents/agent-orchestrator/backend/internal/httpd" + "github.com/aoagents/agent-orchestrator/backend/internal/httpd/apierr" +) + +type fakeSettingsService struct { + defaults domain.AgentDefaults + err error + saved domain.AgentDefaults +} + +func (f *fakeSettingsService) GetAgentDefaults(context.Context) (domain.AgentDefaults, error) { + return f.defaults, f.err +} + +func (f *fakeSettingsService) SetAgentDefaults(_ context.Context, defaults domain.AgentDefaults) (domain.AgentDefaults, error) { + if f.err != nil { + return domain.AgentDefaults{}, f.err + } + f.saved = defaults + f.defaults = defaults + return defaults, nil +} + +func newSettingsTestServer(t *testing.T, svc *fakeSettingsService) *httptest.Server { + t.Helper() + log := slog.New(slog.NewTextHandler(io.Discard, nil)) + srv := httptest.NewServer(httpd.NewRouterWithControl(config.Config{}, log, nil, httpd.APIDeps{Settings: svc}, httpd.ControlDeps{})) + t.Cleanup(srv.Close) + return srv +} + +func TestSettingsAPI_AgentDefaultsRoundTrip(t *testing.T) { + svc := &fakeSettingsService{} + srv := newSettingsTestServer(t, svc) + + body, status, _ := doRequest(t, srv, "GET", "/api/v1/settings/agents", "") + if status != http.StatusOK { + t.Fatalf("GET settings = %d, want 200; body=%s", status, body) + } + var got struct { + DefaultWorkerAgent string `json:"defaultWorkerAgent"` + DefaultOrchestratorAgent string `json:"defaultOrchestratorAgent"` + Configured bool `json:"configured"` + } + mustJSON(t, body, &got) + if got.Configured || got.DefaultWorkerAgent != "" || got.DefaultOrchestratorAgent != "" { + t.Fatalf("first-run settings = %#v, want empty incomplete defaults", got) + } + + body, status, _ = doRequest(t, srv, "PUT", "/api/v1/settings/agents", `{"defaultWorkerAgent":"codex","defaultOrchestratorAgent":"claude-code"}`) + if status != http.StatusOK { + t.Fatalf("PUT settings = %d, want 200; body=%s", status, body) + } + mustJSON(t, body, &got) + if !got.Configured || got.DefaultWorkerAgent != "codex" || got.DefaultOrchestratorAgent != "claude-code" { + t.Fatalf("saved settings = %#v", got) + } + if svc.saved.DefaultWorkerAgent != domain.HarnessCodex || svc.saved.DefaultOrchestratorAgent != domain.HarnessClaudeCode { + t.Fatalf("service saved = %+v", svc.saved) + } +} + +func TestSettingsAPI_ValidationError(t *testing.T) { + srv := newSettingsTestServer(t, &fakeSettingsService{err: apierr.Invalid("INVALID_AGENT_DEFAULTS", "defaultOrchestratorAgent is required", nil)}) + + body, status, _ := doRequest(t, srv, "PUT", "/api/v1/settings/agents", `{"defaultWorkerAgent":"codex"}`) + assertErrorCode(t, body, status, http.StatusBadRequest, "INVALID_AGENT_DEFAULTS") +} diff --git a/backend/internal/integration/lifecycle_sqlite_test.go b/backend/internal/integration/lifecycle_sqlite_test.go index 9a0a9960..7ff9b884 100644 --- a/backend/internal/integration/lifecycle_sqlite_test.go +++ b/backend/internal/integration/lifecycle_sqlite_test.go @@ -92,12 +92,18 @@ func newStack(t *testing.T) *stack { if err := store.UpsertProject(ctx, domain.ProjectRecord{ID: "mer", Path: "/repo/mer", RegisteredAt: time.Now()}); err != nil { t.Fatal(err) } + if err := store.SetAgentDefaults(ctx, domain.AgentDefaults{ + DefaultWorkerAgent: domain.HarnessClaudeCode, + DefaultOrchestratorAgent: domain.HarnessClaudeCode, + }); err != nil { + t.Fatal(err) + } msg := &captureMessenger{} lcm := lifecycle.New(store, msg) prm := prsvc.New(prsvc.Deps{Writer: store, Lifecycle: lcm}) rt := &stubRuntime{} ws := &stubWorkspace{} - mgr := sessionmanager.New(sessionmanager.Deps{Runtime: rt, Agents: stubAgents{}, Workspace: ws, Store: store, Messenger: msg, Lifecycle: lcm, LookPath: func(string) (string, error) { return "/usr/bin/true", nil }}) + mgr := sessionmanager.New(sessionmanager.Deps{Runtime: rt, Agents: stubAgents{}, Workspace: ws, Store: store, Messenger: msg, Lifecycle: lcm, LookPath: func(string) (string, error) { return "/usr/bin/true", nil }, AgentDefaults: store}) sm := sessionsvc.New(mgr, store) return &stack{store: store, sm: sm, lcm: lcm, prm: prm, rt: rt, ws: ws, msg: msg} } diff --git a/backend/internal/service/session/service.go b/backend/internal/service/session/service.go index 26a39fa0..26aaa1a8 100644 --- a/backend/internal/service/session/service.go +++ b/backend/internal/service/session/service.go @@ -158,7 +158,7 @@ func (s *Service) requireProject(ctx context.Context, id domain.ProjectID) error // true it first tears down any active orchestrator(s) for that project so the new // one is the only live coordinator — a business rule that belongs here, not in the // HTTP controller. -func (s *Service) SpawnOrchestrator(ctx context.Context, projectID domain.ProjectID, clean bool) (domain.Session, error) { +func (s *Service) SpawnOrchestrator(ctx context.Context, projectID domain.ProjectID, clean bool, harness domain.AgentHarness) (domain.Session, error) { if err := s.requireProject(ctx, projectID); err != nil { return domain.Session{}, err } @@ -174,7 +174,7 @@ func (s *Service) SpawnOrchestrator(ctx context.Context, projectID domain.Projec } } } - return s.Spawn(ctx, ports.SpawnConfig{ProjectID: projectID, Kind: domain.KindOrchestrator}) + return s.Spawn(ctx, ports.SpawnConfig{ProjectID: projectID, Kind: domain.KindOrchestrator, Harness: harness}) } // Restore relaunches a terminated session and returns the API-facing read model. @@ -341,6 +341,8 @@ func toAPIError(err error) error { return apierr.Invalid("PROJECT_NOT_RESOLVABLE", "Project is not registered or has no repo — register it with `ao project add`", nil) case errors.Is(err, sessionmanager.ErrUnknownHarness): return apierr.Invalid("UNKNOWN_HARNESS", err.Error(), nil) + case errors.Is(err, sessionmanager.ErrDefaultAgentRequired): + return apierr.Invalid("DEFAULT_AGENT_REQUIRED", "Choose default agents in Settings before spawning sessions", nil) case errors.Is(err, ports.ErrWorkspaceBranchCheckedOutElsewhere): return apierr.Conflict("BRANCH_CHECKED_OUT_ELSEWHERE", err.Error(), nil) case errors.Is(err, ports.ErrWorkspaceBranchNotFetched): diff --git a/backend/internal/service/session/service_test.go b/backend/internal/service/session/service_test.go index d9c2ec0b..c6567a26 100644 --- a/backend/internal/service/session/service_test.go +++ b/backend/internal/service/session/service_test.go @@ -140,11 +140,13 @@ type fakeCommander struct { killErr error cleanupErr error spawned bool + spawnedConfig ports.SpawnConfig killsAtSpawn int } func (f *fakeCommander) Spawn(_ context.Context, cfg ports.SpawnConfig) (domain.SessionRecord, error) { f.spawned = true + f.spawnedConfig = cfg f.killsAtSpawn = len(f.killed) return domain.SessionRecord{ID: "mer-9", ProjectID: cfg.ProjectID, Kind: cfg.Kind}, nil } @@ -237,7 +239,7 @@ func TestSpawnOrchestratorCleanKillsActiveOrchestratorsBeforeSpawn(t *testing.T) fc := &fakeCommander{} svc := &Service{manager: fc, store: st} - if _, err := svc.SpawnOrchestrator(context.Background(), "mer", true); err != nil { + if _, err := svc.SpawnOrchestrator(context.Background(), "mer", true, ""); err != nil { t.Fatalf("SpawnOrchestrator: %v", err) } @@ -249,6 +251,20 @@ func TestSpawnOrchestratorCleanKillsActiveOrchestratorsBeforeSpawn(t *testing.T) } } +func TestSpawnOrchestratorForwardsExplicitHarness(t *testing.T) { + st := newFakeStore() + st.projects["mer"] = domain.ProjectRecord{ID: "mer"} + fc := &fakeCommander{} + svc := &Service{manager: fc, store: st} + + if _, err := svc.SpawnOrchestrator(context.Background(), "mer", false, domain.HarnessCodex); err != nil { + t.Fatalf("SpawnOrchestrator: %v", err) + } + if fc.spawnedConfig.Kind != domain.KindOrchestrator || fc.spawnedConfig.Harness != domain.HarnessCodex { + t.Fatalf("spawn config = %+v, want orchestrator codex", fc.spawnedConfig) + } +} + // TestSpawnUnknownProjectReturns404 covers Bug 1: an HTTP spawn for an // unregistered projectId must surface PROJECT_NOT_FOUND (apierr.NotFound) // BEFORE any session row is created, so no orphan terminated row is left @@ -275,7 +291,7 @@ func TestSpawnOrchestratorUnknownProjectReturns404(t *testing.T) { fc := &fakeCommander{} svc := &Service{manager: fc, store: st} - _, err := svc.SpawnOrchestrator(context.Background(), "ghost", false) + _, err := svc.SpawnOrchestrator(context.Background(), "ghost", false, "") var e *apierr.Error if !errors.As(err, &e) || e.Kind != apierr.KindNotFound || e.Code != "PROJECT_NOT_FOUND" { t.Fatalf("err = %v, want apierr.NotFound PROJECT_NOT_FOUND", err) @@ -300,6 +316,7 @@ func TestToAPIErrorMapsWorkspaceBranchSentinels(t *testing.T) { {"invalid branch", fmt.Errorf("spawn mer-1: workspace: %w: \"bad!!\" (exit 1)", ports.ErrWorkspaceBranchInvalid), apierr.KindInvalid, "INVALID_BRANCH"}, {"agent binary not found", fmt.Errorf("spawn mer-1: %w", ports.ErrAgentBinaryNotFound), apierr.KindInvalid, "AGENT_BINARY_NOT_FOUND"}, {"unknown harness", fmt.Errorf("spawn: %w: %q", sessionmanager.ErrUnknownHarness, "bogus"), apierr.KindInvalid, "UNKNOWN_HARNESS"}, + {"default agent required", fmt.Errorf("spawn: %w", sessionmanager.ErrDefaultAgentRequired), apierr.KindInvalid, "DEFAULT_AGENT_REQUIRED"}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { @@ -320,7 +337,7 @@ func TestSpawnOrchestratorNoCleanSkipsKills(t *testing.T) { fc := &fakeCommander{} svc := &Service{manager: fc, store: st} - if _, err := svc.SpawnOrchestrator(context.Background(), "mer", false); err != nil { + if _, err := svc.SpawnOrchestrator(context.Background(), "mer", false, ""); err != nil { t.Fatalf("SpawnOrchestrator: %v", err) } if len(fc.killed) != 0 || !fc.spawned { diff --git a/backend/internal/service/settings/service.go b/backend/internal/service/settings/service.go new file mode 100644 index 00000000..925ff521 --- /dev/null +++ b/backend/internal/service/settings/service.go @@ -0,0 +1,46 @@ +package settings + +import ( + "context" + "fmt" + + "github.com/aoagents/agent-orchestrator/backend/internal/domain" + "github.com/aoagents/agent-orchestrator/backend/internal/httpd/apierr" +) + +// Store is the persistence surface for app-wide user settings. +type Store interface { + GetAgentDefaults(ctx context.Context) (domain.AgentDefaults, bool, error) + SetAgentDefaults(ctx context.Context, defaults domain.AgentDefaults) error +} + +// Service owns validation and persistence for app-wide user settings. +type Service struct { + store Store +} + +// New wires a settings service over a store. +func New(store Store) *Service { + return &Service{store: store} +} + +// GetAgentDefaults returns configured defaults. Missing settings return the +// zero value so callers can distinguish first-run setup by Complete(). +func (s *Service) GetAgentDefaults(ctx context.Context) (domain.AgentDefaults, error) { + defaults, _, err := s.store.GetAgentDefaults(ctx) + if err != nil { + return domain.AgentDefaults{}, fmt.Errorf("get agent defaults: %w", err) + } + return defaults, nil +} + +// SetAgentDefaults validates and persists app-wide agent defaults. +func (s *Service) SetAgentDefaults(ctx context.Context, defaults domain.AgentDefaults) (domain.AgentDefaults, error) { + if err := defaults.ValidateComplete(); err != nil { + return domain.AgentDefaults{}, apierr.Invalid("INVALID_AGENT_DEFAULTS", err.Error(), nil) + } + if err := s.store.SetAgentDefaults(ctx, defaults); err != nil { + return domain.AgentDefaults{}, fmt.Errorf("set agent defaults: %w", err) + } + return defaults, nil +} diff --git a/backend/internal/service/settings/service_test.go b/backend/internal/service/settings/service_test.go new file mode 100644 index 00000000..eb83d3e7 --- /dev/null +++ b/backend/internal/service/settings/service_test.go @@ -0,0 +1,59 @@ +package settings + +import ( + "context" + "errors" + "testing" + + "github.com/aoagents/agent-orchestrator/backend/internal/domain" + "github.com/aoagents/agent-orchestrator/backend/internal/httpd/apierr" +) + +type fakeStore struct { + defaults domain.AgentDefaults + ok bool + err error + saved domain.AgentDefaults +} + +func (f *fakeStore) GetAgentDefaults(context.Context) (domain.AgentDefaults, bool, error) { + return f.defaults, f.ok, f.err +} + +func (f *fakeStore) SetAgentDefaults(_ context.Context, defaults domain.AgentDefaults) error { + f.saved = defaults + return f.err +} + +func TestGetAgentDefaultsReturnsZeroWhenUnset(t *testing.T) { + got, err := New(&fakeStore{}).GetAgentDefaults(context.Background()) + if err != nil { + t.Fatalf("GetAgentDefaults: %v", err) + } + if got.Complete() { + t.Fatalf("defaults = %+v, want first-run incomplete defaults", got) + } +} + +func TestSetAgentDefaultsValidatesAndPersists(t *testing.T) { + store := &fakeStore{} + defaults := domain.AgentDefaults{ + DefaultWorkerAgent: domain.HarnessCodex, + DefaultOrchestratorAgent: domain.HarnessClaudeCode, + } + got, err := New(store).SetAgentDefaults(context.Background(), defaults) + if err != nil { + t.Fatalf("SetAgentDefaults: %v", err) + } + if got != defaults || store.saved != defaults { + t.Fatalf("got=%+v saved=%+v, want %+v", got, store.saved, defaults) + } +} + +func TestSetAgentDefaultsRejectsMissingValues(t *testing.T) { + _, err := New(&fakeStore{}).SetAgentDefaults(context.Background(), domain.AgentDefaults{DefaultWorkerAgent: domain.HarnessCodex}) + var apiErr *apierr.Error + if !errors.As(err, &apiErr) || apiErr.Kind != apierr.KindInvalid || apiErr.Code != "INVALID_AGENT_DEFAULTS" { + t.Fatalf("err = %v, want INVALID_AGENT_DEFAULTS", err) + } +} diff --git a/backend/internal/session_manager/manager.go b/backend/internal/session_manager/manager.go index 0e33d9b5..7ab1bff5 100644 --- a/backend/internal/session_manager/manager.go +++ b/backend/internal/session_manager/manager.go @@ -32,6 +32,9 @@ var ( // adapter. The API maps it to a 400 so a typo'd `--harness` is a validation // error, not an opaque 500. ErrUnknownHarness = errors.New("session: unknown agent harness") + // ErrDefaultAgentRequired means a spawn named no harness, the project had no + // role override, and the app-wide default for that role has not been set. + ErrDefaultAgentRequired = errors.New("session: default agent required") ) // Env vars a spawned process reads to learn who it is. @@ -76,6 +79,12 @@ type Store interface { DeleteSession(ctx context.Context, id domain.SessionID) (bool, error) } +// AgentDefaults is the app-wide settings surface needed to resolve a spawn that +// does not carry an explicit or project-level harness. +type AgentDefaults interface { + GetAgentDefaults(ctx context.Context) (domain.AgentDefaults, bool, error) +} + // Manager coordinates internal session spawn, restore, kill, and cleanup over // the outbound ports. User-facing read-model assembly lives in the service package. type Manager struct { @@ -86,11 +95,10 @@ type Manager struct { messenger ports.AgentMessenger lcm lifecycleRecorder dataDir string - // defaultHarness is the daemon's configured default agent (AO_AGENT). A spawn - // that names no harness resolves to it before the seed row is written, so the - // stored/returned harness matches the agent the resolver actually launches. - defaultHarness domain.AgentHarness - clock func() time.Time + // agentDefaults supplies app-wide defaults when a spawn names no harness and + // the project has no role override. + agentDefaults AgentDefaults + clock func() time.Time // lookPath is exec.LookPath in production; tests substitute a stub so // they don't need real binaries on PATH. Returns ports.ErrAgentBinaryNotFound // when the binary is missing so the sentinel propagates through toAPIError. @@ -113,12 +121,10 @@ type Deps struct { // DataDir is exported to spawned agents as AO_DATA_DIR so their hook // commands can open the same store. DataDir string - // DefaultHarness is the daemon's configured default agent (AO_AGENT), used to - // resolve a spawn that names no harness. Wiring passes config.DefaultAgent; - // left empty, an unspecified harness stays empty (the resolver still defaults - // it at launch, but the record won't reflect the real agent). - DefaultHarness domain.AgentHarness - Clock func() time.Time + // AgentDefaults stores app-wide defaults used when a spawn names no harness + // and the project has no role override. + AgentDefaults AgentDefaults + Clock func() time.Time // LookPath overrides exec.LookPath for the pre-launch agent-binary check. // Production wiring leaves this nil and the manager defaults to // exec.LookPath; tests inject a stub so they need not seed real binaries. @@ -136,18 +142,18 @@ type Deps struct { // time.Now when Deps.Clock is nil. func New(d Deps) *Manager { m := &Manager{ - runtime: d.Runtime, - agents: d.Agents, - workspace: d.Workspace, - store: d.Store, - messenger: d.Messenger, - lcm: d.Lifecycle, - dataDir: d.DataDir, - defaultHarness: d.DefaultHarness, - clock: d.Clock, - lookPath: d.LookPath, - executable: d.Executable, - logger: d.Logger, + runtime: d.Runtime, + agents: d.Agents, + workspace: d.Workspace, + store: d.Store, + messenger: d.Messenger, + lcm: d.Lifecycle, + dataDir: d.DataDir, + agentDefaults: d.AgentDefaults, + clock: d.Clock, + lookPath: d.LookPath, + executable: d.Executable, + logger: d.Logger, } if m.clock == nil { // UTC so spawn-stamped CreatedAt/UpdatedAt match every other session @@ -167,6 +173,27 @@ func New(d Deps) *Manager { return m } +func (m *Manager) defaultHarnessForKind(ctx context.Context, kind domain.SessionKind) (domain.AgentHarness, error) { + if m.agentDefaults == nil { + return "", ErrDefaultAgentRequired + } + defaults, ok, err := m.agentDefaults.GetAgentDefaults(ctx) + if err != nil { + return "", fmt.Errorf("agent defaults: %w", err) + } + if !ok { + return "", ErrDefaultAgentRequired + } + harness := defaults.HarnessFor(kind) + if harness == "" { + return "", ErrDefaultAgentRequired + } + if !harness.IsKnown() { + return "", fmt.Errorf("%w: %q", ErrUnknownHarness, harness) + } + return harness, nil +} + // Spawn creates the session row (which assigns the "{project}-{n}" id), then the // workspace and runtime, then reports completion to the LCM. If workspace // materialization fails the still-seed row is deleted outright; a later failure @@ -179,12 +206,12 @@ func (m *Manager) Spawn(ctx context.Context, cfg ports.SpawnConfig) (domain.Sess // A per-project role override picks the harness when the spawn names none, // so a project can default workers to one agent and orchestrators to another. cfg.Harness = effectiveHarness(cfg.Harness, cfg.Kind, project.Config) - // Resolve an unspecified harness to the daemon default BEFORE the seed row is - // written, so the stored/returned harness matches the agent the resolver - // launches (otherwise a default-agent session persists an empty harness and - // the UI can't tell which agent is running). if cfg.Harness == "" { - cfg.Harness = m.defaultHarness + harness, err := m.defaultHarnessForKind(ctx, cfg.Kind) + if err != nil { + return domain.SessionRecord{}, fmt.Errorf("spawn: %w", err) + } + cfg.Harness = harness } // Reject an unknown harness before any durable state is created. Doing this @@ -308,7 +335,7 @@ func (m *Manager) loadProject(ctx context.Context, projectID domain.ProjectID) ( // effectiveHarness resolves the harness for a spawn: an explicit harness wins; // otherwise the project's role override for the session kind applies; otherwise -// it stays empty so the daemon's global default (AO_AGENT) is used downstream. +// it stays empty so the app-wide default can be read from settings. func effectiveHarness(explicit domain.AgentHarness, kind domain.SessionKind, cfg domain.ProjectConfig) domain.AgentHarness { if explicit != "" { return explicit diff --git a/backend/internal/session_manager/manager_test.go b/backend/internal/session_manager/manager_test.go index d2474a2b..87ff33ed 100644 --- a/backend/internal/session_manager/manager_test.go +++ b/backend/internal/session_manager/manager_test.go @@ -85,6 +85,23 @@ func (f *fakeStore) GetDisplayPRFactsForSession(_ context.Context, id domain.Ses return domain.PRFacts{}, false, nil } +type fakeAgentDefaults struct { + defaults domain.AgentDefaults + ok bool + err error +} + +func (f fakeAgentDefaults) GetAgentDefaults(context.Context) (domain.AgentDefaults, bool, error) { + return f.defaults, f.ok, f.err +} + +func testAgentDefaults() fakeAgentDefaults { + return fakeAgentDefaults{ok: true, defaults: domain.AgentDefaults{ + DefaultWorkerAgent: domain.HarnessClaudeCode, + DefaultOrchestratorAgent: domain.HarnessClaudeCode, + }} +} + type fakeLCM struct { store *fakeStore completed int @@ -228,7 +245,11 @@ func newManager() (*Manager, *fakeStore, *fakeRuntime, *fakeWorkspace) { // Stub lookPath so the pre-launch agent-binary check passes; the fakeAgent // returns argv ["launch"] which is not a real binary on PATH. lookPath := func(string) (string, error) { return "/bin/true", nil } - m := New(Deps{Runtime: rt, Agents: fakeAgents{}, Workspace: ws, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath}) + m := New(Deps{ + Runtime: rt, Agents: fakeAgents{}, Workspace: ws, Store: st, + Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath, + AgentDefaults: testAgentDefaults(), + }) return m, st, rt, ws } func seedTerminal(st *fakeStore, id domain.SessionID, meta domain.SessionMetadata) { @@ -251,7 +272,7 @@ func TestSpawn_ResolvesProjectConfig(t *testing.T) { rt := &fakeRuntime{} ws := &fakeWorkspace{} lookPath := func(string) (string, error) { return "/bin/true", nil } - m := New(Deps{Runtime: rt, Agents: singleAgent{agent: agent}, Workspace: ws, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath}) + m := New(Deps{Runtime: rt, Agents: singleAgent{agent: agent}, Workspace: ws, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath, AgentDefaults: testAgentDefaults()}) rec, err := m.Spawn(ctx, ports.SpawnConfig{ProjectID: "mer", Kind: domain.KindWorker}) if err != nil { @@ -284,24 +305,27 @@ func TestSpawn_ResolvesProjectConfig(t *testing.T) { } } -// TestSpawn_PersistsResolvedDefaultHarness locks the fix for the mislabelled -// agent: a spawn that names no harness must persist the daemon's default agent -// (so the API/UI report what actually runs), while an explicit harness wins. -func TestSpawn_PersistsResolvedDefaultHarness(t *testing.T) { +// TestSpawn_PersistsResolvedAppDefaultHarness locks the default resolution +// order: a spawn that names no harness persists the stored app default, while +// an explicit harness wins. +func TestSpawn_PersistsResolvedAppDefaultHarness(t *testing.T) { st := newFakeStore() st.projects["mer"] = domain.ProjectRecord{ID: "mer"} m := New(Deps{ Runtime: &fakeRuntime{}, Agents: fakeAgents{}, Workspace: &fakeWorkspace{}, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, - LookPath: func(string) (string, error) { return "/bin/true", nil }, - DefaultHarness: domain.HarnessClaudeCode, + LookPath: func(string) (string, error) { return "/bin/true", nil }, + AgentDefaults: fakeAgentDefaults{ok: true, defaults: domain.AgentDefaults{ + DefaultWorkerAgent: domain.HarnessCodex, + DefaultOrchestratorAgent: domain.HarnessClaudeCode, + }}, }) if _, err := m.Spawn(ctx, ports.SpawnConfig{ProjectID: "mer", Kind: domain.KindWorker}); err != nil { t.Fatal(err) } - if got := st.sessions["mer-1"].Harness; got != domain.HarnessClaudeCode { - t.Fatalf("unspecified harness = %q, want resolved default %q", got, domain.HarnessClaudeCode) + if got := st.sessions["mer-1"].Harness; got != domain.HarnessCodex { + t.Fatalf("unspecified harness = %q, want resolved default %q", got, domain.HarnessCodex) } if _, err := m.Spawn(ctx, ports.SpawnConfig{ProjectID: "mer", Kind: domain.KindWorker, Harness: domain.HarnessCodex}); err != nil { @@ -312,6 +336,24 @@ func TestSpawn_PersistsResolvedDefaultHarness(t *testing.T) { } } +func TestSpawn_UnspecifiedHarnessRequiresConfiguredDefault(t *testing.T) { + st := newFakeStore() + st.projects["mer"] = domain.ProjectRecord{ID: "mer"} + m := New(Deps{ + Runtime: &fakeRuntime{}, Agents: fakeAgents{}, Workspace: &fakeWorkspace{}, Store: st, + Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, + LookPath: func(string) (string, error) { return "/bin/true", nil }, + }) + + _, err := m.Spawn(ctx, ports.SpawnConfig{ProjectID: "mer", Kind: domain.KindWorker}) + if !errors.Is(err, ErrDefaultAgentRequired) { + t.Fatalf("err = %v, want ErrDefaultAgentRequired", err) + } + if len(st.sessions) != 0 { + t.Fatalf("spawn without defaults must not create a session row: %+v", st.sessions) + } +} + func TestSpawn_AssignsIDAndGoesIdle(t *testing.T) { m, st, rt, _ := newManager() s, err := m.Spawn(ctx, ports.SpawnConfig{ProjectID: "mer", Kind: domain.KindWorker, Prompt: "do it"}) @@ -471,7 +513,7 @@ func TestRestore_AppliesProjectAgentConfig(t *testing.T) { seedTerminal(st, "mer-1", domain.SessionMetadata{WorkspacePath: "/ws/mer-1", Branch: "b", AgentSessionID: "agent-x"}) agent := &recordingAgent{} lookPath := func(string) (string, error) { return "/bin/true", nil } - m := New(Deps{Runtime: &fakeRuntime{}, Agents: singleAgent{agent: agent}, Workspace: &fakeWorkspace{}, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath}) + m := New(Deps{Runtime: &fakeRuntime{}, Agents: singleAgent{agent: agent}, Workspace: &fakeWorkspace{}, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath, AgentDefaults: testAgentDefaults()}) if _, err := m.Restore(ctx, "mer-1"); err != nil { t.Fatal(err) @@ -565,7 +607,7 @@ func TestSpawn_ForwardsResolvedAgentConfigPermissions(t *testing.T) { }} agent := &recordingAgent{} lookPath := func(string) (string, error) { return "/bin/true", nil } - m := New(Deps{Runtime: &fakeRuntime{}, Agents: singleAgent{agent: agent}, Workspace: &fakeWorkspace{}, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath}) + m := New(Deps{Runtime: &fakeRuntime{}, Agents: singleAgent{agent: agent}, Workspace: &fakeWorkspace{}, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath, AgentDefaults: testAgentDefaults()}) _, err := m.Spawn(ctx, ports.SpawnConfig{ProjectID: "mer", Kind: domain.KindWorker}) if err != nil { @@ -616,7 +658,7 @@ func TestSpawnWorker_AppendsActiveOrchestratorContact(t *testing.T) { rt := &fakeRuntime{} ws := &fakeWorkspace{} lookPath := func(string) (string, error) { return "/bin/true", nil } - m := New(Deps{Runtime: rt, Agents: singleAgent{agent: agent}, Workspace: ws, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath}) + m := New(Deps{Runtime: rt, Agents: singleAgent{agent: agent}, Workspace: ws, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath, AgentDefaults: testAgentDefaults()}) s, err := m.Spawn(ctx, ports.SpawnConfig{ProjectID: "mer", Kind: domain.KindWorker, Prompt: "do it"}) if err != nil { @@ -652,7 +694,7 @@ func TestSpawnWorker_SkipsTerminatedOrchestratorContact(t *testing.T) { rt := &fakeRuntime{} ws := &fakeWorkspace{} lookPath := func(string) (string, error) { return "/bin/true", nil } - m := New(Deps{Runtime: rt, Agents: singleAgent{agent: agent}, Workspace: ws, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath}) + m := New(Deps{Runtime: rt, Agents: singleAgent{agent: agent}, Workspace: ws, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath, AgentDefaults: testAgentDefaults()}) _, err := m.Spawn(ctx, ports.SpawnConfig{ProjectID: "mer", Kind: domain.KindWorker, Prompt: "do it"}) if err != nil { @@ -670,7 +712,7 @@ func TestSpawnOrchestrator_UsesCoordinatorPrompt(t *testing.T) { rt := &fakeRuntime{} ws := &fakeWorkspace{} lookPath := func(string) (string, error) { return "/bin/true", nil } - m := New(Deps{Runtime: rt, Agents: singleAgent{agent: agent}, Workspace: ws, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath}) + m := New(Deps{Runtime: rt, Agents: singleAgent{agent: agent}, Workspace: ws, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath, AgentDefaults: testAgentDefaults()}) _, err := m.Spawn(ctx, ports.SpawnConfig{ProjectID: "mer", Kind: domain.KindOrchestrator}) if err != nil { @@ -843,7 +885,7 @@ func TestSpawn_RejectsMissingAgentBinary(t *testing.T) { notFound := func(name string) (string, error) { return "", fmt.Errorf("exec: %q: not found", name) } - m := New(Deps{Runtime: rt, Agents: fakeAgents{}, Workspace: ws, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: notFound}) + m := New(Deps{Runtime: rt, Agents: fakeAgents{}, Workspace: ws, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: notFound, AgentDefaults: testAgentDefaults()}) _, err := m.Spawn(ctx, ports.SpawnConfig{ProjectID: "mer", Kind: domain.KindWorker}) if !errors.Is(err, ports.ErrAgentBinaryNotFound) { @@ -894,7 +936,8 @@ func pathPinManager(executable func() (string, error)) (*Manager, *fakeStore, *f Runtime: rt, Agents: fakeAgents{}, Workspace: &fakeWorkspace{}, Store: st, Messenger: &fakeMessenger{}, Lifecycle: &fakeLCM{store: st}, LookPath: lookPath, Executable: executable, - Logger: slog.New(slog.NewTextHandler(logBuf, nil)), + AgentDefaults: testAgentDefaults(), + Logger: slog.New(slog.NewTextHandler(logBuf, nil)), }) return m, st, rt, logBuf } diff --git a/backend/internal/storage/sqlite/gen/models.go b/backend/internal/storage/sqlite/gen/models.go index 589bfed0..21fd48ed 100644 --- a/backend/internal/storage/sqlite/gen/models.go +++ b/backend/internal/storage/sqlite/gen/models.go @@ -12,6 +12,12 @@ import ( "github.com/aoagents/agent-orchestrator/backend/internal/domain" ) +type AppSetting struct { + ID int64 + DefaultWorkerAgent domain.AgentHarness + DefaultOrchestratorAgent domain.AgentHarness +} + type ChangeLog struct { Seq int64 ProjectID domain.ProjectID diff --git a/backend/internal/storage/sqlite/gen/settings.sql.go b/backend/internal/storage/sqlite/gen/settings.sql.go new file mode 100644 index 00000000..c64130d1 --- /dev/null +++ b/backend/internal/storage/sqlite/gen/settings.sql.go @@ -0,0 +1,48 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: settings.sql + +package gen + +import ( + "context" + + "github.com/aoagents/agent-orchestrator/backend/internal/domain" +) + +const getAgentDefaults = `-- name: GetAgentDefaults :one +SELECT default_worker_agent, default_orchestrator_agent +FROM app_settings +WHERE id = 1 +` + +type GetAgentDefaultsRow struct { + DefaultWorkerAgent domain.AgentHarness + DefaultOrchestratorAgent domain.AgentHarness +} + +func (q *Queries) GetAgentDefaults(ctx context.Context) (GetAgentDefaultsRow, error) { + row := q.db.QueryRowContext(ctx, getAgentDefaults) + var i GetAgentDefaultsRow + err := row.Scan(&i.DefaultWorkerAgent, &i.DefaultOrchestratorAgent) + return i, err +} + +const upsertAgentDefaults = `-- name: UpsertAgentDefaults :exec +INSERT INTO app_settings (id, default_worker_agent, default_orchestrator_agent) +VALUES (1, ?, ?) +ON CONFLICT(id) DO UPDATE SET + default_worker_agent = excluded.default_worker_agent, + default_orchestrator_agent = excluded.default_orchestrator_agent +` + +type UpsertAgentDefaultsParams struct { + DefaultWorkerAgent domain.AgentHarness + DefaultOrchestratorAgent domain.AgentHarness +} + +func (q *Queries) UpsertAgentDefaults(ctx context.Context, arg UpsertAgentDefaultsParams) error { + _, err := q.db.ExecContext(ctx, upsertAgentDefaults, arg.DefaultWorkerAgent, arg.DefaultOrchestratorAgent) + return err +} diff --git a/backend/internal/storage/sqlite/migrations/0014_app_settings.sql b/backend/internal/storage/sqlite/migrations/0014_app_settings.sql new file mode 100644 index 00000000..d60d3c55 --- /dev/null +++ b/backend/internal/storage/sqlite/migrations/0014_app_settings.sql @@ -0,0 +1,17 @@ +-- +goose Up +-- +goose StatementBegin + +CREATE TABLE app_settings ( + id INTEGER PRIMARY KEY CHECK (id = 1), + default_worker_agent TEXT NOT NULL DEFAULT '' + CHECK (default_worker_agent IN ('', 'claude-code', 'codex', 'aider', 'opencode', 'grok', 'droid', 'amp', 'agy', 'crush', 'cursor', 'qwen', 'copilot', 'goose', 'auggie', 'continue', 'devin', 'cline', 'kimi', 'kiro', 'kilocode', 'vibe', 'pi', 'autohand')), + default_orchestrator_agent TEXT NOT NULL DEFAULT '' + CHECK (default_orchestrator_agent IN ('', 'claude-code', 'codex', 'aider', 'opencode', 'grok', 'droid', 'amp', 'agy', 'crush', 'cursor', 'qwen', 'copilot', 'goose', 'auggie', 'continue', 'devin', 'cline', 'kimi', 'kiro', 'kilocode', 'vibe', 'pi', 'autohand')) +); + +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP TABLE app_settings; +-- +goose StatementEnd diff --git a/backend/internal/storage/sqlite/queries/settings.sql b/backend/internal/storage/sqlite/queries/settings.sql new file mode 100644 index 00000000..2b0de0c7 --- /dev/null +++ b/backend/internal/storage/sqlite/queries/settings.sql @@ -0,0 +1,11 @@ +-- name: GetAgentDefaults :one +SELECT default_worker_agent, default_orchestrator_agent +FROM app_settings +WHERE id = 1; + +-- name: UpsertAgentDefaults :exec +INSERT INTO app_settings (id, default_worker_agent, default_orchestrator_agent) +VALUES (1, ?, ?) +ON CONFLICT(id) DO UPDATE SET + default_worker_agent = excluded.default_worker_agent, + default_orchestrator_agent = excluded.default_orchestrator_agent; diff --git a/backend/internal/storage/sqlite/store/settings_store.go b/backend/internal/storage/sqlite/store/settings_store.go new file mode 100644 index 00000000..d77752f9 --- /dev/null +++ b/backend/internal/storage/sqlite/store/settings_store.go @@ -0,0 +1,40 @@ +package store + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/aoagents/agent-orchestrator/backend/internal/domain" + "github.com/aoagents/agent-orchestrator/backend/internal/storage/sqlite/gen" +) + +// GetAgentDefaults returns the app-wide spawn defaults. ok=false means the +// user has not configured them yet. +func (s *Store) GetAgentDefaults(ctx context.Context) (domain.AgentDefaults, bool, error) { + row, err := s.qr.GetAgentDefaults(ctx) + if errors.Is(err, sql.ErrNoRows) { + return domain.AgentDefaults{}, false, nil + } + if err != nil { + return domain.AgentDefaults{}, false, fmt.Errorf("get agent defaults: %w", err) + } + return domain.AgentDefaults{ + DefaultWorkerAgent: row.DefaultWorkerAgent, + DefaultOrchestratorAgent: row.DefaultOrchestratorAgent, + }, true, nil +} + +// SetAgentDefaults replaces the app-wide spawn defaults. +func (s *Store) SetAgentDefaults(ctx context.Context, defaults domain.AgentDefaults) error { + s.writeMu.Lock() + defer s.writeMu.Unlock() + if err := s.qw.UpsertAgentDefaults(ctx, gen.UpsertAgentDefaultsParams{ + DefaultWorkerAgent: defaults.DefaultWorkerAgent, + DefaultOrchestratorAgent: defaults.DefaultOrchestratorAgent, + }); err != nil { + return fmt.Errorf("set agent defaults: %w", err) + } + return nil +} diff --git a/backend/internal/storage/sqlite/store/settings_store_test.go b/backend/internal/storage/sqlite/store/settings_store_test.go new file mode 100644 index 00000000..e93a0b59 --- /dev/null +++ b/backend/internal/storage/sqlite/store/settings_store_test.go @@ -0,0 +1,37 @@ +package store_test + +import ( + "context" + "testing" + + "github.com/aoagents/agent-orchestrator/backend/internal/domain" + "github.com/aoagents/agent-orchestrator/backend/internal/storage/sqlite" +) + +func TestAgentDefaultsRoundTrip(t *testing.T) { + st, err := sqlite.Open(t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = st.Close() }) + + ctx := context.Background() + if got, ok, err := st.GetAgentDefaults(ctx); err != nil || ok || got.Complete() { + t.Fatalf("initial defaults got=%+v ok=%v err=%v, want unset", got, ok, err) + } + + want := domain.AgentDefaults{ + DefaultWorkerAgent: domain.HarnessCodex, + DefaultOrchestratorAgent: domain.HarnessClaudeCode, + } + if err := st.SetAgentDefaults(ctx, want); err != nil { + t.Fatalf("SetAgentDefaults: %v", err) + } + got, ok, err := st.GetAgentDefaults(ctx) + if err != nil { + t.Fatalf("GetAgentDefaults: %v", err) + } + if !ok || got != want { + t.Fatalf("defaults got=%+v ok=%v, want %+v", got, ok, want) + } +} diff --git a/backend/sqlc.yaml b/backend/sqlc.yaml index 070b6916..bca92b99 100644 --- a/backend/sqlc.yaml +++ b/backend/sqlc.yaml @@ -80,6 +80,14 @@ sql: go_type: import: "github.com/aoagents/agent-orchestrator/backend/internal/domain" type: "ProjectID" + - column: "app_settings.default_worker_agent" + go_type: + import: "github.com/aoagents/agent-orchestrator/backend/internal/domain" + type: "AgentHarness" + - column: "app_settings.default_orchestrator_agent" + go_type: + import: "github.com/aoagents/agent-orchestrator/backend/internal/domain" + type: "AgentHarness" - column: "session_worktrees.session_id" go_type: import: "github.com/aoagents/agent-orchestrator/backend/internal/domain" diff --git a/frontend/src/api/schema.ts b/frontend/src/api/schema.ts index 0bb8bd4a..7383a752 100644 --- a/frontend/src/api/schema.ts +++ b/frontend/src/api/schema.ts @@ -400,6 +400,24 @@ export interface paths { patch?: never; trace?: never; }; + "/api/v1/settings/agents": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** Get app-wide default agents */ + get: operations["getAgentDefaults"]; + /** Set app-wide default agents */ + put: operations["setAgentDefaults"]; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; } export type webhooks = Record; export interface components { @@ -424,6 +442,19 @@ export interface components { model?: string; permissions?: string; }; + AgentDefaultsRequest: { + /** @enum {string} */ + defaultOrchestratorAgent?: "claude-code" | "codex" | "aider" | "opencode" | "grok" | "droid" | "amp" | "agy" | "crush" | "cursor" | "qwen" | "copilot" | "goose" | "auggie" | "continue" | "devin" | "cline" | "kimi" | "kiro" | "kilocode" | "vibe" | "pi" | "autohand"; + /** @enum {string} */ + defaultWorkerAgent?: "claude-code" | "codex" | "aider" | "opencode" | "grok" | "droid" | "amp" | "agy" | "crush" | "cursor" | "qwen" | "copilot" | "goose" | "auggie" | "continue" | "devin" | "cline" | "kimi" | "kiro" | "kilocode" | "vibe" | "pi" | "autohand"; + }; + AgentDefaultsResponse: { + configured: boolean; + /** @enum {string} */ + defaultOrchestratorAgent?: "claude-code" | "codex" | "aider" | "opencode" | "grok" | "droid" | "amp" | "agy" | "crush" | "cursor" | "qwen" | "copilot" | "goose" | "auggie" | "continue" | "devin" | "cline" | "kimi" | "kiro" | "kilocode" | "vibe" | "pi" | "autohand"; + /** @enum {string} */ + defaultWorkerAgent?: "claude-code" | "codex" | "aider" | "opencode" | "grok" | "droid" | "amp" | "agy" | "crush" | "cursor" | "qwen" | "copilot" | "goose" | "auggie" | "continue" | "devin" | "cline" | "kimi" | "kiro" | "kilocode" | "vibe" | "pi" | "autohand"; + }; ClaimPRRequest: { allowTakeover?: null | boolean; pr: string; @@ -658,6 +689,8 @@ export interface components { }; SpawnOrchestratorRequest: { clean?: boolean; + /** @enum {string} */ + harness?: "claude-code" | "codex" | "aider" | "opencode" | "grok" | "droid" | "amp" | "agy" | "crush" | "cursor" | "qwen" | "copilot" | "goose" | "auggie" | "continue" | "devin" | "cline" | "kimi" | "kiro" | "kilocode" | "vibe" | "pi" | "autohand"; projectId: string; }; SpawnOrchestratorResponse: { @@ -2124,4 +2157,93 @@ export interface operations { }; }; }; + getAgentDefaults: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description OK */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["AgentDefaultsResponse"]; + }; + }; + /** @description Internal Server Error */ + 500: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["APIError"]; + }; + }; + /** @description Not Implemented */ + 501: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["APIError"]; + }; + }; + }; + }; + setAgentDefaults: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["AgentDefaultsRequest"]; + }; + }; + responses: { + /** @description OK */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["AgentDefaultsResponse"]; + }; + }; + /** @description Bad Request */ + 400: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["APIError"]; + }; + }; + /** @description Internal Server Error */ + 500: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["APIError"]; + }; + }; + /** @description Not Implemented */ + 501: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["APIError"]; + }; + }; + }; + }; } diff --git a/frontend/src/renderer/components/AgentDefaultsDialog.test.tsx b/frontend/src/renderer/components/AgentDefaultsDialog.test.tsx new file mode 100644 index 00000000..7726669a --- /dev/null +++ b/frontend/src/renderer/components/AgentDefaultsDialog.test.tsx @@ -0,0 +1,94 @@ +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { render, screen, waitFor } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const { getMock, putMock } = vi.hoisted(() => ({ + getMock: vi.fn(), + putMock: vi.fn(), +})); + +vi.mock("../lib/api-client", () => ({ + apiClient: { + GET: getMock, + PUT: putMock, + }, + apiErrorMessage: (error: unknown) => { + if (error instanceof Error) return error.message; + if (typeof error === "object" && error !== null && "message" in error) { + return String((error as { message: unknown }).message); + } + return "Request failed"; + }, +})); + +import { AgentDefaultsDialog } from "./AgentDefaultsDialog"; + +function renderDialog(open = false) { + const queryClient = new QueryClient({ + defaultOptions: { queries: { retry: false }, mutations: { retry: false } }, + }); + const onOpenChange = vi.fn(); + render( + + + , + ); + return onOpenChange; +} + +async function chooseOption(trigger: HTMLElement, optionName: string) { + await userEvent.click(trigger); + await userEvent.click(await screen.findByRole("option", { name: optionName })); +} + +beforeEach(() => { + getMock.mockReset(); + putMock.mockReset(); +}); + +describe("AgentDefaultsDialog", () => { + it("opens on first run and saves selected defaults", async () => { + getMock.mockResolvedValue({ + data: { configured: false }, + error: undefined, + }); + putMock.mockResolvedValue({ + data: { + configured: true, + defaultWorkerAgent: "codex", + defaultOrchestratorAgent: "goose", + }, + error: undefined, + }); + const onOpenChange = renderDialog(false); + + expect(await screen.findByRole("dialog", { name: "Choose Default Agents" })).toBeInTheDocument(); + const save = screen.getByRole("button", { name: "Save defaults" }); + expect(save).toBeDisabled(); + + await chooseOption(screen.getByRole("combobox", { name: "Worker agent" }), "codex"); + await chooseOption(screen.getByRole("combobox", { name: "Orchestrator agent" }), "goose"); + await userEvent.click(save); + + await waitFor(() => expect(putMock).toHaveBeenCalledTimes(1)); + expect(putMock).toHaveBeenCalledWith("/api/v1/settings/agents", { + body: { defaultWorkerAgent: "codex", defaultOrchestratorAgent: "goose" }, + }); + expect(onOpenChange).toHaveBeenCalledWith(false); + }); + + it("stays hidden when configured and not explicitly opened", async () => { + getMock.mockResolvedValue({ + data: { + configured: true, + defaultWorkerAgent: "codex", + defaultOrchestratorAgent: "claude-code", + }, + error: undefined, + }); + renderDialog(false); + + await waitFor(() => expect(screen.queryByRole("dialog")).not.toBeInTheDocument()); + }); +}); diff --git a/frontend/src/renderer/components/AgentDefaultsDialog.tsx b/frontend/src/renderer/components/AgentDefaultsDialog.tsx new file mode 100644 index 00000000..931ec7f7 --- /dev/null +++ b/frontend/src/renderer/components/AgentDefaultsDialog.tsx @@ -0,0 +1,148 @@ +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { Bot, X } from "lucide-react"; +import type { ReactNode } from "react"; +import { useEffect, useMemo, useState } from "react"; +import { agentDefaultsQueryKey, fetchAgentDefaults, saveAgentDefaults } from "../lib/agent-defaults"; +import { AGENT_OPTIONS, type AgentOption } from "../lib/agent-options"; +import { Button } from "./ui/button"; +import { Label } from "./ui/label"; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "./ui/select"; + +type AgentDefaultsDialogProps = { + daemonReady: boolean; + open: boolean; + onOpenChange: (open: boolean) => void; +}; + +export function AgentDefaultsDialog({ daemonReady, open, onOpenChange }: AgentDefaultsDialogProps) { + const queryClient = useQueryClient(); + const query = useQuery({ + queryKey: agentDefaultsQueryKey, + queryFn: fetchAgentDefaults, + enabled: daemonReady, + }); + const firstRunRequired = daemonReady && (query.isLoading || query.isError || (query.isSuccess && !query.data.configured)); + const visible = open || firstRunRequired; + const locked = firstRunRequired; + const [workerAgent, setWorkerAgent] = useState(""); + const [orchestratorAgent, setOrchestratorAgent] = useState(""); + + useEffect(() => { + if (!visible || !query.data) return; + setWorkerAgent(query.data.defaultWorkerAgent ?? ""); + setOrchestratorAgent(query.data.defaultOrchestratorAgent ?? ""); + }, [query.data, visible]); + + const canSave = daemonReady && workerAgent !== "" && orchestratorAgent !== ""; + const title = firstRunRequired ? "Choose Default Agents" : "Default Agents"; + const mutation = useMutation({ + mutationFn: () => + saveAgentDefaults({ + defaultWorkerAgent: workerAgent as AgentOption, + defaultOrchestratorAgent: orchestratorAgent as AgentOption, + }), + onSuccess: (defaults) => { + queryClient.setQueryData(agentDefaultsQueryKey, defaults); + onOpenChange(false); + }, + }); + + const statusText = useMemo(() => { + if (query.isLoading) return "Loading agent settings..."; + if (!daemonReady) return "Daemon is not ready."; + if (query.isError) return query.error instanceof Error ? query.error.message : "Could not load agent settings"; + if (mutation.isError) return mutation.error instanceof Error ? mutation.error.message : "Could not save agent settings"; + return null; + }, [daemonReady, mutation.error, mutation.isError, query.error, query.isError, query.isLoading]); + + if (!visible) return null; + + const close = () => { + if (!locked) onOpenChange(false); + }; + + return ( +
+
{ + event.preventDefault(); + if (canSave) mutation.mutate(); + }} + role="dialog" + aria-modal="true" + > +
+
+
+
+

+ {title} +

+

+ {firstRunRequired ? "Required before spawning sessions." : "Used when a project has no role override."} +

+
+ {!locked && ( + + )} +
+ +
+ + + + + + +
+ +
+ + {statusText} + + +
+
+
+ ); +} + +function AgentSelect({ id, value, onChange }: { id: string; value: string; onChange: (value: string) => void }) { + return ( + + ); +} + +function Field({ label, htmlFor, children }: { label: string; htmlFor: string; children: ReactNode }) { + return ( +
+ + {children} +
+ ); +} diff --git a/frontend/src/renderer/components/ProjectSettingsForm.tsx b/frontend/src/renderer/components/ProjectSettingsForm.tsx index 673204e1..9e2dbef8 100644 --- a/frontend/src/renderer/components/ProjectSettingsForm.tsx +++ b/frontend/src/renderer/components/ProjectSettingsForm.tsx @@ -1,8 +1,9 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { useState } from "react"; import type { components } from "../../api/schema"; -import { apiClient, apiErrorMessage } from "../lib/api-client"; import { workspaceQueryKey } from "../hooks/useWorkspaceQuery"; +import { apiClient, apiErrorMessage } from "../lib/api-client"; +import { AGENT_OPTIONS } from "../lib/agent-options"; import { DashboardSubhead } from "./DashboardSubhead"; import { Button } from "./ui/button"; import { Card, CardContent, CardHeader, CardTitle } from "./ui/card"; @@ -12,9 +13,6 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from ". type Project = components["schemas"]["Project"]; type ProjectConfig = components["schemas"]["ProjectConfig"]; -// Agents the daemon registers. Empty = "use the daemon default". -const AGENT_OPTIONS = ["claude-code", "codex", "opencode", "amp", "goose", "kiro"] as const; - const PERMISSION_MODE_OPTIONS = [ { value: "default", label: "Default" }, { value: "accept-edits", label: "Accept edits" }, @@ -251,14 +249,14 @@ function PermissionModeSelect({ } function AgentSelect({ id, value, onChange }: { id: string; value: string; onChange: (value: string) => void }) { - // "" sentinel → daemon default; Select can't hold an empty value, so map it. + // "" sentinel → app default; Select can't hold an empty value, so map it. return ( + { - if (next !== AGENT_SELECT_PLACEHOLDER) onChange(next); - }} - > - - - - - - Select agent - - {AGENT_OPTIONS.map((agent) => ( - - {agent} - - ))} - - - ); -} - -function Field({ label, htmlFor, children }: { label: string; htmlFor: string; children: ReactNode }) { - return ( -
- - {children} -
- ); -} diff --git a/frontend/src/renderer/components/ProjectSettingsForm.tsx b/frontend/src/renderer/components/ProjectSettingsForm.tsx index 9e2dbef8..673204e1 100644 --- a/frontend/src/renderer/components/ProjectSettingsForm.tsx +++ b/frontend/src/renderer/components/ProjectSettingsForm.tsx @@ -1,9 +1,8 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { useState } from "react"; import type { components } from "../../api/schema"; -import { workspaceQueryKey } from "../hooks/useWorkspaceQuery"; import { apiClient, apiErrorMessage } from "../lib/api-client"; -import { AGENT_OPTIONS } from "../lib/agent-options"; +import { workspaceQueryKey } from "../hooks/useWorkspaceQuery"; import { DashboardSubhead } from "./DashboardSubhead"; import { Button } from "./ui/button"; import { Card, CardContent, CardHeader, CardTitle } from "./ui/card"; @@ -13,6 +12,9 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from ". type Project = components["schemas"]["Project"]; type ProjectConfig = components["schemas"]["ProjectConfig"]; +// Agents the daemon registers. Empty = "use the daemon default". +const AGENT_OPTIONS = ["claude-code", "codex", "opencode", "amp", "goose", "kiro"] as const; + const PERMISSION_MODE_OPTIONS = [ { value: "default", label: "Default" }, { value: "accept-edits", label: "Accept edits" }, @@ -249,14 +251,14 @@ function PermissionModeSelect({ } function AgentSelect({ id, value, onChange }: { id: string; value: string; onChange: (value: string) => void }) { - // "" sentinel → app default; Select can't hold an empty value, so map it. + // "" sentinel → daemon default; Select can't hold an empty value, so map it. return (