From 6d7af6d1a57afd208798805d7e64ccbe64c7c125 Mon Sep 17 00:00:00 2001 From: Paul Elliott Date: Fri, 15 May 2026 15:23:21 -0400 Subject: [PATCH] test(parity): replace retired xfail tests with sync-green harness coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove three xfail tests in test_fsm_red_env_differential.py that depended on retired CybORG green/replay tape infrastructure. The first two (red_4 known-hosts parity and red_4 action-selection parity) are already covered by test_red_policy_matches_cyborg_multistep across 200 steps x 5 seeds. The third (end-state host_compromised/red_privilege parity under FSM red + green phish) had no equivalent — existing green-sync tests use SleepAgent for red, so no exploit/privesc chains fire. Add TestFsmRedGreenSyncParity::test_no_critical_state_diffs_over_10_steps, which closes that gap via CC4DifferentialHarness(FSM red + EnterpriseGreen + sync_green_rng=True) and asserts at least one privesc fired so the test can't pass on a degenerate trajectory. Also add seed=0 to the existing red_policy_parity parametrize to preserve the original tests' seed. 
--- tests/differential/test_red_policy_parity.py | 2 +- tests/test_fsm_red_env_differential.py | 286 +++---------------- 2 files changed, 48 insertions(+), 240 deletions(-) diff --git a/tests/differential/test_red_policy_parity.py b/tests/differential/test_red_policy_parity.py index b0dc066..affd11e 100644 --- a/tests/differential/test_red_policy_parity.py +++ b/tests/differential/test_red_policy_parity.py @@ -25,7 +25,7 @@ pytestmark = pytest.mark.slow -@pytest.mark.parametrize("seed", [1000, 1001, 1002, 1003, 1004]) +@pytest.mark.parametrize("seed", [0, 1000, 1001, 1002, 1003, 1004]) def test_red_policy_matches_cyborg_multistep(seed: int) -> None: """JAX red picks identical to CybORG red picks across 200 steps under matched RNG.""" harness = CC4DifferentialHarness( diff --git a/tests/test_fsm_red_env_differential.py b/tests/test_fsm_red_env_differential.py index 66603ca..3ca69ae 100644 --- a/tests/test_fsm_red_env_differential.py +++ b/tests/test_fsm_red_env_differential.py @@ -9,9 +9,6 @@ pytestmark = pytest.mark.slow -TWO_STEP_TRACE_STEPS = 2 -THREE_STEP_TRACE_STEPS = 3 - @pytest.fixture def cyborg_sleep_env(): @@ -258,242 +255,6 @@ def test_sleep_blue_cumulative_reward_same_sign(self, cyborg_sleep_env, jax_env_ if cyborg_total < 0: assert jax_total <= 0, f"JAX sleep reward should be <= 0 when CybORG is {cyborg_total}" - @pytest.mark.xfail( - reason="requires retired CybORG green/replay tapes; topology snapshots preserve static layout only", - ) - def test_snapshot_topology_matches_red4_known_hosts_after_first_green_phish(self, tmp_path): - """After activation, native JAX red_4 should know only the same hosts as CybORG's FSM agent.""" - import types - - from CybORG import CybORG - from CybORG.Agents import EnterpriseGreenAgent, FiniteStateRedAgent, SleepAgent - from CybORG.Agents.Wrappers import BlueFlatWrapper - from CybORG.Simulator.Actions import Sleep - from CybORG.Simulator.Scenarios import EnterpriseScenarioGenerator - - from jaxborg.actions.encoding 
import BLUE_SLEEP - from jaxborg.constants import GLOBAL_MAX_HOSTS, NUM_BLUE_AGENTS - from jaxborg.parity.fsm_red_env import FsmRedCC4Env - from jaxborg.parity.translate import build_mappings_from_cyborg - from jaxborg.scenarios.cc4.red_fsm import fsm_red_apply_delayed_update - - seed = 0 - scenario = EnterpriseScenarioGenerator( - blue_agent_class=SleepAgent, - green_agent_class=EnterpriseGreenAgent, - red_agent_class=FiniteStateRedAgent, - steps=TWO_STEP_TRACE_STEPS, - ) - cyborg = CybORG(scenario_generator=scenario, seed=seed) - cyborg_env = BlueFlatWrapper(env=cyborg, pad_spaces=True) - cyborg_env.reset() - mappings = build_mappings_from_cyborg(cyborg) - topology_path = _save_cyborg_topology_snapshot(cyborg, tmp_path, seed) - - captured_known_hosts = None - interface = cyborg.environment_controller.agent_interfaces["red_agent_4"] - agent = interface.agent - original_get_action = agent.get_action - - def _wrapped(self, observation, action_space): - nonlocal captured_known_hosts - action = original_get_action(observation, action_space) - captured_known_hosts = np.zeros(GLOBAL_MAX_HOSTS, dtype=bool) - for ip, info in self.host_states.items(): - hostname = info.get("hostname") - if hostname in mappings.hostname_to_idx: - captured_known_hosts[mappings.hostname_to_idx[hostname]] = True - return action - - agent.get_action = types.MethodType(_wrapped, agent) - - jax_env = FsmRedCC4Env(num_steps=TWO_STEP_TRACE_STEPS, topology_path=topology_path) - loop_key = jax.random.PRNGKey(seed) - _, jax_state = jax_env.reset(loop_key) - blue_actions = {f"blue_{i}": jnp.int32(BLUE_SLEEP) for i in range(NUM_BLUE_AGENTS)} - - loop_key, step_key = jax.random.split(loop_key) - _, jax_state, _, _, _ = jax_env.step(step_key, jax_state, blue_actions) - _, _, _, _, _ = cyborg_env.step(actions={a: Sleep() for a in cyborg_env.agents}) - - loop_key, step_key = jax.random.split(loop_key) - state_before = fsm_red_apply_delayed_update(jax_state.state) - _, _, _, _, _ = 
cyborg_env.step(actions={a: Sleep() for a in cyborg_env.agents}) - - assert captured_known_hosts is not None, "Expected wrapped CybORG red_4 action to capture known hosts" - np.testing.assert_array_equal( - np.array(state_before.red_discovered_hosts[4], dtype=bool), - captured_known_hosts, - ) - - @pytest.mark.xfail( - reason="requires retired CybORG green/replay tapes; topology snapshots preserve static layout only", - ) - def test_snapshot_topology_matches_second_step_red4_action_after_green_phish(self, tmp_path): - """After the seed-0 phishing foothold, JAX and CybORG should pick the same red_4 follow-up action.""" - from CybORG import CybORG - from CybORG.Agents import EnterpriseGreenAgent, FiniteStateRedAgent, SleepAgent - from CybORG.Agents.Wrappers import BlueFlatWrapper - from CybORG.Simulator.Actions import Sleep - from CybORG.Simulator.Scenarios import EnterpriseScenarioGenerator - - from jaxborg.actions.encoding import BLUE_SLEEP - from jaxborg.constants import NUM_BLUE_AGENTS, NUM_RED_AGENTS - from jaxborg.parity.fsm_red_env import FsmRedCC4Env - from jaxborg.parity.translate import build_mappings_from_cyborg, jax_red_to_cyborg - from jaxborg.scenarios.cc4.red_fsm import fsm_red_apply_delayed_update, fsm_red_select_actions - - seed = 0 - scenario = EnterpriseScenarioGenerator( - blue_agent_class=SleepAgent, - green_agent_class=EnterpriseGreenAgent, - red_agent_class=FiniteStateRedAgent, - steps=TWO_STEP_TRACE_STEPS, - ) - cyborg = CybORG(scenario_generator=scenario, seed=seed) - cyborg_env = BlueFlatWrapper(env=cyborg, pad_spaces=True) - cyborg_env.reset() - mappings = build_mappings_from_cyborg(cyborg) - topology_path = _save_cyborg_topology_snapshot(cyborg, tmp_path, seed) - - logged_actions = {} - for agent_name, interface in cyborg.environment_controller.agent_interfaces.items(): - if not agent_name.startswith("red_agent_"): - continue - agent = interface.agent - original_get_action = agent.get_action - - def _wrap_get_action(orig_fn, wrapped_name): 
- def _wrapped(self, observation, action_space): - action = orig_fn(observation, action_space) - logged_actions[wrapped_name] = action - return action - - return types.MethodType(_wrapped, agent) - - agent.get_action = _wrap_get_action(original_get_action, agent_name) - - jax_env = FsmRedCC4Env(num_steps=TWO_STEP_TRACE_STEPS, topology_path=topology_path) - loop_key = jax.random.PRNGKey(seed) - _, jax_state = jax_env.reset(loop_key) - blue_actions = {f"blue_{i}": jnp.int32(BLUE_SLEEP) for i in range(NUM_BLUE_AGENTS)} - - loop_key, step_key = jax.random.split(loop_key) - _, jax_state, _, _, _ = jax_env.step(step_key, jax_state, blue_actions) - _, _, _, _, _ = cyborg_env.step(actions={a: Sleep() for a in cyborg_env.agents}) - - loop_key, step_key = jax.random.split(loop_key) - key_for_step_env, _key_reset = jax.random.split(step_key) - _key_unused, key_red = jax.random.split(key_for_step_env) - red_keys = jax.random.split(key_red, NUM_RED_AGENTS) - state_before = fsm_red_apply_delayed_update(jax_state.state) - jax_red_actions = fsm_red_select_actions(state_before, jax_state.const, red_keys)[0] - jax_red4 = jax_red_to_cyborg(int(jax_red_actions[4]), 4, mappings) - - _, _, _, _, _ = cyborg_env.step(actions={a: Sleep() for a in cyborg_env.agents}) - cyborg_red4 = logged_actions["red_agent_4"] - - def _action_target(action): - if hasattr(action, "hostname"): - return action.hostname - if hasattr(action, "ip_address"): - return str(action.ip_address) - return None - - assert type(jax_red4).__name__ == type(cyborg_red4).__name__ == "PrivilegeEscalate" - assert _action_target(jax_red4) == _action_target(cyborg_red4) - - @pytest.mark.xfail( - reason="requires retired CybORG green/replay tapes; topology snapshots preserve static layout only", - ) - def test_explicit_cyborg_red_trace_matches_green_phish_privesc_privilege(self, tmp_path): - """Replaying CybORG's first seed-0 red trace should preserve the red_4 privesc privilege gain.""" - from CybORG import CybORG - from 
CybORG.Agents import EnterpriseGreenAgent, FiniteStateRedAgent, SleepAgent - from CybORG.Agents.Wrappers import BlueFlatWrapper - from CybORG.Simulator.Actions import Sleep - from CybORG.Simulator.Scenarios import EnterpriseScenarioGenerator - - from jaxborg.actions.encoding import BLUE_SLEEP, RED_SLEEP - from jaxborg.constants import COMPROMISE_PRIVILEGED, NUM_BLUE_AGENTS, NUM_RED_AGENTS - from jaxborg.env import ScenarioEnv - from jaxborg.parity.translate import build_mappings_from_cyborg, cyborg_red_to_jax - from tests.differential.state_comparator import compare_snapshots, extract_cyborg_snapshot, extract_jax_snapshot - - seed = 0 - scenario = EnterpriseScenarioGenerator( - blue_agent_class=SleepAgent, - green_agent_class=EnterpriseGreenAgent, - red_agent_class=FiniteStateRedAgent, - steps=THREE_STEP_TRACE_STEPS, - ) - cyborg = CybORG(scenario_generator=scenario, seed=seed) - cyborg_env = BlueFlatWrapper(env=cyborg, pad_spaces=True) - cyborg_env.reset() - mappings = build_mappings_from_cyborg(cyborg) - topology_path = _save_cyborg_topology_snapshot(cyborg, tmp_path, seed) - - logged_actions = {} - for agent_name, interface in cyborg.environment_controller.agent_interfaces.items(): - if not agent_name.startswith("red_agent_"): - continue - agent = interface.agent - original_get_action = agent.get_action - - def _wrap_get_action(orig_fn, wrapped_name): - def _wrapped(self, observation, action_space): - action = orig_fn(observation, action_space) - logged_actions[wrapped_name] = action - return action - - return types.MethodType(_wrapped, agent) - - agent.get_action = _wrap_get_action(original_get_action, agent_name) - - jax_env = ScenarioEnv(num_steps=THREE_STEP_TRACE_STEPS, topology_path=topology_path) - loop_key = jax.random.PRNGKey(seed) - _, jax_state = jax_env.reset(loop_key) - - for _step in range(3): - logged_actions.clear() - _, _, _, _, _ = cyborg_env.step(actions={a: Sleep() for a in cyborg_env.agents}) - red_actions = {} - for agent_id in 
range(NUM_RED_AGENTS): - agent_name = f"red_agent_{agent_id}" - cy_action = logged_actions.get(agent_name) - red_actions[f"red_{agent_id}"] = jnp.int32( - RED_SLEEP if cy_action is None else cyborg_red_to_jax(cy_action, agent_name, mappings) - ) - - loop_key, step_key = jax.random.split(loop_key) - blue_actions = {f"blue_{i}": jnp.int32(BLUE_SLEEP) for i in range(NUM_BLUE_AGENTS)} - _, jax_state, _, _, _ = jax_env.step(step_key, jax_state, {**blue_actions, **red_actions}) - - target_hostname = "operational_zone_b_subnet_user_host_5" - target_host = mappings.hostname_to_idx[target_hostname] - target_sessions = [ - sess - for sess in cyborg.environment_controller.state.sessions["red_agent_4"].values() - if sess.hostname == target_hostname - ] - assert any(sess.has_privileged_access() for sess in target_sessions), target_sessions - assert int(jax_state.state.red_privilege[4, target_host]) == COMPROMISE_PRIVILEGED - assert int(jax_state.state.host_compromised[target_host]) == COMPROMISE_PRIVILEGED - - diffs = compare_snapshots( - extract_cyborg_snapshot(cyborg, mappings), - extract_jax_snapshot(jax_state.state, jax_state.const, mappings), - ) - host_label = f"host_{target_host}" - agent_host_label = f"red_4_host_{target_host}" - target_diffs = [ - diff - for diff in diffs - if diff.field_name in {"host_compromised", "red_privilege"} - and diff.host_or_agent in {host_label, agent_host_label} - ] - assert target_diffs == [] - def test_explicit_replay_corrects_generic_exploit_to_cyborg_subaction(self, tmp_path): """Seed-0 generic exploit replay should not invent a host_22 foothold.""" from CybORG import CybORG @@ -743,3 +504,50 @@ def test_native_generic_exploit_respects_blocked_scan_source_route_matches_cybor break assert found_step is not None, "Never found a red exploit from a blocked subnet in 300 steps" + + +class TestFsmRedGreenSyncParity: + """End-state parity for {FSM red + EnterpriseGreen + Sleep blue} under sync_green_rng. 
+
+    Replaces three retired tests that depended on green/replay tape infrastructure.
+    `tests/differential/test_red_policy_parity.py` already verifies CybORG/JAX red-policy
+    picks and host_states eligibility match step-by-step; this test guards the remaining
+    angle — that the resulting host_compromised/red_privilege/red_sessions state also
+    matches after a multi-step green-phish → red-exploit → red-privesc chain.
+    """
+
+    @pytest.mark.parametrize("seed", [0, 1000])
+    def test_no_critical_state_diffs_over_10_steps(self, seed):
+        from CybORG.Agents import EnterpriseGreenAgent, FiniteStateRedAgent, SleepAgent
+
+        from jaxborg.actions.encoding import BLUE_SLEEP
+        from jaxborg.constants import COMPROMISE_PRIVILEGED, NUM_BLUE_AGENTS
+        from tests.differential.harness import CC4DifferentialHarness
+
+        harness = CC4DifferentialHarness(
+            seed=seed,
+            max_steps=10,
+            blue_cls=SleepAgent,
+            green_cls=EnterpriseGreenAgent,
+            red_cls=FiniteStateRedAgent,
+            sync_green_rng=True,
+            check_rewards=False,
+            check_obs=False,
+            check_masks=False,
+        )
+        harness.reset()
+
+        critical = {"host_compromised", "red_privilege", "red_sessions"}
+        sleep_blue = {b: BLUE_SLEEP for b in range(NUM_BLUE_AGENTS)}
+        saw_privesc = False
+        for step in range(10):
+            result = harness.full_step(sleep_blue)
+            errors = [d for d in result.diffs if d.field_name in critical]
+            assert errors == [], f"seed={seed} step={step}: critical state diffs: {errors[:5]}"
+            # Require COMPROMISE_PRIVILEGED; "> 0" would presumably also pass on a user-level foothold.
+            saw_privesc = saw_privesc or int(jnp.max(harness.jax_state.red_privilege)) >= COMPROMISE_PRIVILEGED
+
+        assert saw_privesc, (
+            f"seed={seed}: no red agent reached privileged access in 10 steps; "
+            "test trajectory is degenerate (green phish → red exploit → privesc chain didn't fire)"
+        )