diff --git a/binary/tests/conftest.py b/binary/tests/conftest.py index 10aec6a5..bc65736b 100644 --- a/binary/tests/conftest.py +++ b/binary/tests/conftest.py @@ -59,6 +59,14 @@ def run_binary(args, payload, home, extra_env=None, stdin_close=False): return _run([str(BUILT_BINARY)] + args, payload, home, extra_env, stdin_close) +def run_go_binary(args, payload, home, extra_env=None, stdin_close=False): + """The Go rewrite (WEB-4809); opt-in via UNBOUND_GO_BINARY, else skipped.""" + go_binary = os.environ.get("UNBOUND_GO_BINARY") + if not go_binary: + pytest.skip("UNBOUND_GO_BINARY not set; Go parity is opt-in") + return _run([go_binary] + args, payload, home, extra_env, stdin_close) + + @pytest.fixture def sandbox_home(tmp_path): home = tmp_path / "home" diff --git a/binary/tests/test_hook_cli.py b/binary/tests/test_hook_cli.py index ac5086b1..a1b8cd9b 100644 --- a/binary/tests/test_hook_cli.py +++ b/binary/tests/test_hook_cli.py @@ -10,7 +10,7 @@ import pytest -from conftest import run_binary, run_cli_dev, run_python_path +from conftest import run_binary, run_cli_dev, run_go_binary, run_python_path S = {"session_id": "test-session", "transcript_path": "/nonexistent/transcript.jsonl"} @@ -96,6 +96,15 @@ def test_frozen_binary_matches_python_path(tool, event, sandbox_home): assert got.returncode == ref.returncode +@pytest.mark.parametrize("tool,event", CASES) +def test_go_binary_matches_python_path(tool, event, sandbox_home): + payload = json.dumps(EVENT_PAYLOADS[tool][event]) + ref = run_python_path(tool, payload, sandbox_home) + got = run_go_binary(["hook", tool, event], payload, sandbox_home) + assert got.stdout == ref.stdout + assert got.returncode == ref.returncode + + @pytest.mark.parametrize("tool", list(EVENT_PAYLOADS)) @pytest.mark.parametrize("junk", ["", "not json at all"]) def test_malformed_stdin_parity(tool, junk, sandbox_home): diff --git a/go/.gitignore b/go/.gitignore new file mode 100644 index 00000000..c0f64a83 --- /dev/null +++ b/go/.gitignore @@ -0,0 +1,2 @@ +dist/ +build/ diff --git a/go/README.md b/go/README.md new file mode 100644 index 00000000..95c4d5a6 --- /dev/null +++ b/go/README.md @@ -0,0 +1,62 @@ +# unbound-hook — Go rewrite (WEB-4809) + +Phase 1 scaffold of a Go port of the PyInstaller `unbound-hook` binary in +`binary/`. Rationale: a single static Go binary avoids the PyInstaller +onedir bundle that EDR/AV agents flag and slow-scan on managed fleets. + +## Contract + +The CLI surface mirrors `binary/` exactly — see `binary/README.md` and +`binary/src/unbound_hook/main.py` / `hook_cmd.py`: + +- `unbound-hook hook []` — tools: claude-code | cursor | + copilot | codex. stdin event JSON → stdout response JSON. **Fail-open:** + unknown tool, bad input, or any dispatcher failure prints `{}` and exits 0; + this process sits between the user and their editor. +- `unbound-hook setup|backfill|clear` — admin commands, NOT fail-open; + currently exit 1 with "not implemented". +- `unbound-hook --version` / `version` — `unbound-hook `, never + reads stdin (pkg postinstall pre-warm contract, packaging/README.md + "Version contract"). Version is baked via `-ldflags "-X main.Version=..."`. + +Phase 1 status: each tool handler is a fail-open stub (reads stdin, prints +`{}`, exits 0). The real per-tool ports come next; sources are named in the +TODO header of each `internal/hooks/*.go` file. + +Phase 2 status: the shared core the four python hook modules duplicate is +ported as stdlib-only packages (not yet wired into the stubs). Each package +doc comment names the python lines it mirrors; `claude-code/hooks/unbound.py` +is the canonical reference: + +- `internal/pyjson` — python-`json.dumps`-byte-identical encode/decode + (ordered objects, ensure_ascii, repr(float)); required for stdout and + audit-line parity, since Go's encoding/json formats differently +- `internal/config` — ~/.unbound/config.json + UNBOUND_GATEWAY_URL / + UNBOUND__API_KEY precedence (codex is env-only, quirk kept) +- `internal/httpc` — HTTP via curl subprocess (house rule: corporate-CA / + Zscaler compat; never net/http), exact python argv, fail-open +- `internal/report` — error.log (25-line cap) + rate-limited best-effort + POST to /v1/hooks/errors +- `internal/audit` — agent-audit.log JSONL load/append/save + session-keyed + cleanup (grouping key is per tool, supplied by callers) +- `internal/locks` — mtime-TTL lock files (self-update lock, dispatch + claim-and-steal, staleness probe, touch) +- `internal/transcript` — claude-code Stop-path transcript JSONL parsing + (parse_transcript_file), including its abort-on-exception quirks + +## Build & test + +``` +./build.sh # Go 1.22+ + lipo: universal2 dist/unbound-hook/unbound-hook + # UNBOUND_HOOK_VERSION=1.2.3 bakes the release version + +UNBOUND_GO_BINARY=$PWD/dist/unbound-hook/unbound-hook \ + python3 -m pytest ../binary/tests/ -q # opt-in tool×event parity +``` + +Stdlib only — zero Go dependencies. + +The python path in `binary/` remains the golden reference; the parity +harness in `binary/tests/test_hook_cli.py` compares this binary's stdout + +exit code against `python3 /unbound.py` for every tool × event when +`UNBOUND_GO_BINARY` is set (skipped otherwise, so CI is unchanged). diff --git a/go/build.sh b/go/build.sh new file mode 100755 index 00000000..12b034fc --- /dev/null +++ b/go/build.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# Build the unbound-hook Go universal2 binary (WEB-4809). +# Requires Go 1.22+ and macOS lipo. UNBOUND_BUILD_GO overrides the toolchain. +set -euo pipefail +cd "$(dirname "$0")" + +GO="${UNBOUND_BUILD_GO:-go}" +command -v "$GO" >/dev/null 2>&1 || { echo "ERROR: go toolchain not found (set UNBOUND_BUILD_GO)"; exit 1; } +command -v lipo >/dev/null 2>&1 || { echo "ERROR: lipo not found (macOS required)"; exit 1; } + +# packaging/README.md "Version contract": --version must self-identify with +# the release version as a whitespace-delimited token. +VERSION="${UNBOUND_HOOK_VERSION:-0.0.0-dev}" +LDFLAGS="-s -w -X main.Version=${VERSION}" + +OUT=dist/unbound-hook +mkdir -p "$OUT" + +CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 "$GO" build -trimpath -ldflags "$LDFLAGS" \ + -o "$OUT/unbound-hook.arm64" ./cmd/unbound-hook +CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 "$GO" build -trimpath -ldflags "$LDFLAGS" \ + -o "$OUT/unbound-hook.amd64" ./cmd/unbound-hook +lipo -create -output "$OUT/unbound-hook" "$OUT/unbound-hook.arm64" "$OUT/unbound-hook.amd64" +rm "$OUT/unbound-hook.arm64" "$OUT/unbound-hook.amd64" + +echo "--- verifying universal2 ---" +archs=$(lipo -archs "$OUT/unbound-hook") +case "$archs" in + *x86_64*arm64*|*arm64*x86_64*) echo "OK: universal2 ($archs)" ;; + *) echo "NOT-UNIVERSAL: $archs"; exit 1 ;; +esac + +echo "--- smoke ---" +"./$OUT/unbound-hook" --version +echo '{}' | "./$OUT/unbound-hook" hook claude-code PreToolUse diff --git a/go/cmd/unbound-hook/main.go b/go/cmd/unbound-hook/main.go new file mode 100644 index 00000000..4dc7b7bf --- /dev/null +++ b/go/cmd/unbound-hook/main.go @@ -0,0 +1,77 @@ +// Command unbound-hook is the Go rewrite of the PyInstaller hook binary. +// The python implementation under binary/ is the golden reference; this +// dispatcher mirrors binary/src/unbound_hook/main.py exactly. +// +// Subcommands: +// +// hook [] stdin/stdout hook dispatch (fail-open, exit 0) +// setup [...] MDM onboarding (not implemented yet) +// backfill [...] historical transcript seeding (not implemented yet) +// clear full deregistration (not implemented yet) +// --version / version print version (pkg postinstall pre-warm contract: +// must exit fast without reading stdin) +package main + +import ( + "fmt" + "os" + + "github.com/websentry-ai/setup/go/internal/hooks" +) + +// Version is baked at build time via -ldflags "-X main.Version=...". +var Version = "0.0.0-dev" + +func usage() string { + return fmt.Sprintf(`unbound-hook %s + +Usage: + unbound-hook hook [] tools: claude-code|cursor|copilot|codex + unbound-hook setup --api-key [--discovery-key ] [options] + unbound-hook backfill (--all | --user ) [--dry-run] [options] + unbound-hook clear + unbound-hook --version +`, Version) +} + +func run(args []string) int { + if len(args) > 0 && (args[0] == "--version" || args[0] == "-V" || args[0] == "version") { + // Pre-warm contract: print and exit, never touch stdin. + fmt.Printf("unbound-hook %s\n", Version) + return 0 + } + if len(args) == 0 { + fmt.Println(usage()) + return 2 + } + if args[0] == "-h" || args[0] == "--help" || args[0] == "help" { + fmt.Println(usage()) + return 0 + } + + cmd, rest := args[0], args[1:] + switch cmd { + case "hook": + tool, event := "", "" + if len(rest) > 0 { + tool = rest[0] + } + if len(rest) > 1 { + event = rest[1] + } + return hooks.Dispatch(tool, event, os.Stdin, os.Stdout) + case "setup", "backfill", "clear": + // Admin commands are NOT fail-open: a silent no-op here would look + // like a successful install/backfill/deregistration. + fmt.Fprintf(os.Stderr, "unbound-hook %s: not implemented\n", cmd) + return 1 + } + + fmt.Fprintf(os.Stderr, "Unknown command: %s\n", cmd) + fmt.Fprintln(os.Stderr, usage()) + return 2 +} + +func main() { + os.Exit(run(os.Args[1:])) +} diff --git a/go/go.mod b/go/go.mod new file mode 100644 index 00000000..a10173bc --- /dev/null +++ b/go/go.mod @@ -0,0 +1,3 @@ +module github.com/websentry-ai/setup/go + +go 1.22 diff --git a/go/internal/audit/audit.go b/go/internal/audit/audit.go new file mode 100644 index 00000000..45c899e9 --- /dev/null +++ b/go/internal/audit/audit.go @@ -0,0 +1,127 @@ +// Package audit ports the per-tool agent-audit.log JSONL handling: +// load_existing_logs / save_logs / append_to_audit_log +// (claude-code/hooks/unbound.py lines 205-238) and cleanup_old_logs +// (lines 1103-1126; cursor/unbound.py keys on event.conversation_id +// instead of the top-level session_id, so the grouping key is a caller +// parameter here). +// +// Entries are decoded pyjson values so that Save re-renders each line +// byte-identically to python's json.dumps(json.loads(line)). Paths are +// per tool and owned by callers (~/.claude/hooks/agent-audit.log, +// ~/.cursor/hooks/..., ~/.codex/hooks/..., ~/.copilot/hooks/...). There is +// no size-based rotation — only the session-scoped cleanup; files are +// created with the process umask like python's open(). +// +// Quirk copied as-is: a non-object JSONL line loads fine in python and +// only blows up later (AttributeError in cleanup, caught by main's blanket +// handler). Here Cleanup's key func decides what such entries map to. +package audit + +import ( + "bufio" + "bytes" + "os" + "path/filepath" + + "github.com/websentry-ai/setup/go/internal/pyjson" +) + +// Load reads every parseable JSONL entry. Blank and undecodable lines are +// skipped; an unreadable file yields whatever was collected (python +// swallows the exception and returns the partial list). +func Load(path string) []any { + logs := []any{} + f, err := os.Open(path) + if err != nil { + return logs + } + defer f.Close() + r := bufio.NewReader(f) + for { + line, err := r.ReadBytes('\n') + trimmed := bytes.TrimSpace(line) + if len(trimmed) > 0 { + if entry, perr := pyjson.Loads(trimmed); perr == nil { + logs = append(logs, entry) + } + } + if err != nil { + return logs + } + } +} + +// Save rewrites the file with one python-format JSON line per entry. +// Errors are swallowed; an entry that fails to encode aborts the rest, +// leaving a partial file exactly like a mid-write python exception would. +func Save(path string, logs []any) { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return + } + f, err := os.Create(path) + if err != nil { + return + } + defer f.Close() + for _, entry := range logs { + line, err := pyjson.Dumps(entry) + if err != nil { + return + } + if _, err := f.WriteString(line + "\n"); err != nil { + return + } + } +} + +// Append adds one entry to the log. Errors are swallowed. +func Append(path string, entry any) { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return + } + line, err := pyjson.Dumps(entry) + if err != nil { + return + } + f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + return + } + defer f.Close() + _, _ = f.WriteString(line + "\n") +} + +// Cleanup trims the log once it exceeds limit entries. key extracts each +// entry's grouping id ("" for none): with more than one distinct id, only +// the most recently first-seen id's entries survive (entries with other or +// missing ids are dropped, including the size headroom — python keeps the +// whole last session however large); with at most one, the newest `limit` +// entries survive. +func Cleanup(path string, limit int, key func(entry any) string) { + logs := Load(path) + if len(logs) <= limit { + return + } + + var order []string + seen := map[string]bool{} + for _, entry := range logs { + if id := key(entry); id != "" && !seen[id] { + order = append(order, id) + seen[id] = true + } + } + + if len(order) > 1 { + latest := order[len(order)-1] + kept := []any{} + for _, entry := range logs { + if key(entry) == latest { + kept = append(kept, entry) + } + } + Save(path, kept) + } else if len(logs) > limit { + Save(path, logs[len(logs)-limit:]) + } +} diff --git a/go/internal/audit/audit_test.go b/go/internal/audit/audit_test.go new file mode 100644 index 00000000..1fd4ef34 --- /dev/null +++ b/go/internal/audit/audit_test.go @@ -0,0 +1,172 @@ +package audit + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/websentry-ai/setup/go/internal/pyjson" +) + +// sessionKey mirrors claude-code cleanup_old_logs: top-level session_id. +func sessionKey(entry any) string { + obj, ok := entry.(*pyjson.Object) + if !ok { + return "" + } + if s, ok := obj.GetDefault("session_id", nil).(string); ok { + return s + } + return "" +} + +func entryLine(session string, n int) string { + return fmt.Sprintf(`{"timestamp": "2026-06-12T00:00:%02dZ", "session_id": "%s", "event": {"hook_event_name": "PostToolUse", "n": %d}}`, n%60, session, n) +} + +func writeLog(t *testing.T, path string, lines []string) { + t.Helper() + if err := os.WriteFile(path, []byte(strings.Join(lines, "\n")+"\n"), 0o644); err != nil { + t.Fatal(err) + } +} + +func TestLoadSkipsBlankAndCorruptLines(t *testing.T) { + path := filepath.Join(t.TempDir(), "agent-audit.log") + writeLog(t, path, []string{ + entryLine("s1", 0), + "", + " ", + "{not json", + entryLine("s1", 1), + }) + logs := Load(path) + if len(logs) != 2 { + t.Fatalf("got %d entries, want 2", len(logs)) + } +} + +func TestLoadMissingFileIsEmpty(t *testing.T) { + if logs := Load(filepath.Join(t.TempDir(), "nope")); len(logs) != 0 { + t.Errorf("got %d entries", len(logs)) + } +} + +func TestAppendThenLoadRoundTripsBytes(t *testing.T) { + path := filepath.Join(t.TempDir(), "agent-audit.log") + line := entryLine("s1", 0) + entry, err := pyjson.Loads([]byte(line)) + if err != nil { + t.Fatal(err) + } + Append(path, entry) + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + if got := string(data); got != line+"\n" { + t.Errorf("file = %q, want %q", got, line+"\n") + } +} + +func TestAppendCreatesParentDir(t *testing.T) { + path := filepath.Join(t.TempDir(), ".claude", "hooks", "agent-audit.log") + Append(path, pyjson.NewObject().Set("k", "v")) + if _, err := os.Stat(path); err != nil { + t.Fatalf("log not created: %v", err) + } +} + +func TestSaveRewritesPythonFormat(t *testing.T) { + path := filepath.Join(t.TempDir(), "agent-audit.log") + // compact input must come back python-formatted, like json.dumps(json.loads(line)) + entry, err := pyjson.Loads([]byte(`{"session_id":"s1","event":{"a":1,"b":[1,2]}}`)) + if err != nil { + t.Fatal(err) + } + Save(path, []any{entry}) + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + want := `{"session_id": "s1", "event": {"a": 1, "b": [1, 2]}}` + "\n" + if string(data) != want { + t.Errorf("file = %q, want %q", string(data), want) + } +} + +func TestCleanupUnderLimitUntouched(t *testing.T) { + path := filepath.Join(t.TempDir(), "agent-audit.log") + lines := []string{entryLine("s1", 0), entryLine("s2", 1)} + writeLog(t, path, lines) + before, _ := os.ReadFile(path) + Cleanup(path, 100, sessionKey) + after, _ := os.ReadFile(path) + if string(before) != string(after) { + t.Error("under-limit log must not be rewritten") + } +} + +func TestCleanupMultiSessionKeepsOnlyMostRecent(t *testing.T) { + path := filepath.Join(t.TempDir(), "agent-audit.log") + var lines []string + for i := 0; i < 60; i++ { + lines = append(lines, entryLine("old-session", i)) + } + for i := 0; i < 50; i++ { + lines = append(lines, entryLine("new-session", i)) + } + writeLog(t, path, lines) + Cleanup(path, 100, sessionKey) + logs := Load(path) + if len(logs) != 50 { + t.Fatalf("got %d entries, want 50", len(logs)) + } + for _, e := range logs { + if sessionKey(e) != "new-session" { + t.Fatalf("kept entry from %q", sessionKey(e)) + } + } +} + +func TestCleanupSingleSessionTrimsToLimit(t *testing.T) { + path := filepath.Join(t.TempDir(), "agent-audit.log") + var lines []string + for i := 0; i < 130; i++ { + lines = append(lines, entryLine("only-session", i)) + } + writeLog(t, path, lines) + Cleanup(path, 100, sessionKey) + logs := Load(path) + if len(logs) != 100 { + t.Fatalf("got %d entries, want 100", len(logs)) + } + // must keep the NEWEST 100 + first, _ := logs[0].(*pyjson.Object) + ev, _ := first.GetDefault("event", nil).(*pyjson.Object) + if n, _ := ev.Get("n"); n != pyjson.Number("30") { + t.Errorf("first kept entry n = %v, want 30", n) + } +} + +func TestCleanupDropsKeylessEntriesWhenMultiSession(t *testing.T) { + // python: logs without session_id are dropped by the kept_logs filter + path := filepath.Join(t.TempDir(), "agent-audit.log") + var lines []string + for i := 0; i < 60; i++ { + lines = append(lines, entryLine("a", i)) + } + lines = append(lines, `{"event": {"hook_event_name": "Stop"}}`) + for i := 0; i < 50; i++ { + lines = append(lines, entryLine("b", i)) + } + writeLog(t, path, lines) + Cleanup(path, 100, sessionKey) + for _, e := range Load(path) { + if sessionKey(e) != "b" { + t.Fatalf("kept non-b entry: %v", e) + } + } +} diff --git a/go/internal/config/config.go b/go/internal/config/config.go new file mode 100644 index 00000000..108d73d9 --- /dev/null +++ b/go/internal/config/config.go @@ -0,0 +1,129 @@ +// Package config reads ~/.unbound/config.json and the hook environment +// variables, mirroring the python hook modules: +// +// - GatewayURL mirrors the UNBOUND_GATEWAY_URL module constant +// (claude-code/hooks/unbound.py lines 17-19): env value if the variable +// is set (even empty, like os.environ.get), else the baked default, +// with trailing slashes stripped. +// - APIKey mirrors get_api_key (claude-code/hooks/unbound.py lines +// 1181-1203, cursor/unbound.py lines 1182-1195, copilot/hooks/unbound.py +// lines 938-951) and codex's env-only lookup (codex/hooks/unbound.py +// line 1523): per-tool env var first, then config.json's api_key — +// except codex, which reads ONLY the env var (quirk copied as-is). +// - Read mirrors the discovery/mcp-scan config reads +// (claude-code/hooks/unbound.py lines 1429-1439, 1533-1546): api_key +// and base_url come from config.json; gateway_url and frontend_url are +// also written there by the setup scripts (claude-code/hooks/setup.py +// write_unbound_config, line 1245) and exposed for future callers. +package config + +import ( + "encoding/json" + "errors" + "io/fs" + "os" + "path/filepath" + "strings" +) + +// DefaultGatewayURL is the python modules' baked env-var default. The +// python install path rewrites the literal in the script for non-default +// tenants; the Go equivalent is -ldflags "-X .../config.DefaultGatewayURL=...". +var DefaultGatewayURL = "https://api.getunbound.ai" + +// APIKeyEnvVars maps each tool to its API-key environment variable. +var APIKeyEnvVars = map[string]string{ + "claude-code": "UNBOUND_CLAUDE_API_KEY", + "cursor": "UNBOUND_CURSOR_API_KEY", + "codex": "UNBOUND_CODEX_API_KEY", + "copilot": "UNBOUND_COPILOT_API_KEY", +} + +// Config is the subset of ~/.unbound/config.json the hooks consume. +// Non-string values for any field are treated as missing (python's +// dict.get would pass them through, but no caller survives one). +type Config struct { + APIKey string // "api_key" + BaseURL string // "base_url" — the backend, used as discovery --domain + GatewayURL string // "gateway_url" + FrontendURL string // "frontend_url" +} + +// Path returns ~/.unbound/config.json (UNBOUND_CONFIG_PATH in python). +func Path() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(home, ".unbound", "config.json"), nil +} + +// GatewayURL resolves the gateway base URL: UNBOUND_GATEWAY_URL if set +// (python os.environ.get treats a set-but-empty variable as set), else +// DefaultGatewayURL; trailing '/' stripped like python's rstrip("/"). +func GatewayURL() string { + url, ok := os.LookupEnv("UNBOUND_GATEWAY_URL") + if !ok { + url = DefaultGatewayURL + } + return strings.TrimRight(url, "/") +} + +// Read parses ~/.unbound/config.json. A missing file returns fs.ErrNotExist +// (python's FileNotFoundError branch); other read/parse failures return +// their error so callers can log them the way the python modules do. +func Read() (Config, error) { + path, err := Path() + if err != nil { + return Config{}, err + } + return readFile(path) +} + +func readFile(path string) (Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return Config{}, err + } + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return Config{}, err + } + str := func(key string) string { + if s, ok := raw[key].(string); ok { + return s + } + return "" + } + return Config{ + APIKey: str("api_key"), + BaseURL: str("base_url"), + GatewayURL: str("gateway_url"), + FrontendURL: str("frontend_url"), + }, nil +} + +// APIKey resolves the API key for a tool: the tool's env var when non-empty +// (python: `if key:` — a set-but-empty var falls through), then config.json. +// codex never falls back to config.json (its main() reads only the env var). +// A missing config file yields ("", nil), mirroring the silent +// FileNotFoundError branch; any other failure is returned for the caller to +// log ('config' category in python). +func APIKey(tool string) (string, error) { + if env := APIKeyEnvVars[tool]; env != "" { + if key := os.Getenv(env); key != "" { + return key, nil + } + } + if tool == "codex" { + return "", nil + } + cfg, err := Read() + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return "", nil + } + return "", err + } + return cfg.APIKey, nil +} diff --git a/go/internal/config/config_test.go b/go/internal/config/config_test.go new file mode 100644 index 00000000..05792ca7 --- /dev/null +++ b/go/internal/config/config_test.go @@ -0,0 +1,135 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func withHome(t *testing.T) string { + t.Helper() + home := t.TempDir() + t.Setenv("HOME", home) + return home +} + +func writeConfig(t *testing.T, home, content string) { + t.Helper() + dir := filepath.Join(home, ".unbound") + if err := os.MkdirAll(dir, 0o700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(content), 0o600); err != nil { + t.Fatal(err) + } +} + +func TestGatewayURLDefault(t *testing.T) { + os.Unsetenv("UNBOUND_GATEWAY_URL") + if got := GatewayURL(); got != "https://api.getunbound.ai" { + t.Errorf("GatewayURL() = %q", got) + } +} + +func TestGatewayURLEnvAndRstrip(t *testing.T) { + t.Setenv("UNBOUND_GATEWAY_URL", "https://tenant.example.com///") + if got := GatewayURL(); got != "https://tenant.example.com" { + t.Errorf("GatewayURL() = %q", got) + } +} + +func TestGatewayURLSetButEmptyCountsAsSet(t *testing.T) { + // python os.environ.get returns "" for a set-but-empty var. + t.Setenv("UNBOUND_GATEWAY_URL", "") + if got := GatewayURL(); got != "" { + t.Errorf("GatewayURL() = %q, want empty", got) + } +} + +func TestAPIKeyEnvWins(t *testing.T) { + home := withHome(t) + writeConfig(t, home, `{"api_key": "from-config"}`) + t.Setenv("UNBOUND_CLAUDE_API_KEY", "from-env") + key, err := APIKey("claude-code") + if err != nil || key != "from-env" { + t.Errorf("APIKey = %q, %v", key, err) + } +} + +func TestAPIKeyEmptyEnvFallsThroughToConfig(t *testing.T) { + home := withHome(t) + writeConfig(t, home, `{"api_key": "from-config"}`) + t.Setenv("UNBOUND_CLAUDE_API_KEY", "") // falsy in python: falls through + key, err := APIKey("claude-code") + if err != nil || key != "from-config" { + t.Errorf("APIKey = %q, %v", key, err) + } +} + +func TestAPIKeyCodexNeverReadsConfig(t *testing.T) { + home := withHome(t) + writeConfig(t, home, `{"api_key": "from-config"}`) + os.Unsetenv("UNBOUND_CODEX_API_KEY") + key, err := APIKey("codex") + if err != nil || key != "" { + t.Errorf("APIKey(codex) = %q, %v; codex is env-only", key, err) + } + t.Setenv("UNBOUND_CODEX_API_KEY", "codex-env") + if key, _ := APIKey("codex"); key != "codex-env" { + t.Errorf("APIKey(codex) = %q", key) + } +} + +func TestAPIKeyMissingFileIsSilent(t *testing.T) { + withHome(t) + os.Unsetenv("UNBOUND_CURSOR_API_KEY") + key, err := APIKey("cursor") + if err != nil || key != "" { + t.Errorf("APIKey = %q, %v; want silent empty (FileNotFoundError branch)", key, err) + } +} + +func TestAPIKeyCorruptConfigReturnsError(t *testing.T) { + home := withHome(t) + writeConfig(t, home, `{not json`) + os.Unsetenv("UNBOUND_COPILOT_API_KEY") + if _, err := APIKey("copilot"); err == nil { + t.Error("expected error for corrupt config.json (python logs it)") + } +} + +func TestReadAllFields(t *testing.T) { + home := withHome(t) + writeConfig(t, home, `{ + "api_key": "k", + "base_url": "https://backend.example.com", + "gateway_url": "https://gw.example.com", + "frontend_url": "https://fe.example.com", + "extra": 42 + }`) + cfg, err := Read() + if err != nil { + t.Fatal(err) + } + want := Config{ + APIKey: "k", + BaseURL: "https://backend.example.com", + GatewayURL: "https://gw.example.com", + FrontendURL: "https://fe.example.com", + } + if cfg != want { + t.Errorf("Read() = %+v, want %+v", cfg, want) + } +} + +func TestReadNonStringFieldTreatedAsMissing(t *testing.T) { + home := withHome(t) + writeConfig(t, home, `{"api_key": 12345, "base_url": "ok"}`) + cfg, err := Read() + if err != nil { + t.Fatal(err) + } + if cfg.APIKey != "" || cfg.BaseURL != "ok" { + t.Errorf("Read() = %+v", cfg) + } +} diff --git a/go/internal/hooks/claude_code.go b/go/internal/hooks/claude_code.go new file mode 100644 index 00000000..debd6d38 --- /dev/null +++ b/go/internal/hooks/claude_code.go @@ -0,0 +1,279 @@ +package hooks + +// claude-code hook handler: a behavioral port of +// claude-code/hooks/unbound.py — the golden reference; doc comments cite +// its line numbers and quirks are copied verbatim. This binary is the +// packaged ("frozen") variant: self-update is the internal/selfupdate +// no-op and discovery always runs the locally installed binary, never the +// install.sh download path (unbound.py lines 60-66). + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "time" + + "github.com/websentry-ai/setup/go/internal/audit" + "github.com/websentry-ai/setup/go/internal/config" + "github.com/websentry-ai/setup/go/internal/pyjson" + "github.com/websentry-ai/setup/go/internal/report" + "github.com/websentry-ai/setup/go/internal/selfupdate" +) + +// Module constants (unbound.py lines 17-73). +const ( + mcpToolPrefix = "mcp__" + ccCacheTTLSeconds = 300 + ccPolicyCheckFailureBlockReason = "policy engine unavailable — please retry" + ccPretoolUserMessagesLimit = 5 + ccAuditLogTotalLimit = 100 + + approvalTimeout = 4 * time.Hour + + discoveryDebounce = 24 * time.Hour + discoveryHookFlagTTL = 24 * time.Hour + discoveryHookFlagPath = "/v1/hooks/discovery-enabled" + discoveryStaleLock = 15 * time.Minute + discoveryDispatchTTL = 10 * time.Second +) + +// ALLOWED_NON_MCP_HOOK_NAMES / NATIVE_FILE_TOOLS (lines 23-24): MCP tools +// (mcp__*) are always checked separately. +var ( + ccAllowedNonMCPHookNames = []string{"Bash", "Read", "Write", "Edit"} + ccNativeFileTools = map[string]bool{"Read": true, "Write": true, "Edit": true} +) + +// approvalPollPhases mirrors APPROVAL_POLL_PHASES (lines 68-73): +// (elapsed-below, interval) pairs in seconds. +var approvalPollPhases = [4][2]int{ + {5 * 60, 3}, + {30 * 60, 15}, + {2 * 60 * 60, 60}, + {4 * 60 * 60, 120}, +} + +// claudeCodeHook carries per-process state held in module globals on the +// python side, plus the home-derived paths (lines 20-58). +type claudeCodeHook struct { + gatewayURL string + apiKey string + rep *report.Reporter + + auditLog string // ~/.claude/hooks/agent-audit.log + policyCache string // ~/.claude/hooks/.policy_cache.json + approvalMarker string // ~/.claude/hooks/.approval_pending + claudeConfig string // ~/.claude.json + unboundConfig string // ~/.unbound/config.json + identityCache string // ~/.unbound/identity.json + discoveryCache string // ~/.unbound/discovery-cache.json + discoveryLock string // ~/.unbound/discovery.lock + dispatchLock string // ~/.unbound/discovery.dispatch.lock +} + +func newClaudeCodeHook() (*claudeCodeHook, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + hooksDir := filepath.Join(home, ".claude", "hooks") + unboundDir := filepath.Join(home, ".unbound") + gw := config.GatewayURL() + return &claudeCodeHook{ + gatewayURL: gw, + rep: &report.Reporter{ + GatewayURL: gw, + HookSource: "claude-code", + ErrorLog: filepath.Join(hooksDir, "error.log"), + LastReportFile: filepath.Join(hooksDir, ".last_error_report"), + }, + auditLog: filepath.Join(hooksDir, "agent-audit.log"), + policyCache: filepath.Join(hooksDir, ".policy_cache.json"), + approvalMarker: filepath.Join(hooksDir, ".approval_pending"), + claudeConfig: filepath.Join(home, ".claude.json"), + unboundConfig: filepath.Join(unboundDir, "config.json"), + identityCache: filepath.Join(unboundDir, "identity.json"), + discoveryCache: filepath.Join(unboundDir, "discovery-cache.json"), + discoveryLock: filepath.Join(unboundDir, "discovery.lock"), + dispatchLock: filepath.Join(unboundDir, "discovery.dispatch.lock"), + }, nil +} + +func runClaudeCode(_ string, stdin io.Reader, stdout io.Writer) int { + c, err := newClaudeCodeHook() + if err != nil { + // No resolvable home; the dispatcher contract is fail-open. + fmt.Fprintln(stdout, `{"suppressOutput": true}`) + return 0 + } + c.main(stdin, stdout) + return 0 +} + +// main mirrors main() (lines 1607-1681). All paths print one JSON line and +// exit 0; the deferred recover is python's blanket except (1677-1680). +func (c *claudeCodeHook) main(stdin io.Reader, stdout io.Writer) { + // get_api_key (1181-1203) + the _cached_api_key global (1608-1610). + key, err := config.APIKey("claude-code") + if err != nil { + var syn *json.SyntaxError + if errors.As(err, &syn) { + c.rep.LogError(fmt.Sprintf("~/.unbound/config.json is not valid JSON: %v", err), "config") + } else { + c.rep.LogError(fmt.Sprintf("Failed to read config file: %v", err), "config") + } + key = "" + } + c.apiKey = key + c.rep.APIKey = key + + defer func() { + if r := recover(); r != nil { + c.rep.LogError(fmt.Sprintf("Exception in main: %v", r), "general") + fmt.Fprintln(stdout, `{"suppressOutput": true}`) + } + }() + + raw, err := io.ReadAll(stdin) + if err != nil { + raise("stdin read failed: %v", err) + } + input := strings.TrimSpace(string(raw)) + if input == "" { + fmt.Fprintln(stdout, `{"suppressOutput": true}`) + return + } + parsed, err := pyjson.Loads([]byte(input)) + if err != nil { + fmt.Fprintln(stdout, `{"suppressOutput": true}`) + return + } + event := mustObj(parsed) + hookEventName, _ := event.GetDefault("hook_event_name", nil).(string) + + // SessionStart fires once per session — natural TTL gate for the + // debounced discovery scan dispatch (1629-1634). + if hookEventName == "SessionStart" { + c.deviceSerial(true) // warm the (slow) serial probe + cache once per session + selfupdate.Check() + c.dispatchDiscovery() + fmt.Fprintln(stdout, "{}") + return + } + sessionID := event.GetDefault("session_id", nil) + + if hookEventName == "PreToolUse" { + response := c.processPreToolUse(event) + response.Set("suppressOutput", true) + c.printJSON(stdout, response) + return + } + + if hookEventName == "UserPromptSubmit" { + response := c.processUserPromptSubmit(event) + if d, _ := response.GetDefault("decision", nil).(string); d == "block" { + audit.Append(c.auditLog, pyjson.NewObject(). + Set("timestamp", report.UTCTimestamp(time.Now())). + Set("session_id", event.GetDefault("session_id", nil)). + Set("event", event)) + response.Set("suppressOutput", true) + c.printJSON(stdout, response) + return + } + // Allowed: continue to log the event (output printed at end). + } + + audit.Append(c.auditLog, pyjson.NewObject(). + Set("timestamp", report.UTCTimestamp(time.Now())). + Set("session_id", event.GetDefault("session_id", nil)). + Set("event", event)) + + if hookEventName == "Stop" && pyjson.Truthy(sessionID) { + c.processStopEvent(event) + } + + c.cleanupOldLogs() + + fmt.Fprintln(stdout, `{"suppressOutput": true}`) +} + +func (c *claudeCodeHook) printJSON(w io.Writer, v any) { + s, err := pyjson.Dumps(v) + if err != nil { + raise("json dumps failed: %v", err) + } + fmt.Fprintln(w, s) +} + +// cleanupOldLogs mirrors cleanup_old_logs (1103-1126): grouping is by the +// TOP-LEVEL session_id only. +func (c *claudeCodeHook) cleanupOldLogs() { + audit.Cleanup(c.auditLog, ccAuditLogTotalLimit, func(entry any) string { + obj, ok := entry.(*pyjson.Object) + if !ok { + return "" + } + v, _ := obj.Get("session_id") + if !pyjson.Truthy(v) { + return "" + } + if s, ok := v.(string); ok { + return "s:" + s + } + // Non-string ids are keyed by their JSON form, prefixed so they + // never collide with a string id of the same spelling. + s, err := pyjson.Dumps(v) + if err != nil { + return "" + } + return "j:" + s + }) +} + +// getSessionModel mirrors _get_session_model (355-364). +func (c *claudeCodeHook) getSessionModel(sessionID any) any { + if !pyjson.Truthy(sessionID) { + return nil + } + return extractSessionModel(audit.Load(c.auditLog), sessionID) +} + +// extractSessionModel mirrors _extract_session_model (331-352): forward +// scan, latest SessionStart wins; the first malformed entry aborts the scan +// keeping what was found (python's broad except around the loop). +func extractSessionModel(logs []any, sessionID any) (found any) { + if !pyjson.Truthy(sessionID) || len(logs) == 0 { + return nil + } + defer func() { + if r := recover(); r != nil { + if _, ok := r.(pyRaise); !ok { + panic(r) + } + } + }() + for _, entry := range logs { + log := mustObj(entry) + logSession := log.GetDefault("session_id", nil) + if !pyjson.Truthy(logSession) { + logSession = objGet(log.GetDefault("event", pyjson.NewObject()), "session_id", nil) + } + if !pyEq(logSession, sessionID) { + continue + } + event := log + if v, has := log.Get("event"); has { + event = mustObj(v) + } + if hen, _ := event.GetDefault("hook_event_name", nil).(string); hen == "SessionStart" { + if model := event.GetDefault("model", nil); pyjson.Truthy(model) { + found = model // keep scanning — latest SessionStart wins + } + } + } + return found +} diff --git a/go/internal/hooks/claude_code_discovery.go b/go/internal/hooks/claude_code_discovery.go new file mode 100644 index 00000000..ae566f5f --- /dev/null +++ b/go/internal/hooks/claude_code_discovery.go @@ -0,0 +1,277 @@ +package hooks + +// Discovery + MCP-scan dispatch for the claude-code port +// (claude-code/hooks/unbound.py lines 1347-1409, 1419-1476, 1479-1604). +// The Go binary is always the frozen variant: it never downloads +// install.sh — discovery runs from the locally installed binary or is +// skipped with a logged error (lines 1441-1448, 1548-1554). + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "syscall" + "time" + + "github.com/websentry-ai/setup/go/internal/config" + "github.com/websentry-ai/setup/go/internal/httpc" + "github.com/websentry-ai/setup/go/internal/locks" + "github.com/websentry-ai/setup/go/internal/pyjson" +) + +// frozenDiscoveryBin mirrors FROZEN_DISCOVERY_BIN (line 66). Var, not +// const, so tests can point it at a sandbox binary. +var frozenDiscoveryBin = "/opt/unbound/current/unbound-discovery/unbound-discovery" + +// utcStamp mirrors datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"). +func utcStamp(t time.Time) string { + return t.UTC().Format("2006-01-02T15:04:05") + "Z" +} + +// parseUTCStamp is the strict strptime("%Y-%m-%dT%H:%M:%SZ") counterpart. +func parseUTCStamp(s string) (time.Time, error) { + if !strings.HasSuffix(s, "Z") { + return time.Time{}, errors.New("timestamp does not match format") + } + return time.Parse("2006-01-02T15:04:05", strings.TrimSuffix(s, "Z")) +} + +// readDiscoveryCache mirrors the lenient cache reads (1353-1361, +// 1483-1491): missing, unreadable, corrupt, falsy, or non-dict all yield {}. +func (c *claudeCodeHook) readDiscoveryCache() *pyjson.Object { + if data, err := os.ReadFile(c.discoveryCache); err == nil { + if parsed, perr := pyjson.Loads(data); perr == nil && pyjson.Truthy(parsed) { + if obj, ok := parsed.(*pyjson.Object); ok { + return obj + } + } + } + return pyjson.NewObject() +} + +// writeDiscoveryCache mirrors the atomic cache writes (1401-1406, +// 1592-1597): json.dump(indent=2, sort_keys=True) to a sibling .tmp, then +// os.replace. Filesystem errors are returned for the caller to handle per +// call site; an unencodable value raises (python json.dump TypeError is +// outside the except OSError). +func (c *claudeCodeHook) writeDiscoveryCache(cache *pyjson.Object) error { + s, err := pyjson.DumpsIndentSorted(cache) + if err != nil { + raise("json dump failed: %v", err) + } + // Path.with_suffix(".tmp"): discovery-cache.json -> discovery-cache.tmp. + tmp := strings.TrimSuffix(c.discoveryCache, ".json") + ".tmp" + if err := os.MkdirAll(filepath.Dir(c.discoveryCache), 0o755); err != nil { + return err + } + if err := os.WriteFile(tmp, []byte(s), 0o644); err != nil { + return err + } + return os.Rename(tmp, c.discoveryCache) +} + +// hookDiscoveryEnabledForOrg mirrors _hook_discovery_enabled_for_org +// (1347-1409): cached flag within TTL, else refetch from the gateway. +// Fail-closed: any error with no usable cached value means false. +func (c *claudeCodeHook) hookDiscoveryEnabledForOrg() bool { + cache := c.readDiscoveryCache() + flag := pyjson.NewObject() + if hd, ok := cache.GetDefault("hook_discovery", nil).(*pyjson.Object); ok { + flag = hd + } + cachedEnabled := func() bool { return pyjson.Truthy(flag.GetDefault("enabled", false)) } + + if lastFetched, ok := flag.GetDefault("fetched_at", nil).(string); ok { + if ts, err := parseUTCStamp(lastFetched); err == nil && time.Since(ts) < discoveryHookFlagTTL { + return cachedEnabled() + } + } + + cfg, err := config.Read() + if err != nil { + // python catches OSError/JSONDecodeError only; a non-dict config + // raises at cfg.get into main's blanket except. + var typeErr *json.UnmarshalTypeError + if errors.As(err, &typeErr) { + raise("config.json is not a dict: %v", err) + } + return cachedEnabled() + } + if cfg.APIKey == "" { + return cachedEnabled() + } + + res, err := httpc.Get(c.gatewayURL+discoveryHookFlagPath, cfg.APIKey, 5, 8*time.Second) + if err != nil || res.ExitCode != 0 { + return cachedEnabled() + } + parsed, perr := pyjson.Loads(res.Stdout) + if perr != nil { + return cachedEnabled() + } + obj, ok := parsed.(*pyjson.Object) + if !ok { + return cachedEnabled() // .get raised into the except Exception + } + enabled := pyjson.Truthy(obj.GetDefault("enabled", false)) + + cache.Set("hook_discovery", pyjson.NewObject(). + Set("enabled", enabled). + Set("fetched_at", utcStamp(time.Now()))) + _ = c.writeDiscoveryCache(cache) // except OSError: pass + return enabled +} + +// dispatchDiscovery mirrors _dispatch_discovery (1479-1604): org flag, +// 24h debounce, stale-lock check, atomic dispatch claim, then a +// fire-and-forget detached run of the local discovery binary. +func (c *claudeCodeHook) dispatchDiscovery() { + if !c.hookDiscoveryEnabledForOrg() { + return + } + defer func() { + if r := recover(); r != nil { + pe, ok := r.(pyRaise) + if !ok { + panic(r) + } + c.rep.LogError(fmt.Sprintf("discovery gate failed: %s", pe.msg), "discovery_gate") + } + }() + + cache := c.readDiscoveryCache() + + if last, ok := cache.GetDefault("last_run_at", nil).(string); ok { + if ts, err := parseUTCStamp(last); err == nil && time.Since(ts) < discoveryDebounce { + return + } + } + + // Another run in flight (1502-1508). + if locks.IsFresh(c.discoveryLock, discoveryStaleLock) { + return + } + + // Atomic dispatch claim — first hook to create the marker wins; + // concurrent peers bail to avoid duplicate detached spawns (1510-1530). + claimed, err := locks.Claim(c.dispatchLock, discoveryDispatchTTL) + if err != nil { + raise("dispatch claim failed: %v", err) + } + if !claimed { + return + } + defer locks.Release(c.dispatchLock) + + cfg, err := config.Read() + if err != nil { + var typeErr *json.UnmarshalTypeError + if errors.As(err, &typeErr) { + raise("config.json is not a dict: %v", err) + } + c.rep.LogError(fmt.Sprintf("discovery gate: could not read %s: %v", c.unboundConfig, err), "discovery_gate") + return + } + if cfg.APIKey == "" { + c.rep.LogError("discovery gate: api_key missing in ~/.unbound/config.json", "discovery_gate") + return + } + if cfg.BaseURL == "" { + c.rep.LogError("discovery gate: base_url missing in ~/.unbound/config.json", "discovery_gate") + return + } + + // Frozen binary: never fetch install.sh — run the locally installed + // discovery binary, or skip if it isn't there. + if fi, err := os.Stat(frozenDiscoveryBin); err != nil || fi.IsDir() { + c.rep.LogError(fmt.Sprintf("discovery gate: discovery binary missing at %s", frozenDiscoveryBin), "discovery_gate") + return + } + + // api_key goes via env so it never appears in argv / /proc//cmdline. + if err := spawnDetached( + []string{frozenDiscoveryBin, "--domain", cfg.BaseURL}, + []string{"UNBOUND_API_KEY=" + cfg.APIKey}, + ); err != nil { + c.rep.LogError(fmt.Sprintf("discovery gate: Popen failed: %v", err), "discovery_gate") + return + } + + // Stamp last_run_at only after the spawn succeeds so a launch failure + // doesn't burn the 24h window (1590-1597). + cache.Set("last_run_at", utcStamp(time.Now())) + if err := c.writeDiscoveryCache(cache); err != nil { + raise("%v", err) + } +} + +// dispatchMCPServerScan mirrors _dispatch_mcp_server_scan (1419-1476): +// report ONE unknown MCP server out-of-band. Detached so the blocking +// PreToolUse hook returns immediately; secrets travel via env, never argv. +func (c *claudeCodeHook) dispatchMCPServerScan(serverName string, serverConfig *pyjson.Object) { + if serverName == "" { + c.rep.LogError("mcp scan dispatch: empty server name, skipping", "mcp_server") + return + } + defer func() { + if r := recover(); r != nil { + pe, ok := r.(pyRaise) + if !ok { + panic(r) + } + c.rep.LogError(fmt.Sprintf("mcp scan dispatch failed for %s: %s", serverName, pe.msg), "mcp_server") + } + }() + + cfg, err := config.Read() + if err != nil { + var typeErr *json.UnmarshalTypeError + if errors.As(err, &typeErr) { + raise("config.json is not a dict: %v", err) + } + c.rep.LogError(fmt.Sprintf("mcp scan dispatch: cannot read config: %v", err), "mcp_server") + return + } + if cfg.APIKey == "" || cfg.BaseURL == "" { + c.rep.LogError("mcp scan dispatch: api_key/base_url missing in config", "mcp_server") + return + } + + if fi, err := os.Stat(frozenDiscoveryBin); err != nil || fi.IsDir() { + c.rep.LogError(fmt.Sprintf("mcp scan dispatch: discovery binary missing at %s", frozenDiscoveryBin), "mcp_server") + return + } + + serverJSON, err := pyjson.Dumps(serverConfig) + if err != nil { + raise("json dumps failed: %v", err) + } + if err := spawnDetached( + []string{frozenDiscoveryBin, "mcp-scan", "--name", serverName, "--domain", cfg.BaseURL}, + []string{ + "UNBOUND_API_KEY=" + cfg.APIKey, + "UNBOUND_MCP_SERVER_JSON=" + serverJSON, + "UNBOUND_MCP_SERVER_NAME=" + serverName, + "UNBOUND_MCP_DOMAIN=" + cfg.BaseURL, + }, + ); err != nil { + raise("Popen failed: %v", err) + } +} + +// spawnDetached mirrors the fire-and-forget Popen kwargs (1460-1473, +// 1577-1585): all stdio on the null device (os/exec's default for nil +// std fields), a new session (start_new_session=True), and no wait. +func spawnDetached(argv, extraEnv []string) error { + cmd := exec.Command(argv[0], argv[1:]...) + cmd.Env = append(os.Environ(), extraEnv...) + cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true} + if err := cmd.Start(); err != nil { + return err + } + return cmd.Process.Release() +} diff --git a/go/internal/hooks/claude_code_identity.go b/go/internal/hooks/claude_code_identity.go new file mode 100644 index 00000000..111a94cd --- /dev/null +++ b/go/internal/hooks/claude_code_identity.go @@ -0,0 +1,205 @@ +package hooks + +// Account identity + device serial for the claude-code port +// (claude-code/hooks/unbound.py lines 674-831). + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/websentry-ai/setup/go/internal/pyjson" +) + +// placeholderSerials mirrors _PLACEHOLDER_SERIALS (713-719): DMI/BIOS +// serials that are shared sentinel strings on VMs and OEM boards — treat as +// "no serial" so machines never collide on the same fake value. +var placeholderSerials = map[string]bool{ + "": true, "0": true, "00000000": true, "000000000": true, "0000000000": true, + "none": true, "na": true, "n/a": true, + "unknown": true, "default": true, "default string": true, + "to be filled by o.e.m.": true, "to be filled by oem": true, + "system serial number": true, "serial number": true, + "not applicable": true, "not specified": true, "not available": true, + "oem": true, "o.e.m.": true, "invalid": true, "123456789": true, "xxxxxxxx": true, +} + +// validSerial mirrors _valid_serial (722-723). +func validSerial(value string) bool { + return value != "" && !placeholderSerials[strings.ToLower(strings.TrimSpace(value))] +} + +// runProbe runs a probe command with a hard timeout, mirroring +// subprocess.run(capture_output=True, text=True, timeout=N). ok is the +// returncode == 0 check; timeouts and spawn failures report !ok. +func runProbe(timeout time.Duration, name string, args ...string) (string, bool) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + cmd := exec.CommandContext(ctx, name, args...) + cmd.WaitDelay = time.Second + out, err := cmd.Output() + if err != nil || ctx.Err() != nil { + return "", false + } + return string(out), true +} + +// getDeviceSerial mirrors _get_device_serial (726-775): best-effort +// hardware serial, filtering placeholders, falling through to a stable +// per-install id. "" stands in for python's None. +func getDeviceSerial() string { + switch runtime.GOOS { + case "darwin": + if out, ok := runProbe(10*time.Second, "system_profiler", "SPHardwareDataType"); ok { + for _, line := range strings.Split(out, "\n") { + if strings.Contains(line, "Serial Number") { + parts := strings.SplitN(line, ": ", 2) + if len(parts) >= 2 && validSerial(parts[1]) { + return strings.TrimSpace(parts[1]) + } + } + } + } + case "linux": + if out, ok := runProbe(10*time.Second, "dmidecode", "-s", "system-serial-number"); ok && validSerial(out) { + return strings.TrimSpace(out) + } + for _, path := range []string{"/etc/machine-id", "/var/lib/dbus/machine-id"} { + if data, err := os.ReadFile(path); err == nil { + value := strings.TrimSpace(string(data)) + if validSerial(value) { + return value + } + } + } + case "windows": + if out, ok := runProbe(10*time.Second, "powershell", "-NoProfile", "-Command", + "(Get-CimInstance -ClassName Win32_BIOS).SerialNumber"); ok && validSerial(out) { + return strings.TrimSpace(out) + } + if out, ok := runProbe(10*time.Second, "powershell", "-NoProfile", "-Command", + `(Get-ItemProperty 'HKLM:\SOFTWARE\Microsoft\Cryptography').MachineGuid`); ok && validSerial(out) { + return strings.TrimSpace(out) + } + } + return "" +} + +// deviceSerial mirrors _device_serial (778-811): cache-first, probe only +// when probe=true (SessionStart and the end-of-turn exchange; the +// latency-critical pre-tool path passes false). The cache is shared with +// the cursor hook, so it is merged and written atomically. +func (c *claudeCodeHook) deviceSerial(probe bool) string { + data := pyjson.NewObject() + if raw, err := os.ReadFile(c.identityCache); err == nil { + if parsed, perr := pyjson.Loads(raw); perr == nil { + if obj, ok := parsed.(*pyjson.Object); ok { + data = obj + if cached, ok := obj.GetDefault("device_serial", nil).(string); ok && strings.TrimSpace(cached) != "" { + return strings.TrimSpace(cached) + } + } + } + } + if !probe { + return "" + } + serial := getDeviceSerial() + if serial != "" { + data.Set("device_serial", serial) + if err := os.MkdirAll(filepath.Dir(c.identityCache), 0o755); err == nil { + if s, err := pyjson.Dumps(data); err == nil { + tmp := filepath.Join(filepath.Dir(c.identityCache), fmt.Sprintf(".identity.%d.tmp", os.Getpid())) + if os.WriteFile(tmp, []byte(s), 0o644) == nil { + _ = os.Rename(tmp, c.identityCache) + } + } + } + } + return serial +} + +// orNone mirrors `value or None`: falsy collapses to nil. +func orNone(v any) any { + if pyjson.Truthy(v) { + return v + } + return nil +} + +// emailDomain mirrors _email_domain (674-681). Any non-string input ends +// as None in python (TypeError into the bare except, or a failed rsplit). +func emailDomain(email any) any { + s, ok := email.(string) + if !ok || s == "" || !strings.Contains(s, "@") { + return nil + } + domain := strings.ToLower(strings.TrimSpace(s[strings.LastIndex(s, "@")+1:])) + if domain == "" { + return nil + } + return domain +} + +// readAccountIdentity mirrors read_account_identity (684-707): pulled from +// ~/.claude.json; every failure leaves the fields None. +func (c *claudeCodeHook) readAccountIdentity() *pyjson.Object { + var orgID, plan, authMode, email any + func() { + data, err := os.ReadFile(c.claudeConfig) + if err != nil { + return + } + parsed, perr := pyjson.Loads(data) + if perr != nil { + return + } + cfg, ok := parsed.(*pyjson.Object) + if !ok { + return + } + if oauth, ok := cfg.GetDefault("oauthAccount", nil).(*pyjson.Object); ok { + orgID = orNone(oauth.GetDefault("organizationUuid", nil)) + plan = orNone(oauth.GetDefault("organizationType", nil)) + email = orNone(oauth.GetDefault("emailAddress", nil)) + authMode = "subscription" + return + } + if os.Getenv("ANTHROPIC_API_KEY") != "" { + authMode = "api_key" + return + } + cak := cfg.GetDefault("customApiKeyResponses", nil) + if !pyjson.Truthy(cak) { + return // (None or {}).get('approved') -> None + } + cakObj, ok := cak.(*pyjson.Object) + if !ok { + return // .get on a non-dict raised into the bare except + } + if pyjson.Truthy(cakObj.GetDefault("approved", nil)) { + authMode = "api_key" + } + }() + return pyjson.NewObject(). + Set("org_id", orgID). + Set("plan", plan). + Set("auth_mode", authMode). + Set("user_email", email). + Set("email_domain", emailDomain(email)) +} + +// buildAccountIdentity mirrors build_account_identity (814-831): the +// account identity plus the (possibly cached-only) device serial. +func (c *claudeCodeHook) buildAccountIdentity(probe bool) *pyjson.Object { + identity := c.readAccountIdentity() + if serial := c.deviceSerial(probe); serial != "" { + identity.Set("device_serial", serial) + } + return identity +} diff --git a/go/internal/hooks/claude_code_pretool.go b/go/internal/hooks/claude_code_pretool.go new file mode 100644 index 00000000..33313426 --- /dev/null +++ b/go/internal/hooks/claude_code_pretool.go @@ -0,0 +1,704 @@ +package hooks + +// PreToolUse / UserPromptSubmit path of the claude-code port: policy cache, +// approval marker + polling, command extraction, MCP server config lookup, +// gateway call, and the Claude Code response transforms. + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "os" + "path/filepath" + "slices" + "strings" + "time" + + "github.com/websentry-ai/setup/go/internal/audit" + "github.com/websentry-ai/setup/go/internal/httpc" + "github.com/websentry-ai/setup/go/internal/pyjson" + "github.com/websentry-ai/setup/go/internal/report" + "github.com/websentry-ai/setup/go/internal/transcript" +) + +// processPreToolUse mirrors process_pre_tool_use (834-967). DO NOT LOG. +func (c *claudeCodeHook) processPreToolUse(event *pyjson.Object) *pyjson.Object { + sessionID := event.GetDefault("session_id", nil) + model := event.GetDefault("model", nil) + if !pyjson.Truthy(model) { + model = c.getSessionModel(sessionID) + } + if !pyjson.Truthy(model) { + model = "auto" + } + transcriptPath := event.GetDefault("transcript_path", nil) + tn := event.GetDefault("tool_name", "") + toolName, ok := tn.(string) + if !ok { + raise("tool_name %v has no attribute 'startswith'", tn) + } + + isMCP := strings.HasPrefix(toolName, mcpToolPrefix) + if !isMCP && !slices.Contains(ccAllowedNonMCPHookNames, toolName) { + return pyjson.NewObject() + } + + cache := c.loadPolicyCache() + toolsToCheck := []any{} + if cache != nil { + toolsToCheck, _ = cache.GetDefault("tools_to_check", []any{}).([]any) + } + needPullPolicies := cache == nil || c.isCacheStale(cache) + + if ccNativeFileTools[toolName] && !pyIn(toolName, toolsToCheck) && !needPullPolicies { + return pyjson.NewObject() + } + + recentUserPrompts := c.getRecentUserPromptsForSession(sessionID, ccPretoolUserMessagesLimit, transcriptPath) + command := extractCommandForPretool(event) + + // Build metadata with the raw event (861-865). + metadata := copyObject(event) + toolInput := event.GetDefault("tool_input", nil) + if !pyjson.Truthy(toolInput) { + toolInput = pyjson.NewObject() + } + if pyIn("file_path", toolInput) { + metadata.Set("file_path", pyIndex(toolInput, "file_path")) + } + + if isMCP { + // Parse mcp____ for gateway matching (867-880). + parts := strings.SplitN(toolName[len(mcpToolPrefix):], "__", 2) + mcpServerName := parts[0] + metadata.Set("mcp_server", mcpServerName) + mcpTool := "" + if len(parts) >= 2 { + mcpTool = parts[1] + } + metadata.Set("mcp_tool", mcpTool) + + if mcpServerName != "" { + cwd := event.GetDefault("cwd", nil) + if serverCfg := c.readMCPServerConfig(mcpServerName, cwd); serverCfg != nil { + metadata.Set("mcp_server_config", serverCfg) + } + } + } + + approvalKey := toolName + ":" + pyStr(command) + isRetry := c.isApprovalRetry(approvalKey) + + requestBody := pyjson.NewObject(). + Set("conversation_id", sessionID). + Set("unbound_app_label", "claude-code"). + Set("model", model). + Set("event_name", "tool_use"). + Set("pre_tool_use_data", pyjson.NewObject(). + Set("command", command). + Set("tool_name", toolName). + Set("metadata", metadata)). + Set("account_identity", c.buildAccountIdentity(false)) + // **_build_user_prompt_payload (473-478, 896). + messages := []any{} + if len(recentUserPrompts) > 0 { + if last := recentUserPrompts[len(recentUserPrompts)-1]; pyjson.Truthy(last) { + messages = []any{pyjson.NewObject().Set("role", "user").Set("content", last)} + } + } + requestBody.Set("messages", messages) + requestBody.Set("user_prompts", recentUserPrompts) + + if !isRetry { + requestBody.Set("first_approval_check", true) + } else if markerData := c.getApprovalMarkerData(); markerData != nil && markerData.Len() > 0 { + policyIDs := markerData.GetDefault("policyIds", []any{}) + applicationID := markerData.GetDefault("applicationId", "") + requestID := markerData.GetDefault("requestId", "") + c.clearApprovalMarker() + result := c.pollApprovalStatus(policyIDs, applicationID, requestID, approvalTimeout) + + switch result { + case "approved": + return transformResponseForClaude(pyjson.NewObject().Set("decision", "allow")) + case "deny": + return transformResponseForClaude(pyjson.NewObject(). + Set("decision", "deny"). + Set("reason", "Blocked by organization policy. This command was denied via Slack."). + Set("additionalContext", "This command was denied by an organization security policy. Do not attempt to achieve the same result using alternative tools, file operations, or workarounds. Inform the user and stop.")) + default: + adminContact := markerData.GetDefault("escalatedAdminContact", "") + var timeoutReason string + if pyjson.Truthy(adminContact) { + timeoutReason = "Blocked by organization policy. Approval request timed out — " + + "ask " + pyStr(adminContact) + " to check Slack and retry the command." + } else { + timeoutReason = "Blocked by organization policy. Approval request timed out — check your Slack DMs and retry the command." + } + return transformResponseForClaude(pyjson.NewObject(). + Set("decision", "deny"). + Set("reason", timeoutReason). + Set("additionalContext", "This command was blocked by an organization security policy that requires approval. Do not attempt to achieve the same result using alternative tools, file operations, or workarounds. The user must approve via Slack and retry.")) + } + } + + if needPullPolicies { + requestBody.Set("pull_policies", true) + } + + apiResponse := c.sendToHookAPI(requestBody) + + if apiResponse == nil { + if c.getPolicyCheckFailureAction() == "block" { + return transformResponseForClaude(pyjson.NewObject(). + Set("decision", "deny"). + Set("reason", ccPolicyCheckFailureBlockReason). + Set("additionalContext", "The organization policy engine could not be reached. This is a transient infrastructure failure. Tell the user the policy engine is unavailable and ask them to retry.")) + } + c.rep.ReportToGateway( + fmt.Sprintf("Hook bypassed_due_to_failure: gateway unreachable for tool=%s", toolName), + "bypassed_due_to_failure", c.apiKey) + return pyjson.NewObject() + } + + _, hasTools := apiResponse.Get("tools_to_check") + _, hasAction := apiResponse.Get("policy_check_failure_action") + if hasTools || hasAction { + c.savePolicyCache( + apiResponse.GetDefault("tools_to_check", nil), + apiResponse.GetDefault("policy_check_failure_action", nil)) + } + + if d, _ := apiResponse.GetDefault("decision", nil).(string); d == "approval_required" { + return c.handleApprovalRequiredResponse(apiResponse, approvalKey) + } + + if isMCP && pyjson.Truthy(apiResponse.GetDefault("unknown_mcp_server", nil)) { + if serverCfg, ok := metadata.GetDefault("mcp_server_config", nil).(*pyjson.Object); ok && pyjson.Truthy(serverCfg) { + serverName, _ := metadata.GetDefault("mcp_server", "").(string) + c.dispatchMCPServerScan(serverName, serverCfg) + } + } + + return transformResponseForClaude(apiResponse) +} + +// processUserPromptSubmit mirrors process_user_prompt_submit (970-986). +func (c *claudeCodeHook) processUserPromptSubmit(event *pyjson.Object) *pyjson.Object { + sessionID := event.GetDefault("session_id", nil) + model := event.GetDefault("model", nil) + if !pyjson.Truthy(model) { + model = c.getSessionModel(sessionID) + } + if !pyjson.Truthy(model) { + model = "auto" + } + prompt := event.GetDefault("prompt", "") + + messages := []any{} + if pyjson.Truthy(prompt) { + messages = []any{pyjson.NewObject().Set("role", "user").Set("content", prompt)} + } + requestBody := pyjson.NewObject(). + Set("conversation_id", sessionID). + Set("unbound_app_label", "claude-code"). + Set("model", model). + Set("event_name", "user_prompt"). + Set("account_identity", c.buildAccountIdentity(false)). + Set("messages", messages) + + return transformResponseForClaudePrompt(c.sendToHookAPI(requestBody)) +} + +// sendToHookAPI mirrors send_to_hook_api (514-538). nil stands in for +// python's falsy {} on every no-result path; a truthy non-dict response +// raises like the attribute access python would hit next (a str body that +// happens to contain a checked substring diverges — accepted). +func (c *claudeCodeHook) sendToHookAPI(requestBody *pyjson.Object) *pyjson.Object { + if c.apiKey == "" { + return nil + } + data, err := pyjson.Dumps(requestBody) + if err != nil { + c.rep.LogError(fmt.Sprintf("Hook API error: %v", err), "api_call") + return nil + } + res, err := httpc.PostJSON(c.gatewayURL+"/v1/hooks/pretool", c.apiKey, []byte(data), 20*time.Second) + if err != nil { + c.rep.LogError(fmt.Sprintf("Hook API error: %v", err), "api_call") + return nil + } + if res.ExitCode == 0 && len(res.Stdout) > 0 { + parsed, perr := pyjson.Loads(res.Stdout) + if perr != nil { + c.rep.LogError(fmt.Sprintf("Hook API error: %v", perr), "api_call") + return nil + } + if obj, ok := parsed.(*pyjson.Object); ok { + if obj.Len() == 0 { + return nil + } + return obj + } + if pyjson.Truthy(parsed) { + raise("hook api response is not a dict") + } + } + return nil +} + +// transformResponseForClaude mirrors transform_response_for_claude (586-602). +func transformResponseForClaude(apiResponse *pyjson.Object) *pyjson.Object { + if apiResponse == nil || apiResponse.Len() == 0 { + return pyjson.NewObject() + } + return pyjson.NewObject().Set("hookSpecificOutput", pyjson.NewObject(). + Set("hookEventName", "PreToolUse"). + Set("permissionDecision", apiResponse.GetDefault("decision", "allow")). + Set("permissionDecisionReason", apiResponse.GetDefault("reason", "")). + Set("additionalContext", apiResponse.GetDefault("additionalContext", ""))) +} + +// transformResponseForClaudePrompt mirrors transform_response_for_claude_prompt +// (605-620): for UserPromptSubmit, 'deny' maps to 'block'. +func transformResponseForClaudePrompt(apiResponse *pyjson.Object) *pyjson.Object { + if apiResponse == nil || apiResponse.Len() == 0 { + return pyjson.NewObject() + } + if d, _ := apiResponse.GetDefault("decision", "allow").(string); d == "deny" { + return pyjson.NewObject(). + Set("decision", "block"). + Set("reason", apiResponse.GetDefault("reason", "")) + } + return pyjson.NewObject() +} + +// extractCommandForPretool mirrors extract_command_for_pretool (481-511). +func extractCommandForPretool(event *pyjson.Object) any { + toolInput := event.GetDefault("tool_input", nil) + if !pyjson.Truthy(toolInput) { + toolInput = pyjson.NewObject() + } + tn := event.GetDefault("tool_name", "") + toolName, ok := tn.(string) + if !ok { + raise("tool_name %v has no attribute 'startswith'", tn) + } + + switch { + case toolName == "Bash" && pyIn("command", toolInput): + return pyIndex(toolInput, "command") + case strings.HasPrefix(toolName, mcpToolPrefix): + s, err := pyjson.Dumps(toolInput) + if err != nil { + raise("json dumps failed: %v", err) + } + return s + case (toolName == "Write" || toolName == "Edit" || toolName == "Read") && pyIn("file_path", toolInput): + return pyIndex(toolInput, "file_path") + case toolName == "Grep" && pyIn("pattern", toolInput): + return pyIndex(toolInput, "pattern") + case toolName == "Glob" && pyIn("pattern", toolInput): + return pyIndex(toolInput, "pattern") + case toolName == "WebFetch" && pyIn("url", toolInput): + return pyIndex(toolInput, "url") + case toolName == "WebSearch" && pyIn("query", toolInput): + return pyIndex(toolInput, "query") + case toolName == "Task" && pyIn("prompt", toolInput): + return pyIndex(toolInput, "prompt") + } + return toolName +} + +// getRecentUserPromptsForSession mirrors get_recent_user_prompts_for_session +// (441-470): audit-log prompts first, transcript user messages as fallback. +func (c *claudeCodeHook) getRecentUserPromptsForSession(sessionID any, n int, transcriptPath any) []any { + if n <= 0 { + return []any{} + } + + prompts := []any{} + for _, entry := range audit.Load(c.auditLog) { + log := mustObj(entry) + logSession := log.GetDefault("session_id", nil) + if !pyjson.Truthy(logSession) { + logSession = objGet(log.GetDefault("event", pyjson.NewObject()), "session_id", nil) + } + if !pyEq(logSession, sessionID) { + continue + } + event := mustObj(log.GetDefault("event", pyjson.NewObject())) + if hen, _ := event.GetDefault("hook_event_name", nil).(string); hen != "UserPromptSubmit" { + continue + } + if prompt := event.GetDefault("prompt", nil); pyjson.Truthy(prompt) { + prompts = append(prompts, prompt) + } + } + + if len(prompts) > 0 { + if len(prompts) > n { + prompts = prompts[len(prompts)-n:] + } + return prompts + } + + if pyjson.Truthy(transcriptPath) { + tp, ok := transcriptPath.(string) + if !ok { + raise("os.path.exists on a non-str transcript_path") + } + if tp != "undefined" { + if _, err := os.Stat(tp); err == nil { + userMessages := transcript.ParseFile(tp, "").UserMessages + if len(userMessages) > n { + userMessages = userMessages[len(userMessages)-n:] + } + out := []any{} + for _, m := range userMessages { + if pyjson.Truthy(m.Content) { + out = append(out, m.Content) + } + } + return out + } + } + } + + return []any{} +} + +// --- policy cache (144-202) --- + +// readPolicyCacheRaw mirrors _read_policy_cache_raw (144-153): nil on +// missing, unreadable, corrupt, or non-dict. +func (c *claudeCodeHook) readPolicyCacheRaw() *pyjson.Object { + data, err := os.ReadFile(c.policyCache) + if err != nil { + return nil + } + parsed, perr := pyjson.Loads(data) + if perr != nil { + return nil + } + obj, _ := parsed.(*pyjson.Object) + return obj +} + +// loadPolicyCache mirrors load_policy_cache (156-163). +func (c *claudeCodeHook) loadPolicyCache() *pyjson.Object { + cache := c.readPolicyCacheRaw() + if cache == nil { + return nil + } + if _, ok := cache.Get("last_synced"); !ok { + return nil + } + ttc, ok := cache.Get("tools_to_check") + if !ok { + return nil + } + if _, ok := ttc.([]any); !ok { + return nil + } + return cache +} + +// getPolicyCheckFailureAction mirrors get_policy_check_failure_action +// (166-172): defaults to 'allow', ignores TTL. +func (c *claudeCodeHook) getPolicyCheckFailureAction() string { + cache := c.readPolicyCacheRaw() + if cache == nil { + return "allow" + } + if v, _ := cache.GetDefault("policy_check_failure_action", nil).(string); v == "allow" || v == "block" { + return v + } + return "allow" +} + +// savePolicyCache mirrors save_policy_cache (175-192): nil for either field +// preserves the prior value; all errors are swallowed. +func (c *claudeCodeHook) savePolicyCache(toolsToCheck any, policyCheckFailureAction any) { + if err := os.MkdirAll(filepath.Dir(c.policyCache), 0o755); err != nil { + return + } + if toolsToCheck == nil { + toolsToCheck = []any{} + if prior := c.readPolicyCacheRaw(); prior != nil { + toolsToCheck = prior.GetDefault("tools_to_check", []any{}) + } + } + action, ok := policyCheckFailureAction.(string) + if !ok || (action != "allow" && action != "block") { + action = c.getPolicyCheckFailureAction() + } + cache := pyjson.NewObject(). + Set("last_synced", report.UTCTimestamp(time.Now())). + Set("tools_to_check", toolsToCheck). + Set("policy_check_failure_action", action) + s, err := pyjson.Dumps(cache) + if err != nil { + return + } + _ = os.WriteFile(c.policyCache, []byte(s), 0o644) +} + +// isCacheStale mirrors is_cache_stale (195-202): parse errors mean stale; a +// non-str last_synced raises (python AttributeError is not in the except). +func (c *claudeCodeHook) isCacheStale(cache *pyjson.Object) bool { + v, ok := cache.Get("last_synced") + if !ok { + return true + } + s, isStr := v.(string) + if !isStr { + raise("last_synced %v has no attribute 'rstrip'", v) + } + // fromisoformat parse of our own isoformat()+'Z' output. (Python 3.11+ + // accepts more ISO shapes; this layout is the only one the hook writes.) + synced, err := time.Parse("2006-01-02T15:04:05.999999", strings.TrimRight(s, "Z")) + if err != nil { + return true + } + return time.Now().UTC().Sub(synced).Seconds() > ccCacheTTLSeconds +} + +// --- approval marker + polling (241-328, 541-583) --- + +func approvalCmdHash(command string) string { + sum := sha256.Sum256([]byte(command)) + return hex.EncodeToString(sum[:])[:16] +} + +// isApprovalRetry mirrors _is_approval_retry (244-253): true iff a marker +// exists for this exact command and is fresh. A non-dict / non-numeric +// marker raises (python only catches OSError and JSONDecodeError here). +func (c *claudeCodeHook) isApprovalRetry(command string) bool { + data, err := os.ReadFile(c.approvalMarker) + if err != nil { + return false + } + parsed, perr := pyjson.Loads(data) + if perr != nil { + return false + } + obj := mustObj(parsed) + if !pyEq(obj.GetDefault("cmd", nil), approvalCmdHash(command)) { + return false + } + ts, ok := toFloat(obj.GetDefault("ts", pyjson.Number("0"))) + if !ok { + raise("unsupported operand type for -: approval marker ts") + } + return float64(time.Now().UnixNano())/1e9-ts < approvalTimeout.Seconds() +} + +// setApprovalMarker mirrors _set_approval_marker (256-272). Errors raise: +// python has no try here, exceptions surface in main's blanket except. +func (c *claudeCodeHook) setApprovalMarker(command string, policyIDs, applicationID, requestID, escalatedAdminContact any) { + if err := os.MkdirAll(filepath.Dir(c.approvalMarker), 0o755); err != nil { + raise("approval marker mkdir failed: %v", err) + } + data := pyjson.NewObject(). + Set("cmd", approvalCmdHash(command)). + Set("ts", float64(time.Now().UnixNano())/1e9). + Set("policyIds", policyIDs). + Set("applicationId", applicationID). + Set("requestId", requestID). + Set("escalatedAdminContact", escalatedAdminContact) + s, err := pyjson.Dumps(data) + if err != nil { + raise("json dumps failed: %v", err) + } + if err := os.WriteFile(c.approvalMarker, []byte(s), 0o644); err != nil { + raise("approval marker write failed: %v", err) + } +} + +// getApprovalMarkerData mirrors _get_approval_marker_data (275-281). +func (c *claudeCodeHook) getApprovalMarkerData() *pyjson.Object { + data, err := os.ReadFile(c.approvalMarker) + if err != nil { + return nil + } + parsed, perr := pyjson.Loads(data) + if perr != nil { + return nil + } + // python returns whatever json gives; a non-dict raises at the caller's + // first .get — raised here instead, same outcome. + return mustObj(parsed) +} + +func (c *claudeCodeHook) clearApprovalMarker() { + _ = os.Remove(c.approvalMarker) +} + +// handleApprovalRequiredResponse mirrors _handle_approval_required_response +// (291-328): set the marker, deny with explicit retry instructions. +func (c *claudeCodeHook) handleApprovalRequiredResponse(apiResponse *pyjson.Object, approvalKey string) *pyjson.Object { + approvalCheck := mustObj(apiResponse.GetDefault("approvalCheck", pyjson.NewObject())) + policyIDs := approvalCheck.GetDefault("policyIds", []any{}) + applicationID := approvalCheck.GetDefault("applicationId", "") + requestID := approvalCheck.GetDefault("requestId", "") + adminContact := approvalCheck.GetDefault("escalatedAdminContact", "") + if !pyjson.Truthy(adminContact) { + adminContact = "" // `or ''` + } + + var reason, tellUser string + if pyjson.Truthy(adminContact) { + reason = "We could not find your Slack account, so an approval request was sent to " + pyStr(adminContact) + ". " + + "Please ask them to approve it in Slack." + tellUser = `Tell the user: "` + reason + `" ` + } else { + reason = "An approval request has been sent to your Slack DMs. Please approve it there." + tellUser = `Tell the user: "An approval request has been sent to your Slack DMs. ` + + `Please approve it and I will retry automatically." ` + } + + c.setApprovalMarker(approvalKey, policyIDs, applicationID, requestID, adminContact) + return transformResponseForClaude(pyjson.NewObject(). + Set("decision", "deny"). + Set("reason", reason). + Set("additionalContext", + "This is NOT a permanent block — it is a temporary hold pending Slack approval. "+ + tellUser+ + "Then immediately retry the exact same tool call with the exact same command — "+ + "do not modify the command in any way, do not add sleep or any prefix. "+ + "Retry exactly once — the second attempt will wait for the approval.")) +} + +// nextPollInterval mirrors _next_poll_interval (541-546). +func nextPollInterval(elapsed float64) int { + for _, phase := range approvalPollPhases { + if elapsed < float64(phase[0]) { + return phase[1] + } + } + return approvalPollPhases[len(approvalPollPhases)-1][1] +} + +// pollApprovalStatus mirrors poll_approval_status (548-583): poll until +// approved, denied, or timeout, sleeping per the backoff phases. +func (c *claudeCodeHook) pollApprovalStatus(policyIDs, applicationID, requestID any, timeout time.Duration) string { + url := c.gatewayURL + "/v1/hooks/pretool/approval-status" + payload := pyjson.NewObject(). + Set("policyIds", policyIDs). + Set("applicationId", applicationID) + if pyjson.Truthy(requestID) { + payload.Set("requestId", requestID) + } + body, err := pyjson.Dumps(payload) + if err != nil { + raise("json dumps failed: %v", err) + } + + start := time.Now() + deadline := start.Add(timeout) + + for time.Now().Before(deadline) { + time.Sleep(time.Duration(nextPollInterval(time.Since(start).Seconds())) * time.Second) + res, err := httpc.PostJSON(url, c.apiKey, []byte(body), 10*time.Second) + if err != nil { + c.rep.LogError(fmt.Sprintf("Approval poll error: %v", err), "general") + continue + } + if res.ExitCode == 0 && len(res.Stdout) > 0 { + parsed, perr := pyjson.Loads(res.Stdout) + if perr != nil { + c.rep.LogError(fmt.Sprintf("Approval poll error: %v", perr), "general") + continue + } + obj, ok := parsed.(*pyjson.Object) + if !ok { + c.rep.LogError("Approval poll error: response is not a dict", "general") + continue + } + decision, _ := obj.GetDefault("decision", "pending").(string) + if decision == "allow" { + return "approved" + } + if decision == "deny" { + return "deny" + } + } + } + + return "timeout" +} + +// --- MCP server config (623-671) --- + +// extractMCPServerFields mirrors _extract_mcp_server_fields (623-635). +func extractMCPServerFields(server any) *pyjson.Object { + obj, ok := server.(*pyjson.Object) + if !ok { + return nil + } + result := pyjson.NewObject() + for _, key := range []string{"url", "command", "args", "type"} { + if v := obj.GetDefault(key, nil); pyjson.Truthy(v) { + result.Set(key, v) + } + } + if result.Len() == 0 { + return nil + } + return result +} + +// readMCPServerConfig mirrors _read_mcp_server_config (638-671): project +// servers for cwd (walking up parents) first, then top-level mcpServers. +// Any failure returns nil (python's broad except). +func (c *claudeCodeHook) readMCPServerConfig(serverName string, cwd any) *pyjson.Object { + data, err := os.ReadFile(c.claudeConfig) + if err != nil { + return nil + } + parsed, perr := pyjson.Loads(data) + if perr != nil { + return nil + } + cfg, ok := parsed.(*pyjson.Object) + if !ok { + return nil + } + + if pyjson.Truthy(cwd) { + cwdStr, isStr := cwd.(string) + if !isStr { + return nil // cwd.replace would have raised into the broad except + } + if projects, ok := cfg.GetDefault("projects", pyjson.NewObject()).(*pyjson.Object); ok { + cwdPath := strings.TrimRight(strings.ReplaceAll(cwdStr, "\\", "/"), "/") + for cwdPath != "" { + if projData, ok := projects.GetDefault(cwdPath, nil).(*pyjson.Object); ok { + if projServers, ok := projData.GetDefault("mcpServers", pyjson.NewObject()).(*pyjson.Object); ok { + if v, has := projServers.Get(serverName); has { + if result := extractMCPServerFields(v); result != nil { + return result + } + } + } + } + parent := posixDirname(cwdPath) + if parent == cwdPath { + break + } + cwdPath = parent + } + } + } + + if topServers, ok := cfg.GetDefault("mcpServers", pyjson.NewObject()).(*pyjson.Object); ok { + if v, has := topServers.Get(serverName); has { + if result := extractMCPServerFields(v); result != nil { + return result + } + } + } + + return nil +} diff --git a/go/internal/hooks/claude_code_stop.go b/go/internal/hooks/claude_code_stop.go new file mode 100644 index 00000000..176f781b --- /dev/null +++ b/go/internal/hooks/claude_code_stop.go @@ -0,0 +1,260 @@ +package hooks + +// Stop path of the claude-code port: session-event extraction from the +// audit log, transcript merge, exchange build, and the /v1/hooks/claude +// POST. + +import ( + "fmt" + "strings" + "time" + + "github.com/websentry-ai/setup/go/internal/audit" + "github.com/websentry-ai/setup/go/internal/httpc" + "github.com/websentry-ai/setup/go/internal/pyjson" + "github.com/websentry-ai/setup/go/internal/transcript" +) + +// processStopEvent mirrors process_stop_event (1129-1178). +func (c *claudeCodeHook) processStopEvent(event *pyjson.Object) { + sessionID := event.GetDefault("session_id", nil) + transcriptPath := event.GetDefault("transcript_path", nil) + lastAssistantMessage := event.GetDefault("last_assistant_message", nil) + + logs := audit.Load(c.auditLog) + + // A new UserPromptSubmit RESETS the list: the exchange covers the + // latest turn only (1140-1151). + sessionEvents := []any{} + currentConversationStarted := false + var userPromptTimestamp any + + for _, entry := range logs { + log := mustObj(entry) + logSessionID := log.GetDefault("session_id", nil) + if !pyjson.Truthy(logSessionID) { + logSessionID = objGet(log.GetDefault("event", pyjson.NewObject()), "session_id", nil) + } + if !pyEq(logSessionID, sessionID) { + continue + } + var eventName any + if v, has := log.Get("event"); has { + eventName = objGet(v, "hook_event_name", nil) + } else { + eventName = log.GetDefault("hook_event_name", nil) + } + + if en, _ := eventName.(string); en == "UserPromptSubmit" { + sessionEvents = []any{entry} + currentConversationStarted = true + userPromptTimestamp = log.GetDefault("timestamp", nil) + } else if currentConversationStarted { + sessionEvents = append(sessionEvents, entry) + } + } + + transcriptAssistantMessages := []any{} + var transcriptUsage *transcript.Usage + var transcriptModel any + if pyjson.Truthy(transcriptPath) { + tp, isStr := transcriptPath.(string) + // python `transcript_path != 'undefined'` is True for any non-str. + if (!isStr || tp != "undefined") && pyjson.Truthy(userPromptTimestamp) { + if !isStr { + raise("os.path.exists on a non-str transcript_path") + } + ts, ok := userPromptTimestamp.(string) + if !ok { + // python would abort the transcript scan on the first + // filtered assistant entry instead — accepted divergence + // on corrupt audit timestamps. + ts = "" + } + data := transcript.ParseFile(tp, ts) + for _, m := range data.AssistantMessages { + if pyjson.Truthy(m.Content) { + transcriptAssistantMessages = append(transcriptAssistantMessages, m.Content) + } + } + transcriptUsage = data.Usage + transcriptModel = data.Model + } + } + + // Prefer the dominant model from the transcript (covers sub-agent turns + // where the cached session model is wrong); audit log otherwise (1167). + sessionModel := transcriptModel + if !pyjson.Truthy(sessionModel) { + sessionModel = extractSessionModel(logs, sessionID) + } + if !pyjson.Truthy(sessionModel) { + sessionModel = "auto" + } + + exchange := c.buildLLMExchange(sessionEvents, lastAssistantMessage, transcriptAssistantMessages, sessionModel, transcriptUsage) + if exchange != nil { + c.sendToAPI(exchange) + } +} + +// buildLLMExchange mirrors build_llm_exchange (989-1070). Returns nil when +// fewer than two messages were assembled — send nothing. +func (c *claudeCodeHook) buildLLMExchange(events []any, stopAssistantMessage any, transcriptAssistantMessages []any, model any, usage *transcript.Usage) *pyjson.Object { + messages := []any{} + assistantToolUses := []any{} + + var userPrompt, sessionID, permissionMode any + + for _, logEntry := range events { + le := mustObj(logEntry) + event := le + if v, has := le.Get("event"); has { + event = mustObj(v) + } + hookEventName, _ := event.GetDefault("hook_event_name", nil).(string) + + if !pyjson.Truthy(sessionID) { + sessionID = event.GetDefault("session_id", nil) + } + if !pyjson.Truthy(permissionMode) { + permissionMode = event.GetDefault("permission_mode", nil) + } + + if hookEventName == "UserPromptSubmit" { + if prompt := event.GetDefault("prompt", nil); pyjson.Truthy(prompt) { + userPrompt = prompt + } + } else if hookEventName == "PostToolUse" { + toolName := event.GetDefault("tool_name", nil) + toolInput := event.GetDefault("tool_input", pyjson.NewObject()) + toolResponse := event.GetDefault("tool_response", pyjson.NewObject()) + + // Dedup: drop the response content when it just echoes the + // input content (1017-1019). + if pyIn("content", toolResponse) && pyIn("content", toolInput) { + if pyEq(pyIndex(toolResponse, "content"), pyIndex(toolInput, "content")) { + tr := mustObj(toolResponse) + copied := pyjson.NewObject() + for _, m := range tr.Members() { + if m.Key != "content" { + copied.Set(m.Key, m.Value) + } + } + toolResponse = copied + } + } + + assistantToolUses = append(assistantToolUses, pyjson.NewObject(). + Set("type", "PostToolUse"). + Set("tool_name", toolName). + Set("tool_input", toolInput). + Set("tool_response", toolResponse)) + } + } + + if pyjson.Truthy(userPrompt) { + messages = append(messages, pyjson.NewObject().Set("role", "user").Set("content", userPrompt)) + } + + allResponses := append([]any{}, transcriptAssistantMessages...) + if pyjson.Truthy(stopAssistantMessage) { + found := false + for _, r := range allResponses { + if pyEq(r, stopAssistantMessage) { + found = true + break + } + } + if !found { + allResponses = append(allResponses, stopAssistantMessage) + } + } + assistantResponse := "" + if len(allResponses) > 0 { + parts := make([]string, len(allResponses)) + for i, r := range allResponses { + s, ok := r.(string) + if !ok { + raise("sequence item %d: expected str instance in join", i) + } + parts[i] = s + } + assistantResponse = strings.Join(parts, "\n\n") + } + + if assistantResponse != "" || len(assistantToolUses) > 0 { + assistantMsg := pyjson.NewObject(). + Set("role", "assistant"). + Set("content", assistantResponse) + if len(assistantToolUses) > 0 { + assistantMsg.Set("tool_use", assistantToolUses) + } + messages = append(messages, assistantMsg) + } + + if len(messages) < 2 { + return nil + } + + if !pyjson.Truthy(permissionMode) { + permissionMode = "default" + } + + if !pyjson.Truthy(model) { + model = c.getSessionModel(sessionID) + if !pyjson.Truthy(model) { + model = "auto" + } + } + + conversationID := sessionID + if !pyjson.Truthy(conversationID) { + conversationID = "unknown" + } + exchange := pyjson.NewObject(). + Set("conversation_id", conversationID). + Set("model", model). + Set("messages", messages). + Set("permission_mode", permissionMode). + Set("account_identity", c.buildAccountIdentity(true)) + + if usage != nil { + exchange.Set("usage", pyjson.NewObject(). + Set("input_tokens", usage.InputTokens). + Set("output_tokens", usage.OutputTokens). + Set("cache_read_input_tokens", usage.CacheReadInputTokens). + Set("cache_creation_input_tokens", usage.CacheCreationInputTokens). + Set("total_tokens", usage.TotalTokens)) + } + + return exchange +} + +// sendToAPI mirrors send_to_api (1073-1100): POST the exchange, log-only on +// any failure. +func (c *claudeCodeHook) sendToAPI(exchange *pyjson.Object) bool { + if c.apiKey == "" { + c.rep.LogError("No API key present in send_to_api function", "config") + return false + } + data, err := pyjson.Dumps(exchange) + if err != nil { + c.rep.LogError(fmt.Sprintf("Exception in send_to_api: %v", err), "api_call") + return false + } + res, err := httpc.PostJSON(c.gatewayURL+"/v1/hooks/claude", c.apiKey, []byte(data), 10*time.Second) + if err != nil { + c.rep.LogError(fmt.Sprintf("Exception in send_to_api: %v", err), "api_call") + return false + } + if res.ExitCode != 0 { + errorMsg := "Unknown error" + if len(res.Stderr) > 0 { + errorMsg = strings.TrimSpace(string(res.Stderr)) + } + c.rep.LogError("API request failed: "+errorMsg, "api_call") + return false + } + return true +} diff --git a/go/internal/hooks/claude_code_test.go b/go/internal/hooks/claude_code_test.go new file mode 100644 index 00000000..112e690e --- /dev/null +++ b/go/internal/hooks/claude_code_test.go @@ -0,0 +1,660 @@ +package hooks + +// Hermetic tests for the claude-code port: temp HOME, fake curl / +// system_profiler shims on PATH (the httpc test pattern), no network. +// Byte-level python parity is enforced separately by +// binary/tests/test_hook_cli.py against the real python hook. + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/websentry-ai/setup/go/internal/pyjson" +) + +const suppressLine = "{\"suppressOutput\": true}\n" + +func writeScript(t *testing.T, path, body string) { + t.Helper() + if err := os.WriteFile(path, []byte(body), 0o755); err != nil { + t.Fatal(err) + } +} + +// installFakeBins shims curl and system_profiler. curl records argv/stdin +// and answers per CURL_STDOUT / CURL_EXIT; system_profiler prints a +// hardware-serial line per FAKE_SERIAL. +func installFakeBins(t *testing.T) (argsFile, stdinFile string) { + t.Helper() + dir := t.TempDir() + argsFile = filepath.Join(dir, "curl-args") + stdinFile = filepath.Join(dir, "curl-stdin") + writeScript(t, filepath.Join(dir, "curl"), `#!/bin/sh +for a in "$@"; do printf '%s\n' "$a"; done > "$CURL_ARGS_FILE" +cat > "$CURL_STDIN_FILE" +printf '%s' "$CURL_STDOUT" +exit "${CURL_EXIT:-0}" +`) + writeScript(t, filepath.Join(dir, "system_profiler"), `#!/bin/sh +printf ' Serial Number (system): %s\n' "$FAKE_SERIAL" +`) + t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH")) + t.Setenv("CURL_ARGS_FILE", argsFile) + t.Setenv("CURL_STDIN_FILE", stdinFile) + t.Setenv("CURL_STDOUT", "") + t.Setenv("CURL_EXIT", "0") + t.Setenv("FAKE_SERIAL", "TESTSERIAL123") + return argsFile, stdinFile +} + +func setupHome(t *testing.T) string { + t.Helper() + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("UNBOUND_GATEWAY_URL", "https://gw.test") + t.Setenv("UNBOUND_CLAUDE_API_KEY", "test-key") + t.Setenv("ANTHROPIC_API_KEY", "") + return home +} + +func runHook(t *testing.T, payload string) string { + t.Helper() + var out bytes.Buffer + if code := Dispatch("claude-code", "", strings.NewReader(payload), &out); code != 0 { + t.Fatalf("exit code = %d, want 0", code) + } + return out.String() +} + +func auditLines(t *testing.T, home string) []string { + t.Helper() + data, err := os.ReadFile(filepath.Join(home, ".claude", "hooks", "agent-audit.log")) + if err != nil { + return nil + } + return strings.Split(strings.TrimRight(string(data), "\n"), "\n") +} + +func loadsObj(t *testing.T, s string) *pyjson.Object { + t.Helper() + v, err := pyjson.Loads([]byte(s)) + if err != nil { + t.Fatalf("Loads(%q): %v", s, err) + } + obj, ok := v.(*pyjson.Object) + if !ok { + t.Fatalf("not an object: %q", s) + } + return obj +} + +func waitForFile(t *testing.T, path string) string { + t.Helper() + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + if data, err := os.ReadFile(path); err == nil { + return string(data) + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("file never appeared: %s", path) + return "" +} + +// --- entry semantics (main, lines 1607-1681) --- + +func TestEntryEmptyStdinSuppresses(t *testing.T) { + setupHome(t) + installFakeBins(t) + for _, in := range []string{"", " \n\t "} { + if got := runHook(t, in); got != suppressLine { + t.Errorf("stdin %q -> %q, want %q", in, got, suppressLine) + } + } +} + +func TestEntryMalformedJSONSuppresses(t *testing.T) { + setupHome(t) + installFakeBins(t) + if got := runHook(t, "not json at all"); got != suppressLine { + t.Errorf("got %q, want %q", got, suppressLine) + } +} + +func TestEntryNonDictEventLogsAndSuppresses(t *testing.T) { + home := setupHome(t) + t.Setenv("UNBOUND_CLAUDE_API_KEY", "") // keep the error report offline + installFakeBins(t) + if got := runHook(t, "[1, 2]"); got != suppressLine { + t.Errorf("got %q, want %q", got, suppressLine) + } + data, err := os.ReadFile(filepath.Join(home, ".claude", "hooks", "error.log")) + if err != nil || !strings.Contains(string(data), "Exception in main") { + t.Errorf("error.log = %q, %v", data, err) + } +} + +func TestEntryUnknownEventIsAuditLogged(t *testing.T) { + home := setupHome(t) + installFakeBins(t) + if got := runHook(t, `{"hook_event_name": "SomethingNew", "session_id": "s1"}`); got != suppressLine { + t.Errorf("got %q, want %q", got, suppressLine) + } + lines := auditLines(t, home) + if len(lines) != 1 { + t.Fatalf("audit lines = %d, want 1", len(lines)) + } + entry := loadsObj(t, lines[0]) + if v, _ := entry.Get("session_id"); v != "s1" { + t.Errorf("session_id = %v", v) + } + if _, ok := entry.Get("timestamp"); !ok { + t.Error("timestamp missing from audit entry") + } +} + +// --- PreToolUse (process_pre_tool_use, lines 834-967) --- + +func TestPreToolUseUncheckedToolSkipsGateway(t *testing.T) { + home := setupHome(t) + argsFile, _ := installFakeBins(t) + got := runHook(t, `{"hook_event_name": "PreToolUse", "session_id": "s1", "tool_name": "Grep", "tool_input": {"pattern": "x"}}`) + if got != suppressLine { + t.Errorf("got %q, want %q", got, suppressLine) + } + if _, err := os.Stat(argsFile); err == nil { + t.Error("gateway was called for a non-allowed tool") + } + if lines := auditLines(t, home); lines != nil { + t.Errorf("PreToolUse must not audit-log, got %v", lines) + } +} + +func TestPreToolUseNativeToolSkippedOnFreshCache(t *testing.T) { + home := setupHome(t) + argsFile, _ := installFakeBins(t) + cacheFile := filepath.Join(home, ".claude", "hooks", ".policy_cache.json") + if err := os.MkdirAll(filepath.Dir(cacheFile), 0o755); err != nil { + t.Fatal(err) + } + cache := fmt.Sprintf(`{"last_synced": "%s", "tools_to_check": [], "policy_check_failure_action": "allow"}`, + time.Now().UTC().Format("2006-01-02T15:04:05.000000")+"Z") + if err := os.WriteFile(cacheFile, []byte(cache), 0o644); err != nil { + t.Fatal(err) + } + got := runHook(t, `{"hook_event_name": "PreToolUse", "session_id": "s1", "tool_name": "Read", "tool_input": {"file_path": "/tmp/x"}}`) + if got != suppressLine { + t.Errorf("got %q, want %q", got, suppressLine) + } + if _, err := os.Stat(argsFile); err == nil { + t.Error("gateway was called for an unchecked native tool with a fresh cache") + } +} + +func TestPreToolUseDenySavesCacheAndTransforms(t *testing.T) { + home := setupHome(t) + _, stdinFile := installFakeBins(t) + t.Setenv("CURL_STDOUT", `{"decision": "deny", "reason": "blocked by policy", "additionalContext": "ctx", "tools_to_check": ["Bash"], "policy_check_failure_action": "block"}`) + + got := runHook(t, `{"hook_event_name": "PreToolUse", "session_id": "s1", "tool_name": "Bash", "tool_input": {"command": "git status"}}`) + want := `{"hookSpecificOutput": {"hookEventName": "PreToolUse", "permissionDecision": "deny", "permissionDecisionReason": "blocked by policy", "additionalContext": "ctx"}, "suppressOutput": true}` + "\n" + if got != want { + t.Errorf("stdout = %q, want %q", got, want) + } + + body := loadsObj(t, waitForFile(t, stdinFile)) + for key, wantVal := range map[string]any{ + "conversation_id": "s1", + "unbound_app_label": "claude-code", + "model": "auto", + "event_name": "tool_use", + "first_approval_check": true, + "pull_policies": true, + } { + if v, _ := body.Get(key); v != wantVal { + t.Errorf("request %s = %v, want %v", key, v, wantVal) + } + } + pretool := body.GetDefault("pre_tool_use_data", nil).(*pyjson.Object) + if cmd, _ := pretool.Get("command"); cmd != "git status" { + t.Errorf("command = %v", cmd) + } + + cacheData, err := os.ReadFile(filepath.Join(home, ".claude", "hooks", ".policy_cache.json")) + if err != nil { + t.Fatal(err) + } + cache := loadsObj(t, string(cacheData)) + if v, _ := cache.Get("policy_check_failure_action"); v != "block" { + t.Errorf("cached failure action = %v", v) + } + if !pyEq(cache.GetDefault("tools_to_check", nil), []any{"Bash"}) { + t.Errorf("cached tools_to_check = %v", cache.GetDefault("tools_to_check", nil)) + } +} + +func TestPreToolUseGatewayFailureFailsOpenByDefault(t *testing.T) { + setupHome(t) + installFakeBins(t) + t.Setenv("CURL_EXIT", "7") + got := runHook(t, `{"hook_event_name": "PreToolUse", "session_id": "s1", "tool_name": "Bash", "tool_input": {"command": "git status"}}`) + if got != suppressLine { + t.Errorf("got %q, want %q", got, suppressLine) + } +} + +func TestPreToolUseGatewayFailureBlocksWhenCachedActionIsBlock(t *testing.T) { + home := setupHome(t) + installFakeBins(t) + t.Setenv("CURL_EXIT", "7") + cacheFile := filepath.Join(home, ".claude", "hooks", ".policy_cache.json") + if err := os.MkdirAll(filepath.Dir(cacheFile), 0o755); err != nil { + t.Fatal(err) + } + // Stale on purpose: get_policy_check_failure_action ignores the TTL. + if err := os.WriteFile(cacheFile, []byte(`{"last_synced": "2020-01-01T00:00:00Z", "tools_to_check": [], "policy_check_failure_action": "block"}`), 0o644); err != nil { + t.Fatal(err) + } + got := runHook(t, `{"hook_event_name": "PreToolUse", "session_id": "s1", "tool_name": "Bash", "tool_input": {"command": "git status"}}`) + resp := loadsObj(t, got) + hso := resp.GetDefault("hookSpecificOutput", nil).(*pyjson.Object) + if v, _ := hso.Get("permissionDecision"); v != "deny" { + t.Errorf("permissionDecision = %v, want deny", v) + } + if v, _ := hso.Get("permissionDecisionReason"); v != "policy engine unavailable — please retry" { + t.Errorf("reason = %v", v) + } +} + +func TestPreToolUseApprovalRequiredWritesMarker(t *testing.T) { + home := setupHome(t) + installFakeBins(t) + t.Setenv("CURL_STDOUT", `{"decision": "approval_required", "approvalCheck": {"policyIds": ["p1"], "applicationId": "app1", "requestId": "r1"}}`) + + got := runHook(t, `{"hook_event_name": "PreToolUse", "session_id": "s1", "tool_name": "Bash", "tool_input": {"command": "git status"}}`) + resp := loadsObj(t, got) + hso := resp.GetDefault("hookSpecificOutput", nil).(*pyjson.Object) + if v, _ := hso.Get("permissionDecision"); v != "deny" { + t.Errorf("permissionDecision = %v, want deny", v) + } + if v, _ := hso.Get("permissionDecisionReason"); v != "An approval request has been sent to your Slack DMs. Please approve it there." { + t.Errorf("reason = %v", v) + } + + marker, err := os.ReadFile(filepath.Join(home, ".claude", "hooks", ".approval_pending")) + if err != nil { + t.Fatal(err) + } + m := loadsObj(t, string(marker)) + if v, _ := m.Get("cmd"); v != approvalCmdHash("Bash:git status") { + t.Errorf("marker cmd = %v", v) + } + if !pyEq(m.GetDefault("policyIds", nil), []any{"p1"}) { + t.Errorf("marker policyIds = %v", m.GetDefault("policyIds", nil)) + } + if v, _ := m.Get("requestId"); v != "r1" { + t.Errorf("marker requestId = %v", v) + } +} + +func TestPreToolUseApprovalRetryPollsAndAllows(t *testing.T) { + if testing.Short() { + t.Skip("first poll interval is 3s") + } + home := setupHome(t) + argsFile, stdinFile := installFakeBins(t) + t.Setenv("CURL_STDOUT", `{"decision": "allow"}`) + + markerPath := filepath.Join(home, ".claude", "hooks", ".approval_pending") + if err := os.MkdirAll(filepath.Dir(markerPath), 0o755); err != nil { + t.Fatal(err) + } + marker := fmt.Sprintf(`{"cmd": "%s", "ts": %d.5, "policyIds": ["p1"], "applicationId": "app1", "requestId": "r1", "escalatedAdminContact": ""}`, + approvalCmdHash("Bash:git status"), time.Now().Unix()-10) + if err := os.WriteFile(markerPath, []byte(marker), 0o644); err != nil { + t.Fatal(err) + } + + got := runHook(t, `{"hook_event_name": "PreToolUse", "session_id": "s1", "tool_name": "Bash", "tool_input": {"command": "git status"}}`) + resp := loadsObj(t, got) + hso := resp.GetDefault("hookSpecificOutput", nil).(*pyjson.Object) + if v, _ := hso.Get("permissionDecision"); v != "allow" { + t.Errorf("permissionDecision = %v, want allow", v) + } + if _, err := os.Stat(markerPath); err == nil { + t.Error("approval marker was not cleared") + } + args := waitForFile(t, argsFile) + if !strings.Contains(args, "https://gw.test/v1/hooks/pretool/approval-status") { + t.Errorf("poll did not hit approval-status: %q", args) + } + wantBody := `{"policyIds": ["p1"], "applicationId": "app1", "requestId": "r1"}` + if body := waitForFile(t, stdinFile); body != wantBody { + t.Errorf("poll body = %q, want %q", body, wantBody) + } +} + +// --- UserPromptSubmit (process_user_prompt_submit, lines 970-986) --- + +func TestUserPromptSubmitDenyBlocksAndLogs(t *testing.T) { + home := setupHome(t) + installFakeBins(t) + t.Setenv("CURL_STDOUT", `{"decision": "deny", "reason": "nope"}`) + got := runHook(t, `{"hook_event_name": "UserPromptSubmit", "session_id": "s1", "prompt": "hello"}`) + want := `{"decision": "block", "reason": "nope", "suppressOutput": true}` + "\n" + if got != want { + t.Errorf("stdout = %q, want %q", got, want) + } + if lines := auditLines(t, home); len(lines) != 1 { + t.Errorf("audit lines = %d, want 1", len(lines)) + } +} + +func TestUserPromptSubmitAllowFallsThroughAndLogs(t *testing.T) { + home := setupHome(t) + installFakeBins(t) + t.Setenv("CURL_STDOUT", `{"decision": "allow"}`) + got := runHook(t, `{"hook_event_name": "UserPromptSubmit", "session_id": "s1", "prompt": "hello"}`) + if got != suppressLine { + t.Errorf("got %q, want %q", got, suppressLine) + } + lines := auditLines(t, home) + if len(lines) != 1 { + t.Fatalf("audit lines = %d, want 1", len(lines)) + } + event := loadsObj(t, lines[0]).GetDefault("event", nil).(*pyjson.Object) + if v, _ := event.Get("prompt"); v != "hello" { + t.Errorf("logged prompt = %v", v) + } +} + +// --- PostToolUse / audit append --- + +func TestPostToolUseAppendsAudit(t *testing.T) { + home := setupHome(t) + installFakeBins(t) + got := runHook(t, `{"hook_event_name": "PostToolUse", "session_id": "s1", "tool_name": "Bash", "tool_input": {"command": "ls"}, "tool_response": {"output": "ok"}}`) + if got != suppressLine { + t.Errorf("got %q, want %q", got, suppressLine) + } + if lines := auditLines(t, home); len(lines) != 1 { + t.Errorf("audit lines = %d, want 1", len(lines)) + } +} + +// --- Stop (process_stop_event + build_llm_exchange, lines 989-1178) --- + +func TestStopSendsExchange(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("deterministic device serial needs the system_profiler shim") + } + home := setupHome(t) + argsFile, stdinFile := installFakeBins(t) + + hooksDir := filepath.Join(home, ".claude", "hooks") + if err := os.MkdirAll(hooksDir, 0o755); err != nil { + t.Fatal(err) + } + auditSeed := `{"timestamp": "2026-01-01T00:00:00Z", "session_id": "s1", "event": {"hook_event_name": "UserPromptSubmit", "session_id": "s1", "prompt": "hi"}} +{"timestamp": "2026-01-01T00:00:00Z", "session_id": "s1", "event": {"hook_event_name": "PostToolUse", "session_id": "s1", "tool_name": "Bash", "tool_input": {"command": "ls"}, "tool_response": {"output": "ok"}}} +` + if err := os.WriteFile(filepath.Join(hooksDir, "agent-audit.log"), []byte(auditSeed), 0o644); err != nil { + t.Fatal(err) + } + + transcriptPath := filepath.Join(home, "transcript.jsonl") + transcriptSeed := `{"type": "assistant", "timestamp": "2026-01-01T00:00:01Z", "message": {"role": "assistant", "model": "claude-test-1", "content": [{"type": "text", "text": "From transcript"}], "usage": {"input_tokens": 5, "output_tokens": 7}}} +` + if err := os.WriteFile(transcriptPath, []byte(transcriptSeed), 0o644); err != nil { + t.Fatal(err) + } + + payload := fmt.Sprintf(`{"hook_event_name": "Stop", "session_id": "s1", "transcript_path": %q, "last_assistant_message": "done"}`, transcriptPath) + if got := runHook(t, payload); got != suppressLine { + t.Errorf("got %q, want %q", got, suppressLine) + } + + args := waitForFile(t, argsFile) + if !strings.Contains(args, "https://gw.test/v1/hooks/claude") { + t.Errorf("exchange did not hit /v1/hooks/claude: %q", args) + } + want := `{"conversation_id": "s1", "model": "claude-test-1", ` + + `"messages": [{"role": "user", "content": "hi"}, ` + + `{"role": "assistant", "content": "From transcript\n\ndone", ` + + `"tool_use": [{"type": "PostToolUse", "tool_name": "Bash", "tool_input": {"command": "ls"}, "tool_response": {"output": "ok"}}]}], ` + + `"permission_mode": "default", ` + + `"account_identity": {"org_id": null, "plan": null, "auth_mode": null, "user_email": null, "email_domain": null, "device_serial": "TESTSERIAL123"}, ` + + `"usage": {"input_tokens": 5, "output_tokens": 7, "cache_read_input_tokens": 0, "cache_creation_input_tokens": 0, "total_tokens": 12}}` + if body := waitForFile(t, stdinFile); body != want { + t.Errorf("exchange body =\n%s\nwant\n%s", body, want) + } +} + +func TestStopWithoutUserPromptSendsNothing(t *testing.T) { + setupHome(t) + argsFile, _ := installFakeBins(t) + got := runHook(t, `{"hook_event_name": "Stop", "session_id": "s1", "transcript_path": "/nonexistent.jsonl", "last_assistant_message": "done"}`) + if got != suppressLine { + t.Errorf("got %q, want %q", got, suppressLine) + } + // < 2 messages -> build_llm_exchange returns nothing, no POST. + if _, err := os.Stat(argsFile); err == nil { + t.Error("exchange was sent despite < 2 messages") + } +} + +// --- SessionStart: serial probe + discovery dispatch (lines 1629-1634) --- + +func seedDiscoveryEnabled(t *testing.T, home string) { + t.Helper() + unbound := filepath.Join(home, ".unbound") + if err := os.MkdirAll(unbound, 0o755); err != nil { + t.Fatal(err) + } + cfg := `{"api_key": "k1", "base_url": "https://backend.example"}` + if err := os.WriteFile(filepath.Join(unbound, "config.json"), []byte(cfg), 0o644); err != nil { + t.Fatal(err) + } + cache := fmt.Sprintf(`{"hook_discovery": {"enabled": true, "fetched_at": "%s"}}`, utcStamp(time.Now())) + if err := os.WriteFile(filepath.Join(unbound, "discovery-cache.json"), []byte(cache), 0o644); err != nil { + t.Fatal(err) + } +} + +func overrideDiscoveryBin(t *testing.T, path string) { + t.Helper() + old := frozenDiscoveryBin + frozenDiscoveryBin = path + t.Cleanup(func() { frozenDiscoveryBin = old }) +} + +func TestSessionStartDispatchesDiscovery(t *testing.T) { + home := setupHome(t) + installFakeBins(t) + seedDiscoveryEnabled(t, home) + + outFile := filepath.Join(t.TempDir(), "discovery-out") + t.Setenv("DISCOVERY_OUT", outFile) + binPath := filepath.Join(t.TempDir(), "unbound-discovery") + writeScript(t, binPath, `#!/bin/sh +printf '%s %s' "$UNBOUND_API_KEY" "$*" > "$DISCOVERY_OUT" +`) + overrideDiscoveryBin(t, binPath) + + if got := runHook(t, `{"hook_event_name": "SessionStart", "session_id": "s1"}`); got != "{}\n" { + t.Errorf("got %q, want {}", got) + } + if out := waitForFile(t, outFile); out != "k1 --domain https://backend.example" { + t.Errorf("discovery argv/env = %q", out) + } + cacheData, err := os.ReadFile(filepath.Join(home, ".unbound", "discovery-cache.json")) + if err != nil || !strings.Contains(string(cacheData), `"last_run_at"`) { + t.Errorf("last_run_at not stamped: %q, %v", cacheData, err) + } + if _, err := os.Stat(filepath.Join(home, ".unbound", "discovery.dispatch.lock")); err == nil { + t.Error("dispatch lock not released") + } + if serial, err := os.ReadFile(filepath.Join(home, ".unbound", "identity.json")); err != nil || + !strings.Contains(string(serial), "TESTSERIAL123") { + if runtime.GOOS == "darwin" { + t.Errorf("identity cache = %q, %v", serial, err) + } + } +} + +func TestSessionStartDiscoveryDebounced(t *testing.T) { + home := setupHome(t) + installFakeBins(t) + seedDiscoveryEnabled(t, home) + cache := fmt.Sprintf(`{"hook_discovery": {"enabled": true, "fetched_at": "%s"}, "last_run_at": "%s"}`, + utcStamp(time.Now()), utcStamp(time.Now())) + if err := os.WriteFile(filepath.Join(home, ".unbound", "discovery-cache.json"), []byte(cache), 0o644); err != nil { + t.Fatal(err) + } + + outFile := filepath.Join(t.TempDir(), "discovery-out") + t.Setenv("DISCOVERY_OUT", outFile) + binPath := filepath.Join(t.TempDir(), "unbound-discovery") + writeScript(t, binPath, "#!/bin/sh\ntouch \"$DISCOVERY_OUT\"\n") + overrideDiscoveryBin(t, binPath) + + if got := runHook(t, `{"hook_event_name": "SessionStart"}`); got != "{}\n" { + t.Errorf("got %q, want {}", got) + } + time.Sleep(300 * time.Millisecond) + if _, err := os.Stat(outFile); err == nil { + t.Error("discovery dispatched despite fresh last_run_at") + } +} + +func TestSessionStartDiscoveryStaleLockSkips(t *testing.T) { + home := setupHome(t) + installFakeBins(t) + seedDiscoveryEnabled(t, home) + if err := os.WriteFile(filepath.Join(home, ".unbound", "discovery.lock"), nil, 0o644); err != nil { + t.Fatal(err) + } + + outFile := filepath.Join(t.TempDir(), "discovery-out") + t.Setenv("DISCOVERY_OUT", outFile) + binPath := filepath.Join(t.TempDir(), "unbound-discovery") + writeScript(t, binPath, "#!/bin/sh\ntouch \"$DISCOVERY_OUT\"\n") + overrideDiscoveryBin(t, binPath) + + if got := runHook(t, `{"hook_event_name": "SessionStart"}`); got != "{}\n" { + t.Errorf("got %q, want {}", got) + } + time.Sleep(300 * time.Millisecond) + if _, err := os.Stat(outFile); err == nil { + t.Error("discovery dispatched despite a fresh discovery.lock") + } +} + +func TestSessionStartDiscoveryBinaryMissingLogsAndSkips(t *testing.T) { + home := setupHome(t) + installFakeBins(t) + seedDiscoveryEnabled(t, home) + overrideDiscoveryBin(t, filepath.Join(t.TempDir(), "missing-binary")) + + if got := runHook(t, `{"hook_event_name": "SessionStart"}`); got != "{}\n" { + t.Errorf("got %q, want {}", got) + } + data, err := os.ReadFile(filepath.Join(home, ".claude", "hooks", "error.log")) + if err != nil || !strings.Contains(string(data), "discovery binary missing") { + t.Errorf("error.log = %q, %v", data, err) + } + if _, err := os.Stat(filepath.Join(home, ".local", "share", "unbound", "install.sh")); err == nil { + t.Error("frozen hook downloaded install.sh") + } +} + +// --- device serial + identity (lines 674-831) --- + +func TestDeviceSerialProbeCachesAndPlaceholderRejected(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("system_profiler shim is the darwin probe path") + } + home := setupHome(t) + installFakeBins(t) + c, err := newClaudeCodeHook() + if err != nil { + t.Fatal(err) + } + + if got := c.deviceSerial(false); got != "" { + t.Errorf("probe=false must not probe, got %q", got) + } + + t.Setenv("FAKE_SERIAL", "To be filled by O.E.M.") + if got := c.deviceSerial(true); got != "" { + t.Errorf("placeholder serial accepted: %q", got) + } + if _, err := os.Stat(filepath.Join(home, ".unbound", "identity.json")); err == nil { + t.Error("identity cache written for a placeholder serial") + } + + t.Setenv("FAKE_SERIAL", "REALSERIAL9") + if got := c.deviceSerial(true); got != "REALSERIAL9" { + t.Errorf("serial = %q", got) + } + // Cached value wins over a fresh probe now. + t.Setenv("FAKE_SERIAL", "OTHER") + if got := c.deviceSerial(true); got != "REALSERIAL9" { + t.Errorf("cache not used: %q", got) + } +} + +func TestReadAccountIdentity(t *testing.T) { + home := setupHome(t) + installFakeBins(t) + c, err := newClaudeCodeHook() + if err != nil { + t.Fatal(err) + } + + dumps := func(o *pyjson.Object) string { + s, err := pyjson.Dumps(o) + if err != nil { + t.Fatal(err) + } + return s + } + + // Missing ~/.claude.json: every field None. + want := `{"org_id": null, "plan": null, "auth_mode": null, "user_email": null, "email_domain": null}` + if got := dumps(c.readAccountIdentity()); got != want { + t.Errorf("got %s, want %s", got, want) + } + + claudeJSON := filepath.Join(home, ".claude.json") + oauth := `{"oauthAccount": {"organizationUuid": "org-1", "organizationType": "team", "emailAddress": "Dev@Example.COM "}}` + if err := os.WriteFile(claudeJSON, []byte(oauth), 0o644); err != nil { + t.Fatal(err) + } + want = `{"org_id": "org-1", "plan": "team", "auth_mode": "subscription", "user_email": "Dev@Example.COM ", "email_domain": "example.com"}` + if got := dumps(c.readAccountIdentity()); got != want { + t.Errorf("got %s, want %s", got, want) + } + + if err := os.WriteFile(claudeJSON, []byte(`{"customApiKeyResponses": {"approved": ["sk"]}}`), 0o644); err != nil { + t.Fatal(err) + } + want = `{"org_id": null, "plan": null, "auth_mode": "api_key", "user_email": null, "email_domain": null}` + if got := dumps(c.readAccountIdentity()); got != want { + t.Errorf("got %s, want %s", got, want) + } + + if err := os.WriteFile(claudeJSON, []byte(`{}`), 0o644); err != nil { + t.Fatal(err) + } + t.Setenv("ANTHROPIC_API_KEY", "sk-ant") + if got := dumps(c.readAccountIdentity()); got != want { + t.Errorf("got %s, want %s", got, want) + } +} diff --git a/go/internal/hooks/codex.go b/go/internal/hooks/codex.go new file mode 100644 index 00000000..6757fdf9 --- /dev/null +++ b/go/internal/hooks/codex.go @@ -0,0 +1,8 @@ +package hooks + +import "io" + +// TODO: port codex/hooks/unbound.py — fail-open stub until then. +func runCodex(event string, stdin io.Reader, stdout io.Writer) int { + return failOpenStub(stdin, stdout) +} diff --git a/go/internal/hooks/copilot.go b/go/internal/hooks/copilot.go new file mode 100644 index 00000000..c64bc757 --- /dev/null +++ b/go/internal/hooks/copilot.go @@ -0,0 +1,8 @@ +package hooks + +import "io" + +// TODO: port copilot/hooks/unbound.py — fail-open stub until then. +func runCopilot(event string, stdin io.Reader, stdout io.Writer) int { + return failOpenStub(stdin, stdout) +} diff --git a/go/internal/hooks/cursor.go b/go/internal/hooks/cursor.go new file mode 100644 index 00000000..f0360c0e --- /dev/null +++ b/go/internal/hooks/cursor.go @@ -0,0 +1,9 @@ +package hooks + +import "io" + +// TODO: port cursor/unbound.py — fail-open stub until then. Note: the real +// module exits 2 on deny; that contract exit must be preserved in the port. +func runCursor(event string, stdin io.Reader, stdout io.Writer) int { + return failOpenStub(stdin, stdout) +} diff --git a/go/internal/hooks/hooks.go b/go/internal/hooks/hooks.go new file mode 100644 index 00000000..70546bc4 --- /dev/null +++ b/go/internal/hooks/hooks.go @@ -0,0 +1,50 @@ +// Package hooks dispatches `unbound-hook hook []`: the tool's +// handler reads the event JSON from stdin and prints its response JSON to +// stdout, exactly as the python serving path does. The argument +// exists because managed settings register one command per event; handlers +// dispatch on the stdin payload's hook_event_name, so argv is +// routing/diagnostics only. +// +// Fail-open is non-negotiable here: this process sits between the user and +// their editor. Any dispatcher-level failure prints neutral JSON and exits 0. +// Mirrors binary/src/unbound_hook/hook_cmd.py. +package hooks + +import ( + "fmt" + "io" +) + +type handler func(event string, stdin io.Reader, stdout io.Writer) int + +var handlers = map[string]handler{ + "claude-code": runClaudeCode, + "cursor": runCursor, + "copilot": runCopilot, + "codex": runCodex, +} + +// Dispatch runs the tool's hook handler. Unknown/missing tool or a panic +// anywhere below never blocks the editor: neutral JSON, exit 0. +func Dispatch(tool, event string, stdin io.Reader, stdout io.Writer) (code int) { + defer func() { + if r := recover(); r != nil { + fmt.Fprintln(stdout, "{}") + code = 0 + } + }() + h, ok := handlers[tool] + if !ok { + fmt.Fprintln(stdout, "{}") + return 0 + } + return h(event, stdin, stdout) +} + +// failOpenStub is the shared phase-1 handler body: consume the event JSON, +// answer with neutral JSON, exit 0. +func failOpenStub(stdin io.Reader, stdout io.Writer) int { + _, _ = io.Copy(io.Discard, stdin) + fmt.Fprintln(stdout, "{}") + return 0 +} diff --git a/go/internal/hooks/pyutil.go b/go/internal/hooks/pyutil.go new file mode 100644 index 00000000..d87d2658 --- /dev/null +++ b/go/internal/hooks/pyutil.go @@ -0,0 +1,211 @@ +package hooks + +// Python-semantics helpers for the hook ports. The python originals run +// under main()'s blanket try/except; raise() panics stand in for the +// exceptions that escape a handler and are recovered at the port's main() +// equivalent, which logs and prints the neutral response like the python +// except branch does. + +import ( + "fmt" + "strconv" + "strings" + + "github.com/websentry-ai/setup/go/internal/pyjson" +) + +// pyRaise is the panic payload standing in for an uncaught python exception. +type pyRaise struct{ msg string } + +func (e pyRaise) String() string { return e.msg } + +func raise(format string, args ...any) { + panic(pyRaise{fmt.Sprintf(format, args...)}) +} + +// mustObj mirrors dict-method access on a value assumed to be a dict: any +// other type raises (python AttributeError). +func mustObj(v any) *pyjson.Object { + obj, ok := v.(*pyjson.Object) + if !ok { + raise("'%T' object has no attribute 'get'", v) + } + return obj +} + +// objGet mirrors value.get(key, def) where value must be a dict. +func objGet(v any, key string, def any) any { + return mustObj(v).GetDefault(key, def) +} + +// pyIn mirrors `key in container` for string keys: dict key lookup, +// substring test on str, membership on list; anything else raises +// (python TypeError). +func pyIn(key string, container any) bool { + switch t := container.(type) { + case *pyjson.Object: + _, ok := t.Get(key) + return ok + case string: + return strings.Contains(t, key) + case []any: + for _, e := range t { + if pyEq(e, key) { + return true + } + } + return false + } + raise("argument of type '%T' is not iterable", container) + return false +} + +// pyIndex mirrors container[key] with a string key: only dicts support it; +// a missing key raises like python KeyError. +func pyIndex(container any, key string) any { + obj, ok := container.(*pyjson.Object) + if !ok { + raise("'%T' object is not subscriptable with a str", container) + } + v, has := obj.Get(key) + if !has { + raise("KeyError: %s", key) + } + return v +} + +// pyStr approximates python str() for f-string interpolation. Hook inputs +// make these strings in practice; non-strings fall back to their JSON form +// (python would render repr-style — accepted divergence). +func pyStr(v any) string { + switch t := v.(type) { + case nil: + return "None" + case bool: + if t { + return "True" + } + return "False" + case string: + return t + case pyjson.Number: + return string(t) + } + if s, err := pyjson.Dumps(v); err == nil { + return s + } + return fmt.Sprintf("%v", v) +} + +// pyEq mirrors python ==: deep equality with cross-type numeric equality +// (True == 1, 1 == 1.0) and order-insensitive dict comparison. +func pyEq(a, b any) bool { + an, aNum := numVal(a) + bn, bNum := numVal(b) + if aNum || bNum { + return aNum && bNum && numEq(an, bn) + } + switch ta := a.(type) { + case nil: + return b == nil + case string: + tb, ok := b.(string) + return ok && ta == tb + case []any: + tb, ok := b.([]any) + if !ok || len(ta) != len(tb) { + return false + } + for i := range ta { + if !pyEq(ta[i], tb[i]) { + return false + } + } + return true + case *pyjson.Object: + tb, ok := b.(*pyjson.Object) + if !ok || ta.Len() != tb.Len() { + return false + } + for _, m := range ta.Members() { + bv, has := tb.Get(m.Key) + if !has || !pyEq(m.Value, bv) { + return false + } + } + return true + } + return false +} + +type pyNum struct { + f float64 + i int64 + isInt bool +} + +func numVal(v any) (pyNum, bool) { + switch t := v.(type) { + case bool: + if t { + return pyNum{1, 1, true}, true + } + return pyNum{0, 0, true}, true + case int: + return pyNum{float64(t), int64(t), true}, true + case int64: + return pyNum{float64(t), t, true}, true + case float64: + return pyNum{t, 0, false}, true + case pyjson.Number: + s := string(t) + if !strings.ContainsAny(s, ".eE") { + if i, err := strconv.ParseInt(s, 10, 64); err == nil { + return pyNum{float64(i), i, true}, true + } + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return pyNum{}, false + } + return pyNum{f, 0, false}, true + } + return pyNum{}, false +} + +func numEq(a, b pyNum) bool { + if a.isInt && b.isInt { + return a.i == b.i + } + return a.f == b.f +} + +// toFloat mirrors python float coercion in arithmetic (time.time() - ts): +// non-numeric operands make the caller raise. +func toFloat(v any) (float64, bool) { + n, ok := numVal(v) + if !ok { + return 0, false + } + return n.f, ok +} + +// copyObject mirrors dict(d): a shallow copy preserving insertion order. +func copyObject(o *pyjson.Object) *pyjson.Object { + out := pyjson.NewObject() + for _, m := range o.Members() { + out.Set(m.Key, m.Value) + } + return out +} + +// posixDirname mirrors posixpath.dirname for the project-path walk in +// _read_mcp_server_config (claude-code/hooks/unbound.py line 658). +func posixDirname(p string) string { + i := strings.LastIndexByte(p, '/') + 1 + head := p[:i] + if head != "" && head != strings.Repeat("/", len(head)) { + head = strings.TrimRight(head, "/") + } + return head +} diff --git a/go/internal/hooks/pyutil_test.go b/go/internal/hooks/pyutil_test.go new file mode 100644 index 00000000..c33c0891 --- /dev/null +++ b/go/internal/hooks/pyutil_test.go @@ -0,0 +1,81 @@ +package hooks + +import ( + "testing" + + "github.com/websentry-ai/setup/go/internal/pyjson" +) + +func TestPosixDirname(t *testing.T) { + cases := map[string]string{ + "/a/b/c": "/a/b", + "/a/b": "/a", + "/a": "/", + "/": "/", + "a/b": "a", + "a": "", + "": "", + "//a": "//", + } + for in, want := range cases { + if got := posixDirname(in); got != want { + t.Errorf("posixDirname(%q) = %q, want %q", in, got, want) + } + } +} + +func TestPyEq(t *testing.T) { + cases := []struct { + a, b any + want bool + }{ + {pyjson.Number("1"), pyjson.Number("1.0"), true}, // python 1 == 1.0 + {true, pyjson.Number("1"), true}, // python True == 1 + {false, pyjson.Number("0"), true}, + {pyjson.Number("1"), "1", false}, + {"x", "x", true}, + {nil, nil, true}, + {nil, "", false}, + {[]any{pyjson.Number("1"), "a"}, []any{pyjson.Number("1"), "a"}, true}, + {[]any{}, []any{}, true}, + { + pyjson.NewObject().Set("a", pyjson.Number("1")).Set("b", "x"), + pyjson.NewObject().Set("b", "x").Set("a", pyjson.Number("1.0")), + true, // dict equality is order-insensitive + }, + { + pyjson.NewObject().Set("a", pyjson.Number("1")), + pyjson.NewObject().Set("a", pyjson.Number("2")), + false, + }, + } + for _, c := range cases { + if got := pyEq(c.a, c.b); got != c.want { + t.Errorf("pyEq(%v, %v) = %v, want %v", c.a, c.b, got, c.want) + } + } +} + +func TestPyIn(t *testing.T) { + if !pyIn("comm", "this command") { // python: substring test on str + t.Error("substring containment failed") + } + if pyIn("xyz", "this command") { + t.Error("false substring matched") + } + if !pyIn("a", []any{"b", "a"}) { + t.Error("list membership failed") + } + if !pyIn("k", pyjson.NewObject().Set("k", nil)) { + t.Error("dict key with None value must count as present") + } +} + +func TestNextPollInterval(t *testing.T) { + cases := map[float64]int{0: 3, 299: 3, 300: 15, 1799: 15, 1800: 60, 7199: 60, 7200: 120, 99999: 120} + for elapsed, want := range cases { + if got := nextPollInterval(elapsed); got != want { + t.Errorf("nextPollInterval(%v) = %d, want %d", elapsed, got, want) + } + } +} diff --git a/go/internal/httpc/httpc.go b/go/internal/httpc/httpc.go new file mode 100644 index 00000000..03aef521 --- /dev/null +++ b/go/internal/httpc/httpc.go @@ -0,0 +1,141 @@ +// Package httpc shells out to curl for all HTTP, exactly like the python +// hook modules — never net/http. This is a deliberate house rule: corporate +// TLS interception (Zscaler etc.) trusts the system curl's CA handling +// where a static Go TLS stack would fail closed. +// +// Each helper mirrors a python curl invocation byte-for-byte in argv order: +// +// - PostJSON mirrors send_to_hook_api / send_to_api / poll_approval_status +// (claude-code/hooks/unbound.py lines 523-531, 564-572, 1083-1091): +// curl -fsSL -X POST -H "Authorization: Bearer " +// -H "Content-Type: application/json" --data-binary @- , +// body on stdin, subprocess timeout per caller. +// - Get mirrors _hook_discovery_enabled_for_org (lines 1383-1389): +// curl -fsSL -H "Authorization: Bearer " --max-time . +// - Fetch mirrors _download_latest_hook (lines 1253-1255): +// curl -fsSL --max-time . +// - Download mirrors the install.sh refresh (lines 1561-1563): +// curl -fsSL -o . +// - PostJSONDetached mirrors report_error_to_gateway's fire-and-forget +// Popen (lines 103-113): start curl, write the body to its stdin, +// close, never wait. +// +// Fail-open contract: any failure here (curl missing, non-zero exit, +// timeout) is an error/exit-code the callers treat as allow/skip — the +// hook process must never block the editor on transport problems. +package httpc + +import ( + "bytes" + "context" + "errors" + "fmt" + "os/exec" + "strconv" + "time" +) + +// Result is the captured outcome of a curl run, mirroring python +// subprocess.run(capture_output=True): callers check ExitCode == 0 and +// Stdout themselves (e.g. `result.returncode == 0 and result.stdout`). +type Result struct { + ExitCode int + Stdout []byte + Stderr []byte +} + +// run executes curl with args, mirroring subprocess.run(..., timeout=N): +// a non-zero curl exit is a Result (not an error); spawn failures and the +// timeout kill are errors (python's except branch / TimeoutExpired). +func run(args []string, stdin []byte, timeout time.Duration) (Result, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, "curl", args...) + // After the deadline kill, don't wait forever for pipe EOF (curl spawns + // no children, so in practice the pipes close with the process). + cmd.WaitDelay = time.Second + if stdin != nil { + cmd.Stdin = bytes.NewReader(stdin) + } + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + if ctx.Err() == context.DeadlineExceeded { + return Result{}, fmt.Errorf("httpc: curl timed out after %s", timeout) + } + res := Result{Stdout: stdout.Bytes(), Stderr: stderr.Bytes()} + if err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + res.ExitCode = exitErr.ExitCode() + return res, nil + } + return Result{}, err + } + return res, nil +} + +// PostJSON POSTs body as application/json with a Bearer api key. The body +// travels via stdin (--data-binary @-) so it never appears in argv. +func PostJSON(url, apiKey string, body []byte, timeout time.Duration) (Result, error) { + args := []string{"-fsSL", "-X", "POST", + "-H", "Authorization: Bearer " + apiKey, + "-H", "Content-Type: application/json", + "--data-binary", "@-", url} + if body == nil { + body = []byte{} + } + return run(args, body, timeout) +} + +// Get performs an authenticated GET with curl's own --max-time cap plus the +// outer subprocess timeout (python passes both, e.g. --max-time 5 / timeout=8). +func Get(url, apiKey string, maxTimeSecs int, timeout time.Duration) (Result, error) { + args := []string{"-fsSL", + "-H", "Authorization: Bearer " + apiKey, + "--max-time", strconv.Itoa(maxTimeSecs), + url} + return run(args, nil, timeout) +} + +// Fetch GETs a URL with no auth header (self-update payload download). +func Fetch(url string, maxTimeSecs int, timeout time.Duration) (Result, error) { + args := []string{"-fsSL", "--max-time", strconv.Itoa(maxTimeSecs), url} + return run(args, nil, timeout) +} + +// Download GETs a URL straight to a file (-o dest), no auth header. +func Download(url, dest string, timeout time.Duration) (Result, error) { + args := []string{"-fsSL", "-o", dest, url} + return run(args, nil, timeout) +} + +// PostJSONDetached fires a POST and returns without waiting, like python's +// Popen with DEVNULL stdio: the curl child outlives the hook process. The +// body is written to curl's stdin synchronously (small payloads only) and +// a background goroutine reaps the child if we are still alive when it +// exits. Errors are returned for callers that log; none ever block. +func PostJSONDetached(url, apiKey string, body []byte) error { + cmd := exec.Command("curl", "-fsSL", "-X", "POST", + "-H", "Authorization: Bearer "+apiKey, + "-H", "Content-Type: application/json", + "--data-binary", "@-", url) + stdin, err := cmd.StdinPipe() + if err != nil { + return err + } + if err := cmd.Start(); err != nil { + stdin.Close() + return err + } + _, werr := stdin.Write(body) + cerr := stdin.Close() + go func() { _ = cmd.Wait() }() + if werr != nil { + return werr + } + return cerr +} diff --git a/go/internal/httpc/httpc_test.go b/go/internal/httpc/httpc_test.go new file mode 100644 index 00000000..b347cc1a --- /dev/null +++ b/go/internal/httpc/httpc_test.go @@ -0,0 +1,164 @@ +package httpc + +import ( + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +// installFakeCurl puts an executable `curl` shim first on PATH that records +// its argv and stdin, then behaves per CURL_STDOUT / CURL_EXIT. +func installFakeCurl(t *testing.T) (argsFile, stdinFile string) { + t.Helper() + dir := t.TempDir() + argsFile = filepath.Join(dir, "args") + stdinFile = filepath.Join(dir, "stdin") + script := `#!/bin/sh +for a in "$@"; do printf '%s\n' "$a"; done > "$CURL_ARGS_FILE" +cat > "$CURL_STDIN_FILE" +[ -n "$CURL_SLEEP" ] && sleep "$CURL_SLEEP" +printf '%s' "$CURL_STDOUT" +[ -n "$CURL_STDERR" ] && printf '%s' "$CURL_STDERR" >&2 +exit "${CURL_EXIT:-0}" +` + if err := os.WriteFile(filepath.Join(dir, "curl"), []byte(script), 0o755); err != nil { + t.Fatal(err) + } + t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH")) + t.Setenv("CURL_ARGS_FILE", argsFile) + t.Setenv("CURL_STDIN_FILE", stdinFile) + t.Setenv("CURL_STDOUT", "") + t.Setenv("CURL_STDERR", "") + t.Setenv("CURL_EXIT", "0") + t.Setenv("CURL_SLEEP", "") + return argsFile, stdinFile +} + +func recordedArgs(t *testing.T, argsFile string) []string { + t.Helper() + data, err := os.ReadFile(argsFile) + if err != nil { + t.Fatal(err) + } + return strings.Split(strings.TrimRight(string(data), "\n"), "\n") +} + +func TestPostJSONArgsMirrorPython(t *testing.T) { + argsFile, stdinFile := installFakeCurl(t) + t.Setenv("CURL_STDOUT", `{"decision": "allow"}`) + + res, err := PostJSON("https://gw.example.com/v1/hooks/pretool", "sk-123", + []byte(`{"a": 1}`), 20*time.Second) + if err != nil { + t.Fatal(err) + } + if res.ExitCode != 0 || string(res.Stdout) != `{"decision": "allow"}` { + t.Errorf("Result = %+v", res) + } + + want := []string{"-fsSL", "-X", "POST", + "-H", "Authorization: Bearer sk-123", + "-H", "Content-Type: application/json", + "--data-binary", "@-", + "https://gw.example.com/v1/hooks/pretool"} + got := recordedArgs(t, argsFile) + if strings.Join(got, "\x00") != strings.Join(want, "\x00") { + t.Errorf("argv = %q, want %q", got, want) + } + body, err := os.ReadFile(stdinFile) + if err != nil { + t.Fatal(err) + } + if string(body) != `{"a": 1}` { + t.Errorf("stdin = %q", body) + } +} + +func TestGetArgsMirrorPython(t *testing.T) { + argsFile, _ := installFakeCurl(t) + t.Setenv("CURL_STDOUT", `{"enabled": true}`) + + res, err := Get("https://gw.example.com/v1/hooks/discovery-enabled", "sk-123", 5, 8*time.Second) + if err != nil { + t.Fatal(err) + } + if res.ExitCode != 0 || string(res.Stdout) != `{"enabled": true}` { + t.Errorf("Result = %+v", res) + } + want := []string{"-fsSL", + "-H", "Authorization: Bearer sk-123", + "--max-time", "5", + "https://gw.example.com/v1/hooks/discovery-enabled"} + got := recordedArgs(t, argsFile) + if strings.Join(got, "\x00") != strings.Join(want, "\x00") { + t.Errorf("argv = %q, want %q", got, want) + } +} + +func TestFetchArgsMirrorPython(t *testing.T) { + argsFile, _ := installFakeCurl(t) + if _, err := Fetch("https://raw.example.com/unbound.py", 10, 15*time.Second); err != nil { + t.Fatal(err) + } + want := []string{"-fsSL", "--max-time", "10", "https://raw.example.com/unbound.py"} + got := recordedArgs(t, argsFile) + if strings.Join(got, "\x00") != strings.Join(want, "\x00") { + t.Errorf("argv = %q, want %q", got, want) + } +} + +func TestDownloadArgsMirrorPython(t *testing.T) { + argsFile, _ := installFakeCurl(t) + if _, err := Download("https://raw.example.com/install.sh", "/tmp/install.tmp", 30*time.Second); err != nil { + t.Fatal(err) + } + want := []string{"-fsSL", "-o", "/tmp/install.tmp", "https://raw.example.com/install.sh"} + got := recordedArgs(t, argsFile) + if strings.Join(got, "\x00") != strings.Join(want, "\x00") { + t.Errorf("argv = %q, want %q", got, want) + } +} + +func TestNonZeroExitIsResultNotError(t *testing.T) { + installFakeCurl(t) + t.Setenv("CURL_EXIT", "22") + t.Setenv("CURL_STDERR", "curl: (22) The requested URL returned error: 500") + + res, err := PostJSON("https://gw.example.com/x", "k", []byte("{}"), 10*time.Second) + if err != nil { + t.Fatalf("non-zero exit must not be an error (python checks returncode): %v", err) + } + if res.ExitCode != 22 { + t.Errorf("ExitCode = %d, want 22", res.ExitCode) + } + if !strings.Contains(string(res.Stderr), "(22)") { + t.Errorf("Stderr = %q", res.Stderr) + } +} + +func TestTimeoutIsError(t *testing.T) { + installFakeCurl(t) + t.Setenv("CURL_SLEEP", "5") + if _, err := Get("https://gw.example.com/x", "k", 5, 200*time.Millisecond); err == nil { + t.Error("expected timeout error (python TimeoutExpired)") + } +} + +func TestPostJSONDetachedDeliversBody(t *testing.T) { + _, stdinFile := installFakeCurl(t) + if err := PostJSONDetached("https://gw.example.com/v1/hooks/errors", "k", []byte(`{"errors": []}`)); err != nil { + t.Fatal(err) + } + deadline := time.Now().Add(5 * time.Second) + for { + if data, err := os.ReadFile(stdinFile); err == nil && string(data) == `{"errors": []}` { + return + } + if time.Now().After(deadline) { + t.Fatal("detached curl never received the body") + } + time.Sleep(10 * time.Millisecond) + } +} diff --git a/go/internal/locks/locks.go b/go/internal/locks/locks.go new file mode 100644 index 00000000..a69b2d3c --- /dev/null +++ b/go/internal/locks/locks.go @@ -0,0 +1,101 @@ +// Package locks ports the python hooks' mtime-TTL lock-file patterns. +// Exact TTLs stay with the callers; this package is only the mechanism. +// +// - AcquireExcl mirrors _acquire_self_update_lock +// (claude-code/hooks/unbound.py lines 1238-1248): a fresh lock loses, +// a stale one is unlinked, then an O_CREAT|O_EXCL create decides the +// winner; every failure mode returns false (fail-closed: skip the work). +// - Claim mirrors the discovery dispatch marker (lines 1513-1530): +// O_CREAT|O_EXCL first; on "already exists" check staleness and steal +// (unlink + re-create) only if older than the TTL. Errors other than +// "exists" on the first create propagate, like the python code letting +// non-FileExistsError OSErrors reach the outer handler. +// - IsFresh mirrors the discovery.lock busy check (lines 1502-1508): +// a stat failure is treated as stale (python sets age = TTL + 1). +// - Release mirrors unlink(missing_ok=True) under try/except. +// - Touch mirrors Path.touch(): create the file or bump its mtime +// (self-update state stamp, error-report rate-limit marker). +package locks + +import ( + "errors" + "io/fs" + "os" + "time" +) + +// AcquireExcl takes the lock at path unless a fresh one (younger than ttl) +// exists. Returns true only when this caller created the lock file. +func AcquireExcl(path string, ttl time.Duration) bool { + if fi, err := os.Stat(path); err == nil { + if time.Since(fi.ModTime()) < ttl { + return false + } + if err := os.Remove(path); err != nil && !errors.Is(err, fs.ErrNotExist) { + return false + } + } + f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o600) + if err != nil { + return false + } + f.Close() + return true +} + +// Claim atomically creates the dispatch marker at path. A fresh existing +// marker yields (false, nil); a stale one is stolen. A first-create failure +// that is not fs.ErrExist is returned as an error (python propagates it). +func Claim(path string, ttl time.Duration) (bool, error) { + f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o600) + if err == nil { + f.Close() + return true, nil + } + if !errors.Is(err, fs.ErrExist) { + return false, err + } + age := ttl + time.Second // stat failure counts as stale (python: TTL + 1) + if fi, statErr := os.Stat(path); statErr == nil { + age = time.Since(fi.ModTime()) + } + if age < ttl { + return false, nil + } + if err := os.Remove(path); err != nil { + return false, nil + } + f, err = os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o600) + if err != nil { + return false, nil + } + f.Close() + return true, nil +} + +// IsFresh reports whether path exists and is younger than ttl. Unstattable +// files count as stale. The python self-update throttle (_self_update_due) +// is exactly !IsFresh(statePath, interval). +func IsFresh(path string, ttl time.Duration) bool { + fi, err := os.Stat(path) + if err != nil { + return false + } + return time.Since(fi.ModTime()) < ttl +} + +// Release removes the lock file, ignoring all errors. +func Release(path string) { + _ = os.Remove(path) +} + +// Touch creates path if missing and bumps its mtime to now. +func Touch(path string) error { + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o666) + if err != nil { + return err + } + f.Close() + now := time.Now() + return os.Chtimes(path, now, now) +} diff --git a/go/internal/locks/locks_test.go b/go/internal/locks/locks_test.go new file mode 100644 index 00000000..61fac045 --- /dev/null +++ b/go/internal/locks/locks_test.go @@ -0,0 +1,126 @@ +package locks + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func backdate(t *testing.T, path string, age time.Duration) { + t.Helper() + old := time.Now().Add(-age) + if err := os.Chtimes(path, old, old); err != nil { + t.Fatal(err) + } +} + +func TestAcquireExclCreatesLock(t *testing.T) { + path := filepath.Join(t.TempDir(), "x.lock") + if !AcquireExcl(path, time.Minute) { + t.Fatal("expected acquisition of missing lock") + } + if _, err := os.Stat(path); err != nil { + t.Fatalf("lock file not created: %v", err) + } +} + +func TestAcquireExclFreshLockLoses(t *testing.T) { + path := filepath.Join(t.TempDir(), "x.lock") + if !AcquireExcl(path, time.Minute) { + t.Fatal("setup acquire failed") + } + if AcquireExcl(path, time.Minute) { + t.Error("fresh lock must not be re-acquired") + } +} + +func TestAcquireExclStealsStaleLock(t *testing.T) { + path := filepath.Join(t.TempDir(), "x.lock") + if !AcquireExcl(path, time.Minute) { + t.Fatal("setup acquire failed") + } + backdate(t, path, 2*time.Minute) + if !AcquireExcl(path, time.Minute) { + t.Error("stale lock must be stolen") + } +} + +func TestAcquireExclMissingParentDirFails(t *testing.T) { + path := filepath.Join(t.TempDir(), "missing", "x.lock") + if AcquireExcl(path, time.Minute) { + t.Error("expected failure when parent dir is missing") + } +} + +func TestClaimFirstWins(t *testing.T) { + path := filepath.Join(t.TempDir(), "d.lock") + ok, err := Claim(path, 10*time.Second) + if err != nil || !ok { + t.Fatalf("Claim = %v, %v", ok, err) + } + ok, err = Claim(path, 10*time.Second) + if err != nil || ok { + t.Errorf("fresh marker must block: Claim = %v, %v", ok, err) + } +} + +func TestClaimStealsStaleMarker(t *testing.T) { + path := filepath.Join(t.TempDir(), "d.lock") + if ok, _ := Claim(path, 10*time.Second); !ok { + t.Fatal("setup claim failed") + } + backdate(t, path, time.Minute) + ok, err := Claim(path, 10*time.Second) + if err != nil || !ok { + t.Errorf("stale marker must be stolen: Claim = %v, %v", ok, err) + } +} + +func TestClaimMissingParentDirPropagatesError(t *testing.T) { + path := filepath.Join(t.TempDir(), "missing", "d.lock") + ok, err := Claim(path, 10*time.Second) + if ok || err == nil { + t.Errorf("Claim = %v, %v; want false + error (python propagates non-EEXIST)", ok, err) + } +} + +func TestIsFresh(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "f.lock") + if IsFresh(path, time.Minute) { + t.Error("missing file must not be fresh") + } + if err := Touch(path); err != nil { + t.Fatal(err) + } + if !IsFresh(path, time.Minute) { + t.Error("just-touched file must be fresh") + } + backdate(t, path, 2*time.Minute) + if IsFresh(path, time.Minute) { + t.Error("backdated file must be stale") + } +} + +func TestReleaseIgnoresMissing(t *testing.T) { + Release(filepath.Join(t.TempDir(), "never-existed.lock")) // must not panic +} + +func TestTouchBumpsMtime(t *testing.T) { + path := filepath.Join(t.TempDir(), "t") + if err := Touch(path); err != nil { + t.Fatal(err) + } + backdate(t, path, time.Hour) + if err := Touch(path); err != nil { + t.Fatal(err) + } + fi, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if time.Since(fi.ModTime()) > time.Minute { + t.Errorf("Touch did not bump mtime: %v", fi.ModTime()) + } +} diff --git a/go/internal/pyjson/pyjson.go b/go/internal/pyjson/pyjson.go new file mode 100644 index 00000000..e665c3a0 --- /dev/null +++ b/go/internal/pyjson/pyjson.go @@ -0,0 +1,423 @@ +// Package pyjson encodes and decodes JSON byte-identically to python's +// stdlib json module with default options, as used throughout +// claude-code/hooks/unbound.py (json.dumps / json.loads): separators +// (", ", ": "), ensure_ascii=True (non-ASCII escaped as lowercase \uXXXX, +// astral planes as surrogate pairs), floats printed with python's +// repr(float) rules, and object key order preserved (python dicts are +// insertion-ordered). The parity harness compares stdout byte-for-byte +// against the python hooks, and audit-log lines are rewritten via +// json.dumps(json.loads(line)), so a standard-library json.Marshal +// (",":" separators, sorted map keys, HTML escaping) would diverge. +// +// Known gaps, deliberate: Loads rejects the non-standard NaN/Infinity +// literals python accepts (they never occur in hook inputs, which are +// themselves produced by JSON serializers), and invalid UTF-8 is replaced +// with U+FFFD where python would raise UnicodeDecodeError. +package pyjson + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "math" + "sort" + "strconv" + "strings" + "unicode/utf16" +) + +// Object is an insertion-ordered JSON object mirroring python dict +// semantics: setting an existing key updates the value in place and keeps +// the original position. +type Object struct { + members []Member + index map[string]int +} + +// Member is a single key/value pair of an Object. +type Member struct { + Key string + Value any +} + +// NewObject returns an empty ordered object. +func NewObject() *Object { + return &Object{index: map[string]int{}} +} + +// Set inserts or updates a key, preserving insertion order on update. +func (o *Object) Set(key string, value any) *Object { + if i, ok := o.index[key]; ok { + o.members[i].Value = value + return o + } + o.index[key] = len(o.members) + o.members = append(o.members, Member{Key: key, Value: value}) + return o +} + +// Get returns the value for key and whether it was present. +func (o *Object) Get(key string) (any, bool) { + if i, ok := o.index[key]; ok { + return o.members[i].Value, true + } + return nil, false +} + +// GetDefault mirrors python dict.get(key, default). +func (o *Object) GetDefault(key string, def any) any { + if v, ok := o.Get(key); ok { + return v + } + return def +} + +// Len returns the number of members. +func (o *Object) Len() int { return len(o.members) } + +// Members returns the key/value pairs in insertion order. The returned +// slice is the Object's backing storage; do not mutate. +func (o *Object) Members() []Member { return o.members } + +// Number holds a JSON number literal verbatim. Integer literals round-trip +// unchanged (python int(str) -> repr emits the same digits, except "-0" +// which python normalizes to "0"); literals containing '.', 'e' or 'E' are +// floats in python and are re-rendered with repr(float) on encode (python +// normalizes e.g. "1e5" to "100000.0"). +type Number string + +// Loads decodes a single JSON document the way python json.loads does: +// objects keep key order (duplicate keys keep the first position, last +// value), numbers keep their literal text, and trailing non-whitespace +// data is an error. +func Loads(data []byte) (any, error) { + dec := json.NewDecoder(bytes.NewReader(data)) + dec.UseNumber() + v, err := decodeValue(dec) + if err != nil { + return nil, err + } + if _, err := dec.Token(); err != io.EOF { + return nil, errors.New("pyjson: trailing data after JSON document") + } + return v, nil +} + +func decodeValue(dec *json.Decoder) (any, error) { + tok, err := dec.Token() + if err != nil { + return nil, err + } + switch t := tok.(type) { + case json.Delim: + switch t { + case '{': + return decodeObject(dec) + case '[': + return decodeArray(dec) + } + return nil, fmt.Errorf("pyjson: unexpected delimiter %q", t.String()) + case string: + return t, nil + case json.Number: + return Number(t.String()), nil + case bool: + return t, nil + case nil: + return nil, nil + } + return nil, fmt.Errorf("pyjson: unexpected token %v", tok) +} + +func decodeObject(dec *json.Decoder) (*Object, error) { + obj := NewObject() + for { + tok, err := dec.Token() + if err != nil { + return nil, err + } + if d, ok := tok.(json.Delim); ok && d == '}' { + return obj, nil + } + key, ok := tok.(string) + if !ok { + return nil, fmt.Errorf("pyjson: non-string object key %v", tok) + } + val, err := decodeValue(dec) + if err != nil { + return nil, err + } + obj.Set(key, val) + } +} + +func decodeArray(dec *json.Decoder) ([]any, error) { + arr := []any{} + for { + if !dec.More() { + if _, err := dec.Token(); err != nil { // consume ']' + return nil, err + } + return arr, nil + } + v, err := decodeValue(dec) + if err != nil { + return nil, err + } + arr = append(arr, v) + } +} + +// Dumps encodes v exactly as python json.dumps(v) would. Supported values: +// nil, bool, string, Number, int, int64, float64, []any, *Object. +func Dumps(v any) (string, error) { + var sb strings.Builder + if err := encode(&sb, v); err != nil { + return "", err + } + return sb.String(), nil +} + +func encode(sb *strings.Builder, v any) error { + switch t := v.(type) { + case nil: + sb.WriteString("null") + case bool: + if t { + sb.WriteString("true") + } else { + sb.WriteString("false") + } + case string: + encodeString(sb, t) + case Number: + return encodeNumber(sb, t) + case int: + sb.WriteString(strconv.Itoa(t)) + case int64: + sb.WriteString(strconv.FormatInt(t, 10)) + case float64: + sb.WriteString(FloatRepr(t)) + case []any: + sb.WriteByte('[') + for i, e := range t { + if i > 0 { + sb.WriteString(", ") + } + if err := encode(sb, e); err != nil { + return err + } + } + sb.WriteByte(']') + case *Object: + sb.WriteByte('{') + for i, m := range t.Members() { + if i > 0 { + sb.WriteString(", ") + } + encodeString(sb, m.Key) + sb.WriteString(": ") + if err := encode(sb, m.Value); err != nil { + return err + } + } + sb.WriteByte('}') + default: + return fmt.Errorf("pyjson: unsupported type %T", v) + } + return nil +} + +func encodeNumber(sb *strings.Builder, n Number) error { + s := string(n) + if strings.ContainsAny(s, ".eE") { + f, err := strconv.ParseFloat(s, 64) + if err != nil && !errors.Is(err, strconv.ErrRange) { + return fmt.Errorf("pyjson: bad number literal %q", s) + } + // Out-of-range literals overflow to ±Inf in python too + // (json.loads("1e400") -> inf -> dumps -> "Infinity"). + sb.WriteString(FloatRepr(f)) + return nil + } + if s == "-0" { + s = "0" // python: int("-0") == 0 + } + sb.WriteString(s) + return nil +} + +// DumpsIndentSorted encodes v exactly as python +// json.dumps(v, indent=2, sort_keys=True) — the discovery-cache writer's +// format (claude-code/hooks/unbound.py lines 1405, 1596): newline plus a +// two-space per-level indent after '{' / '[' and each ',', ": " after keys, +// the closing bracket at the parent's indent, object keys sorted by code +// point, and empty containers rendered inline as {} / []. +func DumpsIndentSorted(v any) (string, error) { + var sb strings.Builder + if err := encodeIndent(&sb, v, 0); err != nil { + return "", err + } + return sb.String(), nil +} + +func encodeIndent(sb *strings.Builder, v any, level int) error { + const indent = 2 + pad := strings.Repeat(" ", indent*(level+1)) + switch t := v.(type) { + case []any: + if len(t) == 0 { + sb.WriteString("[]") + return nil + } + sb.WriteString("[\n") + for i, e := range t { + if i > 0 { + sb.WriteString(",\n") + } + sb.WriteString(pad) + if err := encodeIndent(sb, e, level+1); err != nil { + return err + } + } + sb.WriteString("\n" + strings.Repeat(" ", indent*level) + "]") + case *Object: + if t.Len() == 0 { + sb.WriteString("{}") + return nil + } + members := append([]Member(nil), t.Members()...) + sort.SliceStable(members, func(i, j int) bool { return members[i].Key < members[j].Key }) + sb.WriteString("{\n") + for i, m := range members { + if i > 0 { + sb.WriteString(",\n") + } + sb.WriteString(pad) + encodeString(sb, m.Key) + sb.WriteString(": ") + if err := encodeIndent(sb, m.Value, level+1); err != nil { + return err + } + } + sb.WriteString("\n" + strings.Repeat(" ", indent*level) + "}") + default: + return encode(sb, v) + } + return nil +} + +// python json's ESCAPE_DCT short escapes. +var shortEscapes = map[rune]string{ + '"': `\"`, + '\\': `\\`, + '\b': `\b`, + '\f': `\f`, + '\n': `\n`, + '\r': `\r`, + '\t': `\t`, +} + +// encodeString mirrors python json.encoder.py_encode_basestring_ascii: +// everything outside 0x20-0x7e is escaped, '/' is not. +func encodeString(sb *strings.Builder, s string) { + sb.WriteByte('"') + for _, r := range s { + if esc, ok := shortEscapes[r]; ok { + sb.WriteString(esc) + continue + } + if r >= 0x20 && r <= 0x7e { + sb.WriteRune(r) + continue + } + if r > 0xffff { + r1, r2 := utf16.EncodeRune(r) + fmt.Fprintf(sb, `\u%04x\u%04x`, r1, r2) + continue + } + fmt.Fprintf(sb, `\u%04x`, r) + } + sb.WriteByte('"') +} + +// FloatRepr renders f exactly as python repr(float) / json.dumps(float): +// shortest round-trip digits, fixed notation while the decimal point +// position is in (-4, 16], otherwise scientific with a sign and at least +// two exponent digits. Non-finite values use python json's non-standard +// Infinity / -Infinity / NaN spellings. +func FloatRepr(f float64) string { + if math.IsInf(f, 1) { + return "Infinity" + } + if math.IsInf(f, -1) { + return "-Infinity" + } + if math.IsNaN(f) { + return "NaN" + } + s := strconv.FormatFloat(f, 'e', -1, 64) + neg := strings.HasPrefix(s, "-") + if neg { + s = s[1:] + } + mant, expStr, _ := strings.Cut(s, "e") + digits := strings.Replace(mant, ".", "", 1) + exp, _ := strconv.Atoi(expStr) + decpt := exp + 1 // decimal point position relative to digits + + var body string + switch { + case decpt <= -4 || decpt > 16: + body = digits[:1] + if len(digits) > 1 { + body += "." + digits[1:] + } + e := decpt - 1 + sign := "+" + if e < 0 { + sign = "-" + e = -e + } + body += "e" + sign + fmt.Sprintf("%02d", e) + case decpt <= 0: + body = "0." + strings.Repeat("0", -decpt) + digits + case decpt >= len(digits): + body = digits + strings.Repeat("0", decpt-len(digits)) + ".0" + default: + body = digits[:decpt] + "." + digits[decpt:] + } + if neg { + body = "-" + body + } + return body +} + +// Truthy mirrors python truthiness for the value kinds Loads produces. +// Used by ports of python code that branch on `if value:`. +func Truthy(v any) bool { + switch t := v.(type) { + case nil: + return false + case bool: + return t + case string: + return t != "" + case Number: + f, err := strconv.ParseFloat(string(t), 64) + return err != nil || f != 0 + case int: + return t != 0 + case int64: + return t != 0 + case float64: + return t != 0 + case []any: + return len(t) > 0 + case *Object: + return t.Len() > 0 + } + return true +} diff --git a/go/internal/pyjson/pyjson_test.go b/go/internal/pyjson/pyjson_test.go new file mode 100644 index 00000000..da856bd1 --- /dev/null +++ b/go/internal/pyjson/pyjson_test.go @@ -0,0 +1,229 @@ +package pyjson + +import ( + "math" + "testing" +) + +// bs is a literal backslash for building \u goldens (kept out of string +// literals so tooling cannot normalize the escapes). +const bs = "\x5c" + +// Golden values produced by python3 json.dumps / repr. Any change here must +// be re-verified against python. +func TestFloatReprMatchesPython(t *testing.T) { + cases := []struct { + in float64 + want string + }{ + {1e15, "1000000000000000.0"}, + {1e16, "1e+16"}, + {1e17, "1e+17"}, + {123456789012345.0, "123456789012345.0"}, + {1234567890123456.0, "1234567890123456.0"}, + {0.0001, "0.0001"}, + {1e-5, "1e-05"}, + {1.5, "1.5"}, + {0.0, "0.0"}, + {math.Copysign(0, -1), "-0.0"}, + {3.14, "3.14"}, + {1e100, "1e+100"}, + {5e-324, "5e-324"}, + {1.7976931348623157e308, "1.7976931348623157e+308"}, + {2.5e-10, "2.5e-10"}, + {100000.0, "100000.0"}, + {1e21, "1e+21"}, + {math.Inf(1), "Infinity"}, + {math.Inf(-1), "-Infinity"}, + {math.NaN(), "NaN"}, + } + for _, c := range cases { + if got := FloatRepr(c.in); got != c.want { + t.Errorf("FloatRepr(%v) = %q, want %q", c.in, got, c.want) + } + } +} + +func TestDumpsMatchesPython(t *testing.T) { + obj := NewObject(). + Set("a", 1). + Set("b", "héllo ☃"). + Set("c", []any{true, nil, math.Copysign(0, -1)}) + got, err := Dumps(obj) + if err != nil { + t.Fatal(err) + } + // python: json.dumps({'a': 1, 'b': 'h\xe9llo ☃', 'c': [True, None, -0.0]}) + want := `{"a": 1, "b": "h` + bs + `u00e9llo ` + bs + `u2603", "c": [true, null, -0.0]}` + if got != want { + t.Errorf("Dumps = %q, want %q", got, want) + } +} + +func TestDumpsStringEscaping(t *testing.T) { + got, err := Dumps("a\x7f\x01\"\\/\n") + if err != nil { + t.Fatal(err) + } + // python golden: '/' not escaped, DEL and control chars \u-escaped + want := `"a` + bs + `u007f` + bs + `u0001\"\\/\n"` + if got != want { + t.Errorf("Dumps = %q, want %q", got, want) + } +} +func TestDumpsSurrogatePairs(t *testing.T) { + got, err := Dumps("\U0001F600\u2028") + if err != nil { + t.Fatal(err) + } + // python golden: astral plane as a surrogate pair, U+2028 escaped too + want := `"` + bs + `ud83d` + bs + `ude00` + bs + `u2028"` + if got != want { + t.Errorf("Dumps = %q, want %q", got, want) + } +} + +func TestNumberLiterals(t *testing.T) { + cases := []struct{ in, want string }{ + {"1E2", "100.0"}, // python float normalization + {"-0", "0"}, // python int("-0") == 0 + {"5.00", "5.0"}, + {"123456789012345678901234567890", "123456789012345678901234567890"}, // big int verbatim + {"1e5", "100000.0"}, + {"1e400", "Infinity"}, // overflow like python json.loads + } + for _, c := range cases { + got, err := Dumps(Number(c.in)) + if err != nil { + t.Fatalf("Dumps(Number(%q)): %v", c.in, err) + } + if got != c.want { + t.Errorf("Dumps(Number(%q)) = %q, want %q", c.in, got, c.want) + } + } +} + +func TestLoadsDumpsRoundTrip(t *testing.T) { + // A claude-code audit-log style line: order must be preserved. + in := `{"timestamp": "2026-06-12T01:02:03Z", "session_id": "s1", "event": {"hook_event_name": "Stop", "n": 5, "f": 1.5, "deep": [1, {"k": null}]}}` + v, err := Loads([]byte(in)) + if err != nil { + t.Fatal(err) + } + out, err := Dumps(v) + if err != nil { + t.Fatal(err) + } + if out != in { + t.Errorf("round trip changed bytes:\n in: %s\nout: %s", in, out) + } +} + +func TestLoadsCompactInputNormalized(t *testing.T) { + // python json.dumps(json.loads(s)) normalizes whitespace. + v, err := Loads([]byte(`{"a":1,"b":[1,2]}`)) + if err != nil { + t.Fatal(err) + } + out, err := Dumps(v) + if err != nil { + t.Fatal(err) + } + if want := `{"a": 1, "b": [1, 2]}`; out != want { + t.Errorf("got %q, want %q", out, want) + } +} + +func TestLoadsDuplicateKeysLastValueFirstPosition(t *testing.T) { + v, err := Loads([]byte(`{"a": 1, "b": 2, "a": 3}`)) + if err != nil { + t.Fatal(err) + } + out, err := Dumps(v) + if err != nil { + t.Fatal(err) + } + if want := `{"a": 3, "b": 2}`; out != want { // python dict semantics + t.Errorf("got %q, want %q", out, want) + } +} + +func TestLoadsRejectsTrailingData(t *testing.T) { + if _, err := Loads([]byte(`{} {}`)); err == nil { + t.Error("expected error for trailing data") + } + if _, err := Loads([]byte(`{"a": 1}` + "\n ")); err != nil { + t.Errorf("trailing whitespace should be fine: %v", err) + } +} + +func TestObjectSetUpdatesInPlace(t *testing.T) { + o := NewObject().Set("x", 1).Set("y", 2).Set("x", 3) + out, err := Dumps(o) + if err != nil { + t.Fatal(err) + } + if want := `{"x": 3, "y": 2}`; out != want { + t.Errorf("got %q, want %q", out, want) + } + if v, ok := o.Get("x"); !ok || v != 3 { + t.Errorf("Get(x) = %v, %v", v, ok) + } +} + +func TestTruthy(t *testing.T) { + truthy := []any{true, "x", Number("1"), Number("0.5"), []any{nil}, NewObject().Set("k", nil), 1, int64(2), 0.1} + falsy := []any{nil, false, "", Number("0"), Number("0.0"), Number("-0"), []any{}, NewObject(), 0, int64(0), 0.0} + for _, v := range truthy { + if !Truthy(v) { + t.Errorf("Truthy(%#v) = false, want true", v) + } + } + for _, v := range falsy { + if Truthy(v) { + t.Errorf("Truthy(%#v) = true, want false", v) + } + } +} + +func TestDumpsEmptyContainers(t *testing.T) { + for in, want := range map[string]string{`{}`: `{}`, `[]`: `[]`, `[[]]`: `[[]]`} { + v, err := Loads([]byte(in)) + if err != nil { + t.Fatal(err) + } + out, err := Dumps(v) + if err != nil { + t.Fatal(err) + } + if out != want { + t.Errorf("Dumps(Loads(%q)) = %q, want %q", in, out, want) + } + } +} + +func TestDumpsIndentSortedMatchesPython(t *testing.T) { + // python: json.dumps(v, indent=2, sort_keys=True) goldens. + obj := NewObject(). + Set("b", []any{1, 2.5, "x"}). + Set("a", NewObject(). + Set("nested", NewObject().Set("z", nil).Set("y", true)). + Set("empty", NewObject()). + Set("list", []any{})). + Set("c", "ué") + got, err := DumpsIndentSorted(obj) + if err != nil { + t.Fatal(err) + } + want := "{\n \"a\": {\n \"empty\": {},\n \"list\": [],\n \"nested\": {\n \"y\": true,\n \"z\": null\n }\n },\n \"b\": [\n 1,\n 2.5,\n \"x\"\n ],\n \"c\": \"u" + bs + "u00e9\"\n}" + if got != want { + t.Errorf("DumpsIndentSorted = %q, want %q", got, want) + } + + if got, _ := DumpsIndentSorted(NewObject()); got != "{}" { + t.Errorf("empty object = %q, want {}", got) + } + if got, _ := DumpsIndentSorted([]any{1, []any{2}}); got != "[\n 1,\n [\n 2\n ]\n]" { + t.Errorf("nested list = %q", got) + } +} diff --git a/go/internal/report/report.go b/go/internal/report/report.go new file mode 100644 index 00000000..7ca6b318 --- /dev/null +++ b/go/internal/report/report.go @@ -0,0 +1,183 @@ +// Package report ports the python hooks' error logging + best-effort +// backend reporting. It must never fail the hook: every path swallows its +// own errors, mirroring the blanket try/except in the python originals. +// +// - Reporter.LogError mirrors log_error (claude-code/hooks/unbound.py +// lines 120-141): ": \n" appended to error.log, +// trimmed to the last 25 lines, then forwarded to the gateway. +// - Reporter.ReportToGateway mirrors report_error_to_gateway (lines +// 92-117): rate-limited to one report per 60s via the .last_error_report +// marker mtime (_should_report, fail-closed), reentrancy-guarded, and +// fired as a detached curl POST to /v1/hooks/errors with payload +// {"errors": [{"message", "timestamp", "category"}], "hook_source"}. +// No message truncation — python sends the full string. +// +// Timestamp quirk copied as-is: claude-code and codex stamp error.log lines +// with datetime.utcnow().isoformat()+"Z" while cursor and copilot use the +// local-zone datetime.now().astimezone().isoformat() (with "+00:00" +// rewritten to "Z"); the gateway payload timestamp is always the UTC form. +// Python isoformat omits the .%06d microseconds entirely when they are +// exactly zero — UTCTimestamp/LocalTimestamp reproduce that. +package report + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/websentry-ai/setup/go/internal/httpc" + "github.com/websentry-ai/setup/go/internal/locks" + "github.com/websentry-ai/setup/go/internal/pyjson" +) + +// maxErrorLogLines mirrors the "keep only last 25 errors" trim. +const maxErrorLogLines = 25 + +// rateLimitWindow mirrors _should_report's 60-second window. +const rateLimitWindow = 60 * time.Second + +// Reporter carries the per-tool error-reporting state that lives in module +// globals on the python side. +type Reporter struct { + GatewayURL string // already slash-stripped (config.GatewayURL) + HookSource string // "claude-code" | "cursor" | "codex" | "copilot" + ErrorLog string // per-tool error.log path + LastReportFile string // per-tool .last_error_report marker + APIKey string // _cached_api_key: set once by main, used by LogError + LocalTime bool // cursor/copilot stamp error.log in local time + + reporting bool // _reporting_error reentrancy flag + now func() time.Time // test seam; nil means time.Now +} + +func (r *Reporter) clock() time.Time { + if r.now != nil { + return r.now() + } + return time.Now() +} + +// isoSeconds renders python datetime.isoformat()'s date-time part: +// microseconds are six digits, omitted entirely when zero. +func isoSeconds(t time.Time) string { + s := t.Format("2006-01-02T15:04:05") + if us := t.Nanosecond() / 1000; us != 0 { + s += fmt.Sprintf(".%06d", us) + } + return s +} + +// UTCTimestamp mirrors datetime.utcnow().isoformat() + "Z". +func UTCTimestamp(t time.Time) string { + return isoSeconds(t.UTC()) + "Z" +} + +// LocalTimestamp mirrors +// datetime.now().astimezone().isoformat().replace("+00:00", "Z"). +func LocalTimestamp(t time.Time) string { + local := t.Local() + _, offset := local.Zone() + sign := "+" + if offset < 0 { + sign = "-" + offset = -offset + } + suffix := fmt.Sprintf("%s%02d:%02d", sign, offset/3600, offset%3600/60) + if suffix == "+00:00" { + suffix = "Z" + } + return isoSeconds(local) + suffix +} + +// shouldReport rate-limits to one gateway report per window. Fails closed: +// any filesystem error means "do not report". +func (r *Reporter) shouldReport() bool { + if fi, err := os.Stat(r.LastReportFile); err == nil { + if r.clock().Sub(fi.ModTime()) < rateLimitWindow { + return false + } + } + if err := locks.Touch(r.LastReportFile); err != nil { + return false + } + return true +} + +// ReportToGateway fires a best-effort error report. Never blocks (detached +// curl), never returns an error. +func (r *Reporter) ReportToGateway(message, category, apiKey string) { + if r.reporting || apiKey == "" || !r.shouldReport() { + return + } + r.reporting = true + defer func() { r.reporting = false }() + + entry := pyjson.NewObject(). + Set("message", message). + Set("timestamp", UTCTimestamp(r.clock())). + Set("category", category) + payload, err := pyjson.Dumps(pyjson.NewObject(). + Set("errors", []any{entry}). + Set("hook_source", r.HookSource)) + if err != nil { + return + } + _ = httpc.PostJSONDetached(r.GatewayURL+"/v1/hooks/errors", apiKey, []byte(payload)) +} + +// LogError appends a timestamped line to error.log, trims it to the last +// 25 lines, then forwards the message to the gateway using the cached API +// key. All errors are swallowed. +func (r *Reporter) LogError(message, category string) { + ts := UTCTimestamp(r.clock()) + if r.LocalTime { + ts = LocalTimestamp(r.clock()) + } + r.appendAndTrim(ts + ": " + message + "\n") + r.ReportToGateway(message, category, r.APIKey) +} + +// appendAndTrim mirrors log_error's file handling. claude-code mkdirs the +// parent first (lines 126); cursor/copilot create LOG_DIR at startup +// instead, so the mkdir is a no-op there — kept unconditional. +func (r *Reporter) appendAndTrim(entry string) { + if err := os.MkdirAll(filepath.Dir(r.ErrorLog), 0o755); err != nil { + return + } + f, err := os.OpenFile(r.ErrorLog, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + return + } + _, werr := f.WriteString(entry) + if cerr := f.Close(); werr != nil || cerr != nil { + return + } + + data, err := os.ReadFile(r.ErrorLog) + if err != nil { + return + } + lines := splitKeepEnds(string(data)) + if len(lines) > maxErrorLogLines { + trimmed := strings.Join(lines[len(lines)-maxErrorLogLines:], "") + _ = os.WriteFile(r.ErrorLog, []byte(trimmed), 0o644) + } +} + +// splitKeepEnds mirrors python readlines(): split after each '\n', a +// trailing unterminated chunk counts as a line. +func splitKeepEnds(s string) []string { + var lines []string + for len(s) > 0 { + i := strings.IndexByte(s, '\n') + if i < 0 { + lines = append(lines, s) + break + } + lines = append(lines, s[:i+1]) + s = s[i+1:] + } + return lines +} diff --git a/go/internal/report/report_test.go b/go/internal/report/report_test.go new file mode 100644 index 00000000..a0b7107a --- /dev/null +++ b/go/internal/report/report_test.go @@ -0,0 +1,172 @@ +package report + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func newReporter(t *testing.T) (*Reporter, string) { + t.Helper() + dir := t.TempDir() + return &Reporter{ + GatewayURL: "https://gw.example.com", + HookSource: "claude-code", + ErrorLog: filepath.Join(dir, "error.log"), + LastReportFile: filepath.Join(dir, ".last_error_report"), + }, dir +} + +// installFakeCurl captures the detached error-report POST. +func installFakeCurl(t *testing.T) (argsFile, stdinFile string) { + t.Helper() + dir := t.TempDir() + argsFile = filepath.Join(dir, "args") + stdinFile = filepath.Join(dir, "stdin") + script := `#!/bin/sh +for a in "$@"; do printf '%s\n' "$a"; done > "$CURL_ARGS_FILE" +cat > "$CURL_STDIN_FILE" +` + if err := os.WriteFile(filepath.Join(dir, "curl"), []byte(script), 0o755); err != nil { + t.Fatal(err) + } + t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH")) + t.Setenv("CURL_ARGS_FILE", argsFile) + t.Setenv("CURL_STDIN_FILE", stdinFile) + return argsFile, stdinFile +} + +func waitForFile(t *testing.T, path string) string { + t.Helper() + deadline := time.Now().Add(5 * time.Second) + for { + if data, err := os.ReadFile(path); err == nil && len(data) > 0 { + return string(data) + } + if time.Now().After(deadline) { + t.Fatalf("file %s never appeared", path) + } + time.Sleep(10 * time.Millisecond) + } +} + +func TestUTCTimestampMatchesPythonIsoformat(t *testing.T) { + ts := time.Date(2026, 6, 12, 1, 2, 3, 456789000, time.UTC) + if got := UTCTimestamp(ts); got != "2026-06-12T01:02:03.456789Z" { + t.Errorf("UTCTimestamp = %q", got) + } + // python isoformat omits microseconds entirely when zero + zero := time.Date(2026, 6, 12, 1, 2, 3, 0, time.UTC) + if got := UTCTimestamp(zero); got != "2026-06-12T01:02:03Z" { + t.Errorf("UTCTimestamp = %q", got) + } +} + +func TestLocalTimestampOffsetFormat(t *testing.T) { + loc := time.FixedZone("IST", 5*3600+30*60) + ts := time.Date(2026, 6, 12, 1, 2, 3, 0, loc) + defer func(orig *time.Location) { time.Local = orig }(time.Local) + time.Local = loc + if got := LocalTimestamp(ts); got != "2026-06-12T01:02:03+05:30" { + t.Errorf("LocalTimestamp = %q", got) + } + time.Local = time.FixedZone("UTC", 0) // python replaces "+00:00" with "Z" + if got := LocalTimestamp(ts.UTC()); !strings.HasSuffix(got, "Z") { + t.Errorf("LocalTimestamp = %q, want Z suffix", got) + } +} + +func TestLogErrorAppendsTimestampedLine(t *testing.T) { + r, _ := newReporter(t) + r.now = func() time.Time { return time.Date(2026, 6, 12, 1, 2, 3, 0, time.UTC) } + r.LogError("boom happened", "general") + data, err := os.ReadFile(r.ErrorLog) + if err != nil { + t.Fatal(err) + } + if got := string(data); got != "2026-06-12T01:02:03Z: boom happened\n" { + t.Errorf("error.log = %q", got) + } +} + +func TestLogErrorTrimsToLast25Lines(t *testing.T) { + r, _ := newReporter(t) + for i := 0; i < 30; i++ { + r.LogError(fmt.Sprintf("err %d", i), "general") + } + data, err := os.ReadFile(r.ErrorLog) + if err != nil { + t.Fatal(err) + } + lines := strings.Split(strings.TrimRight(string(data), "\n"), "\n") + if len(lines) != 25 { + t.Fatalf("got %d lines, want 25", len(lines)) + } + if !strings.HasSuffix(lines[0], ": err 5") || !strings.HasSuffix(lines[24], ": err 29") { + t.Errorf("window wrong: first %q last %q", lines[0], lines[24]) + } +} + +func TestReportToGatewayPayloadMatchesPython(t *testing.T) { + argsFile, stdinFile := installFakeCurl(t) + r, _ := newReporter(t) + r.now = func() time.Time { return time.Date(2026, 6, 12, 1, 2, 3, 456789000, time.UTC) } + + r.ReportToGateway("it broke", "api_call", "sk-9") + + // python: json.dumps({'errors': [{'message': ..., 'timestamp': ..., 'category': ...}], 'hook_source': 'claude-code'}) + wantBody := `{"errors": [{"message": "it broke", "timestamp": "2026-06-12T01:02:03.456789Z", "category": "api_call"}], "hook_source": "claude-code"}` + if got := waitForFile(t, stdinFile); got != wantBody { + t.Errorf("payload:\n got %s\nwant %s", got, wantBody) + } + + wantArgs := []string{"-fsSL", "-X", "POST", + "-H", "Authorization: Bearer sk-9", + "-H", "Content-Type: application/json", + "--data-binary", "@-", + "https://gw.example.com/v1/hooks/errors"} + got := strings.Split(strings.TrimRight(waitForFile(t, argsFile), "\n"), "\n") + if strings.Join(got, "\x00") != strings.Join(wantArgs, "\x00") { + t.Errorf("argv = %q, want %q", got, wantArgs) + } +} + +func TestReportToGatewayRateLimited(t *testing.T) { + _, stdinFile := installFakeCurl(t) + r, _ := newReporter(t) + + r.ReportToGateway("first", "general", "sk-9") + waitForFile(t, stdinFile) + if err := os.Remove(stdinFile); err != nil { + t.Fatal(err) + } + r.ReportToGateway("second", "general", "sk-9") // inside the 60s window + time.Sleep(200 * time.Millisecond) + if _, err := os.ReadFile(stdinFile); err == nil { + t.Error("second report within 60s must be suppressed") + } +} + +func TestReportToGatewayNoAPIKeyIsNoop(t *testing.T) { + _, stdinFile := installFakeCurl(t) + r, _ := newReporter(t) + r.ReportToGateway("msg", "general", "") + time.Sleep(100 * time.Millisecond) + if _, err := os.ReadFile(stdinFile); err == nil { + t.Error("no api key must mean no report") + } + if _, err := os.Stat(r.LastReportFile); err == nil { + t.Error("rate-limit marker must not be touched before the key check") + } +} + +func TestShouldReportFailsClosed(t *testing.T) { + r, _ := newReporter(t) + r.LastReportFile = filepath.Join(t.TempDir(), "missing-dir", "marker") + if r.shouldReport() { + t.Error("untouchable marker must fail closed") + } +} diff --git a/go/internal/selfupdate/selfupdate.go b/go/internal/selfupdate/selfupdate.go new file mode 100644 index 00000000..ab566392 --- /dev/null +++ b/go/internal/selfupdate/selfupdate.go @@ -0,0 +1,19 @@ +// Package selfupdate will hold the Go binary's self-update, the +// binary-swap variant of _check_self_update and its helpers +// (claude-code/hooks/unbound.py lines 1283-1344: the 2h throttle via the +// .self_update_check mtime, the 30s .self_update.lock, the curl download, +// and the same-directory tempfile + rename swap; cursor/codex/copilot carry +// per-tool copies of the same flow). +// +// Until that lands, Check is a deliberate no-op — and a faithful one: the +// python frozen-binary gate (lines 1284-1286) skips self-update entirely +// because packaged deployments are updated by the MDM package, never in +// place, and the managed-location gate (lines 1292-1301) skips whenever the +// running file is not the user-level ~/.claude/hooks/unbound.py script, +// which is never true for this binary. No state files are written and no +// network calls are made. +package selfupdate + +// Check is the SessionStart self-update entry point (unbound.py line 1631). +// No-op for now; see the package comment. +func Check() {} diff --git a/go/internal/transcript/transcript.go b/go/internal/transcript/transcript.go new file mode 100644 index 00000000..bfc979d7 --- /dev/null +++ b/go/internal/transcript/transcript.go @@ -0,0 +1,245 @@ +// Package transcript ports the claude-code Stop-path transcript reader: +// parse_transcript_file (claude-code/hooks/unbound.py lines 367-438), the +// JSONL primitive process_stop_event and get_recent_user_prompts_for_session +// build on. It walks the ~/.claude/projects session transcript line by +// line, collecting user messages, this turn's assistant text, the turn +// model, and summed token usage. +// +// Python quirks copied as-is: +// +// - The userPromptTimestamp exchange-boundary filter (string-compare +// entry.timestamp <= boundary) applies ONLY to assistant entries; user +// messages are collected from the whole file. +// - The first truthy assistant model wins ("turn_model = turn_model or +// message.get('model')"), captured even on usage-less entries. +// - Undecodable lines are skipped (except json.JSONDecodeError), but any +// other "exception" — a non-object entry, a non-dict message, a null / +// numeric content, a non-dict usage, an uncoercible usage value, a +// non-string timestamp on a filtered assistant entry — aborts the scan +// (python's blanket `except Exception: pass`), keeping whatever was +// accumulated. +// - Usage is reported only when some counter is non-zero +// (any(usage.values())); total_tokens is the sum of the four counters. +// - ToolUses exists in the python result shape but is never populated; +// kept for parity. +// +// Divergences from python, accepted: invalid UTF-8 becomes U+FFFD instead +// of aborting (python text mode raises UnicodeDecodeError), and usage +// values beyond int64 fail the int coercion (python ints are unbounded). +package transcript + +import ( + "bufio" + "bytes" + "math" + "os" + "strconv" + "strings" + + "github.com/websentry-ai/setup/go/internal/pyjson" +) + +// Message is one collected transcript message. Content and Timestamp hold +// the raw pyjson values (content may be a string or a content-block list; +// timestamp may be absent = nil), exactly what python stores. +type Message struct { + Content any + Timestamp any +} + +// Usage is the summed token usage across the scanned assistant entries. +type Usage struct { + InputTokens int64 + OutputTokens int64 + CacheReadInputTokens int64 + CacheCreationInputTokens int64 + TotalTokens int64 +} + +// Data mirrors parse_transcript_file's conversation_data dict. +type Data struct { + UserMessages []Message + AssistantMessages []Message + ToolUses []any // never populated, like python + Usage *Usage + Model any // raw model value; nil when none seen +} + +var usageKeys = [4]string{ + "input_tokens", "output_tokens", + "cache_read_input_tokens", "cache_creation_input_tokens", +} + +// ParseFile reads a transcript JSONL file. userPromptTimestamp == "" means +// no boundary filter (python's None). Missing/empty path or an unreadable +// file returns empty Data; nothing here ever fails the hook. +func ParseFile(path, userPromptTimestamp string) Data { + data := Data{ + UserMessages: []Message{}, + AssistantMessages: []Message{}, + ToolUses: []any{}, + } + if path == "" { + return data + } + f, err := os.Open(path) + if err != nil { + return data + } + defer f.Close() + + var counts [4]int64 + var turnModel any + + r := bufio.NewReader(f) +scan: + for { + line, readErr := r.ReadBytes('\n') + trimmed := bytes.TrimSpace(line) + if len(trimmed) > 0 { + abort := scanEntry(trimmed, userPromptTimestamp, &data, &counts, &turnModel) + if abort { + break scan + } + } + if readErr != nil { + break scan + } + } + + if counts[0] != 0 || counts[1] != 0 || counts[2] != 0 || counts[3] != 0 { + data.Usage = &Usage{ + InputTokens: counts[0], + OutputTokens: counts[1], + CacheReadInputTokens: counts[2], + CacheCreationInputTokens: counts[3], + TotalTokens: counts[0] + counts[1] + counts[2] + counts[3], + } + } + if pyjson.Truthy(turnModel) { + data.Model = turnModel + } + return data +} + +// scanEntry processes one JSONL line. Returns true where python would have +// raised out of the per-line code into the file-level `except Exception`. +func scanEntry(line []byte, userPromptTimestamp string, data *Data, counts *[4]int64, turnModel *any) bool { + entry, err := pyjson.Loads(line) + if err != nil { + return false // json.JSONDecodeError: continue + } + obj, ok := entry.(*pyjson.Object) + if !ok { + return true // entry.get on a non-dict: AttributeError + } + entryType, _ := obj.GetDefault("type", "").(string) + entryTimestamp, _ := obj.Get("timestamp") + + switch entryType { + case "user": + msg, ok := obj.GetDefault("message", pyjson.NewObject()).(*pyjson.Object) + if !ok { + return true // message.get on a non-dict + } + if role, _ := msg.GetDefault("role", nil).(string); role == "user" { + content := msg.GetDefault("content", "") + if pyjson.Truthy(content) { + data.UserMessages = append(data.UserMessages, Message{content, entryTimestamp}) + } + } + + case "assistant": + if userPromptTimestamp != "" && pyjson.Truthy(entryTimestamp) { + ts, ok := entryTimestamp.(string) + if !ok { + return true // str <= non-str: TypeError + } + if ts <= userPromptTimestamp { + return false + } + } + msg, ok := obj.GetDefault("message", pyjson.NewObject()).(*pyjson.Object) + if !ok { + return true + } + if role, _ := msg.GetDefault("role", nil).(string); role != "assistant" { + return false + } + + switch content := msg.GetDefault("content", []any{}).(type) { + case []any: + for _, item := range content { + io, ok := item.(*pyjson.Object) + if !ok { + continue // isinstance(content_item, dict) check + } + if t, _ := io.GetDefault("type", nil).(string); t != "text" { + continue + } + text := io.GetDefault("text", "") + if pyjson.Truthy(text) { + data.AssistantMessages = append(data.AssistantMessages, Message{text, entryTimestamp}) + } + } + case string, *pyjson.Object: + // python iterates chars / keys — strings are never dicts, so no-op + default: + return true // `for ... in None/number/bool`: TypeError + } + + if !pyjson.Truthy(*turnModel) { + *turnModel = msg.GetDefault("model", nil) + } + + usageVal := msg.GetDefault("usage", nil) + if pyjson.Truthy(usageVal) { // `message.get('usage') or {}` then `if msg_usage:` + uo, ok := usageVal.(*pyjson.Object) + if !ok { + return true // msg_usage.get on a non-dict + } + for i, key := range usageKeys { + n, ok := pyInt(uo.GetDefault(key, nil)) + if !ok { + return true // int() raised + } + counts[i] += n + } + } + } + return false +} + +// pyInt mirrors `int(value or 0)`: falsy values are 0; numbers truncate +// toward zero; numeric strings parse base-10 after strip. ok=false where +// python int() would raise (and the scan must abort). +func pyInt(v any) (int64, bool) { + if !pyjson.Truthy(v) { + return 0, true + } + switch t := v.(type) { + case bool: + return 1, true // only reachable for true; false is falsy + case pyjson.Number: + s := string(t) + if strings.ContainsAny(s, ".eE") { + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return 0, false + } + return int64(math.Trunc(f)), true + } + n, err := strconv.ParseInt(s, 10, 64) + return n, err == nil + case string: + n, err := strconv.ParseInt(strings.TrimSpace(t), 10, 64) + return n, err == nil + case int: + return int64(t), true + case int64: + return t, true + case float64: + return int64(math.Trunc(t)), true + } + return 0, false // TypeError: dict/list +} diff --git a/go/internal/transcript/transcript_test.go b/go/internal/transcript/transcript_test.go new file mode 100644 index 00000000..2dabdec9 --- /dev/null +++ b/go/internal/transcript/transcript_test.go @@ -0,0 +1,205 @@ +package transcript + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/websentry-ai/setup/go/internal/pyjson" +) + +func writeFixture(t *testing.T, lines ...string) string { + t.Helper() + path := filepath.Join(t.TempDir(), "transcript.jsonl") + if err := os.WriteFile(path, []byte(strings.Join(lines, "\n")+"\n"), 0o644); err != nil { + t.Fatal(err) + } + return path +} + +const ( + userLine1 = `{"type": "user", "timestamp": "2026-06-12T00:00:01Z", "message": {"role": "user", "content": "first prompt"}}` + // list-form content is kept verbatim (python appends whatever truthy value) + userLine2 = `{"type": "user", "timestamp": "2026-06-12T00:00:05Z", "message": {"role": "user", "content": [{"type": "text", "text": "second prompt"}]}}` + // empty content is dropped + userEmpty = `{"type": "user", "timestamp": "2026-06-12T00:00:06Z", "message": {"role": "user", "content": ""}}` + // assistant BEFORE the boundary timestamp + asstEarly = `{"type": "assistant", "timestamp": "2026-06-12T00:00:02Z", "message": {"role": "assistant", "model": "model-early", "content": [{"type": "text", "text": "early reply"}], "usage": {"input_tokens": 1, "output_tokens": 1}}}` + // assistant AFTER the boundary: two text blocks, a tool_use block, usage + asstLate = `{"type": "assistant", "timestamp": "2026-06-12T00:00:07Z", "message": {"role": "assistant", "model": "claude-sonnet-4-5", "content": [{"type": "text", "text": "part one"}, {"type": "tool_use", "id": "t1"}, {"type": "text", "text": "part two"}], "usage": {"input_tokens": 10, "output_tokens": 20, "cache_read_input_tokens": 30, "cache_creation_input_tokens": 40}}}` + // usage-less assistant entry — model still captured, but first model wins + asstNoUsage = `{"type": "assistant", "timestamp": "2026-06-12T00:00:08Z", "message": {"role": "assistant", "model": "model-late", "content": [{"type": "text", "text": "tail"}]}}` +) + +func TestParseFileNoBoundary(t *testing.T) { + path := writeFixture(t, userLine1, asstEarly, userLine2, userEmpty, "", " ", "{bad json", asstLate, asstNoUsage) + d := ParseFile(path, "") + + if len(d.UserMessages) != 2 { + t.Fatalf("UserMessages = %d, want 2", len(d.UserMessages)) + } + if d.UserMessages[0].Content != "first prompt" { + t.Errorf("user[0] = %v", d.UserMessages[0].Content) + } + if _, ok := d.UserMessages[1].Content.([]any); !ok { + t.Errorf("user[1] content kept verbatim, got %T", d.UserMessages[1].Content) + } + if d.UserMessages[0].Timestamp != "2026-06-12T00:00:01Z" { + t.Errorf("user[0] ts = %v", d.UserMessages[0].Timestamp) + } + + var texts []string + for _, m := range d.AssistantMessages { + texts = append(texts, m.Content.(string)) + } + want := []string{"early reply", "part one", "part two", "tail"} + if strings.Join(texts, "|") != strings.Join(want, "|") { + t.Errorf("assistant texts = %v, want %v", texts, want) + } + + if d.Model != "model-early" { // first truthy model wins + t.Errorf("Model = %v, want model-early", d.Model) + } + if d.Usage == nil { + t.Fatal("Usage = nil") + } + if d.Usage.InputTokens != 11 || d.Usage.OutputTokens != 21 || + d.Usage.CacheReadInputTokens != 30 || d.Usage.CacheCreationInputTokens != 40 || + d.Usage.TotalTokens != 102 { + t.Errorf("Usage = %+v", *d.Usage) + } + if len(d.ToolUses) != 0 { + t.Errorf("ToolUses must stay empty (python never fills it)") + } +} + +func TestParseFileBoundaryFiltersAssistantOnly(t *testing.T) { + path := writeFixture(t, userLine1, asstEarly, userLine2, asstLate, asstNoUsage) + d := ParseFile(path, "2026-06-12T00:00:05Z") + + // user messages are NOT filtered (python quirk) + if len(d.UserMessages) != 2 { + t.Fatalf("UserMessages = %d, want 2", len(d.UserMessages)) + } + var texts []string + for _, m := range d.AssistantMessages { + texts = append(texts, m.Content.(string)) + } + want := []string{"part one", "part two", "tail"} + if strings.Join(texts, "|") != strings.Join(want, "|") { + t.Errorf("assistant texts = %v, want %v", texts, want) + } + // early entry filtered before usage/model capture + if d.Model != "claude-sonnet-4-5" { + t.Errorf("Model = %v", d.Model) + } + if d.Usage == nil || d.Usage.InputTokens != 10 || d.Usage.TotalTokens != 100 { + t.Errorf("Usage = %+v", d.Usage) + } +} + +func TestParseFileMissingOrEmptyPath(t *testing.T) { + for _, path := range []string{"", filepath.Join(t.TempDir(), "nope.jsonl")} { + d := ParseFile(path, "") + if len(d.UserMessages) != 0 || len(d.AssistantMessages) != 0 || d.Usage != nil || d.Model != nil { + t.Errorf("ParseFile(%q) not empty: %+v", path, d) + } + } +} + +func TestParseFileZeroUsageStaysNil(t *testing.T) { + path := writeFixture(t, `{"type": "assistant", "message": {"role": "assistant", "content": [{"type": "text", "text": "x"}], "usage": {"input_tokens": 0, "output_tokens": 0}}}`) + d := ParseFile(path, "") + if d.Usage != nil { // python: any(usage.values()) is False + t.Errorf("Usage = %+v, want nil", *d.Usage) + } +} + +func TestParseFileNonDictEntryAbortsKeepingPriorData(t *testing.T) { + // python: entry.get on a list raises AttributeError -> blanket except + // aborts the scan, keeping what was collected so far. + path := writeFixture(t, userLine1, `["not", "a", "dict"]`, userLine2) + d := ParseFile(path, "") + if len(d.UserMessages) != 1 || d.UserMessages[0].Content != "first prompt" { + t.Errorf("UserMessages = %+v, want only the first", d.UserMessages) + } +} + +func TestParseFileNullContentAborts(t *testing.T) { + // python: `for item in None` raises TypeError -> abort. + path := writeFixture(t, + userLine1, + `{"type": "assistant", "message": {"role": "assistant", "content": null}}`, + userLine2) + d := ParseFile(path, "") + if len(d.UserMessages) != 1 { + t.Errorf("UserMessages = %d, want 1 (scan aborted)", len(d.UserMessages)) + } +} + +func TestParseFileStringContentIsNoop(t *testing.T) { + // python iterates the chars of a string content; none are dicts. + path := writeFixture(t, + `{"type": "assistant", "message": {"role": "assistant", "model": "m", "content": "plain string"}}`, + userLine1) + d := ParseFile(path, "") + if len(d.AssistantMessages) != 0 { + t.Errorf("AssistantMessages = %+v", d.AssistantMessages) + } + if d.Model != "m" || len(d.UserMessages) != 1 { + t.Errorf("scan must continue: Model=%v users=%d", d.Model, len(d.UserMessages)) + } +} + +func TestParseFileBadUsageValueAborts(t *testing.T) { + // python: int("garbage") raises ValueError -> abort, keeping prior sums. + path := writeFixture(t, + asstLate, + `{"type": "assistant", "timestamp": "2026-06-12T00:00:09Z", "message": {"role": "assistant", "content": [], "usage": {"input_tokens": "garbage"}}}`, + userLine1) + d := ParseFile(path, "") + if d.Usage == nil || d.Usage.InputTokens != 10 { + t.Errorf("Usage = %+v, want sums from the first entry only", d.Usage) + } + if len(d.UserMessages) != 0 { + t.Error("scan must have aborted before the user line") + } +} + +func TestParseFileUsageCoercion(t *testing.T) { + // python int() semantics: "5" parses, 5.9 truncates, null/absent are 0. + path := writeFixture(t, `{"type": "assistant", "message": {"role": "assistant", "content": [], "usage": {"input_tokens": "5", "output_tokens": 5.9, "cache_read_input_tokens": null}}}`) + d := ParseFile(path, "") + if d.Usage == nil || d.Usage.InputTokens != 5 || d.Usage.OutputTokens != 5 || + d.Usage.CacheReadInputTokens != 0 || d.Usage.TotalTokens != 10 { + t.Errorf("Usage = %+v", d.Usage) + } +} + +func TestPyInt(t *testing.T) { + cases := []struct { + in any + want int64 + ok bool + }{ + {nil, 0, true}, + {false, 0, true}, + {true, 1, true}, + {"", 0, true}, + {" 7 ", 7, true}, + {"x", 0, false}, + {pyjson.Number("12"), 12, true}, + {pyjson.Number("5.9"), 5, true}, + {pyjson.Number("-5.9"), -5, true}, // int() truncates toward zero + {pyjson.Number("0"), 0, true}, + {[]any{}, 0, true}, // falsy -> 0 + {[]any{1}, 0, false}, + } + for _, c := range cases { + got, ok := pyInt(c.in) + if got != c.want || ok != c.ok { + t.Errorf("pyInt(%#v) = %d, %v; want %d, %v", c.in, got, ok, c.want, c.ok) + } + } +}