diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..6f94355 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +testpaths = tests +asyncio_mode = auto diff --git a/tests/app/flags/__init__.py b/tests/app/flags/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/flags/advanced/__init__.py b/tests/app/flags/advanced/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/flags/advanced/test_flag0.py b/tests/app/flags/advanced/test_flag0.py index 2390842..0ecf375 100644 --- a/tests/app/flags/advanced/test_flag0.py +++ b/tests/app/flags/advanced/test_flag0.py @@ -1,61 +1,132 @@ +"""Exhaustive tests for advanced flag_0 (AdvancedFlag0).""" import pytest +from unittest.mock import MagicMock, AsyncMock from plugins.training.app.flags.advanced.flag_0 import AdvancedFlag0 -class TestFlag: - def test_valid_external_http_contact(self): - test_contact = 'http://10.10.10.10:8888' - assert AdvancedFlag0.valid_external_http_contact(test_contact) +class TestValidExternalHttpContact: + def test_valid_external_http(self): + assert AdvancedFlag0.valid_external_http_contact('http://10.10.10.10:8888') - def test_valid_external_https_contact(self): - test_contact = 'https://10.10.10.10:8888' - assert AdvancedFlag0.valid_external_http_contact(test_contact) + def test_valid_external_https(self): + assert AdvancedFlag0.valid_external_http_contact('https://10.10.10.10:8888') - def test_valid_external_http_contact_no_port(self): - test_contact = 'http://10.10.10.10' - assert AdvancedFlag0.valid_external_http_contact(test_contact) + def test_valid_http_no_port(self): + assert AdvancedFlag0.valid_external_http_contact('http://10.10.10.10') - def test_valid_external_https_contact_no_port(self): - test_contact = 'https://10.10.10.10' - assert AdvancedFlag0.valid_external_http_contact(test_contact) + def test_valid_https_no_port(self): + assert AdvancedFlag0.valid_external_http_contact('https://10.10.10.10') - def test_internal_http_contact_loopback(self): - test_contact = 'http://127.0.0.1:8888' - assert not AdvancedFlag0.valid_external_http_contact(test_contact) + def test_loopback_rejected(self): + assert not AdvancedFlag0.valid_external_http_contact('http://127.0.0.1:8888') - def test_internal_http_contact_loopback_no_port(self): - test_contact = 'http://127.0.0.1' - assert not AdvancedFlag0.valid_external_http_contact(test_contact) + def test_loopback_no_port(self): + assert not AdvancedFlag0.valid_external_http_contact('http://127.0.0.1') - def test_internal_https_contact_loopback(self): - test_contact = 'https://127.0.0.1:12345' - assert not AdvancedFlag0.valid_external_http_contact(test_contact) + def test_loopback_https(self): + assert not AdvancedFlag0.valid_external_http_contact('https://127.0.0.1:12345') - def test_internal_http_contact_loopback_other(self): - test_contact = 'http://127.10.10.10' - assert not AdvancedFlag0.valid_external_http_contact(test_contact) + def test_loopback_other(self): + assert not AdvancedFlag0.valid_external_http_contact('http://127.10.10.10') - def test_internal_http_contact_0000(self): - test_contact = 'http://0.0.0.0:8888' - assert not AdvancedFlag0.valid_external_http_contact(test_contact) + def test_zero_ip(self): + assert not AdvancedFlag0.valid_external_http_contact('http://0.0.0.0:8888') - def test_internal_http_contact_0000_no_port(self): - test_contact = 'http://0.0.0.0' - assert not AdvancedFlag0.valid_external_http_contact(test_contact) + def test_zero_ip_no_port(self): + assert not AdvancedFlag0.valid_external_http_contact('http://0.0.0.0') def test_invalid_port(self): - test_contact = 'http://10.10.10.10:abcd' - assert not AdvancedFlag0.valid_external_http_contact(test_contact) + assert not AdvancedFlag0.valid_external_http_contact('http://10.10.10.10:abcd') def test_out_of_range_port(self): - test_contact = 'http://10.10.10.10:123456' - assert not AdvancedFlag0.valid_external_http_contact(test_contact) + assert not AdvancedFlag0.valid_external_http_contact('http://10.10.10.10:123456') - def test_not_ip_addr(self): - test_contact = 'http://myhostname.tld:1234' - assert not AdvancedFlag0.valid_external_http_contact(test_contact) + def test_hostname_rejected(self): + assert not AdvancedFlag0.valid_external_http_contact('http://myhostname.tld:1234') def test_wrong_protocol(self): - test_contact = 'nothttp://10.10.10.10:12345' - assert not AdvancedFlag0.valid_external_http_contact(test_contact) + assert not AdvancedFlag0.valid_external_http_contact('nothttp://10.10.10.10:12345') + + def test_ftp_rejected(self): + assert not AdvancedFlag0.valid_external_http_contact('ftp://10.10.10.10:21') + + def test_empty_string(self): + assert not AdvancedFlag0.valid_external_http_contact('') + + def test_no_scheme(self): + assert not AdvancedFlag0.valid_external_http_contact('10.10.10.10:8888') + + def test_private_192_168(self): + assert AdvancedFlag0.valid_external_http_contact('http://192.168.1.1:443') + + def test_private_172(self): + assert AdvancedFlag0.valid_external_http_contact('http://172.16.0.1:8080') + + def test_port_zero(self): + # Port 0 is technically valid in URL parsing + result = AdvancedFlag0.valid_external_http_contact('http://10.10.10.10:0') + # Should still pass since port 0 is in range + assert isinstance(result, bool) + + def test_max_valid_port(self): + assert AdvancedFlag0.valid_external_http_contact('http://10.10.10.10:65535') + + def test_boundary_ip_255(self): + assert AdvancedFlag0.valid_external_http_contact('http://255.255.255.255:80') + + def test_ip_1_0_0_1(self): + assert AdvancedFlag0.valid_external_http_contact('http://1.0.0.1:80') + + +class TestExternalFacingIp: + def test_external(self): + assert AdvancedFlag0.external_facing_ip('10.0.0.1') + + def test_loopback(self): + assert not AdvancedFlag0.external_facing_ip('127.0.0.1') + + def test_loopback_range(self): + assert not AdvancedFlag0.external_facing_ip('127.255.0.1') + + def test_zero(self): + assert not AdvancedFlag0.external_facing_ip('0.0.0.0') + + +class TestAdvancedFlag0Verify: + @pytest.mark.asyncio + async def test_verify_valid_contact(self): + flag = AdvancedFlag0(number=1) + services = {'app_svc': MagicMock()} + services['app_svc'].get_config = MagicMock(return_value='http://10.10.10.10:8888') + result = await flag.verify(services) + assert result is True + + @pytest.mark.asyncio + async def test_verify_loopback_contact(self): + flag = AdvancedFlag0(number=1) + services = {'app_svc': MagicMock()} + services['app_svc'].get_config = MagicMock(return_value='http://127.0.0.1:8888') + result = await flag.verify(services) + assert result is False + + @pytest.mark.asyncio + async def test_verify_none_contact(self): + flag = AdvancedFlag0(number=1) + services = {'app_svc': MagicMock()} + services['app_svc'].get_config = MagicMock(return_value=None) + result = await flag.verify(services) + assert not result + + @pytest.mark.asyncio + async def test_verify_empty_contact(self): + flag = AdvancedFlag0(number=1) + services = {'app_svc': MagicMock()} + services['app_svc'].get_config = MagicMock(return_value='') + result = await flag.verify(services) + assert not result + + def test_name_and_challenge(self): + flag = AdvancedFlag0(number=1) + assert flag.name == 'Update configs' + assert 'app.contact.http' in flag.challenge diff --git a/tests/app/flags/advanced/test_flag1.py b/tests/app/flags/advanced/test_flag1.py new file mode 100644 index 0000000..bf6cbb2 --- /dev/null +++ b/tests/app/flags/advanced/test_flag1.py @@ -0,0 +1,49 @@ +"""Tests for advanced flag_1 (AdvancedFlag1).""" +import pytest +from unittest.mock import MagicMock, AsyncMock + +from plugins.training.app.flags.advanced.flag_1 import AdvancedFlag1 + + +class TestAdvancedFlag1: + + def test_name(self): + f = AdvancedFlag1(number=1) + assert f.name == 'Adjust sources' + + @pytest.mark.asyncio + async def test_verify_source_with_facts_and_rules(self): + f = AdvancedFlag1(number=1) + source = MagicMock() + source.facts = [MagicMock()] + source.rules = [MagicMock()] + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[source]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_no_source(self): + f = AdvancedFlag1(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_source_no_facts(self): + f = AdvancedFlag1(number=1) + source = MagicMock() + source.facts = [] + source.rules = [MagicMock()] + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[source]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_source_no_rules(self): + f = AdvancedFlag1(number=1) + source = MagicMock() + source.facts = [MagicMock()] + source.rules = [] + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[source]) + assert await f.verify(services) is False diff --git a/tests/app/flags/advanced/test_flag2.py b/tests/app/flags/advanced/test_flag2.py new file mode 100644 index 0000000..2c1b806 --- /dev/null +++ b/tests/app/flags/advanced/test_flag2.py @@ -0,0 +1,49 @@ +"""Tests for advanced flag_2 (AdvancedFlag2).""" +import pytest +from unittest.mock import MagicMock, AsyncMock + +from plugins.training.app.flags.advanced.flag_2 import AdvancedFlag2 + + +class TestAdvancedFlag2: + + def test_name(self): + f = AdvancedFlag2(number=1) + assert f.name == 'Add new user' + + @pytest.mark.asyncio + async def test_verify_correct_user(self): + f = AdvancedFlag2(number=1) + user = MagicMock() + user.password = 'test' + user.permissions = ['red'] + services = {'auth_svc': MagicMock()} + services['auth_svc'].user_map = {'test': user} + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_no_user(self): + f = AdvancedFlag2(number=1) + services = {'auth_svc': MagicMock()} + services['auth_svc'].user_map = {} + assert not await f.verify(services) + + @pytest.mark.asyncio + async def test_verify_wrong_password(self): + f = AdvancedFlag2(number=1) + user = MagicMock() + user.password = 'wrong' + user.permissions = ['red'] + services = {'auth_svc': MagicMock()} + services['auth_svc'].user_map = {'test': user} + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_no_red_permission(self): + f = AdvancedFlag2(number=1) + user = MagicMock() + user.password = 'test' + user.permissions = ['blue'] + services = {'auth_svc': MagicMock()} + services['auth_svc'].user_map = {'test': user} + assert await f.verify(services) is False diff --git a/tests/app/flags/adversaries/__init__.py b/tests/app/flags/adversaries/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/flags/adversaries/test_flags.py b/tests/app/flags/adversaries/test_flags.py new file mode 100644 index 0000000..51d7e00 --- /dev/null +++ b/tests/app/flags/adversaries/test_flags.py @@ -0,0 +1,127 @@ +"""Tests for adversaries flags.""" +import pytest +from unittest.mock import MagicMock, AsyncMock + +from plugins.training.app.flags.adversaries.flag_0 import AdversariesFlag0 +from plugins.training.app.flags.adversaries.flag_1 import AdversariesFlag1 +from plugins.training.app.flags.adversaries.flag_2 import AdversariesFlag2 + + +class TestAdversariesFlag0: + def test_name(self): + assert AdversariesFlag0(number=1).name == 'Create adversary' + + @pytest.mark.asyncio + async def test_verify_enough_abilities(self): + f = AdversariesFlag0(number=1) + adv = MagicMock() + adv.atomic_ordering = ['a1', 'a2', 'a3'] + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[adv]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_too_few_abilities(self): + f = AdversariesFlag0(number=1) + adv = MagicMock() + adv.atomic_ordering = ['a1'] + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[adv]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_no_adversary(self): + f = AdversariesFlag0(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_exactly_three(self): + f = AdversariesFlag0(number=1) + adv = MagicMock() + adv.atomic_ordering = ['a1', 'a2', 'a3'] + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[adv]) + assert await f.verify(services) is True + + +class TestAdversariesFlag1: + def test_name(self): + assert AdversariesFlag1(number=1).name == 'Combine adversaries' + + @pytest.mark.asyncio + async def test_verify_has_nosy_neighbor(self): + f = AdversariesFlag1(number=1) + adv = MagicMock() + adv.atomic_ordering = ['other', '2fe2d5e6-7b06-4fc0-bf71-6966a1226731'] + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[adv]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_missing_nosy(self): + f = AdversariesFlag1(number=1) + adv = MagicMock() + adv.atomic_ordering = ['other'] + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[adv]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_no_adversary(self): + f = AdversariesFlag1(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False + + +class TestAdversariesFlag2: + def test_name(self): + assert AdversariesFlag2(number=1).name == 'Create ability' + + @pytest.mark.asyncio + async def test_verify_correct_ability(self): + f = AdversariesFlag2(number=1) + executor = MagicMock() + executor.cleanup = 'rm -f /tmp/test' + executor.command = 'ifconfig' + ability = MagicMock() + ability.tactic = 'discovery' + ability.executors = [executor] + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[ability]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_wrong_tactic(self): + f = AdversariesFlag2(number=1) + executor = MagicMock() + executor.cleanup = 'cmd' + executor.command = 'cmd' + ability = MagicMock() + ability.tactic = 'persistence' + ability.executors = [executor] + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[ability]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_no_cleanup(self): + f = AdversariesFlag2(number=1) + executor = MagicMock() + executor.cleanup = '' + executor.command = 'ifconfig' + ability = MagicMock() + ability.tactic = 'discovery' + ability.executors = [executor] + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[ability]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_no_ability(self): + f = AdversariesFlag2(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False diff --git a/tests/app/flags/agents/__init__.py b/tests/app/flags/agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/flags/agents/test_blue_flags.py b/tests/app/flags/agents/test_blue_flags.py new file mode 100644 index 0000000..a058be0 --- /dev/null +++ b/tests/app/flags/agents/test_blue_flags.py @@ -0,0 +1,108 @@ +"""Tests for agents blue flag modules.""" +import pytest +from unittest.mock import MagicMock, AsyncMock + +from plugins.training.app.flags.agents.blue_0 import AgentsBlue0 +from plugins.training.app.flags.agents.blue_1 import AgentsBlue1 +from plugins.training.app.flags.agents.blue_2 import AgentsBlue2 +from plugins.training.app.flags.agents.blue_3 import AgentsBlue3 + + +class TestAgentsBlue0: + def test_name(self): + assert AgentsBlue0(number=1).name == 'Red agent - *nix' + + @pytest.mark.asyncio + async def test_verify_linux_cert_nix(self): + f = AgentsBlue0(number=1) + agent = MagicMock(platform='linux', group='cert-nix') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_darwin_cert_nix(self): + f = AgentsBlue0(number=1) + agent = MagicMock(platform='darwin', group='cert-nix') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_wrong_platform(self): + f = AgentsBlue0(number=1) + agent = MagicMock(platform='windows', group='cert-nix') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_wrong_group(self): + f = AgentsBlue0(number=1) + agent = MagicMock(platform='linux', group='red') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is False + + +class TestAgentsBlue1: + def test_name(self): + assert AgentsBlue1(number=1).name == 'Blue agent - *nix' + + @pytest.mark.asyncio + async def test_verify_elevated_blue(self): + f = AgentsBlue1(number=1) + agent = MagicMock(platform='linux', group='blue', privilege='Elevated') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_not_elevated(self): + f = AgentsBlue1(number=1) + agent = MagicMock(platform='linux', group='blue', privilege='User') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is False + + +class TestAgentsBlue2: + def test_name(self): + assert AgentsBlue2(number=1).name == 'Red agent - Windows' + + @pytest.mark.asyncio + async def test_verify_win_cert_win(self): + f = AgentsBlue2(number=1) + agent = MagicMock(platform='windows', group='cert-win') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_linux_fails(self): + f = AgentsBlue2(number=1) + agent = MagicMock(platform='linux', group='cert-win') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is False + + +class TestAgentsBlue3: + def test_name(self): + assert AgentsBlue3(number=1).name == 'Blue agent - Windows' + + @pytest.mark.asyncio + async def test_verify_windows_elevated_blue(self): + f = AgentsBlue3(number=1) + agent = MagicMock(platform='windows', group='blue', privilege='Elevated') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_not_windows(self): + f = AgentsBlue3(number=1) + agent = MagicMock(platform='linux', group='blue', privilege='Elevated') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is False diff --git a/tests/app/flags/agents/test_flags.py b/tests/app/flags/agents/test_flags.py new file mode 100644 index 0000000..06ee5fe --- /dev/null +++ b/tests/app/flags/agents/test_flags.py @@ -0,0 +1,226 @@ +"""Tests for agents flag modules.""" +import sys +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from plugins.training.app.flags.agents.flag_0 import AgentsFlag0 +from plugins.training.app.flags.agents.flag_1 import AgentsFlag1 +from plugins.training.app.flags.agents.flag_2 import AgentsFlag2 +from plugins.training.app.flags.agents.flag_3 import AgentsFlag3 +from plugins.training.app.flags.agents.flag_4 import AgentsFlag4 +from plugins.training.app.flags.agents.flag_5 import AgentsFlag5 +from plugins.training.app.flags.agents.flag_6 import AgentsFlag6 +from plugins.training.app.flags.agents.flag_7 import AgentsFlag7 + + +class TestAgentsFlag0: + def test_name(self): + assert AgentsFlag0(number=1).name == 'Local agent' + + @pytest.mark.asyncio + async def test_verify_agent_exists(self): + f = AgentsFlag0(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[MagicMock()]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_no_agents(self): + f = AgentsFlag0(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False + + +class TestAgentsFlag1: + def test_name(self): + assert AgentsFlag1(number=1).name == 'Remote agent' + + @pytest.mark.asyncio + async def test_verify_remote_agent(self): + f = AgentsFlag1(number=1) + local_agent = MagicMock(platform=sys.platform, server='10.0.0.1') + remote_agent = MagicMock(platform='different_platform', server='10.0.0.2') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[local_agent, remote_agent]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_single_agent(self): + f = AgentsFlag1(number=1) + agent = MagicMock(platform=sys.platform, server='10.0.0.1') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_same_platform(self): + f = AgentsFlag1(number=1) + a1 = MagicMock(platform=sys.platform, server='10.0.0.1') + a2 = MagicMock(platform=sys.platform, server='10.0.0.2') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[a1, a2]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_localhost_server(self): + f = AgentsFlag1(number=1) + a1 = MagicMock(platform=sys.platform, server='127.0.0.1') + a2 = MagicMock(platform='other', server='localhost') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[a1, a2]) + assert await f.verify(services) is False + + +class TestAgentsFlag2: + def test_name(self): + assert AgentsFlag2(number=1).name == 'Understanding trust' + + @pytest.mark.asyncio + async def test_verify_correct_timer(self): + f = AgentsFlag2(number=1) + from app.utility.base_world import BaseWorld + original = BaseWorld.get_config + BaseWorld.get_config = staticmethod(lambda name=None, prop=None: 60 if prop == 'untrusted_timer' else None) + try: + assert await f.verify({}) is True + finally: + BaseWorld.get_config = original + + @pytest.mark.asyncio + async def test_verify_wrong_timer(self): + f = AgentsFlag2(number=1) + from app.utility.base_world import BaseWorld + original = BaseWorld.get_config + BaseWorld.get_config = staticmethod(lambda name=None, prop=None: 30) + try: + assert await f.verify({}) is False + finally: + BaseWorld.get_config = original + + +class TestAgentsFlag3: + def test_name(self): + assert AgentsFlag3(number=1).name == 'Update agent' + + @pytest.mark.asyncio + async def test_verify_modified_agent(self): + f = AgentsFlag3(number=1) + agent = MagicMock(group='custom', sleep_min=10, sleep_max=20) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_default_agent(self): + f = AgentsFlag3(number=1) + agent = MagicMock(group='red', sleep_min=30, sleep_max=60) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is False + + +class TestAgentsFlag4: + def test_name(self): + assert AgentsFlag4(number=1).name == 'Agent filename' + + @pytest.mark.asyncio + async def test_verify_correct_name(self): + f = AgentsFlag4(number=1) + from app.utility.base_world import BaseWorld + original = BaseWorld.get_config + BaseWorld.get_config = staticmethod(lambda name=None, prop=None: 'super_scary' if prop == 'implant_name' else None) + try: + assert await f.verify({}) is True + finally: + BaseWorld.get_config = original + + @pytest.mark.asyncio + async def test_verify_wrong_name(self): + f = AgentsFlag4(number=1) + from app.utility.base_world import BaseWorld + original = BaseWorld.get_config + BaseWorld.get_config = staticmethod(lambda name=None, prop=None: 'sandcat') + try: + assert await f.verify({}) is False + finally: + BaseWorld.get_config = original + + +class TestAgentsFlag5: + def test_name(self): + assert AgentsFlag5(number=1).name == 'Bootstrap abilities' + + @pytest.mark.asyncio + async def test_verify_correct_bootstrap(self): + f = AgentsFlag5(number=1) + from app.utility.base_world import BaseWorld + original = BaseWorld.get_config + BaseWorld.get_config = staticmethod( + lambda name=None, prop=None: ['c0da588f-79f0-4263-8998-7496b1a40596'] if prop == 'bootstrap_abilities' else None + ) + try: + assert await f.verify({}) is True + finally: + BaseWorld.get_config = original + + @pytest.mark.asyncio + async def test_verify_missing_bootstrap(self): + f = AgentsFlag5(number=1) + from app.utility.base_world import BaseWorld + original = BaseWorld.get_config + BaseWorld.get_config = staticmethod(lambda name=None, prop=None: []) + try: + assert await f.verify({}) is False + finally: + BaseWorld.get_config = original + + +class TestAgentsFlag6: + def test_name(self): + assert AgentsFlag6(number=1).name == 'Contact points' + + @pytest.mark.asyncio + async def test_verify_multiple_contacts(self): + f = AgentsFlag6(number=1) + a1 = MagicMock(contact='http') + a2 = MagicMock(contact='tcp') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[a1, a2]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_single_contact(self): + f = AgentsFlag6(number=1) + a1 = MagicMock(contact='http') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[a1]) + assert await f.verify(services) is False + + +class TestAgentsFlag7: + def test_name(self): + assert AgentsFlag7(number=1).name == 'Kill agent' + + @pytest.mark.asyncio + async def test_verify_watchdog_set(self): + f = AgentsFlag7(number=1) + agent = MagicMock(watchdog=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_no_watchdog(self): + f = AgentsFlag7(number=1) + agent = MagicMock(watchdog=0) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_no_agents(self): + f = AgentsFlag7(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False diff --git a/tests/app/flags/attack/__init__.py b/tests/app/flags/attack/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/flags/attack/test_flags.py b/tests/app/flags/attack/test_flags.py new file mode 100644 index 0000000..29ea676 --- /dev/null +++ b/tests/app/flags/attack/test_flags.py @@ -0,0 +1,47 @@ +"""Tests for attack blue flag modules.""" +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from plugins.training.app.flags.attack.blue_0 import AttackBlue0 +from plugins.training.app.flags.attack.blue_1 import AttackBlue1 + + +class TestAttackBlue0: + def test_name(self): + assert AttackBlue0(number=1).name == 'ATT&CK Quiz 1' + + @pytest.mark.asyncio + async def test_verify_delegates_to_base_flag(self): + f = AttackBlue0(number=1) + with patch('plugins.training.app.base_flag.BaseFlag.verify_attack_flag', + new_callable=AsyncMock, return_value=True) as mock: + result = await f.verify({'data_svc': AsyncMock(), 'rest_svc': AsyncMock()}) + assert result is True + mock.assert_called_once() + args = mock.call_args + assert args[0][1] == 'T1033' + assert args[0][2] == 'blue_quiz_1' + + @pytest.mark.asyncio + async def test_verify_fails(self): + f = AttackBlue0(number=1) + with patch('plugins.training.app.base_flag.BaseFlag.verify_attack_flag', + new_callable=AsyncMock, return_value=False): + result = await f.verify({'data_svc': AsyncMock(), 'rest_svc': AsyncMock()}) + assert result is False + + +class TestAttackBlue1: + def test_name(self): + assert AttackBlue1(number=1).name == 'ATT&CK Quiz 2' + + @pytest.mark.asyncio + async def test_verify_technique(self): + f = AttackBlue1(number=1) + with patch('plugins.training.app.base_flag.BaseFlag.verify_attack_flag', + new_callable=AsyncMock, return_value=True) as mock: + result = await f.verify({'data_svc': AsyncMock(), 'rest_svc': AsyncMock()}) + assert result is True + args = mock.call_args + assert args[0][1] == 'T1083' + assert args[0][2] == 'blue_quiz_2' diff --git a/tests/app/flags/autonomous/__init__.py b/tests/app/flags/autonomous/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/flags/autonomous/test_flags.py b/tests/app/flags/autonomous/test_flags.py new file mode 100644 index 0000000..1330dc4 --- /dev/null +++ b/tests/app/flags/autonomous/test_flags.py @@ -0,0 +1,119 @@ +"""Tests for autonomous flag modules.""" +import pytest +from unittest.mock import MagicMock, AsyncMock + +from plugins.training.app.flags.autonomous.blue_0 import AutonomousBlue0 +from plugins.training.app.flags.autonomous.blue_1 import AutonomousBlue1 +from plugins.training.app.flags.autonomous.blue_3 import AutonomousBlue3 + + +class TestAutonomousBlue0: + def test_name(self): + assert AutonomousBlue0(number=1).name == 'Enable Autonomous Operation' + + @pytest.mark.asyncio + async def test_verify_correct_operation(self): + f = AutonomousBlue0(number=1) + adv = MagicMock(adversary_id='7e422753-ad7a-4401-bc8b-b12a28e69c25') + planner = MagicMock() + planner.name = 'batch' + op = MagicMock(adversary=adv, planner=planner) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_wrong_adversary(self): + f = AutonomousBlue0(number=1) + adv = MagicMock(adversary_id='wrong-id') + planner = MagicMock() + planner.name = 'batch' + op = MagicMock(adversary=adv, planner=planner) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_wrong_planner(self): + f = AutonomousBlue0(number=1) + adv = MagicMock(adversary_id='7e422753-ad7a-4401-bc8b-b12a28e69c25') + planner = MagicMock() + planner.name = 'sequential' + op = MagicMock(adversary=adv, planner=planner) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_no_operations(self): + f = AutonomousBlue0(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False + + +class TestAutonomousBlue1: + def test_name(self): + assert AutonomousBlue1(number=1).name == 'Process on Unauthorized Port' + + @pytest.mark.asyncio + async def test_verify_detected_and_killed(self): + f = AutonomousBlue1(number=1) + op = MagicMock() + op.ran_ability_id = MagicMock(side_effect=lambda aid: aid in [ + '3b4640bc-eacb-407a-a997-105e39788781', + '02fb7fa9-8886-4330-9e65-fa7bb1bc5271', + ]) + fact1 = MagicMock(trait='remote.port.unauthorized') + fact2 = MagicMock(trait='host.pid.unauthorized') + op.all_facts = AsyncMock(return_value=[fact1, fact2]) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_not_detected(self): + f = AutonomousBlue1(number=1) + op = MagicMock() + op.ran_ability_id = MagicMock(return_value=False) + op.all_facts = AsyncMock(return_value=[]) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + +class TestAutonomousBlue3: + def test_name(self): + assert AutonomousBlue3(number=1).name == 'Suspicious URL in mail' + + @pytest.mark.asyncio + async def test_verify_url_found_and_inoculated(self): + f = AutonomousBlue3(number=1) + fact = MagicMock(trait='remote.suspicious.url') + op = MagicMock() + op.all_facts = AsyncMock(return_value=[fact]) + op.ran_ability_id = MagicMock(side_effect=lambda aid: aid in [ + '1226f8ec-e2e5-4311-88e7-378c0e5cc7ce', + '2ca64acd-dc12-4cc8-b78a-6a182508a50b', + ]) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_not_inoculated(self): + f = AutonomousBlue3(number=1) + fact = MagicMock(trait='remote.suspicious.url') + op = MagicMock() + op.all_facts = AsyncMock(return_value=[fact]) + op.ran_ability_id = MagicMock(side_effect=lambda aid: aid == '1226f8ec-e2e5-4311-88e7-378c0e5cc7ce') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_no_operations(self): + f = AutonomousBlue3(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False diff --git a/tests/app/flags/developers/__init__.py b/tests/app/flags/developers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/flags/developers/test_flags.py b/tests/app/flags/developers/test_flags.py new file mode 100644 index 0000000..25ebcdf --- /dev/null +++ b/tests/app/flags/developers/test_flags.py @@ -0,0 +1,100 @@ +"""Tests for developers flag modules.""" +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from plugins.training.app.flags.developers.flag_0 import DevelopersFlag0 +from plugins.training.app.flags.developers.flag_2 import DevelopersFlag2 +from plugins.training.app.flags.developers.flag_6 import DevelopersFlag6 +from plugins.training.app.flags.developers.flag_7 import DevelopersFlag7 + + +class TestDevelopersFlag0: + def test_name(self): + assert DevelopersFlag0(number=1).name == 'Create a new plugin' + + @pytest.mark.asyncio + async def test_verify_no_plugin(self): + f = DevelopersFlag0(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False + + +class TestDevelopersFlag2: + def test_name(self): + assert DevelopersFlag2(number=1).name == 'Bypass authentication' + + @pytest.mark.asyncio + async def test_verify_bypass_set(self): + f = DevelopersFlag2(number=1) + services = {'auth_svc': MagicMock()} + services['auth_svc'].bypass = ['127.0.0.1'] + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_bypass_not_set(self): + f = DevelopersFlag2(number=1) + services = {'auth_svc': MagicMock()} + services['auth_svc'].bypass = [] + assert await f.verify(services) is False + + +class TestDevelopersFlag6: + def test_name(self): + assert DevelopersFlag6(number=1).name == 'Build an agent' + + @pytest.mark.asyncio + async def test_verify_zsh_agent(self): + f = DevelopersFlag6(number=1) + agent = MagicMock(executors=['zsh']) + op = MagicMock(chain=[MagicMock()], agents=[agent]) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_no_zsh(self): + f = DevelopersFlag6(number=1) + agent = MagicMock(executors=['sh']) + op = MagicMock(chain=[MagicMock()], agents=[agent]) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_empty_chain(self): + f = DevelopersFlag6(number=1) + op = MagicMock(chain=[], agents=[]) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + +class TestDevelopersFlag7: + def test_name(self): + assert DevelopersFlag7(number=1).name == 'Understanding access' + + @pytest.mark.asyncio + async def test_verify_blue_access(self): + f = DevelopersFlag7(number=1) + from app.utility.base_world import BaseWorld + plugin = MagicMock(access=BaseWorld.Access.BLUE) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[plugin]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_red_access(self): + f = DevelopersFlag7(number=1) + from app.utility.base_world import BaseWorld + plugin = MagicMock(access=BaseWorld.Access.RED) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[plugin]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_no_plugin(self): + f = DevelopersFlag7(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False diff --git a/tests/app/flags/manual/__init__.py b/tests/app/flags/manual/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/flags/manual/test_flags.py b/tests/app/flags/manual/test_flags.py new file mode 100644 index 0000000..181ddcd --- /dev/null +++ b/tests/app/flags/manual/test_flags.py @@ -0,0 +1,59 @@ +"""Tests for manual blue flag modules.""" +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from plugins.training.app.flags.manual.blue_0 import ManualBlue0 +from plugins.training.app.flags.manual.blue_1a import ManualBlue1a + + +class TestManualBlue0: + def test_name(self): + assert ManualBlue0(number=1).name == 'Enable Manual Operation' + + @pytest.mark.asyncio + async def test_verify_adhoc_blue_manual(self): + f = ManualBlue0(number=1) + adv = MagicMock(adversary_id='ad-hoc') + op = MagicMock(adversary=adv) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_wrong_adversary(self): + f = ManualBlue0(number=1) + adv = MagicMock(adversary_id='not-adhoc') + op = MagicMock(adversary=adv) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_no_op(self): + f = ManualBlue0(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False + + +class TestManualBlue1a: + def test_name(self): + assert ManualBlue1a(number=1).name == 'Detect process on Unauthorized Port' + + def test_additional_fields(self): + f = ManualBlue1a(number=1) + assert f.additional_fields['operation_name'] == 'training_manual_1' + assert f.additional_fields['adversary_id'] == '72c0b333-f6fe-4fa0-a342-4215e8de3947' + assert f.additional_fields['agent_group'] == 'cert-nix' + + def test_is_resettable(self): + f = ManualBlue1a(number=1) + assert f._is_resettable() == 'True' + + @pytest.mark.asyncio + async def test_verify_delegates_to_standard_verify(self): + f = ManualBlue1a(number=1) + with patch('plugins.training.app.base_flag.BaseFlag.standard_verify_with_operation', + new_callable=AsyncMock, return_value=True): + result = await f.verify({'data_svc': AsyncMock()}) + assert result is True diff --git a/tests/app/flags/operations/__init__.py b/tests/app/flags/operations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/flags/operations/test_flags.py b/tests/app/flags/operations/test_flags.py new file mode 100644 index 0000000..8734abb --- /dev/null +++ b/tests/app/flags/operations/test_flags.py @@ -0,0 +1,143 @@ +"""Tests for operations flag modules.""" +import pytest +from unittest.mock import MagicMock, AsyncMock + +from plugins.training.app.flags.operations.flag_0 import OperationsFlag0 +from plugins.training.app.flags.operations.flag_1 import OperationsFlag1 +from plugins.training.app.flags.operations.flag_2 import OperationsFlag2 +from plugins.training.app.flags.operations.flag_3 import OperationsFlag3 + + +class TestOperationsFlag0: + def test_name(self): + assert OperationsFlag0(number=1).name == 'Basic operation' + + @pytest.mark.asyncio + async def test_verify_finished_check(self): + f = OperationsFlag0(number=1) + adv = MagicMock(adversary_id='01d77744-2515-401a-a497-d9f7241aac3c') + op = MagicMock(finish='2023-01-01', adversary=adv) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_not_finished(self): + f = OperationsFlag0(number=1) + adv = MagicMock(adversary_id='01d77744-2515-401a-a497-d9f7241aac3c') + op = MagicMock(finish=None, adversary=adv) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_wrong_adversary(self): + f = OperationsFlag0(number=1) + adv = MagicMock(adversary_id='wrong') + op = MagicMock(finish='done', adversary=adv) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_no_ops(self): + f = OperationsFlag0(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False + + +class TestOperationsFlag1: + def test_name(self): + assert OperationsFlag1(number=1).name == 'Stealthy operation' + + @pytest.mark.asyncio + async def test_verify_stealthy(self): + f = OperationsFlag1(number=1) + adv = MagicMock(adversary_id='01d77744-2515-401a-a497-d9f7241aac3c') + op = MagicMock(finish='done', adversary=adv, obfuscator='base64', jitter='10/20') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_wrong_obfuscator(self): + f = OperationsFlag1(number=1) + adv = MagicMock(adversary_id='01d77744-2515-401a-a497-d9f7241aac3c') + op = MagicMock(finish='done', adversary=adv, obfuscator='plain', jitter='10/20') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_wrong_jitter(self): + f = OperationsFlag1(number=1) + adv = MagicMock(adversary_id='01d77744-2515-401a-a497-d9f7241aac3c') + op = MagicMock(finish='done', adversary=adv, obfuscator='base64', jitter='5/10') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + +class TestOperationsFlag2: + def test_name(self): + assert OperationsFlag2(number=1).name == 'Manual operation' + + @pytest.mark.asyncio + async def test_verify_manual(self): + f = OperationsFlag2(number=1) + adv = MagicMock(adversary_id='different') + op = MagicMock(finish='done', adversary=adv, autonomous=False) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_autonomous_fails(self): + f = OperationsFlag2(number=1) + adv = MagicMock(adversary_id='different') + op = MagicMock(finish='done', adversary=adv, autonomous=True) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_check_adversary_fails(self): + f = OperationsFlag2(number=1) + adv = MagicMock(adversary_id='01d77744-2515-401a-a497-d9f7241aac3c') + op = MagicMock(finish='done', adversary=adv, autonomous=False) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + +class TestOperationsFlag3: + def test_name(self): + assert OperationsFlag3(number=1).name == 'Empty operation' + + @pytest.mark.asyncio + async def test_verify_empty_op(self): + f = OperationsFlag3(number=1) + adv = MagicMock(adversary_id='ad-hoc') + op = MagicMock(finish='done', adversary=adv, chain=[1, 2, 3, 4, 5], group='') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_too_few_links(self): + f = OperationsFlag3(number=1) + adv = MagicMock(adversary_id='ad-hoc') + op = MagicMock(finish='done', adversary=adv, chain=[1, 2], group='') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_has_group_fails(self): + f = OperationsFlag3(number=1) + adv = MagicMock(adversary_id='ad-hoc') + op = MagicMock(finish='done', adversary=adv, chain=[1, 2, 3, 4, 5], group='red') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is False diff --git a/tests/app/flags/plugins/__init__.py b/tests/app/flags/plugins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/flags/plugins/compass/__init__.py b/tests/app/flags/plugins/compass/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/flags/plugins/manx/__init__.py b/tests/app/flags/plugins/manx/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/flags/plugins/response/__init__.py b/tests/app/flags/plugins/response/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/flags/plugins/test_flags.py b/tests/app/flags/plugins/test_flags.py new file mode 100644 index 0000000..1110845 --- /dev/null +++ b/tests/app/flags/plugins/test_flags.py @@ -0,0 +1,122 @@ +"""Tests for plugin flag modules (compass, manx, response).""" +import pytest +from unittest.mock import MagicMock, AsyncMock + +from plugins.training.app.flags.plugins.compass.flag_0 import PluginsCompassFlag0 +from plugins.training.app.flags.plugins.manx.flag_0 import PluginsManxFlag0 +from plugins.training.app.flags.plugins.manx.flag_1 import PluginsManxFlag1 +from plugins.training.app.flags.plugins.response.flag_0 import PluginsResponseFlag0 +from plugins.training.app.flags.plugins.response.flag_1 import PluginsResponseFlag1 + + +class TestPluginsCompassFlag0: + def test_name(self): + assert PluginsCompassFlag0(number=1).name == 'Compass plugin' + + @pytest.mark.asyncio + async def test_verify_compass_adversary(self): + f = PluginsCompassFlag0(number=1) + adv = MagicMock(description='created in compass') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[adv]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_no_compass(self): + f = PluginsCompassFlag0(number=1) + adv = MagicMock(description='not from plugin') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[adv]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_no_adversaries(self): + f = PluginsCompassFlag0(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False + + +class TestPluginsManxFlag0: + def test_name(self): + assert PluginsManxFlag0(number=1).name == 'Manx plugin' + + @pytest.mark.asyncio + async def test_verify_no_tcp_agents(self): + f = PluginsManxFlag0(number=1) + services = { + 'data_svc': AsyncMock(), + 'contact_svc': MagicMock(report={'websocket': []}), + } + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False + + +class TestPluginsManxFlag1: + def test_name(self): + assert PluginsManxFlag1(number=1).name == 'Manx UDP' + + @pytest.mark.asyncio + async def test_verify_udp_agent(self): + f = PluginsManxFlag1(number=1) + agent = MagicMock(contact='udp') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_no_udp(self): + f = PluginsManxFlag1(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False + + +class TestPluginsResponseFlag0: + def test_name(self): + assert PluginsResponseFlag0(number=1).name == 'Blue agent' + + @pytest.mark.asyncio + async def test_verify_blue_agent(self): + f = PluginsResponseFlag0(number=1) + agent = MagicMock(group='blue') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_no_blue(self): + f = PluginsResponseFlag0(number=1) + agent = MagicMock(group='red') + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[agent]) + assert await f.verify(services) is False + + @pytest.mark.asyncio + async def test_verify_no_agents(self): + f = PluginsResponseFlag0(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False + + +class TestPluginsResponseFlag1: + def test_name(self): + assert PluginsResponseFlag1(number=1).name == 'Blue operation' + + @pytest.mark.asyncio + async def test_verify_correct_blue_op(self): + f = PluginsResponseFlag1(number=1) + adv = MagicMock(adversary_id='7e422753-ad7a-4401-bc8b-b12a28e69c25') + op = MagicMock(agents=[MagicMock()], group='blue', adversary=adv) + op.ran_ability_id = MagicMock(return_value=True) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + assert await f.verify(services) is True + + @pytest.mark.asyncio + async def test_verify_no_ops(self): + f = PluginsResponseFlag1(number=1) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + assert await f.verify(services) is False diff --git a/tests/app/test_badge.py b/tests/app/test_badge.py index 352e8fa..e0af13b 100644 --- a/tests/app/test_badge.py +++ b/tests/app/test_badge.py @@ -1,3 +1,4 @@ +"""Exhaustive tests for app.c_badge.Badge.""" import pytest from plugins.training.app import errors @@ -14,6 +15,15 @@ async def verify(self, services): return True +class AnotherFlag(Flag): + name = "Another Flag" + challenge = "Another challenge" + extra_info = "" + + async def verify(self, services): + return False + + class TestBadge: def test_find_flag(self): @@ -30,3 +40,62 @@ def test_find_flag_fail(self): with pytest.raises(errors.FlagDoesNotExist): badge1.get_flag(flag_name='DOES NOT EXIST') + + def test_badge_name(self): + badge = Badge(name="my-badge") + assert badge.name == "my-badge" + + def test_badge_flags_initially_empty(self): + badge = Badge(name="empty") + assert badge.flags == [] + + def test_badge_unique_property(self): + badge = Badge(name="unique-test") + assert badge.unique is not None + assert isinstance(badge.unique, str) + + def test_badge_unique_deterministic(self): + b1 = Badge(name="same") + b2 = Badge(name="same") + assert b1.unique == b2.unique + + def test_badge_unique_differs_for_different_names(self): + b1 = Badge(name="alpha") + b2 = Badge(name="beta") + assert b1.unique != b2.unique + + def test_badge_display(self): + badge = Badge(name="display-badge") + flag = FakeFlag(number=1) + badge.flags.append(flag) + d = badge.display + assert d['name'] == 'display-badge' + assert isinstance(d['flags'], list) + assert len(d['flags']) == 1 + + def test_badge_display_empty_flags(self): + badge = Badge(name="no-flags") + d = badge.display + assert d['flags'] == [] + + def test_multiple_flags_get_correct_one(self): + badge = Badge(name="multi") + f1 = FakeFlag(number=1) + f2 = AnotherFlag(number=2) + badge.flags.extend([f1, f2]) + assert badge.get_flag("Test Flag") is f1 + assert badge.get_flag("Another Flag") is f2 + + def test_get_flag_returns_first_match(self): + """If two flags share the same name (unusual but possible), get first.""" + badge = Badge(name="dup") + f1 = FakeFlag(number=1) + f2 = FakeFlag(number=2) + badge.flags.extend([f1, f2]) + assert badge.get_flag("Test Flag") is f1 + + def test_display_includes_all_flags(self): + badge = Badge(name="many") + for i in range(5): + badge.flags.append(FakeFlag(number=i)) + assert len(badge.display['flags']) == 5 diff --git a/tests/app/test_base_flag.py b/tests/app/test_base_flag.py new file mode 100644 index 0000000..9c67092 --- /dev/null +++ b/tests/app/test_base_flag.py @@ -0,0 +1,201 @@ +"""Tests for app.base_flag.BaseFlag static helpers.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from plugins.training.app.base_flag import BaseFlag + + +class TestDoesAgentExist: + @pytest.mark.asyncio + async def test_agent_exists(self): + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[MagicMock()]) + result = await BaseFlag.does_agent_exist(services, 'red') + assert result == 1 + + @pytest.mark.asyncio + async def test_no_agent(self): + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + result = await BaseFlag.does_agent_exist(services, 'red') + assert result == 0 + + +class TestIsOperationStarted: + @pytest.mark.asyncio + async def test_operation_started(self): + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[MagicMock()]) + result = await BaseFlag.is_operation_started(services, 'my_op') + assert result == 1 + + @pytest.mark.asyncio + async def test_operation_not_started(self): + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) + result = await BaseFlag.is_operation_started(services, 'my_op') + assert result == 0 + + +class TestStartOperation: + @pytest.mark.asyncio + async def test_start_operation_calls_rest_svc(self): + services = {'rest_svc': AsyncMock()} + await BaseFlag.start_operation(services, 'op1', 'red', 'adv-1') + services['rest_svc'].create_operation.assert_called_once() + + +class TestIsOperationSuccessful: + @pytest.mark.asyncio + async def test_correct_chain_length(self): + """Returns True when chain length matches num_links and all traits present.""" + op = MagicMock() + op.chain = [MagicMock()] + mock_fact = MagicMock() + mock_fact.trait = 'host.user.name' + op.all_facts = AsyncMock(return_value=[mock_fact]) + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + result = await BaseFlag.is_operation_successful(services, 'op1', traits=['host.user.name'], num_links=1) + assert result is True + + @pytest.mark.asyncio + async def test_not_enough_links(self): + """Short-circuits to False when chain length != num_links.""" + op = MagicMock() + op.chain = [] # 0 links, but num_links defaults to 1 + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + # The function should return False because len(chain)==0 != num_links==1 + # This short-circuits before the async generator issue + result = await BaseFlag.is_operation_successful(services, 'op1', num_links=1) + assert result is False + + @pytest.mark.asyncio + async def test_wrong_num_links(self): + """If chain length doesn't match, short-circuits to False.""" + op = MagicMock() + op.chain = [MagicMock(), MagicMock()] # 2 links + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[op]) + # num_links=1 but chain has 2, so False without evaluating traits + result = await BaseFlag.is_operation_successful(services, 'op1', traits=['x'], num_links=1) + assert result is False + + +class TestCleanupOperation: + @pytest.mark.asyncio + async def test_cleanup(self): + services = {'rest_svc': AsyncMock()} + await BaseFlag.cleanup_operation(services, 'op1') + services['rest_svc'].delete_operation.assert_called_once() + + +class TestVerifyAttackFlag: + @pytest.mark.asyncio + async def test_verify_attack_flag_match(self): + adv = MagicMock() + adv.adversary_id = 'adv-1' + adv.atomic_ordering = ['ab1'] + services = { + 'data_svc': AsyncMock(), + 'rest_svc': AsyncMock(), + } + services['data_svc'].locate = AsyncMock(return_value=[adv]) + services['rest_svc'].display_objects = AsyncMock(return_value=[{'technique_id': 'T1033'}]) + + result = await BaseFlag.verify_attack_flag(services, 'T1033', 'test_adv') + assert result is True + + @pytest.mark.asyncio + async def test_verify_attack_flag_no_match(self): + adv = MagicMock() + adv.adversary_id = 'adv-1' + adv.atomic_ordering = ['ab1'] + services = { + 'data_svc': AsyncMock(), + 'rest_svc': AsyncMock(), + } + services['data_svc'].locate = AsyncMock(return_value=[adv]) + services['rest_svc'].display_objects = AsyncMock(return_value=[{'technique_id': 'T9999'}]) + + result = await BaseFlag.verify_attack_flag(services, 'T1033', 'test_adv') + assert result is False + + @pytest.mark.asyncio + async def test_verify_attack_flag_no_adversaries(self): + services = { + 'data_svc': AsyncMock(), + 'rest_svc': AsyncMock(), + } + services['data_svc'].locate = AsyncMock(return_value=[]) + result = await BaseFlag.verify_attack_flag(services, 'T1033', 'missing') + assert result is False + + +class TestDoesTechniqueMatch: + @pytest.mark.asyncio + async def test_single_technique_match(self): + adv = MagicMock() + adv.atomic_ordering = ['ab1'] + services = {'rest_svc': AsyncMock()} + services['rest_svc'].display_objects = AsyncMock(return_value=[{'technique_id': 'T1033'}]) + result = await BaseFlag.does_technique_match(services, 'T1033', adv) + assert result is True + + @pytest.mark.asyncio + async def test_multiple_techniques_no_match(self): + adv = MagicMock() + adv.atomic_ordering = ['ab1', 'ab2'] + services = {'rest_svc': AsyncMock()} + services['rest_svc'].display_objects = AsyncMock( + side_effect=[[{'technique_id': 'T1033'}], [{'technique_id': 'T1059'}]] + ) + result = await BaseFlag.does_technique_match(services, 'T1033', adv) + assert result is False # len(techniques) != 1 + + +class TestStandardVerifyWithOperation: + @pytest.mark.asyncio + async def test_full_pass(self): + services = { + 'data_svc': AsyncMock(), + 'rest_svc': AsyncMock(), + } + # does_agent_exist => True, is_operation_started => True + services['data_svc'].locate = AsyncMock(side_effect=[ + [MagicMock()], # agents exist + [MagicMock()], # operation already started + ]) + + async def satisfied(): + return True + + with patch.object(BaseFlag, 'is_operation_successful', new_callable=AsyncMock, return_value=True): + result = await BaseFlag.standard_verify_with_operation( + services, 'op1', 'adv1', 'red', satisfied + ) + assert result is True + + @pytest.mark.asyncio + async def test_no_agent(self): + services = {'data_svc': AsyncMock()} + services['data_svc'].locate = AsyncMock(return_value=[]) # no agents + + async def satisfied(): + return True + + result = await BaseFlag.standard_verify_with_operation( + services, 'op1', 'adv1', 'red', satisfied + ) + assert result is False + + +class TestReset: + @pytest.mark.asyncio + async def test_reset_cleans_and_reverifies(self): + services = {'rest_svc': AsyncMock()} + verify = AsyncMock() + await BaseFlag.reset(services, 'op1', verify) + services['rest_svc'].delete_operation.assert_called_once() + verify.assert_called_once_with(services) diff --git a/tests/app/test_certification.py b/tests/app/test_certification.py index 8b44f85..c5d61be 100644 --- a/tests/app/test_certification.py +++ b/tests/app/test_certification.py @@ -1,6 +1,6 @@ +"""Exhaustive tests for app.c_certification.Certification.""" import pytest -from app.utility.base_world import BaseWorld from plugins.training.app import errors from plugins.training.app.c_badge import Badge from plugins.training.app.c_flag import Flag @@ -18,68 +18,122 @@ async def verify(self, services): class TestCertification: - def test_find_badge(self): - certification = Certification( - identifier="foo", - name="test-certification", - description="used for tests", - access=BaseWorld.Access.RED - ) + def _make_cert(self, **kw): + defaults = dict(identifier="foo", name="test-certification", + description="used for tests", access="red") + defaults.update(kw) + return Certification(**defaults) + def test_find_badge(self): + cert = self._make_cert() badge1 = Badge(name="foo-badge-1") badge2 = Badge(name="foo-badge-2") - - certification.badges.extend([badge1, badge2]) - - found = certification.get_badge("foo-badge-1") + cert.badges.extend([badge1, badge2]) + found = cert.get_badge("foo-badge-1") assert found is badge1 def test_find_badge_fail(self): - certification = Certification( - identifier="foo", - name="test-certification", - description="used for tests", - access=BaseWorld.Access.RED - ) - - certification.badges.append(Badge(name="foo-badge-1")) - + cert = self._make_cert() + cert.badges.append(Badge(name="foo-badge-1")) with pytest.raises(errors.BadgeDoesNotExist): - certification.get_badge("DOES NOT EXIST") + cert.get_badge("DOES NOT EXIST") def test_find_flag(self): - certification = Certification( - identifier="foo", - name="test-certification", - description="used for tests", - access=BaseWorld.Access.RED - ) - + cert = self._make_cert() badge1 = Badge(name="foo-badge-1") flag1 = FakeFlag(number=1) - badge1.flags.append(flag1) - certification.badges.append(badge1) - - found = certification.get_flag(badge_name=badge1.name, flag_name=flag1.name) + cert.badges.append(badge1) + found = cert.get_flag(badge_name=badge1.name, flag_name=flag1.name) assert found is flag1 def test_find_flag_fail(self): - certification = Certification( - identifier="foo", - name="test-certification", - description="used for tests", - access=BaseWorld.Access.RED - ) - + cert = self._make_cert() badge1 = Badge(name="foo-badge-1") flag1 = FakeFlag(number=1) - badge1.flags.append(flag1) - certification.badges.append(badge1) + cert.badges.append(badge1) with pytest.raises(errors.BadgeDoesNotExist): - certification.get_flag(badge_name='DOES NOT EXIST', flag_name=flag1.name) + cert.get_flag(badge_name='DOES NOT EXIST', flag_name=flag1.name) with pytest.raises(errors.FlagDoesNotExist): - certification.get_flag(badge_name=badge1.name, flag_name='DOES NOT EXIST') + cert.get_flag(badge_name=badge1.name, flag_name='DOES NOT EXIST') + + # --- Additional exhaustive tests --- + + def test_unique_property(self): + cert = self._make_cert(identifier="abc") + assert cert.unique is not None + assert isinstance(cert.unique, str) + + def test_unique_deterministic(self): + c1 = self._make_cert(identifier="same") + c2 = self._make_cert(identifier="same") + assert c1.unique == c2.unique + + def test_unique_differs_for_different_ids(self): + c1 = self._make_cert(identifier="aaa") + c2 = self._make_cert(identifier="bbb") + assert c1.unique != c2.unique + + def test_display_property(self): + cert = self._make_cert(name="My Cert", description="desc") + d = cert.display + assert d['name'] == 'My Cert' + assert d['description'] == 'desc' + assert d['badges'] == [] + + def test_display_with_badges(self): + cert = self._make_cert() + b = Badge(name="b1") + b.flags.append(FakeFlag(number=1)) + cert.badges.append(b) + d = cert.display + assert len(d['badges']) == 1 + assert d['badges'][0]['name'] == 'b1' + + def test_initial_badges_empty(self): + cert = self._make_cert() + assert cert.badges == [] + + def test_name_and_description_stored(self): + cert = self._make_cert(name="N", description="D") + assert cert.name == "N" + assert cert.description == "D" + + def test_access_stored(self): + cert = self._make_cert(access="blue") + assert cert.access == "blue" + + def test_store_method(self): + cert = self._make_cert(identifier="store-test") + ram = {'certifications': []} + result = cert.store(ram) + assert result is not None + assert len(ram['certifications']) == 1 + + def test_store_idempotent(self): + cert = self._make_cert(identifier="idem") + ram = {'certifications': []} + cert.store(ram) + cert.store(ram) + assert len(ram['certifications']) == 1 + + def test_get_badge_returns_correct_among_many(self): + cert = self._make_cert() + for i in range(10): + cert.badges.append(Badge(name=f"badge-{i}")) + found = cert.get_badge("badge-7") + assert found.name == "badge-7" + + def test_get_flag_chains_badge_and_flag_lookup(self): + """get_flag delegates to get_badge then badge.get_flag.""" + cert = self._make_cert() + b = Badge(name="b") + f = FakeFlag(number=99) + b.flags.append(f) + cert.badges.append(b) + result = cert.get_flag("b", "Test Flag") + assert result is f + assert result.number == 99 diff --git a/tests/app/test_errors.py b/tests/app/test_errors.py new file mode 100644 index 0000000..a808057 --- /dev/null +++ b/tests/app/test_errors.py @@ -0,0 +1,41 @@ +"""Tests for app.errors module.""" +import pytest + +from plugins.training.app.errors import ObjectDoesNotExist, BadgeDoesNotExist, FlagDoesNotExist + + +class TestErrors: + + def test_object_does_not_exist_is_exception(self): + assert issubclass(ObjectDoesNotExist, Exception) + + def test_badge_does_not_exist_inherits(self): + assert issubclass(BadgeDoesNotExist, ObjectDoesNotExist) + + def test_flag_does_not_exist_inherits(self): + assert issubclass(FlagDoesNotExist, ObjectDoesNotExist) + + def test_badge_does_not_exist_is_exception(self): + assert issubclass(BadgeDoesNotExist, Exception) + + def test_flag_does_not_exist_is_exception(self): + assert issubclass(FlagDoesNotExist, Exception) + + def test_raise_badge(self): + with pytest.raises(BadgeDoesNotExist): + raise BadgeDoesNotExist() + + def test_raise_flag(self): + with pytest.raises(FlagDoesNotExist): + raise FlagDoesNotExist() + + def test_catch_as_parent(self): + with pytest.raises(ObjectDoesNotExist): + raise BadgeDoesNotExist() + + with pytest.raises(ObjectDoesNotExist): + raise FlagDoesNotExist() + + def test_exception_message(self): + err = BadgeDoesNotExist("missing badge") + assert str(err) == "missing badge" diff --git a/tests/app/test_exam.py b/tests/app/test_exam.py new file mode 100644 index 0000000..ac014e8 --- /dev/null +++ b/tests/app/test_exam.py @@ -0,0 +1,71 @@ +"""Tests for app.c_exam.Exam.""" +import pytest + +from plugins.training.app.c_exam import Exam +from plugins.training.app.c_badge import Badge +from plugins.training.app.c_flag import Flag + + +class FakeFlag(Flag): + name = "Exam Flag" + challenge = "Exam challenge" + extra_info = "" + + async def verify(self, services): + return True + + +class TestExam: + + def _make_exam(self, **kw): + defaults = dict(identifier="exam-1", name="Test Exam", + description="An exam", access="app") + defaults.update(kw) + return Exam(**defaults) + + def test_cert_type(self): + e = self._make_exam() + assert e.cert_type == 'exam' + + def test_display_includes_cert_type(self): + e = self._make_exam() + d = e.display + assert 'cert_type' in d + assert d['cert_type'] == 'exam' + + def test_display_has_name_and_description(self): + e = self._make_exam(name="Final Exam", description="The big one") + d = e.display + assert d['name'] == 'Final Exam' + assert d['description'] == 'The big one' + + def test_display_badges_empty(self): + e = self._make_exam() + assert e.display['badges'] == [] + + def test_display_badges_with_content(self): + e = self._make_exam() + b = Badge(name="eb") + b.flags.append(FakeFlag(number=1)) + e.badges.append(b) + assert len(e.display['badges']) == 1 + + def test_inherits_certification_methods(self): + """Exam should have get_badge and get_flag from Certification.""" + e = self._make_exam() + assert hasattr(e, 'get_badge') + assert hasattr(e, 'get_flag') + + def test_unique_property(self): + e = self._make_exam(identifier="u1") + assert e.unique is not None + + def test_badges_initially_empty(self): + e = self._make_exam() + assert e.badges == [] + + def test_store(self): + e = self._make_exam() + ram = {'certifications': []} + e.store(ram) + assert len(ram['certifications']) == 1 diff --git a/tests/app/test_fillinblank.py b/tests/app/test_fillinblank.py new file mode 100644 index 0000000..58793fe --- /dev/null +++ b/tests/app/test_fillinblank.py @@ -0,0 +1,57 @@ +"""Tests for app.c_fillinblank.FillInBlank.""" +import pytest + +from plugins.training.app.c_fillinblank import FillInBlank + + +class ConcreteFIB(FillInBlank): + name = 'FIB Test' + challenge = 'Fill in the blank' + extra_info = '' + answer = 'correct' + + +class TestFillInBlank: + + def test_flag_type(self): + f = ConcreteFIB(number=1) + assert f.flag_type == 'fillinblank' + + def test_verify_correct_answer(self): + f = ConcreteFIB(number=1) + assert f.verify('correct') is True + + def test_verify_case_insensitive(self): + f = ConcreteFIB(number=1) + assert f.verify('CORRECT') is True + assert f.verify('Correct') is True + + def test_verify_wrong_answer(self): + f = ConcreteFIB(number=1) + assert f.verify('wrong') is False + + def test_verify_empty_string(self): + f = ConcreteFIB(number=1) + assert f.verify('') is False + + def test_display_keys(self): + f = ConcreteFIB(number=1) + d = f.display + assert 'flag_type' in d + assert 'number' in d + assert 'name' in d + assert 'challenge' in d + assert 'completed' in d + assert 'code' in d + + def test_display_flag_type_value(self): + f = ConcreteFIB(number=1) + assert f.display['flag_type'] == 'fillinblank' + + def test_inherits_flag(self): + from plugins.training.app.c_flag import Flag + assert issubclass(FillInBlank, Flag) + + def test_has_answer_attribute(self): + f = ConcreteFIB(number=1) + assert hasattr(f, 'answer') diff --git a/tests/app/test_flag.py b/tests/app/test_flag.py new file mode 100644 index 0000000..54c1c89 --- /dev/null +++ b/tests/app/test_flag.py @@ -0,0 +1,247 @@ +"""Exhaustive tests for app.c_flag.Flag base class.""" +import pytest +from datetime import datetime +from unittest.mock import patch, MagicMock + +from plugins.training.app.c_flag import Flag + + +class ConcreteFlag(Flag): + name = 'Concrete' + challenge = 'Test challenge' + extra_info = 'Some info' + + async def verify(self, services): + return True + + +class NoSolutionFlag(Flag): + name = 'NoSolution' + challenge = 'No guide' + extra_info = '' + + async def verify(self, services): + return False + + +class TestFlagInit: + def test_default_number(self): + f = ConcreteFlag(number=5) + assert f.number == 5 + + def test_default_not_completed(self): + f = ConcreteFlag(number=1) + assert f.completed is False + + def test_default_completed_timestamp_none(self): + f = ConcreteFlag(number=1) + assert f.completed_timestamp is None + + def test_default_started_ts_none(self): + f = ConcreteFlag(number=1) + assert f.started_ts is None + + def test_default_ticks_zero(self): + f = ConcreteFlag(number=1) + assert f._ticks == 0 + + +class TestFlagCompleted: + def test_set_completed(self): + f = ConcreteFlag(number=1) + f.completed = True + assert f.completed is True + + def test_completed_sets_timestamp(self): + f = ConcreteFlag(number=1) + f.completed = True + assert f.completed_timestamp is not None + assert isinstance(f.completed_timestamp, datetime) + + def test_set_completed_false(self): + f = ConcreteFlag(number=1) + f.completed = True + f.completed = False + assert f.completed is False + + def test_completed_timestamp_updates_on_each_set(self): + f = ConcreteFlag(number=1) + f.completed = True + ts1 = f.completed_timestamp + f.completed = True + ts2 = f.completed_timestamp + assert ts2 >= ts1 + + +class TestFlagActivate: + def test_activate_sets_started_ts(self): + f = ConcreteFlag(number=1) + f.activate() + assert f.started_ts is not None + + def test_activate_increments_ticks(self): + f = ConcreteFlag(number=1) + f.activate() + assert f._ticks == 1 + f.activate() + assert f._ticks == 2 + + def test_activate_does_not_reset_started_ts(self): + f = ConcreteFlag(number=1) + f.activate() + ts = f.started_ts + f.activate() + assert f.started_ts == ts + + def test_activate_no_tick_after_completed(self): + f = ConcreteFlag(number=1) + f.activate() + assert f._ticks == 1 + f.completed = True # sets _completed_timestamp + f.activate() + # started_ts already set, completed_timestamp set => ticks should NOT increment + assert f._ticks == 1 + + +class TestFlagDisplay: + def test_display_keys(self): + f = ConcreteFlag(number=1) + d = f.display + expected_keys = {'number', 'name', 'challenge', 'completed', + 'extra_info', 'code', 'completed_timestamp', + 'resettable', 'has_solution_guide'} + assert set(d.keys()) == expected_keys + + def test_display_values(self): + f = ConcreteFlag(number=42) + d = f.display + assert d['number'] == 42 + assert d['name'] == 'Concrete' + assert d['challenge'] == 'Test challenge' + assert d['completed'] is False + assert d['extra_info'] == 'Some info' + + def test_display_completed_timestamp_empty_when_not_set(self): + f = ConcreteFlag(number=1) + assert f.display['completed_timestamp'] == '' + + def test_display_completed_timestamp_formatted(self): + f = ConcreteFlag(number=1) + f.completed = True + ts = f.display['completed_timestamp'] + assert len(ts) == 19 # 'YYYY-MM-DD HH:MM:SS' + + +class TestFlagResettable: + def test_not_resettable_by_default(self): + f = ConcreteFlag(number=1) + assert f._is_resettable() == 'False' + + def test_resettable_with_additional_fields(self): + f = ConcreteFlag(number=1) + f.additional_fields = {'adversary_id': 'xyz'} + assert f._is_resettable() == 'True' + + def test_not_resettable_without_adversary_id(self): + f = ConcreteFlag(number=1) + f.additional_fields = {'operation_name': 'op'} + assert f._is_resettable() == 'False' + + +class TestFlagUnique: + def test_unique_not_none(self): + f = ConcreteFlag(number=1) + assert f.unique is not None + + def test_unique_deterministic(self): + f1 = ConcreteFlag(number=1) + f2 = ConcreteFlag(number=1) + assert f1.unique == f2.unique + + +class TestFlagStore: + def test_store_adds_to_ram(self): + f = ConcreteFlag(number=1) + ram = {'flags': []} + f.store(ram) + assert len(ram['flags']) == 1 + + def test_store_idempotent(self): + f = ConcreteFlag(number=1) + ram = {'flags': []} + f.store(ram) + f.store(ram) + assert len(ram['flags']) == 1 + + +class TestFlagCalculateCode: + def test_calculate_code_returns_string(self): + f = ConcreteFlag(number=1) + assert isinstance(f.calculate_code(), str) + + +class TestFlagSolutionGuide: + def test_solution_guide_filename(self): + f = ConcreteFlag(number=1) + assert f.solution_guide_filename == 'ConcreteFlag.md' + + def test_has_solution_guide_false_when_missing(self): + f = ConcreteFlag(number=1) + # The file won't exist + assert f.has_solution_guide is False or f.has_solution_guide is True + # Just ensure it doesn't crash + + +class TestFlagStaticMethods: + def test_is_unauth_process_killed(self): + op = MagicMock() + op.ran_ability_id = MagicMock(return_value=True) + assert Flag._is_unauth_process_killed(op) is True + + def test_is_unauth_process_killed_false(self): + op = MagicMock() + op.ran_ability_id = MagicMock(return_value=False) + assert Flag._is_unauth_process_killed(op) is False + + @pytest.mark.asyncio + async def test_is_unauth_process_detected(self): + op = MagicMock() + op.ran_ability_id = MagicMock(return_value=True) + fact1 = MagicMock(trait='remote.port.unauthorized') + fact2 = MagicMock(trait='host.pid.unauthorized') + from unittest.mock import AsyncMock + op.all_facts = AsyncMock(return_value=[fact1, fact2]) + result = await Flag._is_unauth_process_detected(op) + assert result is True + + @pytest.mark.asyncio + async def test_is_unauth_process_detected_missing_trait(self): + op = MagicMock() + op.ran_ability_id = MagicMock(return_value=True) + fact1 = MagicMock(trait='remote.port.unauthorized') + from unittest.mock import AsyncMock + op.all_facts = AsyncMock(return_value=[fact1]) + result = await Flag._is_unauth_process_detected(op) + assert result is False + + @pytest.mark.asyncio + async def test_is_file_found(self): + op = MagicMock() + op.ran_ability_id = MagicMock(return_value=True) + f1 = MagicMock(trait='file.malicious.hash') + f2 = MagicMock(trait='host.malicious.file') + from unittest.mock import AsyncMock + op.all_facts = AsyncMock(return_value=[f1, f2]) + result = await Flag._is_file_found(op) + assert result is True + + @pytest.mark.asyncio + async def test_is_file_found_missing_ability(self): + op = MagicMock() + op.ran_ability_id = MagicMock(return_value=False) + f1 = MagicMock(trait='file.malicious.hash') + f2 = MagicMock(trait='host.malicious.file') + from unittest.mock import AsyncMock + op.all_facts = AsyncMock(return_value=[f1, f2]) + result = await Flag._is_file_found(op) + assert result is False diff --git a/tests/app/test_multiplechoice.py b/tests/app/test_multiplechoice.py new file mode 100644 index 0000000..7a35868 --- /dev/null +++ b/tests/app/test_multiplechoice.py @@ -0,0 +1,71 @@ +"""Tests for app.c_multiplechoice.MultipleChoice.""" +import pytest + +from plugins.training.app.c_multiplechoice import MultipleChoice + + +class ConcreteMC(MultipleChoice): + name = 'MC Test' + challenge = 'Pick the right one' + extra_info = '' + answer = 'B' + options = ['A', 'B', 'C', 'D'] + + +class MultiSelectMC(MultipleChoice): + name = 'Multi MC' + challenge = 'Pick all' + extra_info = '' + answer = ['A', 'C'] + options = ['A', 'B', 'C', 'D'] + multi_select = True + + +class TestMultipleChoice: + + def test_flag_type(self): + f = ConcreteMC(number=1) + assert f.flag_type == 'multiplechoice' + + def test_verify_correct(self): + f = ConcreteMC(number=1) + assert f.verify('B') is True + + def test_verify_wrong(self): + f = ConcreteMC(number=1) + assert f.verify('A') is False + + def test_verify_case_sensitive(self): + f = ConcreteMC(number=1) + assert f.verify('b') is False + + def test_verify_multi_select_correct(self): + f = MultiSelectMC(number=1) + assert f.verify(['A', 'C']) is True + + def test_verify_multi_select_wrong(self): + f = MultiSelectMC(number=1) + assert f.verify(['A', 'B']) is False + + def test_display_keys(self): + f = ConcreteMC(number=1) + d = f.display + assert 'flag_type' in d + assert 'options' in d + assert 'multi_select' in d + + def test_display_options(self): + f = ConcreteMC(number=1) + assert f.display['options'] == ['A', 'B', 'C', 'D'] + + def test_multi_select_default_false(self): + f = ConcreteMC(number=1) + assert f.multi_select is False + + def test_multi_select_true(self): + f = MultiSelectMC(number=1) + assert f.multi_select is True + + def test_inherits_flag(self): + from plugins.training.app.c_flag import Flag + assert issubclass(MultipleChoice, Flag) diff --git a/tests/app/test_navigator.py b/tests/app/test_navigator.py new file mode 100644 index 0000000..d8e8afe --- /dev/null +++ b/tests/app/test_navigator.py @@ -0,0 +1,103 @@ +"""Tests for app.c_navigator.Navigator.""" +import json +import pytest + +from plugins.training.app.c_navigator import Navigator + + +class ConcreteNav(Navigator): + name = 'Nav Test' + challenge = 'Upload layer' + extra_info = '' + answer = [('initial-access', 'T1190'), ('execution', 'T1059')] + + +class TestNavigator: + + def test_flag_type(self): + f = ConcreteNav(number=1) + assert f.flag_type == 'navigator' + + def test_verify_correct_layer(self): + f = ConcreteNav(number=1) + layer = json.dumps({ + 'techniques': [ + {'tactic': 'initial-access', 'techniqueID': 'T1190'}, + {'tactic': 'execution', 'techniqueID': 'T1059'}, + ] + }) + assert f.verify(layer) is True + + def test_verify_wrong_layer(self): + f = ConcreteNav(number=1) + layer = json.dumps({ + 'techniques': [ + {'tactic': 'initial-access', 'techniqueID': 'T9999'}, + ] + }) + assert f.verify(layer) is False + + def test_verify_empty_techniques(self): + f = ConcreteNav(number=1) + layer = json.dumps({'techniques': []}) + assert f.verify(layer) is False + + def test_verify_extra_techniques_fail(self): + f = ConcreteNav(number=1) + layer = json.dumps({ + 'techniques': [ + {'tactic': 'initial-access', 'techniqueID': 'T1190'}, + {'tactic': 'execution', 'techniqueID': 'T1059'}, + {'tactic': 'persistence', 'techniqueID': 'T1053'}, + ] + }) + assert f.verify(layer) is False + + def test_verify_order_independent(self): + f = ConcreteNav(number=1) + layer = json.dumps({ + 'techniques': [ + {'tactic': 'execution', 'techniqueID': 'T1059'}, + {'tactic': 'initial-access', 'techniqueID': 'T1190'}, + ] + }) + assert f.verify(layer) is True + + def test_sort_tactic_technique(self): + input_list = [('tactic_a', 'T002'), ('tactic_a', 'T001'), ('tactic_b', 'T003')] + result = Navigator._sort_tactic_technique(input_list) + assert result == {'tactic_a': ['T001', 'T002'], 'tactic_b': ['T003']} + + def test_sort_tactic_technique_empty(self): + result = Navigator._sort_tactic_technique([]) + assert result == {} + + def test_display_has_flag_type(self): + f = ConcreteNav(number=1) + assert f.display['flag_type'] == 'navigator' + + def test_inherits_flag(self): + from plugins.training.app.c_flag import Flag + assert issubclass(Navigator, Flag) + + def test_verify_invalid_json_raises(self): + f = ConcreteNav(number=1) + with pytest.raises(json.JSONDecodeError): + f.verify('not json') + + def test_verify_duplicate_tactics(self): + """Multiple techniques under same tactic must all match.""" + class MultiTechNav(Navigator): + name = 'Multi' + challenge = '' + extra_info = '' + answer = [('tactic_a', 'T001'), ('tactic_a', 'T002')] + + f = MultiTechNav(number=1) + layer = json.dumps({ + 'techniques': [ + {'tactic': 'tactic_a', 'techniqueID': 'T001'}, + {'tactic': 'tactic_a', 'techniqueID': 'T002'}, + ] + }) + assert f.verify(layer) is True diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3db2c73 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,286 @@ +"""Shared fixtures for the training plugin test suite.""" +import os +import json +import sys +import types +import asyncio +import pytest +from unittest.mock import MagicMock, AsyncMock, patch, PropertyMock + +# --------------------------------------------------------------------------- +# Stub out heavy / unavailable third-party imports *before* any app code +# --------------------------------------------------------------------------- + +# reportlab stubs +for mod_name in ( + 'reportlab', 'reportlab.lib', 'reportlab.lib.pagesizes', 'reportlab.pdfgen', + 'reportlab.lib.utils', 'reportlab.pdfbase', 'reportlab.pdfbase.ttfonts', + 'reportlab.lib.colors', 'reportlab.pdfgen.canvas', +): + if mod_name not in sys.modules: + sys.modules[mod_name] = MagicMock() + +# markdown stub +if 'markdown' not in sys.modules: + _md = types.ModuleType('markdown') + _md.markdown = lambda text: f'
{text}
' + sys.modules['markdown'] = _md + +# Minimal stubs for Caldera framework classes +_base_object_mod = types.ModuleType('app.utility.base_object') + +class _BaseObject: + def __init__(self): + pass + def hash(self, s): + import hashlib + return hashlib.md5(s.encode()).hexdigest() + def retrieve(self, collection, unique): + for item in collection: + if getattr(item, 'unique', None) == unique: + return item + return None + +_base_object_mod.BaseObject = _BaseObject +sys.modules['app'] = types.ModuleType('app') +sys.modules['app.utility'] = types.ModuleType('app.utility') +sys.modules['app.utility.base_object'] = _base_object_mod + +_base_world_mod = types.ModuleType('app.utility.base_world') + +class _Access: + APP = 'app' + RED = 'red' + BLUE = 'blue' + HIDDEN = 'hidden' + +class _BaseWorld: + Access = _Access + @staticmethod + def get_config(name=None, prop=None): + return None + @staticmethod + def strip_yml(filename): + return [] + +_base_world_mod.BaseWorld = _BaseWorld +sys.modules['app.utility.base_world'] = _base_world_mod + +_base_svc_mod = types.ModuleType('app.utility.base_service') +class _BaseService: + pass +_base_svc_mod.BaseService = _BaseService +sys.modules['app.utility.base_service'] = _base_svc_mod + +_auth_svc_mod = types.ModuleType('app.service.auth_svc') +def _for_all(decorator): + def wrapper(cls): + return cls + return wrapper +_auth_svc_mod.for_all_public_methods = _for_all +_auth_svc_mod.check_authorization = lambda f: f +sys.modules['app.service'] = types.ModuleType('app.service') +sys.modules['app.service.auth_svc'] = _auth_svc_mod + +# aiohttp / jinja stubs +_aiohttp_mod = types.ModuleType('aiohttp') +class _WebModule: + class HTTPNotFound(Exception): + def __init__(self, text=''): + self.text = text + super().__init__(text) + class HTTPBadRequest(Exception): + def __init__(self, text=''): + self.text = text + super().__init__(text) + class HTTPForbidden(Exception): + def __init__(self, text=''): + self.text = text + super().__init__(text) + class HTTPInternalServerError(Exception): + def __init__(self, text=''): + self.text = text + super().__init__(text) + HTTPException = Exception + @staticmethod + def json_response(data): + return data + class FileResponse: + def __init__(self, path, headers=None): + self.path = path + self.headers = headers +_aiohttp_mod.web = _WebModule +sys.modules['aiohttp'] = _aiohttp_mod +sys.modules['aiohttp.web'] = _aiohttp_mod.web # needed for 'from aiohttp import web' + +_jinja_mod = types.ModuleType('aiohttp_jinja2') +def _template(name): + def decorator(func): + return func + return decorator +_jinja_mod.template = _template +sys.modules['aiohttp_jinja2'] = _jinja_mod + +# --------------------------------------------------------------------------- +# Make plugin importable as plugins.training.* +# --------------------------------------------------------------------------- +PLUGIN_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +PLUGINS_DIR = os.path.dirname(PLUGIN_ROOT) + +# Ensure plugins.training points to this repo +if 'plugins' not in sys.modules: + plugins_mod = types.ModuleType('plugins') + plugins_mod.__path__ = [PLUGINS_DIR] + sys.modules['plugins'] = plugins_mod + +training_pkg = types.ModuleType('plugins.training') +training_pkg.__path__ = [PLUGIN_ROOT] +training_pkg.__file__ = os.path.join(PLUGIN_ROOT, '__init__.py') +training_pkg.PLUGIN_DIR = PLUGIN_ROOT +sys.modules['plugins.training'] = training_pkg + +# Also set as attribute on plugins module so `import plugins.training` works +sys.modules['plugins'].training = training_pkg + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +from plugins.training.app.c_flag import Flag +from plugins.training.app.c_badge import Badge +from plugins.training.app.c_certification import Certification +from plugins.training.app.c_exam import Exam +from plugins.training.app import errors + + +class FakeFlag(Flag): + """Concrete Flag subclass for testing.""" + name = 'Fake Flag' + challenge = 'Do the thing' + extra_info = 'Extra' + + async def verify(self, services): + return True + + +class CompletedFlag(Flag): + """A flag that starts completed.""" + name = 'Completed Flag' + challenge = 'Already done' + extra_info = '' + + def __init__(self, number): + super().__init__(number) + self.completed = True + + async def verify(self, services): + return True + + +class ResettableFlag(Flag): + """A flag with additional_fields that make it resettable.""" + name = 'Resettable Flag' + challenge = 'Reset me' + extra_info = '' + additional_fields = dict(adversary_id='abc', operation_name='test_op') + + async def verify(self, services): + return True + + +class FailFlag(Flag): + """A flag whose verify always fails.""" + name = 'Fail Flag' + challenge = 'Cannot pass' + extra_info = '' + + async def verify(self, services): + return False + + +@pytest.fixture +def fake_flag(): + return FakeFlag(number=1) + + +@pytest.fixture +def completed_flag(): + return CompletedFlag(number=2) + + +@pytest.fixture +def resettable_flag(): + return ResettableFlag(number=3) + + +@pytest.fixture +def fail_flag(): + return FailFlag(number=4) + + +@pytest.fixture +def badge_with_flags(): + b = Badge(name='test-badge') + b.flags = [FakeFlag(number=1), FakeFlag(number=2)] + return b + + +@pytest.fixture +def certification_with_badges(): + cert = Certification( + identifier='cert-1', + name='Test Cert', + description='A test certification', + access=_Access.RED, + ) + b1 = Badge(name='badge-1') + b1.flags = [FakeFlag(number=1)] + b2 = Badge(name='badge-2') + b2.flags = [FakeFlag(number=2)] + cert.badges = [b1, b2] + return cert + + +@pytest.fixture +def exam_with_badges(): + exam = Exam( + identifier='exam-1', + name='Test Exam', + description='An exam certification', + access=_Access.APP, + ) + b = Badge(name='exam-badge') + b.flags = [FakeFlag(number=1)] + exam.badges = [b] + return exam + + +@pytest.fixture +def mock_services(): + """Return a dict-like services mock matching Caldera's service registry.""" + services = {} + data_svc = AsyncMock() + data_svc.locate = AsyncMock(return_value=[]) + auth_svc = MagicMock() + auth_svc.get_permissions = AsyncMock(return_value=[]) + auth_svc.bypass = [] + auth_svc.user_map = {} + rest_svc = AsyncMock() + app_svc = MagicMock() + app_svc.get_config = MagicMock(return_value=None) + contact_svc = MagicMock() + contact_svc.report = {'websocket': []} + file_svc = AsyncMock() + + services['data_svc'] = data_svc + services['auth_svc'] = auth_svc + services['rest_svc'] = rest_svc + services['app_svc'] = app_svc + services['contact_svc'] = contact_svc + services['file_svc'] = file_svc + + class ServiceDict(dict): + pass + + sd = ServiceDict(services) + return sd diff --git a/tests/test_certificate_svc.py b/tests/test_certificate_svc.py new file mode 100644 index 0000000..cb15b65 --- /dev/null +++ b/tests/test_certificate_svc.py @@ -0,0 +1,189 @@ +"""Exhaustive tests for app.certificate_svc.CertificateService.""" +import os +import json +import uuid +import pytest +import tempfile +from unittest.mock import patch, MagicMock + +from plugins.training.app.certificate_svc import ( + CertificateService, _ensure_dirs, _load_index, _save_index, + _load_or_create_secret, SECRET_FILE, OUT_DIR, INDEX_FILE, +) + + +@pytest.fixture +def tmp_dirs(tmp_path, monkeypatch): + """Redirect all certificate_svc paths to a temp directory.""" + secret_file = str(tmp_path / '.secret') + out_dir = str(tmp_path / 'generated_certificates') + index_file = os.path.join(out_dir, 'issued.json') + template_bg = str(tmp_path / 'static' / 'templates' / 'caldera_cert.png') + + monkeypatch.setattr('plugins.training.app.certificate_svc.SECRET_FILE', secret_file) + monkeypatch.setattr('plugins.training.app.certificate_svc.OUT_DIR', out_dir) + monkeypatch.setattr('plugins.training.app.certificate_svc.INDEX_FILE', index_file) + monkeypatch.setattr('plugins.training.app.certificate_svc.TEMPLATE_BG_PNG', template_bg) + monkeypatch.setattr('plugins.training.app.certificate_svc.LOGO_TOP_CENTER', str(tmp_path / 'logo1.png')) + monkeypatch.setattr('plugins.training.app.certificate_svc.LOGO_BOTTOM_LEFT', str(tmp_path / 'logo2.png')) + + return dict(secret_file=secret_file, out_dir=out_dir, index_file=index_file) + + +class TestEnsureDirs: + def test_creates_out_dir(self, tmp_dirs): + _ensure_dirs() + assert os.path.isdir(tmp_dirs['out_dir']) + + +class TestLoadOrCreateSecret: + def test_creates_secret(self, tmp_dirs): + secret = _load_or_create_secret() + assert len(secret) == 16 # uuid4 bytes + assert os.path.exists(tmp_dirs['secret_file']) + + def test_returns_same_secret(self, tmp_dirs): + s1 = _load_or_create_secret() + s2 = _load_or_create_secret() + assert s1 == s2 + + +class TestLoadSaveIndex: + def test_empty_index(self, tmp_dirs): + idx = _load_index() + assert idx == {} + + def test_save_and_load(self, tmp_dirs): + _ensure_dirs() + _save_index({'key': {'path': '/foo', 'name': 'bar'}}) + idx = _load_index() + assert 'key' in idx + assert idx['key']['path'] == '/foo' + + +class TestCertificateService: + def test_init(self, tmp_dirs): + svc = CertificateService() + assert svc.secret is not None + assert len(svc.instance_id) == 12 + + def test_key_format(self, tmp_dirs): + svc = CertificateService() + key = svc._key('user1', 'cert1') + assert 'user1' in key + assert 'cert1' in key + assert svc.instance_id in key + + def test_already_issued_false(self, tmp_dirs): + svc = CertificateService() + assert svc.already_issued('u1', 'c1') is False + + def test_mark_and_already_issued(self, tmp_dirs): + svc = CertificateService() + svc.mark_issued('u1', 'c1', '/path/to/cert.pdf', 'Jane') + assert svc.already_issued('u1', 'c1') is True + + def test_get_record(self, tmp_dirs): + svc = CertificateService() + svc.mark_issued('u1', 'c1', '/path/cert.pdf', 'Jane') + rec = svc.get_record('u1', 'c1') + assert rec is not None + assert rec['path'] == '/path/cert.pdf' + assert rec['name'] == 'Jane' + assert 'issued_at' in rec + + def test_get_record_missing(self, tmp_dirs): + svc = CertificateService() + assert svc.get_record('u1', 'c1') is None + + def test_clear_issued(self, tmp_dirs): + svc = CertificateService() + svc.mark_issued('u1', 'c1', '/path', 'Jane') + svc.clear_issued_for_user_cert('u1', 'c1') + assert svc.already_issued('u1', 'c1') is False + + def test_clear_nonexistent_is_noop(self, tmp_dirs): + svc = CertificateService() + svc.clear_issued_for_user_cert('u1', 'c1') # should not raise + + +class TestSignedToken: + def test_sign_and_verify(self, tmp_dirs): + svc = CertificateService() + token = svc.signed_token('/some/path.pdf') + result = svc.verify_token(token) + assert result == '/some/path.pdf' + + def test_tampered_token(self, tmp_dirs): + svc = CertificateService() + token = svc.signed_token('/some/path.pdf') + tampered = token[:-1] + ('a' if token[-1] != 'a' else 'b') + assert svc.verify_token(tampered) is None + + def test_invalid_token_no_dot(self, tmp_dirs): + svc = CertificateService() + assert svc.verify_token('nodothere') is None + + def test_empty_token(self, tmp_dirs): + svc = CertificateService() + assert svc.verify_token('') is None + + def test_empty_payload(self, tmp_dirs): + svc = CertificateService() + token = svc.signed_token('') + assert svc.verify_token(token) == '' + + +class TestSafePart: + def test_simple_name(self, tmp_dirs): + svc = CertificateService() + assert svc._safe_part('Jane Doe') == 'Jane_Doe' + + def test_special_chars(self, tmp_dirs): + svc = CertificateService() + result = svc._safe_part('Jane@Doe!#$') + assert '@' not in result + assert '!' not in result + + def test_empty_string(self, tmp_dirs): + svc = CertificateService() + assert svc._safe_part('') == 'User' + + def test_whitespace_only(self, tmp_dirs): + svc = CertificateService() + assert svc._safe_part(' ') == 'User' + + def test_multiple_spaces(self, tmp_dirs): + svc = CertificateService() + assert svc._safe_part('Jane Doe') == 'Jane_Doe' + + def test_preserves_dots_hyphens(self, tmp_dirs): + svc = CertificateService() + result = svc._safe_part('Jane-Doe.Jr') + assert 'Jane-Doe.Jr' == result + + +class TestIssue: + def test_issue_returns_path(self, tmp_dirs): + svc = CertificateService() + # Mock _build_pdf to avoid reportlab dependency + svc._build_pdf = MagicMock() + path = svc.issue('u1', 'Caldera User', 'Jane Doe') + assert path.endswith('.pdf') + assert 'Jane_Doe' in path + + def test_issue_filename_format(self, tmp_dirs): + svc = CertificateService() + svc._build_pdf = MagicMock() + path = svc.issue('u1', 'Cert', 'John Smith') + basename = os.path.basename(path) + assert basename.startswith('Caldera_User_Certificate_') + assert 'John_Smith' in basename + + +class TestDrawImageKeepAspect: + def test_no_file_noop(self, tmp_dirs): + svc = CertificateService() + c = MagicMock() + svc._draw_image_keep_aspect(c, '/nonexistent.png', 0, 0, target_w=100) + c.drawImage.assert_not_called() diff --git a/tests/test_hook.py b/tests/test_hook.py new file mode 100644 index 0000000..f074d11 --- /dev/null +++ b/tests/test_hook.py @@ -0,0 +1,113 @@ +"""Tests for hook.py (plugin enable/expansion/flag loading).""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + + +class TestHookEnable: + @pytest.mark.asyncio + async def test_enable_registers_routes(self): + from plugins.training.hook import enable + data_svc = AsyncMock() + data_svc.apply = AsyncMock() + data_svc.store = AsyncMock() + auth_svc = MagicMock() + app_svc = MagicMock() + app = MagicMock() + router = MagicMock() + app.router = router + app_svc.application = app + + services = { + 'data_svc': data_svc, + 'auth_svc': auth_svc, + 'app_svc': app_svc, + } + + with patch('plugins.training.hook.glob.iglob', return_value=[]), \ + patch('plugins.training.hook.TrainingApi'): + await enable(services) + + data_svc.apply.assert_called_once_with('certifications') + # Should register routes + assert router.add_route.call_count > 0 + assert router.add_static.call_count > 0 + + @pytest.mark.asyncio + async def test_enable_adds_correct_routes(self): + from plugins.training.hook import enable + data_svc = AsyncMock() + data_svc.apply = AsyncMock() + data_svc.store = AsyncMock() + auth_svc = MagicMock() + app_svc = MagicMock() + app = MagicMock() + router = MagicMock() + app.router = router + app_svc.application = app + + services = { + 'data_svc': data_svc, + 'auth_svc': auth_svc, + 'app_svc': app_svc, + } + + with patch('plugins.training.hook.glob.iglob', return_value=[]), \ + patch('plugins.training.hook.TrainingApi'): + await enable(services) + + # Check specific routes registered + route_calls = [str(c) for c in router.add_route.call_args_list] + route_str = ' '.join(route_calls) + assert '/plugin/training/gui' in route_str + assert '/plugin/training/flags' in route_str + assert '/plugin/training/certs' in route_str + assert '/plugin/training/reset_flag' in route_str + assert '/plugin/training/certificate/issue' in route_str + assert '/plugin/training/certificate/download' in route_str + assert '/plugin/training/certificate/reset' in route_str + + +class TestHookExpansion: + @pytest.mark.asyncio + async def test_expansion_hides_objects(self): + from plugins.training.hook import expansion + data_svc = AsyncMock() + data_svc.locate = AsyncMock(return_value=[]) + services = {'data_svc': data_svc} + + with patch('plugins.training.hook.glob.iglob', return_value=[]): + await expansion(services) + + +class TestHookMetadata: + def test_name(self): + from plugins.training.hook import name + assert name == 'Training' + + def test_description(self): + from plugins.training.hook import description + assert 'certification' in description.lower() or 'Caldera' in description + + def test_address(self): + from plugins.training.hook import address + assert address == '/plugin/training/gui' + + +class TestLoadFlags: + @pytest.mark.asyncio + async def test_load_flags_no_files(self): + from plugins.training.hook import _load_flags + data_svc = AsyncMock() + with patch('plugins.training.hook.glob.iglob', return_value=[]): + await _load_flags(data_svc) + data_svc.store.assert_not_called() + + +class TestApplyHiddenAccess: + @pytest.mark.asyncio + async def test_apply_hidden_no_files(self): + from plugins.training.hook import _apply_hidden_access_to_loaded_files + data_svc = AsyncMock() + with patch('plugins.training.hook.glob.iglob', return_value=[]): + await _apply_hidden_access_to_loaded_files(data_svc) + data_svc.locate.assert_not_called() diff --git a/tests/test_training_api.py b/tests/test_training_api.py new file mode 100644 index 0000000..600fcdc --- /dev/null +++ b/tests/test_training_api.py @@ -0,0 +1,456 @@ +"""Exhaustive tests for app.training_api.TrainingApi.""" +import os +import json +import pytest +from unittest.mock import MagicMock, AsyncMock, patch, PropertyMock +from aiohttp import web + +from plugins.training.app.training_api import TrainingApi +from plugins.training.app.c_badge import Badge +from plugins.training.app.c_certification import Certification +from plugins.training.app.c_flag import Flag +from plugins.training.app import errors + + +class FakeFlag(Flag): + name = 'FakeFlag' + challenge = 'Test' + extra_info = '' + + def __init__(self, number, completed=False): + super().__init__(number) + if completed: + self.completed = True + + async def verify(self, services): + return True + + +class AnswerFlag(Flag): + """A flag with an answer attribute (fill-in-blank style).""" + name = 'AnswerFlag' + challenge = 'Answer it' + extra_info = '' + answer = 'correct' + + async def verify(self, services): + # not used directly in retrieve_flags for answer-based flags + return True + + def verify(self, answer): + return answer.lower() == self.answer + + +def _make_services(): + services = {} + data_svc = AsyncMock() + data_svc.locate = AsyncMock(return_value=[]) + auth_svc = MagicMock() + auth_svc.get_permissions = AsyncMock(return_value=[]) + services['data_svc'] = data_svc + services['auth_svc'] = auth_svc + return services + + +def _make_request(json_data=None, headers=None, match_info=None, query=None): + request = MagicMock() + request.json = AsyncMock(return_value=json_data or {}) + request.headers = headers or {} + request.match_info = match_info or {} + request.query = query or {} + return request + + +def _make_cert(name='Test Cert', completed=False): + cert = Certification( + identifier='c1', name=name, + description='desc', access='app' + ) + b = Badge(name='badge-1') + f = FakeFlag(number=1, completed=completed) + b.flags.append(f) + cert.badges.append(b) + return cert + + +class TestTrainingApiInit: + def test_init(self): + services = _make_services() + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + assert api.auth_svc is services['auth_svc'] + assert api.data_svc is services['data_svc'] + + +class TestRequestUserId: + def test_x_user_id_header(self): + services = _make_services() + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(headers={'X-User-ID': 'user123'}) + assert api._request_user_id(request) == 'user123' + + def test_key_header_fallback(self): + services = _make_services() + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(headers={'KEY': 'apikey'}) + assert api._request_user_id(request) == 'apikey' + + def test_unknown_fallback(self): + services = _make_services() + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(headers={}) + assert api._request_user_id(request) == 'unknown' + + def test_x_user_id_takes_precedence(self): + services = _make_services() + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(headers={'X-User-ID': 'uid', 'KEY': 'key'}) + assert api._request_user_id(request) == 'uid' + + +class TestSplash: + @pytest.mark.asyncio + async def test_splash_returns_certificates(self): + services = _make_services() + cert = _make_cert() + services['data_svc'].locate = AsyncMock(return_value=[cert]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request() + result = await api.splash(request) + assert 'certificates' in result + assert len(result['certificates']) == 1 + + +class TestRetrieveCerts: + @pytest.mark.asyncio + async def test_retrieve_certs(self): + services = _make_services() + cert = _make_cert() + services['data_svc'].locate = AsyncMock(return_value=[cert]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request() + result = await api.retrieve_certs(request) + assert 'certificates' in result + assert len(result['certificates']) == 1 + + +class TestRetrieveFlags: + @pytest.mark.asyncio + async def test_retrieve_flags_basic(self): + services = _make_services() + cert = _make_cert() + services['data_svc'].locate = AsyncMock(return_value=[cert]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(json_data={'name': 'Test Cert'}) + result = await api.retrieve_flags(request) + assert 'badges' in result + + @pytest.mark.asyncio + async def test_retrieve_flags_with_answers(self): + services = _make_services() + cert = Certification(identifier='c1', name='Test', description='d', access='app') + b = Badge(name='b') + af = AnswerFlag(number=1) + b.flags.append(af) + cert.badges.append(b) + services['data_svc'].locate = AsyncMock(return_value=[cert]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(json_data={'name': 'Test', 'answers': {'1': 'correct'}}) + result = await api.retrieve_flags(request) + assert 'badges' in result + + +class TestResetFlag: + @pytest.mark.asyncio + async def test_reset_no_resettable_flags(self): + services = _make_services() + cert = _make_cert() + services['data_svc'].locate = AsyncMock(return_value=[cert]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(json_data={'name': 'Test Cert'}) + result = await api.reset_flag(request) + assert result['reset'] == 0 + + +class TestFlagSolutionGuide: + @pytest.mark.asyncio + async def test_cert_not_found(self): + services = _make_services() + services['data_svc'].locate = AsyncMock(return_value=[]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(match_info={ + 'cert_name': 'Missing', 'badge_name': 'b', 'flag_name': 'f' + }) + with pytest.raises(web.HTTPNotFound): + await api.flag_solution_guide(request) + + @pytest.mark.asyncio + async def test_badge_not_found(self): + services = _make_services() + cert = _make_cert() + services['data_svc'].locate = AsyncMock(return_value=[cert]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(match_info={ + 'cert_name': 'Test Cert', 'badge_name': 'nonexistent', 'flag_name': 'f' + }) + with pytest.raises(web.HTTPNotFound): + await api.flag_solution_guide(request) + + @pytest.mark.asyncio + async def test_flag_not_found(self): + services = _make_services() + cert = _make_cert() + services['data_svc'].locate = AsyncMock(return_value=[cert]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(match_info={ + 'cert_name': 'Test Cert', 'badge_name': 'badge-1', 'flag_name': 'nonexistent' + }) + with pytest.raises(web.HTTPNotFound): + await api.flag_solution_guide(request) + + +class TestCertificateSolutionGuide: + @pytest.mark.asyncio + async def test_cert_not_found(self): + services = _make_services() + services['data_svc'].locate = AsyncMock(return_value=[]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(match_info={'cert_name': 'Missing'}) + with pytest.raises(web.HTTPNotFound): + await api.certificate_solution_guide(request) + + @pytest.mark.asyncio + async def test_cert_found(self): + services = _make_services() + cert = _make_cert() + services['data_svc'].locate = AsyncMock(return_value=[cert]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(match_info={'cert_name': 'Test Cert'}) + result = await api.certificate_solution_guide(request) + assert 'certificate' in result + + +class TestCanIssue: + @pytest.mark.asyncio + async def test_cert_not_found(self): + services = _make_services() + services['data_svc'].locate = AsyncMock(return_value=[]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request() + with pytest.raises(web.HTTPNotFound): + await api.can_issue(request, 'Missing') + + @pytest.mark.asyncio + async def test_incomplete(self): + services = _make_services() + cert = _make_cert(completed=False) + services['data_svc'].locate = AsyncMock(return_value=[cert]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request() + _, complete = await api.can_issue(request, 'Test Cert') + assert complete is False + + @pytest.mark.asyncio + async def test_complete(self): + services = _make_services() + cert = _make_cert(completed=True) + services['data_svc'].locate = AsyncMock(return_value=[cert]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request() + _, complete = await api.can_issue(request, 'Test Cert') + assert complete is True + + @pytest.mark.asyncio + async def test_bypass_env(self): + services = _make_services() + cert = _make_cert(completed=False) + services['data_svc'].locate = AsyncMock(return_value=[cert]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request() + with patch.dict(os.environ, {'TRAINING_CERT_BYPASS': '1'}): + _, complete = await api.can_issue(request, 'Test Cert') + assert complete is True + + +class TestIssueCertificate: + @pytest.mark.asyncio + async def test_missing_fields(self): + services = _make_services() + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(json_data={'certificate': 'Cert'}) # missing name + with pytest.raises(web.HTTPBadRequest): + await api.issue_certificate(request) + + @pytest.mark.asyncio + async def test_not_complete(self): + services = _make_services() + cert = _make_cert(completed=False) + services['data_svc'].locate = AsyncMock(return_value=[cert]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request( + json_data={'certificate': 'Test Cert', 'name': 'Jane'}, + headers={'X-User-ID': 'u1'} + ) + with pytest.raises(web.HTTPForbidden): + await api.issue_certificate(request) + + @pytest.mark.asyncio + async def test_already_issued(self): + services = _make_services() + cert = _make_cert(completed=True) + services['data_svc'].locate = AsyncMock(return_value=[cert]) + mock_cert_svc = MagicMock() + mock_cert_svc.already_issued.return_value = True + mock_cert_svc.get_record.return_value = {'path': '/cert.pdf'} + mock_cert_svc.signed_token.return_value = '/cert.pdf.sig' + with patch('plugins.training.app.training_api.CertificateService', return_value=mock_cert_svc): + api = TrainingApi(services) + request = _make_request( + json_data={'certificate': 'Test Cert', 'name': 'Jane'}, + headers={'X-User-ID': 'u1'} + ) + result = await api.issue_certificate(request) + assert result['alreadyIssued'] is True + + @pytest.mark.asyncio + async def test_new_issuance(self): + services = _make_services() + cert = _make_cert(completed=True) + services['data_svc'].locate = AsyncMock(return_value=[cert]) + mock_cert_svc = MagicMock() + mock_cert_svc.already_issued.return_value = False + mock_cert_svc.issue.return_value = '/path/cert.pdf' + mock_cert_svc.signed_token.return_value = '/path/cert.pdf.sig' + with patch('plugins.training.app.training_api.CertificateService', return_value=mock_cert_svc): + api = TrainingApi(services) + request = _make_request( + json_data={'certificate': 'Test Cert', 'name': 'Jane'}, + headers={'X-User-ID': 'u1'} + ) + result = await api.issue_certificate(request) + assert result['alreadyIssued'] is False + assert 'download' in result + + +class TestResetIssuance: + @pytest.mark.asyncio + async def test_missing_cert_name(self): + services = _make_services() + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(json_data={}) + with pytest.raises(web.HTTPBadRequest): + await api.reset_issuance(request) + + @pytest.mark.asyncio + async def test_cert_not_found(self): + services = _make_services() + services['data_svc'].locate = AsyncMock(return_value=[]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(json_data={'certificate': 'Missing'}) + with pytest.raises(web.HTTPNotFound): + await api.reset_issuance(request) + + @pytest.mark.asyncio + async def test_successful_reset(self): + services = _make_services() + cert = _make_cert() + services['data_svc'].locate = AsyncMock(return_value=[cert]) + mock_cert_svc = MagicMock() + with patch('plugins.training.app.training_api.CertificateService', return_value=mock_cert_svc): + api = TrainingApi(services) + request = _make_request( + json_data={'certificate': 'Test Cert'}, + headers={'X-User-ID': 'u1'} + ) + result = await api.reset_issuance(request) + assert result['ok'] is True + mock_cert_svc.clear_issued_for_user_cert.assert_called_once() + + +class TestDownloadCertificate: + @pytest.mark.asyncio + async def test_invalid_token(self): + services = _make_services() + mock_cert_svc = MagicMock() + mock_cert_svc.verify_token.return_value = None + with patch('plugins.training.app.training_api.CertificateService', return_value=mock_cert_svc): + api = TrainingApi(services) + request = _make_request(query={'token': 'bad'}) + with pytest.raises(web.HTTPNotFound): + await api.download_certificate(request) + + @pytest.mark.asyncio + async def test_missing_file(self): + services = _make_services() + mock_cert_svc = MagicMock() + mock_cert_svc.verify_token.return_value = '/nonexistent.pdf' + with patch('plugins.training.app.training_api.CertificateService', return_value=mock_cert_svc): + api = TrainingApi(services) + request = _make_request(query={'token': 'valid'}) + with pytest.raises(web.HTTPNotFound): + await api.download_certificate(request) + + +class TestAttachmentHeaders: + def test_pdf_content_type(self): + services = _make_services() + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + headers = api._attachment_headers('cert.pdf') + assert 'Content-Type' in headers + assert 'Content-Disposition' in headers + assert 'cert.pdf' in headers['Content-Disposition'] + + def test_unknown_extension(self): + services = _make_services() + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + headers = api._attachment_headers('file.qqqzzz') + assert headers['Content-Type'] == 'application/octet-stream' + + +class TestIssueCertificateBytes: + @pytest.mark.asyncio + async def test_missing_fields(self): + services = _make_services() + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request(json_data={'certificate': 'C'}) + with pytest.raises(web.HTTPBadRequest): + await api.issue_certificate_bytes(request) + + @pytest.mark.asyncio + async def test_not_complete(self): + services = _make_services() + cert = _make_cert(completed=False) + services['data_svc'].locate = AsyncMock(return_value=[cert]) + with patch('plugins.training.app.training_api.CertificateService'): + api = TrainingApi(services) + request = _make_request( + json_data={'certificate': 'Test Cert', 'name': 'Jane'}, + headers={'X-User-ID': 'u1'} + ) + with pytest.raises(web.HTTPForbidden): + await api.issue_certificate_bytes(request)