Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/differential/test_red_policy_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
pytestmark = pytest.mark.slow


@pytest.mark.parametrize("seed", [1000, 1001, 1002, 1003, 1004])
@pytest.mark.parametrize("seed", [0, 1000, 1001, 1002, 1003, 1004])
def test_red_policy_matches_cyborg_multistep(seed: int) -> None:
"""JAX red picks identical to CybORG red picks across 200 steps under matched RNG."""
harness = CC4DifferentialHarness(
Expand Down
286 changes: 47 additions & 239 deletions tests/test_fsm_red_env_differential.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@

pytestmark = pytest.mark.slow

TWO_STEP_TRACE_STEPS = 2
THREE_STEP_TRACE_STEPS = 3


@pytest.fixture
def cyborg_sleep_env():
Expand Down Expand Up @@ -258,242 +255,6 @@ def test_sleep_blue_cumulative_reward_same_sign(self, cyborg_sleep_env, jax_env_
if cyborg_total < 0:
assert jax_total <= 0, f"JAX sleep reward should be <= 0 when CybORG is {cyborg_total}"

@pytest.mark.xfail(
reason="requires retired CybORG green/replay tapes; topology snapshots preserve static layout only",
)
def test_snapshot_topology_matches_red4_known_hosts_after_first_green_phish(self, tmp_path):
"""After activation, native JAX red_4 should know only the same hosts as CybORG's FSM agent."""
import types

from CybORG import CybORG
from CybORG.Agents import EnterpriseGreenAgent, FiniteStateRedAgent, SleepAgent
from CybORG.Agents.Wrappers import BlueFlatWrapper
from CybORG.Simulator.Actions import Sleep
from CybORG.Simulator.Scenarios import EnterpriseScenarioGenerator

from jaxborg.actions.encoding import BLUE_SLEEP
from jaxborg.constants import GLOBAL_MAX_HOSTS, NUM_BLUE_AGENTS
from jaxborg.parity.fsm_red_env import FsmRedCC4Env
from jaxborg.parity.translate import build_mappings_from_cyborg
from jaxborg.scenarios.cc4.red_fsm import fsm_red_apply_delayed_update

seed = 0
scenario = EnterpriseScenarioGenerator(
blue_agent_class=SleepAgent,
green_agent_class=EnterpriseGreenAgent,
red_agent_class=FiniteStateRedAgent,
steps=TWO_STEP_TRACE_STEPS,
)
cyborg = CybORG(scenario_generator=scenario, seed=seed)
cyborg_env = BlueFlatWrapper(env=cyborg, pad_spaces=True)
cyborg_env.reset()
mappings = build_mappings_from_cyborg(cyborg)
topology_path = _save_cyborg_topology_snapshot(cyborg, tmp_path, seed)

captured_known_hosts = None
interface = cyborg.environment_controller.agent_interfaces["red_agent_4"]
agent = interface.agent
original_get_action = agent.get_action

def _wrapped(self, observation, action_space):
nonlocal captured_known_hosts
action = original_get_action(observation, action_space)
captured_known_hosts = np.zeros(GLOBAL_MAX_HOSTS, dtype=bool)
for ip, info in self.host_states.items():
hostname = info.get("hostname")
if hostname in mappings.hostname_to_idx:
captured_known_hosts[mappings.hostname_to_idx[hostname]] = True
return action

agent.get_action = types.MethodType(_wrapped, agent)

jax_env = FsmRedCC4Env(num_steps=TWO_STEP_TRACE_STEPS, topology_path=topology_path)
loop_key = jax.random.PRNGKey(seed)
_, jax_state = jax_env.reset(loop_key)
blue_actions = {f"blue_{i}": jnp.int32(BLUE_SLEEP) for i in range(NUM_BLUE_AGENTS)}

loop_key, step_key = jax.random.split(loop_key)
_, jax_state, _, _, _ = jax_env.step(step_key, jax_state, blue_actions)
_, _, _, _, _ = cyborg_env.step(actions={a: Sleep() for a in cyborg_env.agents})

loop_key, step_key = jax.random.split(loop_key)
state_before = fsm_red_apply_delayed_update(jax_state.state)
_, _, _, _, _ = cyborg_env.step(actions={a: Sleep() for a in cyborg_env.agents})

assert captured_known_hosts is not None, "Expected wrapped CybORG red_4 action to capture known hosts"
np.testing.assert_array_equal(
np.array(state_before.red_discovered_hosts[4], dtype=bool),
captured_known_hosts,
)

@pytest.mark.xfail(
reason="requires retired CybORG green/replay tapes; topology snapshots preserve static layout only",
)
def test_snapshot_topology_matches_second_step_red4_action_after_green_phish(self, tmp_path):
"""After the seed-0 phishing foothold, JAX and CybORG should pick the same red_4 follow-up action."""
from CybORG import CybORG
from CybORG.Agents import EnterpriseGreenAgent, FiniteStateRedAgent, SleepAgent
from CybORG.Agents.Wrappers import BlueFlatWrapper
from CybORG.Simulator.Actions import Sleep
from CybORG.Simulator.Scenarios import EnterpriseScenarioGenerator

from jaxborg.actions.encoding import BLUE_SLEEP
from jaxborg.constants import NUM_BLUE_AGENTS, NUM_RED_AGENTS
from jaxborg.parity.fsm_red_env import FsmRedCC4Env
from jaxborg.parity.translate import build_mappings_from_cyborg, jax_red_to_cyborg
from jaxborg.scenarios.cc4.red_fsm import fsm_red_apply_delayed_update, fsm_red_select_actions

seed = 0
scenario = EnterpriseScenarioGenerator(
blue_agent_class=SleepAgent,
green_agent_class=EnterpriseGreenAgent,
red_agent_class=FiniteStateRedAgent,
steps=TWO_STEP_TRACE_STEPS,
)
cyborg = CybORG(scenario_generator=scenario, seed=seed)
cyborg_env = BlueFlatWrapper(env=cyborg, pad_spaces=True)
cyborg_env.reset()
mappings = build_mappings_from_cyborg(cyborg)
topology_path = _save_cyborg_topology_snapshot(cyborg, tmp_path, seed)

logged_actions = {}
for agent_name, interface in cyborg.environment_controller.agent_interfaces.items():
if not agent_name.startswith("red_agent_"):
continue
agent = interface.agent
original_get_action = agent.get_action

def _wrap_get_action(orig_fn, wrapped_name):
def _wrapped(self, observation, action_space):
action = orig_fn(observation, action_space)
logged_actions[wrapped_name] = action
return action

return types.MethodType(_wrapped, agent)

agent.get_action = _wrap_get_action(original_get_action, agent_name)

jax_env = FsmRedCC4Env(num_steps=TWO_STEP_TRACE_STEPS, topology_path=topology_path)
loop_key = jax.random.PRNGKey(seed)
_, jax_state = jax_env.reset(loop_key)
blue_actions = {f"blue_{i}": jnp.int32(BLUE_SLEEP) for i in range(NUM_BLUE_AGENTS)}

loop_key, step_key = jax.random.split(loop_key)
_, jax_state, _, _, _ = jax_env.step(step_key, jax_state, blue_actions)
_, _, _, _, _ = cyborg_env.step(actions={a: Sleep() for a in cyborg_env.agents})

loop_key, step_key = jax.random.split(loop_key)
key_for_step_env, _key_reset = jax.random.split(step_key)
_key_unused, key_red = jax.random.split(key_for_step_env)
red_keys = jax.random.split(key_red, NUM_RED_AGENTS)
state_before = fsm_red_apply_delayed_update(jax_state.state)
jax_red_actions = fsm_red_select_actions(state_before, jax_state.const, red_keys)[0]
jax_red4 = jax_red_to_cyborg(int(jax_red_actions[4]), 4, mappings)

_, _, _, _, _ = cyborg_env.step(actions={a: Sleep() for a in cyborg_env.agents})
cyborg_red4 = logged_actions["red_agent_4"]

def _action_target(action):
if hasattr(action, "hostname"):
return action.hostname
if hasattr(action, "ip_address"):
return str(action.ip_address)
return None

assert type(jax_red4).__name__ == type(cyborg_red4).__name__ == "PrivilegeEscalate"
assert _action_target(jax_red4) == _action_target(cyborg_red4)

@pytest.mark.xfail(
reason="requires retired CybORG green/replay tapes; topology snapshots preserve static layout only",
)
def test_explicit_cyborg_red_trace_matches_green_phish_privesc_privilege(self, tmp_path):
"""Replaying CybORG's first seed-0 red trace should preserve the red_4 privesc privilege gain."""
from CybORG import CybORG
from CybORG.Agents import EnterpriseGreenAgent, FiniteStateRedAgent, SleepAgent
from CybORG.Agents.Wrappers import BlueFlatWrapper
from CybORG.Simulator.Actions import Sleep
from CybORG.Simulator.Scenarios import EnterpriseScenarioGenerator

from jaxborg.actions.encoding import BLUE_SLEEP, RED_SLEEP
from jaxborg.constants import COMPROMISE_PRIVILEGED, NUM_BLUE_AGENTS, NUM_RED_AGENTS
from jaxborg.env import ScenarioEnv
from jaxborg.parity.translate import build_mappings_from_cyborg, cyborg_red_to_jax
from tests.differential.state_comparator import compare_snapshots, extract_cyborg_snapshot, extract_jax_snapshot

seed = 0
scenario = EnterpriseScenarioGenerator(
blue_agent_class=SleepAgent,
green_agent_class=EnterpriseGreenAgent,
red_agent_class=FiniteStateRedAgent,
steps=THREE_STEP_TRACE_STEPS,
)
cyborg = CybORG(scenario_generator=scenario, seed=seed)
cyborg_env = BlueFlatWrapper(env=cyborg, pad_spaces=True)
cyborg_env.reset()
mappings = build_mappings_from_cyborg(cyborg)
topology_path = _save_cyborg_topology_snapshot(cyborg, tmp_path, seed)

logged_actions = {}
for agent_name, interface in cyborg.environment_controller.agent_interfaces.items():
if not agent_name.startswith("red_agent_"):
continue
agent = interface.agent
original_get_action = agent.get_action

def _wrap_get_action(orig_fn, wrapped_name):
def _wrapped(self, observation, action_space):
action = orig_fn(observation, action_space)
logged_actions[wrapped_name] = action
return action

return types.MethodType(_wrapped, agent)

agent.get_action = _wrap_get_action(original_get_action, agent_name)

jax_env = ScenarioEnv(num_steps=THREE_STEP_TRACE_STEPS, topology_path=topology_path)
loop_key = jax.random.PRNGKey(seed)
_, jax_state = jax_env.reset(loop_key)

for _step in range(3):
logged_actions.clear()
_, _, _, _, _ = cyborg_env.step(actions={a: Sleep() for a in cyborg_env.agents})
red_actions = {}
for agent_id in range(NUM_RED_AGENTS):
agent_name = f"red_agent_{agent_id}"
cy_action = logged_actions.get(agent_name)
red_actions[f"red_{agent_id}"] = jnp.int32(
RED_SLEEP if cy_action is None else cyborg_red_to_jax(cy_action, agent_name, mappings)
)

loop_key, step_key = jax.random.split(loop_key)
blue_actions = {f"blue_{i}": jnp.int32(BLUE_SLEEP) for i in range(NUM_BLUE_AGENTS)}
_, jax_state, _, _, _ = jax_env.step(step_key, jax_state, {**blue_actions, **red_actions})

target_hostname = "operational_zone_b_subnet_user_host_5"
target_host = mappings.hostname_to_idx[target_hostname]
target_sessions = [
sess
for sess in cyborg.environment_controller.state.sessions["red_agent_4"].values()
if sess.hostname == target_hostname
]
assert any(sess.has_privileged_access() for sess in target_sessions), target_sessions
assert int(jax_state.state.red_privilege[4, target_host]) == COMPROMISE_PRIVILEGED
assert int(jax_state.state.host_compromised[target_host]) == COMPROMISE_PRIVILEGED

diffs = compare_snapshots(
extract_cyborg_snapshot(cyborg, mappings),
extract_jax_snapshot(jax_state.state, jax_state.const, mappings),
)
host_label = f"host_{target_host}"
agent_host_label = f"red_4_host_{target_host}"
target_diffs = [
diff
for diff in diffs
if diff.field_name in {"host_compromised", "red_privilege"}
and diff.host_or_agent in {host_label, agent_host_label}
]
assert target_diffs == []

def test_explicit_replay_corrects_generic_exploit_to_cyborg_subaction(self, tmp_path):
"""Seed-0 generic exploit replay should not invent a host_22 foothold."""
from CybORG import CybORG
Expand Down Expand Up @@ -743,3 +504,50 @@ def test_native_generic_exploit_respects_blocked_scan_source_route_matches_cybor
break

assert found_step is not None, "Never found a red exploit from a blocked subnet in 300 steps"


class TestFsmRedGreenSyncParity:
"""End-state parity for {FSM red + EnterpriseGreen + Sleep blue} under sync_green_rng.

Replaces three retired tests that depended on green/replay tape infrastructure.
`tests/differential/test_red_policy_parity.py` already verifies CybORG/JAX red-policy
picks and host_states eligibility match step-by-step; this test guards the remaining
angle — that the resulting host_compromised/red_privilege/red_sessions state also
matches after a multi-step green-phish → red-exploit → red-privesc chain.
"""

@pytest.mark.parametrize("seed", [0, 1000])
def test_no_critical_state_diffs_over_10_steps(self, seed):
from CybORG.Agents import EnterpriseGreenAgent, FiniteStateRedAgent, SleepAgent

from jaxborg.actions.encoding import BLUE_SLEEP
from jaxborg.constants import NUM_BLUE_AGENTS
from tests.differential.harness import CC4DifferentialHarness

harness = CC4DifferentialHarness(
seed=seed,
max_steps=10,
blue_cls=SleepAgent,
green_cls=EnterpriseGreenAgent,
red_cls=FiniteStateRedAgent,
sync_green_rng=True,
check_rewards=False,
check_obs=False,
check_masks=False,
)
harness.reset()

critical = {"host_compromised", "red_privilege", "red_sessions"}
sleep_blue = {b: BLUE_SLEEP for b in range(NUM_BLUE_AGENTS)}
saw_privesc = False
for step in range(10):
result = harness.full_step(sleep_blue)
errors = [d for d in result.diffs if d.field_name in critical]
assert errors == [], f"seed={seed} step={step}: critical state diffs: {errors[:5]}"
if not saw_privesc and int(jnp.max(harness.jax_state.red_privilege)) > 0:
saw_privesc = True

assert saw_privesc, (
f"seed={seed}: no red agent reached privileged access in 10 steps; "
"test trajectory is degenerate (green phish → red exploit → privesc chain didn't fire)"
)
Loading