diff --git a/claude-code/hooks/unbound.py b/claude-code/hooks/unbound.py index d0d5918..7e0194b 100644 --- a/claude-code/hooks/unbound.py +++ b/claude-code/hooks/unbound.py @@ -20,8 +20,8 @@ AUDIT_LOG = Path.home() / ".claude" / "hooks" / "agent-audit.log" ERROR_LOG = Path.home() / ".claude" / "hooks" / "error.log" LAST_REPORT_FILE = Path.home() / ".claude" / "hooks" / ".last_error_report" -ALLOWED_NON_MCP_HOOK_NAMES = ['Bash', 'Read', 'Write', 'Edit'] # MCP tools (mcp__*) are always checked separately -NATIVE_FILE_TOOLS = {'Read', 'Write', 'Edit'} +ALLOWED_NON_MCP_HOOK_NAMES = ['Bash', 'Read', 'Write', 'Edit', 'MultiEdit', 'NotebookEdit'] # MCP tools (mcp__*) are always checked separately +NATIVE_FILE_TOOLS = {'Read', 'Write', 'Edit', 'MultiEdit', 'NotebookEdit'} MCP_TOOL_PREFIX = 'mcp__' CLAUDE_MCP_CONFIG_PATH = Path.home() / ".claude.json" CLAUDE_PLUGIN_CACHE_DIR = Path.home() / ".claude" / "plugins" / "cache" @@ -488,6 +488,13 @@ def _build_user_prompt_payload(recent_user_prompts: List[str]) -> Dict: } +def _tool_file_path(tool_input: Dict) -> Optional[str]: + """Target path for a native file tool. NotebookEdit nests it under + notebook_path; Read/Write/Edit/MultiEdit use file_path.""" + path = tool_input.get('file_path') or tool_input.get('notebook_path') + return path if isinstance(path, str) and path else None + + def extract_command_for_pretool(event: Dict) -> str: """Extract command from tool_input based on tool type.""" tool_input = event.get('tool_input') or {} @@ -499,9 +506,11 @@ def extract_command_for_pretool(event: Dict) -> str: # MCP tools: stringify the input if tool_name.startswith(MCP_TOOL_PREFIX): return json.dumps(tool_input) - # File tools: file_path - if tool_name in ['Write', 'Edit', 'Read'] and 'file_path' in tool_input: - return tool_input['file_path'] + # File tools: file_path / notebook_path + if tool_name in NATIVE_FILE_TOOLS: + path = _tool_file_path(tool_input) + if path: + return path # Grep: pattern if tool_name == 'Grep' and 'pattern' in tool_input: return tool_input['pattern'] @@ -937,6 +946,77 @@ def _get_device_serial() -> Optional[str]: return None +_GIT_CONTEXT_CACHE: Dict = {} +_GIT_CONTEXT_CACHE_MAX = 256 + + +def _cache_git_context(key, value) -> None: + if key not in _GIT_CONTEXT_CACHE and len(_GIT_CONTEXT_CACHE) >= _GIT_CONTEXT_CACHE_MAX: + _GIT_CONTEXT_CACHE.pop(next(iter(_GIT_CONTEXT_CACHE)), None) + _GIT_CONTEXT_CACHE[key] = value + + +def _strip_git_credentials(url): + try: + if not url or '@' not in url: + return url + scheme = re.match(r'^[a-zA-Z][a-zA-Z0-9+.-]*://', url) + prefix = scheme.group(0) if scheme else '' + rest = url[len(prefix):] + slash = rest.find('/') + authority = rest if slash == -1 else rest[:slash] + tail = '' if slash == -1 else rest[slash:] + at = authority.rfind('@') + if at == -1: + return url + return prefix + authority[at + 1:] + tail + except Exception: + return None + + +def _get_git_context(session_id: Optional[str], cwd: Optional[str]) -> Optional[str]: + """Credential-stripped origin remote URL for cwd, or None. Successful and + conclusive-no-repo lookups are cached per (session_id, cwd); transient + failures are not cached. Never raises.""" + key = (session_id, cwd) + if key in _GIT_CONTEXT_CACHE: + return _GIT_CONTEXT_CACHE[key] + if not cwd: + return None + try: + out = subprocess.run( + ['git', '-C', cwd, 'config', '--get', 'remote.origin.url'], + capture_output=True, text=True, timeout=2, + ) + except Exception as exc: + log_error(f"git context lookup failed session={session_id} cwd={cwd}: {exc}", 'git_context') + return None + result = None + if out.returncode == 0: + url = out.stdout.strip() + if url: + result = _strip_git_credentials(url) + _cache_git_context(key, result) + return result + + +def _repo_context_dir(cwd: Optional[str], file_path) -> Optional[str]: + """Directory whose git repo governs the operation: the nearest existing + ancestor of the target file for file tools, else the session cwd. The + ancestor walk lets a write into a not-yet-created path still resolve to its + enclosing repo.""" + if isinstance(file_path, str) and file_path: + base = file_path if os.path.isabs(file_path) else os.path.join(cwd or '', file_path) + d = os.path.dirname(base) or cwd + while d and not os.path.isdir(d): + parent = os.path.dirname(d) + if parent == d: + break + d = parent + return d or cwd + return cwd + + def _device_serial(probe: bool = True) -> Optional[str]: """Hardware serial, computed once and cached. Never raises and never blocks the hook. On the latency-critical pre-tool path callers pass probe=False to read the @@ -1023,8 +1103,9 @@ def process_pre_tool_use(event: Dict, api_key: str) -> Dict: # Build metadata with the raw event metadata = dict(event) tool_input = event.get('tool_input') or {} - if 'file_path' in tool_input: - metadata['file_path'] = tool_input['file_path'] + file_path = _tool_file_path(tool_input) + if file_path: + metadata['file_path'] = file_path if is_mcp: # Parse mcp____ to extract server and tool for gateway matching @@ -1108,6 +1189,10 @@ def process_pre_tool_use(event: Dict, api_key: str) -> Dict: 'additionalContext': 'This command was blocked by an organization security policy that requires approval. Do not attempt to achieve the same result using alternative tools, file operations, or workarounds. The user must approve via Slack and retry.', }) + request_body['git_remote_url'] = _get_git_context( + session_id, _repo_context_dir(event.get('cwd'), file_path) + ) + if need_pull_policies: request_body['pull_policies'] = True diff --git a/copilot/hooks/unbound.py b/copilot/hooks/unbound.py index 8c2469e..f3c772a 100644 --- a/copilot/hooks/unbound.py +++ b/copilot/hooks/unbound.py @@ -919,6 +919,77 @@ def transform_response_for_copilot_prompt(api_response): return {} +_GIT_CONTEXT_CACHE = {} +_GIT_CONTEXT_CACHE_MAX = 256 + + +def _cache_git_context(key, value): + if key not in _GIT_CONTEXT_CACHE and len(_GIT_CONTEXT_CACHE) >= _GIT_CONTEXT_CACHE_MAX: + _GIT_CONTEXT_CACHE.pop(next(iter(_GIT_CONTEXT_CACHE)), None) + _GIT_CONTEXT_CACHE[key] = value + + +def _strip_git_credentials(url): + try: + if not url or '@' not in url: + return url + scheme = re.match(r'^[a-zA-Z][a-zA-Z0-9+.-]*://', url) + prefix = scheme.group(0) if scheme else '' + rest = url[len(prefix):] + slash = rest.find('/') + authority = rest if slash == -1 else rest[:slash] + tail = '' if slash == -1 else rest[slash:] + at = authority.rfind('@') + if at == -1: + return url + return prefix + authority[at + 1:] + tail + except Exception: + return None + + +def _get_git_context(session_id, cwd): + """Credential-stripped origin remote URL for cwd, or None. Successful and + conclusive-no-repo lookups are cached per (session_id, cwd); transient + failures are not cached. Never raises.""" + key = (session_id, cwd) + if key in _GIT_CONTEXT_CACHE: + return _GIT_CONTEXT_CACHE[key] + if not cwd: + return None + try: + out = subprocess.run( + ['git', '-C', cwd, 'config', '--get', 'remote.origin.url'], + capture_output=True, text=True, timeout=2, + ) + except Exception as exc: + log_error(f"git context lookup failed session={session_id} cwd={cwd}: {exc}", 'git_context') + return None + result = None + if out.returncode == 0: + url = out.stdout.strip() + if url: + result = _strip_git_credentials(url) + _cache_git_context(key, result) + return result + + +def _repo_context_dir(cwd, file_path): + """Directory whose git repo governs the operation: the nearest existing + ancestor of the target file for file tools, else the session cwd. The + ancestor walk lets a write into a not-yet-created path still resolve to its + enclosing repo.""" + if isinstance(file_path, str) and file_path: + base = file_path if os.path.isabs(file_path) else os.path.join(cwd or '', file_path) + d = os.path.dirname(base) or cwd + while d and not os.path.isdir(d): + parent = os.path.dirname(d) + if parent == d: + break + d = parent + return d or cwd + return cwd + + def process_pre_tool_use(event, api_key): """Process PreToolUse event - check policy before tool execution.""" raw_tool = event.get('tool_name') or event.get('toolName') or '' @@ -999,7 +1070,8 @@ def process_pre_tool_use(event, api_key): # Preserve the raw event (raw tool_name + tool_input) inside metadata. metadata = dict(event) - file_path = tool_input.get('filePath') or tool_input.get('path') or tool_input.get('file_path') + file_path = (tool_input.get('filePath') or tool_input.get('path') + or tool_input.get('file_path') or _extract_patch_target_path(tool_input)) if file_path: metadata['file_path'] = file_path @@ -1052,6 +1124,10 @@ def process_pre_tool_use(event, api_key): 'additionalContext': 'This action was blocked by an organization security policy that requires approval. Do not attempt to achieve the same result using alternative tools, file operations, or workarounds. The user must approve via Slack and retry.', }) + request_body['git_remote_url'] = _get_git_context( + session_id, _repo_context_dir(event.get('cwd'), file_path) + ) + if need_pull_policies: request_body['pull_policies'] = True diff --git a/cursor/unbound.py b/cursor/unbound.py index fc8afd8..b52c249 100644 --- a/cursor/unbound.py +++ b/cursor/unbound.py @@ -662,6 +662,75 @@ def build_account_identity(event=None, probe=False): return identity +_GIT_CONTEXT_CACHE = {} +_GIT_CONTEXT_CACHE_MAX = 256 + + +def _cache_git_context(key, value): + if key not in _GIT_CONTEXT_CACHE and len(_GIT_CONTEXT_CACHE) >= _GIT_CONTEXT_CACHE_MAX: + _GIT_CONTEXT_CACHE.pop(next(iter(_GIT_CONTEXT_CACHE)), None) + _GIT_CONTEXT_CACHE[key] = value + + +def _strip_git_credentials(url): + try: + if not url or '@' not in url: + return url + scheme = re.match(r'^[a-zA-Z][a-zA-Z0-9+.-]*://', url) + prefix = scheme.group(0) if scheme else '' + rest = url[len(prefix):] + slash = rest.find('/') + authority = rest if slash == -1 else rest[:slash] + tail = '' if slash == -1 else rest[slash:] + at = authority.rfind('@') + if at == -1: + return url + return prefix + authority[at + 1:] + tail + except Exception: + return None + + +def _get_git_context(session_id, cwd): + """Credential-stripped origin remote URL for cwd, or None. Successful and + conclusive-no-repo lookups are cached per (session_id, cwd); transient + failures are not cached. Never raises.""" + key = (session_id, cwd) + if key in _GIT_CONTEXT_CACHE: + return _GIT_CONTEXT_CACHE[key] + if not cwd: + return None + try: + out = subprocess.run( + ['git', '-C', cwd, 'config', '--get', 'remote.origin.url'], + capture_output=True, text=True, timeout=2, + ) + except Exception as exc: + log_error(f"git context lookup failed session={session_id} cwd={cwd}: {exc}", 'git_context') + return None + result = None + if out.returncode == 0: + url = out.stdout.strip() + if url: + result = _strip_git_credentials(url) + _cache_git_context(key, result) + return result + + +def _repo_context_dir(cwd, file_path): + """Directory whose git repo governs the operation: the nearest existing + ancestor of the target file for file tools, else the session cwd.""" + if isinstance(file_path, str) and file_path: + base = file_path if os.path.isabs(file_path) else os.path.join(cwd or '', file_path) + d = os.path.dirname(base) or cwd + while d and not os.path.isdir(d): + parent = os.path.dirname(d) + if parent == d: + break + d = parent + return d or cwd + return cwd + + def process_pre_tool_use(event, api_key): """Process preToolUse event - check policy before tool execution.""" tool_name = event.get('tool_name', '') @@ -742,6 +811,10 @@ def process_pre_tool_use(event, api_key): 'agent_message': 'This action was blocked by an organization security policy that requires approval. Do not attempt to achieve the same result using alternative tools, file operations, or workarounds. The user must approve via Slack and retry.', } + request_body['git_remote_url'] = _get_git_context( + conversation_id, _repo_context_dir(event.get('cwd'), file_path) + ) + if need_pull_policies: request_body['pull_policies'] = True diff --git a/test_repo_allowlist.py b/test_repo_allowlist.py new file mode 100644 index 0000000..0f8dd34 --- /dev/null +++ b/test_repo_allowlist.py @@ -0,0 +1,377 @@ +"""Repo Allowlist client-hook tests: _get_git_context behavior and +git_remote_url payload parity across claude-code and copilot.""" + +import contextlib +import importlib.util +import os +import subprocess +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +ROOT = Path(__file__).resolve().parent + + +def _load(name, rel): + spec = importlib.util.spec_from_file_location(name, ROOT / rel) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +cc = _load("cc_unbound", "claude-code/hooks/unbound.py") +co = _load("co_unbound", "copilot/hooks/unbound.py") +cur = _load("cur_unbound", "cursor/unbound.py") + +ALL_HOOKS = [("claude-code", cc), ("copilot", co), ("cursor", cur)] + +# (label, module, native file-tool name, session-id key, file_path key) +HOOK_TOOL_CASES = [ + ("claude-code", cc, "Edit", "session_id", "file_path"), + ("claude-code-multiedit", cc, "MultiEdit", "session_id", "file_path"), + ("claude-code-notebookedit", cc, "NotebookEdit", "session_id", "notebook_path"), + ("copilot", co, "Edit", "session_id", "filePath"), + ("cursor", cur, "Write", "conversation_id", "file_path"), +] + + +def _git(args, cwd): + subprocess.run(["git", *args], cwd=cwd, check=True, + capture_output=True, text=True) + + +def _make_repo(remote_url): + d = tempfile.mkdtemp() + _git(["init", "-q"], d) + _git(["remote", "add", "origin", remote_url], d) + return d + + +class TestGetGitContext(unittest.TestCase): + def setUp(self): + for _, mod in ALL_HOOKS: + mod._GIT_CONTEXT_CACHE.clear() + + def _both(self): + return ALL_HOOKS + + def test_returns_origin_url(self): + repo = _make_repo("https://github.com/org/repo.git") + for label, mod in self._both(): + with self.subTest(hook=label): + self.assertEqual( + mod._get_git_context("s1", repo), + "https://github.com/org/repo.git", + ) + + def test_strips_credentials(self): + repo = _make_repo("https://user:token@github.com/org/repo.git") + for label, mod in self._both(): + with self.subTest(hook=label): + result = mod._get_git_context("s1", repo) + self.assertNotIn("user:token@", result) + self.assertNotIn("@", result) + self.assertEqual(result, "https://github.com/org/repo.git") + + def test_no_git_repo(self): + plain = tempfile.mkdtemp() + for label, mod in self._both(): + with self.subTest(hook=label): + self.assertIsNone(mod._get_git_context("s1", plain)) + + def test_no_cwd(self): + for label, mod in self._both(): + with self.subTest(hook=label): + self.assertIsNone(mod._get_git_context("s1", None)) + + def test_git_missing_no_raise(self): + for label, mod in self._both(): + with self.subTest(hook=label): + with patch("subprocess.run", side_effect=FileNotFoundError()): + self.assertIsNone(mod._get_git_context("s_missing", "/x")) + + def test_timeout_no_raise(self): + for label, mod in self._both(): + with self.subTest(hook=label): + with patch("subprocess.run", + side_effect=subprocess.TimeoutExpired("git", 2)): + self.assertIsNone(mod._get_git_context("s_timeout", "/x")) + + def test_caches_per_session_cwd(self): + repo = _make_repo("https://github.com/org/repo.git") + for label, mod in self._both(): + with self.subTest(hook=label): + mod._GIT_CONTEXT_CACHE.clear() + real = subprocess.run + with patch("subprocess.run", side_effect=real) as spy: + mod._get_git_context("s_cache", repo) + mod._get_git_context("s_cache", repo) + self.assertEqual(spy.call_count, 1) + + +def _capture(mod, fn, event): + with contextlib.ExitStack() as stack: + send = stack.enter_context( + patch.object(mod, "send_to_hook_api", return_value={"decision": "allow"})) + if hasattr(mod, "build_account_identity"): + stack.enter_context( + patch.object(mod, "build_account_identity", return_value={})) + if hasattr(mod, "load_policy_cache"): + stack.enter_context(patch.object(mod, "load_policy_cache", return_value=None)) + fn(event, "key") + return send.call_args.args[0] + + +class TestRepoContextDir(unittest.TestCase): + """The repo identity follows the operation target: the nearest existing + ancestor of the edited/read file for file tools, the session cwd otherwise.""" + + def test_uses_existing_target_directory(self): + parent = tempfile.mkdtemp() + sub = os.path.join(parent, "src") + os.makedirs(sub) + target = os.path.join(sub, "x.py") + for mod in (cc, co, cur): + with self.subTest(mod=mod.__name__): + self.assertEqual(mod._repo_context_dir(parent, target), sub) + + def test_walks_up_to_nearest_existing_dir(self): + parent = tempfile.mkdtemp() + target = os.path.join(parent, "brand", "new", "deep", "file.py") + for mod in (cc, co, cur): + with self.subTest(mod=mod.__name__): + self.assertEqual(mod._repo_context_dir(parent, target), parent) + + def test_falls_back_to_cwd_without_file_path(self): + parent = tempfile.mkdtemp() + for mod in (cc, co, cur): + with self.subTest(mod=mod.__name__): + self.assertEqual(mod._repo_context_dir(parent, None), parent) + + +class TestTargetDirResolution(unittest.TestCase): + """A file tool launched from a non-git parent resolves to the subdir-repo + it actually edits; a file outside any repo resolves to nothing.""" + + def setUp(self): + for _, mod in ALL_HOOKS: + mod._GIT_CONTEXT_CACHE.clear() + + def test_file_in_subdir_repo_resolves_from_parent_cwd(self): + parent = tempfile.mkdtemp() + repo = Path(parent) / "service" + (repo / "src").mkdir(parents=True) + _git(["init", "-q"], str(repo)) + _git(["remote", "add", "origin", "https://github.com/org/service.git"], str(repo)) + target = str(repo / "src" / "x.py") + + for label, mod, tool, idk, pathk in HOOK_TOOL_CASES: + with self.subTest(hook=label): + body = _capture(mod, mod.process_pre_tool_use, { + idk: "t", "tool_name": tool, "cwd": parent, + "tool_input": {pathk: target}}) + self.assertEqual( + body["git_remote_url"], "https://github.com/org/service.git") + + def test_file_outside_any_repo_resolves_null(self): + parent = tempfile.mkdtemp() + target = str(Path(parent) / "notes.py") + + for label, mod, tool, idk, pathk in HOOK_TOOL_CASES: + with self.subTest(hook=label): + body = _capture(mod, mod.process_pre_tool_use, { + idk: "t2", "tool_name": tool, "cwd": parent, + "tool_input": {pathk: target}}) + self.assertIsNone(body["git_remote_url"]) + + def test_warm_cache_still_sends_file_tool_when_listed(self): + repo = _make_repo("https://github.com/org/repo.git") + target = os.path.join(repo, "x.py") + for label, mod, tool, idk, pathk in HOOK_TOOL_CASES: + with self.subTest(hook=label), contextlib.ExitStack() as stack: + stack.enter_context(patch.object( + mod, "load_policy_cache", return_value={"tools_to_check": [tool]})) + stack.enter_context(patch.object(mod, "is_cache_stale", return_value=False)) + send = stack.enter_context(patch.object( + mod, "send_to_hook_api", return_value={"decision": "allow"})) + if hasattr(mod, "build_account_identity"): + stack.enter_context(patch.object(mod, "build_account_identity", return_value={})) + mod.process_pre_tool_use({ + idk: "w", "tool_name": tool, "cwd": repo, + "tool_input": {pathk: target}}, "k") + self.assertTrue(send.called) + + +class TestPayloadParity(unittest.TestCase): + """The prompt path is no longer gated, so its body omits git_remote_url; + the tool path carries it for both hooks.""" + + def setUp(self): + for _, mod in ALL_HOOKS: + mod._GIT_CONTEXT_CACHE.clear() + self.repo = _make_repo("https://github.com/org/repo.git") + + def test_tool_bodies_carry_git_remote_url(self): + tool_event = {"session_id": "t", "tool_name": "Bash", + "tool_input": {"command": "ls"}, "cwd": self.repo} + for label, mod in (("claude-code", cc), ("copilot", co)): + body = _capture(mod, mod.process_pre_tool_use, dict(tool_event)) + with self.subTest(hook=label): + self.assertEqual( + body["git_remote_url"], "https://github.com/org/repo.git") + + def test_prompt_bodies_omit_git_remote_url(self): + prompt_event = {"session_id": "p", "prompt": "hi", "cwd": self.repo} + for label, mod in (("claude-code", cc), ("copilot", co)): + body = _capture(mod, mod.process_user_prompt_submit, dict(prompt_event)) + with self.subTest(hook=label): + self.assertNotIn("git_remote_url", body) + + +class TestToolFilePath(unittest.TestCase): + def test_empty_or_missing_path_is_none(self): + self.assertIsNone(cc._tool_file_path({})) + self.assertIsNone(cc._tool_file_path({'file_path': ''})) + + +class TestStripGitCredentials(unittest.TestCase): + def _both(self): + return ALL_HOOKS + + def _check(self, url, expected): + for label, mod in self._both(): + with self.subTest(hook=label, url=url): + out = mod._strip_git_credentials(url) + self.assertEqual(out, expected) + self.assertNotIn("@", out) + + def test_scheme_with_token(self): + self._check("https://user:token@github.com/org/repo.git", + "https://github.com/org/repo.git") + + def test_scp_form_with_token(self): + self._check("user:token@github.com:org/repo.git", + "github.com:org/repo.git") + + def test_scp_form_plain_user(self): + self._check("git@github.com:org/repo.git", + "github.com:org/repo.git") + + def test_password_with_at_scheme(self): + self._check("https://user:p@ss@w@rd@github.com/org/repo", + "https://github.com/org/repo") + + def test_password_with_at_scp(self): + self._check("user:p@ss@github.com:org/repo", + "github.com:org/repo") + + def test_ssh_scheme_with_port(self): + self._check("ssh://git@github.com:22/org/repo", + "ssh://github.com:22/org/repo") + + def test_no_userinfo_unchanged(self): + for url in ("https://github.com/org/repo.git", + "github.com:org/repo.git", "", "/local/path/repo"): + for label, mod in self._both(): + with self.subTest(hook=label, url=url): + self.assertEqual(mod._strip_git_credentials(url), url) + + def test_parse_failure_fails_closed(self): + for label, mod in self._both(): + with self.subTest(hook=label): + with patch.object(mod.re, "match", side_effect=ValueError("boom")): + out = mod._strip_git_credentials( + "https://user:token@github.com/org/repo.git") + self.assertIsNone(out) + + +class TestGitContextFailureHandling(unittest.TestCase): + def setUp(self): + for _, mod in ALL_HOOKS: + mod._GIT_CONTEXT_CACHE.clear() + + def test_transient_failure_not_cached(self): + repo = _make_repo("https://github.com/org/repo.git") + for label, mod in ALL_HOOKS: + with self.subTest(hook=label): + mod._GIT_CONTEXT_CACHE.clear() + real = subprocess.run + state = {"n": 0} + + def flaky(*a, **k): + state["n"] += 1 + if state["n"] == 1: + raise subprocess.TimeoutExpired("git", 2) + return real(*a, **k) + + with patch("subprocess.run", side_effect=flaky): + first = mod._get_git_context("s_flaky", repo) + second = mod._get_git_context("s_flaky", repo) + self.assertIsNone(first) + self.assertEqual(second, "https://github.com/org/repo.git") + + def test_cache_is_bounded(self): + for label, mod in ALL_HOOKS: + with self.subTest(hook=label): + mod._GIT_CONTEXT_CACHE.clear() + cap = mod._GIT_CONTEXT_CACHE_MAX + for i in range(cap + 50): + mod._cache_git_context(("s", i), "url") + self.assertLessEqual(len(mod._GIT_CONTEXT_CACHE), cap) + + +class TestApprovalRetrySkipsGit(unittest.TestCase): + def setUp(self): + for _, mod in ALL_HOOKS: + mod._GIT_CONTEXT_CACHE.clear() + + def test_retry_path_skips_git_context(self): + repo = _make_repo("https://github.com/org/repo.git") + cases = [ + (cc, {"session_id": "r", "tool_name": "Bash", + "tool_input": {"command": "ls"}, "cwd": repo}), + (co, {"session_id": "r", "tool_name": "Bash", + "tool_input": {"command": "ls"}, "cwd": repo}), + (cur, {"conversation_id": "r", "tool_name": "Write", + "tool_input": {"file_path": str(Path(repo) / "x.py")}, "cwd": repo}), + ] + for mod, event in cases: + with self.subTest(mod=mod.__name__), contextlib.ExitStack() as stack: + stack.enter_context(patch.object(mod, "_is_approval_retry", return_value=True)) + stack.enter_context(patch.object( + mod, "_get_approval_marker_data", + return_value={"policyIds": [], "applicationId": "", "requestId": ""})) + stack.enter_context(patch.object(mod, "poll_approval_status", return_value="approved")) + stack.enter_context(patch.object(mod, "_clear_approval_marker")) + stack.enter_context(patch.object(mod, "load_policy_cache", return_value=None)) + if hasattr(mod, "build_account_identity"): + stack.enter_context(patch.object(mod, "build_account_identity", return_value={})) + git = stack.enter_context(patch.object(mod, "_get_git_context")) + mod.process_pre_tool_use(event, "k") + git.assert_not_called() + + +class TestApplyPatchRepoContext(unittest.TestCase): + """Copilot canonicalizes apply_patch as Edit; the repo identity must come + from the path inside the patch payload, not the (missing) filePath arg.""" + + def setUp(self): + co._GIT_CONTEXT_CACHE.clear() + + def test_apply_patch_resolves_repo_from_payload(self): + parent = tempfile.mkdtemp() + repo = Path(parent) / "svc" + (repo / "src").mkdir(parents=True) + _git(["init", "-q"], str(repo)) + _git(["remote", "add", "origin", "https://github.com/org/svc.git"], str(repo)) + target = str(repo / "src" / "a.py") + patch_text = f"*** Begin Patch\n*** Update File: {target}\n@@\n-old\n+new\n*** End Patch\n" + body = _capture(co, co.process_pre_tool_use, { + "session_id": "ap", "tool_name": "apply_patch", + "tool_input": {"input": patch_text}, "cwd": parent}) + self.assertEqual(body["git_remote_url"], "https://github.com/org/svc.git") + + +if __name__ == "__main__": + unittest.main()