diff --git a/tests/api/test_bulk.py b/tests/api/test_bulk.py index b5a60ecb..910a3d2c 100644 --- a/tests/api/test_bulk.py +++ b/tests/api/test_bulk.py @@ -203,7 +203,7 @@ def test_update_users_proxy_settings(access_token): response = client.post( "/api/users/bulk/proxy_settings", headers={"Authorization": f"Bearer {access_token}"}, - json={"flow": "xtls-rprx-vision"}, + json={"method": "xchacha20-poly1305"}, ) assert response.status_code == status.HTTP_200_OK @@ -213,12 +213,22 @@ def test_update_users_proxy_settings(access_token): for u in response.json()["users"] if u["username"] in {users[0]["username"], users[1]["username"]} } - assert listed[users[0]["username"]]["proxy_settings"]["vless"]["flow"] == "xtls-rprx-vision" - assert listed[users[1]["username"]]["proxy_settings"]["vless"]["flow"] == "xtls-rprx-vision" + assert listed[users[0]["username"]]["proxy_settings"]["shadowsocks"]["method"] == "xchacha20-poly1305" + assert listed[users[1]["username"]]["proxy_settings"]["shadowsocks"]["method"] == "xchacha20-poly1305" finally: cleanup(access_token, core, groups, users) +def test_update_users_proxy_settings_no_method_returns_400(access_token): + """Bulk proxy settings update with no supported settings should return 400.""" + response = client.post( + "/api/users/bulk/proxy_settings", + headers={"Authorization": f"Bearer {access_token}"}, + json={}, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_bulk_expire_with_range(access_token): # Setup core = create_core(access_token) diff --git a/tests/test_model_changes.py b/tests/test_model_changes.py new file mode 100644 index 00000000..4594768e --- /dev/null +++ b/tests/test_model_changes.py @@ -0,0 +1,326 @@ +""" +Tests for model changes introduced in this PR: +- app/models/proxy.py: XTLSFlows removed, VlessSettings.flow removed +- app/models/user_template.py: ExtraSettings.flow removed +- app/models/user.py: BulkUsersProxy.flow removed +- app/models/settings.py: General.default_flow removed +- app/models/validators.py: ProxyValidator.validate_proxy_url falsy check +- app/models/host.py: XHttpSettings.uplink_chunk_size changed to str|int|None with pattern +- app/models/status_emojis.py: STATUS_EMOJIS dict added +""" + +import pytest +from pydantic import ValidationError + + +class TestVlessSettings: + """VlessSettings no longer has a flow field.""" + + def test_vless_settings_has_no_flow_field(self): + from app.models.proxy import VlessSettings + + instance = VlessSettings() + assert not hasattr(instance, "flow") + + def test_vless_settings_has_id_field(self): + from app.models.proxy import VlessSettings + + instance = VlessSettings() + assert hasattr(instance, "id") + assert instance.id is not None + + def test_vless_settings_ignores_flow_in_input(self): + """Extra fields should be ignored (Pydantic default).""" + from app.models.proxy import VlessSettings + + # model_config may or may not forbid extras; just test no flow attribute + instance = VlessSettings(id="12345678-1234-5678-1234-567812345678") + assert not hasattr(instance, "flow") + + def test_xtls_flows_enum_removed(self): + """XTLSFlows should no longer exist in app.models.proxy.""" + import app.models.proxy as proxy_module + + assert not hasattr(proxy_module, "XTLSFlows") + + +class TestExtraSettings: + """ExtraSettings no longer has a flow field.""" + + def test_extra_settings_has_no_flow_field(self): + from app.models.user_template import ExtraSettings + + instance = ExtraSettings() + assert not hasattr(instance, "flow") + + def test_extra_settings_has_method_field(self): + from app.models.user_template import ExtraSettings + + instance = ExtraSettings() + assert hasattr(instance, "method") + + def test_extra_settings_default_method(self): + from app.models.proxy import ShadowsocksMethods + from app.models.user_template import ExtraSettings + + instance = ExtraSettings() + assert instance.method == ShadowsocksMethods.CHACHA20_POLY1305 + + def test_extra_settings_dict_method_no_obj(self): + from app.models.proxy import ShadowsocksMethods + from app.models.user_template import ExtraSettings + + instance = ExtraSettings(method=ShadowsocksMethods.AES_256_GCM) + result = instance.dict() + assert "method" in result + assert "flow" not in result + + def test_extra_settings_dict_contains_only_method(self): + from app.models.user_template import ExtraSettings + + instance = ExtraSettings() + result = instance.dict() + assert set(result.keys()) == {"method"} + + def test_extra_settings_none_method(self): + from app.models.user_template import ExtraSettings + + instance = ExtraSettings(method=None) + assert instance.method is None + + +class TestBulkUsersProxy: + """BulkUsersProxy no longer has a flow field.""" + + def test_bulk_users_proxy_has_no_flow_field(self): + from app.models.user import BulkUsersProxy + + instance = BulkUsersProxy() + assert not hasattr(instance, "flow") + + def test_bulk_users_proxy_has_method_field(self): + from app.models.user import BulkUsersProxy + + instance = BulkUsersProxy() + assert hasattr(instance, "method") + assert instance.method is None + + def test_bulk_users_proxy_accepts_shadowsocks_method(self): + from app.models.proxy import ShadowsocksMethods + from app.models.user import BulkUsersProxy + + instance = BulkUsersProxy(method=ShadowsocksMethods.AES_128_GCM) + assert instance.method == ShadowsocksMethods.AES_128_GCM + + def test_bulk_users_proxy_method_none_by_default(self): + from app.models.user import BulkUsersProxy + + instance = BulkUsersProxy() + assert instance.method is None + + +class TestGeneralSettings: + """General settings no longer has a default_flow field.""" + + def test_general_has_no_default_flow(self): + from app.models.settings import General + + instance = General() + assert not hasattr(instance, "default_flow") + + def test_general_has_default_method(self): + from app.models.proxy import ShadowsocksMethods + from app.models.settings import General + + instance = General() + assert hasattr(instance, "default_method") + assert instance.default_method == ShadowsocksMethods.CHACHA20_POLY1305 + + def test_general_default_method_can_be_set(self): + from app.models.proxy import ShadowsocksMethods + from app.models.settings import General + + instance = General(default_method=ShadowsocksMethods.AES_256_GCM) + assert instance.default_method == ShadowsocksMethods.AES_256_GCM + + def test_general_fields(self): + from app.models.settings import General + + fields = set(General.model_fields.keys()) + assert "default_method" in fields + assert "default_flow" not in fields + + +class TestProxyValidatorChanges: + """ + ProxyValidator.validate_proxy_url changed from `if value is None` to `if not value`. + Empty string now returns None instead of attempting pattern validation. + """ + + def test_none_returns_none(self): + from app.models.validators import ProxyValidator + + assert ProxyValidator.validate_proxy_url(None) is None + + def test_empty_string_returns_none(self): + """Changed behavior: empty string now returns None.""" + from app.models.validators import ProxyValidator + + assert ProxyValidator.validate_proxy_url("") is None + + def test_valid_http_url_passes(self): + from app.models.validators import ProxyValidator + + result = ProxyValidator.validate_proxy_url("http://127.0.0.1:8080") + assert result == "http://127.0.0.1:8080" + + def test_valid_https_url_passes(self): + from app.models.validators import ProxyValidator + + result = ProxyValidator.validate_proxy_url("https://proxy.example.com:443") + assert result == "https://proxy.example.com:443" + + def test_valid_socks5_url_passes(self): + from app.models.validators import ProxyValidator + + result = ProxyValidator.validate_proxy_url("socks5://127.0.0.1:1080") + assert result == "socks5://127.0.0.1:1080" + + def test_valid_socks4_url_passes(self): + from app.models.validators import ProxyValidator + + result = ProxyValidator.validate_proxy_url("socks4://proxy.example.com:1080") + assert result == "socks4://proxy.example.com:1080" + + def test_valid_url_with_credentials_passes(self): + from app.models.validators import ProxyValidator + + result = ProxyValidator.validate_proxy_url("socks5://user:pass@127.0.0.1:1080") + assert result == "socks5://user:pass@127.0.0.1:1080" + + def test_invalid_scheme_raises(self): + from app.models.validators import ProxyValidator + + with pytest.raises(ValueError, match="proxy_url must be a valid proxy address"): + ProxyValidator.validate_proxy_url("ftp://127.0.0.1:21") + + def test_invalid_url_no_port_raises(self): + from app.models.validators import ProxyValidator + + with pytest.raises(ValueError, match="proxy_url must be a valid proxy address"): + ProxyValidator.validate_proxy_url("http://127.0.0.1") + + def test_invalid_url_no_scheme_raises(self): + from app.models.validators import ProxyValidator + + with pytest.raises(ValueError, match="proxy_url must be a valid proxy address"): + ProxyValidator.validate_proxy_url("127.0.0.1:8080") + + def test_whitespace_string_returns_none(self): + """Whitespace is falsy in Python? No - it's truthy. But it will fail validation.""" + from app.models.validators import ProxyValidator + + # Whitespace is truthy, so it goes to the pattern check and should raise + with pytest.raises(ValueError): + ProxyValidator.validate_proxy_url(" ") + + +class TestXHttpSettingsUplinkChunkSize: + """ + XHttpSettings.uplink_chunk_size changed from int|None to str|int|None with pattern. + Also added to _empty_str_to_none validator list. + """ + + def test_uplink_chunk_size_none_by_default(self): + from app.models.host import XHttpSettings + + instance = XHttpSettings() + assert instance.uplink_chunk_size is None + + def test_uplink_chunk_size_accepts_integer_string(self): + from app.models.host import XHttpSettings + + instance = XHttpSettings(uplink_chunk_size="1024") + assert instance.uplink_chunk_size == "1024" + + def test_uplink_chunk_size_accepts_range_string(self): + from app.models.host import XHttpSettings + + instance = XHttpSettings(uplink_chunk_size="1024-2048") + assert instance.uplink_chunk_size == "1024-2048" + + def test_uplink_chunk_size_empty_string_becomes_none(self): + from app.models.host import XHttpSettings + + instance = XHttpSettings(uplink_chunk_size="") + assert instance.uplink_chunk_size is None + + def test_uplink_chunk_size_accepts_integer(self): + from app.models.host import XHttpSettings + + instance = XHttpSettings(uplink_chunk_size=512) + assert instance.uplink_chunk_size == 512 + + def test_uplink_chunk_size_rejects_invalid_pattern(self): + from app.models.host import XHttpSettings + + with pytest.raises(ValidationError): + XHttpSettings(uplink_chunk_size="abc") + + def test_uplink_chunk_size_rejects_negative(self): + from app.models.host import XHttpSettings + + with pytest.raises(ValidationError): + XHttpSettings(uplink_chunk_size="-1") + + def test_uplink_chunk_size_single_large_number(self): + from app.models.host import XHttpSettings + + instance = XHttpSettings(uplink_chunk_size="1234567890123456") + assert instance.uplink_chunk_size == "1234567890123456" + + def test_uplink_chunk_size_range_with_large_numbers(self): + from app.models.host import XHttpSettings + + instance = XHttpSettings(uplink_chunk_size="100-9999999999999999") + assert instance.uplink_chunk_size == "100-9999999999999999" + + +class TestStatusEmojis: + """STATUS_EMOJIS dict is a new module.""" + + def test_status_emojis_contains_all_statuses(self): + from app.models.status_emojis import STATUS_EMOJIS + + expected_keys = {"active", "expired", "limited", "disabled", "on_hold"} + assert set(STATUS_EMOJIS.keys()) == expected_keys + + def test_status_emojis_active(self): + from app.models.status_emojis import STATUS_EMOJIS + + assert STATUS_EMOJIS["active"] == "✅" + + def test_status_emojis_expired(self): + from app.models.status_emojis import STATUS_EMOJIS + + assert STATUS_EMOJIS["expired"] == "âŒ›ī¸" + + def test_status_emojis_limited(self): + from app.models.status_emojis import STATUS_EMOJIS + + assert STATUS_EMOJIS["limited"] == "đŸĒĢ" + + def test_status_emojis_disabled(self): + from app.models.status_emojis import STATUS_EMOJIS + + assert STATUS_EMOJIS["disabled"] == "❌" + + def test_status_emojis_on_hold(self): + from app.models.status_emojis import STATUS_EMOJIS + + assert STATUS_EMOJIS["on_hold"] == "🔌" + + def test_status_emojis_is_dict(self): + from app.models.status_emojis import STATUS_EMOJIS + + assert isinstance(STATUS_EMOJIS, dict) \ No newline at end of file diff --git a/tests/test_node_user.py b/tests/test_node_user.py new file mode 100644 index 00000000..73685240 --- /dev/null +++ b/tests/test_node_user.py @@ -0,0 +1,355 @@ +""" +Tests for app/node/user.py - _serialize_user_for_node refactored to use allowed_protocols. + +Key change: the function now conditionally includes proxy kwargs based on the +allowed_protocols frozenset, instead of always including all protocol fields. +""" + +import sys +import types +from unittest.mock import MagicMock, call, patch + +import pytest + +from app.models.protocol import ProxyProtocol + + +# --------------------------------------------------------------------------- +# Helpers to set up the PasarGuardNodeBridge mock before importing node.user +# --------------------------------------------------------------------------- + + +def _make_bridge_mock(): + """Return a (module_mock, create_proxy_mock, create_user_mock) triple.""" + bridge_module = types.ModuleType("PasarGuardNodeBridge") + common_module = types.ModuleType("PasarGuardNodeBridge.common") + service_module = types.ModuleType("PasarGuardNodeBridge.common.service_pb2") + + create_proxy_mock = MagicMock(name="create_proxy", return_value=MagicMock(name="proxy_obj")) + create_user_mock = MagicMock(name="create_user", return_value=MagicMock(name="user_obj")) + + bridge_module.create_proxy = create_proxy_mock + bridge_module.create_user = create_user_mock + service_module.User = MagicMock(name="ProtoUser") + + common_module.service_pb2 = service_module + bridge_module.common = common_module + + return bridge_module, create_proxy_mock, create_user_mock + + +@pytest.fixture(autouse=True) +def mock_bridge(monkeypatch): + """Inject fake PasarGuardNodeBridge before any import of app.node.user.""" + bridge_module, create_proxy_mock, create_user_mock = _make_bridge_mock() + + monkeypatch.setitem(sys.modules, "PasarGuardNodeBridge", bridge_module) + monkeypatch.setitem(sys.modules, "PasarGuardNodeBridge.common", bridge_module.common) + monkeypatch.setitem(sys.modules, "PasarGuardNodeBridge.common.service_pb2", bridge_module.common.service_pb2) + + # Remove cached app.node.user so it re-imports with our mock + monkeypatch.delitem(sys.modules, "app.node.user", raising=False) + + yield create_proxy_mock, create_user_mock + + +def _call_serialize(user_settings: dict, inbounds=None, allowed_protocols=None): + """Import _serialize_user_for_node fresh and call it.""" + from app.node.user import _serialize_user_for_node + + return _serialize_user_for_node( + id=42, + username="testuser", + user_settings=user_settings, + inbounds=inbounds, + allowed_protocols=allowed_protocols, + ) + + +FULL_USER_SETTINGS = { + "vmess": {"id": "vmess-uuid-1234"}, + "vless": {"id": "vless-uuid-5678"}, + "trojan": {"password": "trojan-pass"}, + "shadowsocks": {"password": "ss-pass", "method": "chacha20-ietf-poly1305"}, + "wireguard": {"public_key": "wg-pub-key", "peer_ips": ["10.0.0.1/32"]}, + "hysteria": {"auth": "hysteria-auth"}, +} + + +class TestSerializeUserForNodeAllProtocols: + """When allowed_protocols is None, all protocols should be included.""" + + def test_none_allowed_protocols_includes_vmess(self, mock_bridge): + create_proxy, create_user = mock_bridge + _call_serialize(FULL_USER_SETTINGS, inbounds=["tag1"], allowed_protocols=None) + kwargs = create_proxy.call_args.kwargs + assert "vmess_id" in kwargs + assert kwargs["vmess_id"] == "vmess-uuid-1234" + + def test_none_allowed_protocols_includes_vless(self, mock_bridge): + create_proxy, create_user = mock_bridge + _call_serialize(FULL_USER_SETTINGS, inbounds=["tag1"], allowed_protocols=None) + kwargs = create_proxy.call_args.kwargs + assert "vless_id" in kwargs + assert kwargs["vless_id"] == "vless-uuid-5678" + + def test_none_allowed_protocols_includes_trojan(self, mock_bridge): + create_proxy, create_user = mock_bridge + _call_serialize(FULL_USER_SETTINGS, inbounds=["tag1"], allowed_protocols=None) + kwargs = create_proxy.call_args.kwargs + assert "trojan_password" in kwargs + assert kwargs["trojan_password"] == "trojan-pass" + + def test_none_allowed_protocols_includes_shadowsocks(self, mock_bridge): + create_proxy, create_user = mock_bridge + _call_serialize(FULL_USER_SETTINGS, inbounds=["tag1"], allowed_protocols=None) + kwargs = create_proxy.call_args.kwargs + assert "shadowsocks_password" in kwargs + assert kwargs["shadowsocks_password"] == "ss-pass" + assert "shadowsocks_method" in kwargs + assert kwargs["shadowsocks_method"] == "chacha20-ietf-poly1305" + + def test_none_allowed_protocols_includes_wireguard(self, mock_bridge): + create_proxy, create_user = mock_bridge + _call_serialize(FULL_USER_SETTINGS, inbounds=["tag1"], allowed_protocols=None) + kwargs = create_proxy.call_args.kwargs + assert "wireguard_public_key" in kwargs + assert kwargs["wireguard_public_key"] == "wg-pub-key" + assert "wireguard_peer_ips" in kwargs + assert kwargs["wireguard_peer_ips"] == ["10.0.0.1/32"] + + def test_none_allowed_protocols_includes_hysteria(self, mock_bridge): + create_proxy, create_user = mock_bridge + _call_serialize(FULL_USER_SETTINGS, inbounds=["tag1"], allowed_protocols=None) + kwargs = create_proxy.call_args.kwargs + assert "hysteria_auth" in kwargs + assert kwargs["hysteria_auth"] == "hysteria-auth" + + def test_all_protocols_frozenset_same_as_none(self, mock_bridge): + create_proxy, _ = mock_bridge + _call_serialize(FULL_USER_SETTINGS, inbounds=["tag1"], allowed_protocols=frozenset(ProxyProtocol)) + kwargs = create_proxy.call_args.kwargs + assert "vmess_id" in kwargs + assert "vless_id" in kwargs + assert "trojan_password" in kwargs + assert "shadowsocks_password" in kwargs + assert "wireguard_public_key" in kwargs + assert "hysteria_auth" in kwargs + + +class TestSerializeUserForNodeFilteredProtocols: + """When allowed_protocols is provided, only those protocols should appear.""" + + def test_vmess_only(self, mock_bridge): + create_proxy, _ = mock_bridge + _call_serialize( + FULL_USER_SETTINGS, + inbounds=["tag1"], + allowed_protocols=frozenset({ProxyProtocol.vmess}), + ) + kwargs = create_proxy.call_args.kwargs + assert "vmess_id" in kwargs + assert "vless_id" not in kwargs + assert "trojan_password" not in kwargs + assert "shadowsocks_password" not in kwargs + assert "wireguard_public_key" not in kwargs + assert "hysteria_auth" not in kwargs + + def test_vless_only(self, mock_bridge): + create_proxy, _ = mock_bridge + _call_serialize( + FULL_USER_SETTINGS, + inbounds=["tag1"], + allowed_protocols=frozenset({ProxyProtocol.vless}), + ) + kwargs = create_proxy.call_args.kwargs + assert "vless_id" in kwargs + assert "vmess_id" not in kwargs + assert "trojan_password" not in kwargs + + def test_shadowsocks_only(self, mock_bridge): + create_proxy, _ = mock_bridge + _call_serialize( + FULL_USER_SETTINGS, + inbounds=["tag1"], + allowed_protocols=frozenset({ProxyProtocol.shadowsocks}), + ) + kwargs = create_proxy.call_args.kwargs + assert "shadowsocks_password" in kwargs + assert "shadowsocks_method" in kwargs + assert "vmess_id" not in kwargs + assert "vless_id" not in kwargs + assert "trojan_password" not in kwargs + assert "wireguard_public_key" not in kwargs + assert "hysteria_auth" not in kwargs + + def test_wireguard_only(self, mock_bridge): + create_proxy, _ = mock_bridge + _call_serialize( + FULL_USER_SETTINGS, + inbounds=["tag1"], + allowed_protocols=frozenset({ProxyProtocol.wireguard}), + ) + kwargs = create_proxy.call_args.kwargs + assert "wireguard_public_key" in kwargs + assert "wireguard_peer_ips" in kwargs + assert "vmess_id" not in kwargs + assert "vless_id" not in kwargs + + def test_trojan_only(self, mock_bridge): + create_proxy, _ = mock_bridge + _call_serialize( + FULL_USER_SETTINGS, + inbounds=["tag1"], + allowed_protocols=frozenset({ProxyProtocol.trojan}), + ) + kwargs = create_proxy.call_args.kwargs + assert "trojan_password" in kwargs + assert "vmess_id" not in kwargs + assert "shadowsocks_password" not in kwargs + + def test_hysteria_only(self, mock_bridge): + create_proxy, _ = mock_bridge + _call_serialize( + FULL_USER_SETTINGS, + inbounds=["tag1"], + allowed_protocols=frozenset({ProxyProtocol.hysteria}), + ) + kwargs = create_proxy.call_args.kwargs + assert "hysteria_auth" in kwargs + assert "vmess_id" not in kwargs + assert "vless_id" not in kwargs + + def test_empty_allowed_protocols_no_proxy_kwargs(self, mock_bridge): + create_proxy, _ = mock_bridge + _call_serialize( + FULL_USER_SETTINGS, + inbounds=["tag1"], + allowed_protocols=frozenset(), + ) + kwargs = create_proxy.call_args.kwargs + assert len(kwargs) == 0 + + def test_vmess_and_vless(self, mock_bridge): + create_proxy, _ = mock_bridge + _call_serialize( + FULL_USER_SETTINGS, + inbounds=["tag1"], + allowed_protocols=frozenset({ProxyProtocol.vmess, ProxyProtocol.vless}), + ) + kwargs = create_proxy.call_args.kwargs + assert "vmess_id" in kwargs + assert "vless_id" in kwargs + assert "trojan_password" not in kwargs + + def test_xray_protocols(self, mock_bridge): + """Test with typical XRay protocols (vmess, vless, trojan, shadowsocks).""" + create_proxy, _ = mock_bridge + xray_protocols = frozenset({ + ProxyProtocol.vmess, + ProxyProtocol.vless, + ProxyProtocol.trojan, + ProxyProtocol.shadowsocks, + }) + _call_serialize(FULL_USER_SETTINGS, inbounds=["tag1"], allowed_protocols=xray_protocols) + kwargs = create_proxy.call_args.kwargs + assert "vmess_id" in kwargs + assert "vless_id" in kwargs + assert "trojan_password" in kwargs + assert "shadowsocks_password" in kwargs + assert "wireguard_public_key" not in kwargs + assert "hysteria_auth" not in kwargs + + +class TestSerializeUserForNodeMissingSettings: + """Test behavior when user_settings dict is missing protocol keys.""" + + def test_missing_vmess_settings(self, mock_bridge): + create_proxy, _ = mock_bridge + _call_serialize( + {}, + inbounds=["tag1"], + allowed_protocols=frozenset({ProxyProtocol.vmess}), + ) + kwargs = create_proxy.call_args.kwargs + assert "vmess_id" in kwargs + assert kwargs["vmess_id"] is None + + def test_missing_shadowsocks_settings(self, mock_bridge): + create_proxy, _ = mock_bridge + _call_serialize( + {}, + inbounds=["tag1"], + allowed_protocols=frozenset({ProxyProtocol.shadowsocks}), + ) + kwargs = create_proxy.call_args.kwargs + assert kwargs["shadowsocks_password"] is None + assert kwargs["shadowsocks_method"] is None + + def test_missing_wireguard_settings(self, mock_bridge): + create_proxy, _ = mock_bridge + _call_serialize( + {}, + inbounds=["tag1"], + allowed_protocols=frozenset({ProxyProtocol.wireguard}), + ) + kwargs = create_proxy.call_args.kwargs + assert kwargs["wireguard_public_key"] is None + assert kwargs["wireguard_peer_ips"] == [] + + def test_wireguard_peer_ips_none_becomes_empty_list(self, mock_bridge): + """wireguard_peer_ips=None should be coerced to [].""" + create_proxy, _ = mock_bridge + settings = {"wireguard": {"public_key": "some-key", "peer_ips": None}} + _call_serialize( + settings, + inbounds=["tag1"], + allowed_protocols=frozenset({ProxyProtocol.wireguard}), + ) + kwargs = create_proxy.call_args.kwargs + assert kwargs["wireguard_peer_ips"] == [] + + +class TestSerializeUserForNodeCreateUserCall: + """Test that create_user is called with the correct arguments.""" + + def test_create_user_called_with_correct_name(self, mock_bridge): + create_proxy, create_user = mock_bridge + _call_serialize({}, inbounds=["tag1"], allowed_protocols=None) + args = create_user.call_args.args + assert args[0] == "42.testuser" + + def test_create_user_called_with_inbounds(self, mock_bridge): + create_proxy, create_user = mock_bridge + inbounds = ["tag1", "tag2", "tag3"] + _call_serialize({}, inbounds=inbounds, allowed_protocols=None) + args = create_user.call_args.args + assert args[2] == inbounds + + def test_create_user_called_with_none_inbounds(self, mock_bridge): + create_proxy, create_user = mock_bridge + _call_serialize({}, inbounds=None, allowed_protocols=None) + args = create_user.call_args.args + assert args[2] is None + + +class TestNoVlessFlowInSerialize: + """ + The old code handled vless_flow; the new code does not set vless_flow at all. + Regression test: ensure flow-related kwargs are not passed to create_proxy. + """ + + def test_vless_flow_not_in_proxy_kwargs(self, mock_bridge): + create_proxy, _ = mock_bridge + settings = {"vless": {"id": "some-uuid", "flow": "xtls-rprx-vision"}} + _call_serialize(settings, inbounds=["tag1"], allowed_protocols=frozenset({ProxyProtocol.vless})) + kwargs = create_proxy.call_args.kwargs + assert "vless_flow" not in kwargs + + def test_vless_only_id_is_passed(self, mock_bridge): + create_proxy, _ = mock_bridge + settings = {"vless": {"id": "my-vless-uuid"}} + _call_serialize(settings, inbounds=["tag1"], allowed_protocols=frozenset({ProxyProtocol.vless})) + kwargs = create_proxy.call_args.kwargs + assert kwargs["vless_id"] == "my-vless-uuid" + assert "vless_flow" not in kwargs \ No newline at end of file diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 00000000..778b052a --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,136 @@ +"""Tests for app/models/protocol.py - ProxyProtocol enum and from_value classmethod.""" + +import pytest + +from app.models.protocol import ProxyProtocol, _PROXY_PROTOCOL_BY_NAME + + +class TestProxyProtocolValues: + """Test that ProxyProtocol enum has the expected values.""" + + def test_vmess_value(self): + assert ProxyProtocol.vmess == 1 + + def test_vless_value(self): + assert ProxyProtocol.vless == 2 + + def test_trojan_value(self): + assert ProxyProtocol.trojan == 3 + + def test_shadowsocks_value(self): + assert ProxyProtocol.shadowsocks == 4 + + def test_wireguard_value(self): + assert ProxyProtocol.wireguard == 5 + + def test_hysteria_value(self): + assert ProxyProtocol.hysteria == 6 + + def test_total_protocols_count(self): + assert len(ProxyProtocol) == 6 + + def test_is_int_enum(self): + assert isinstance(ProxyProtocol.vmess, int) + assert int(ProxyProtocol.vmess) == 1 + + def test_all_protocol_names(self): + expected_names = {"vmess", "vless", "trojan", "shadowsocks", "wireguard", "hysteria"} + actual_names = {p.name for p in ProxyProtocol} + assert actual_names == expected_names + + +class TestProxyProtocolFromValue: + """Test the from_value classmethod.""" + + def test_from_value_vmess(self): + result = ProxyProtocol.from_value("vmess") + assert result == ProxyProtocol.vmess + + def test_from_value_vless(self): + result = ProxyProtocol.from_value("vless") + assert result == ProxyProtocol.vless + + def test_from_value_trojan(self): + result = ProxyProtocol.from_value("trojan") + assert result == ProxyProtocol.trojan + + def test_from_value_shadowsocks(self): + result = ProxyProtocol.from_value("shadowsocks") + assert result == ProxyProtocol.shadowsocks + + def test_from_value_wireguard(self): + result = ProxyProtocol.from_value("wireguard") + assert result == ProxyProtocol.wireguard + + def test_from_value_hysteria(self): + result = ProxyProtocol.from_value("hysteria") + assert result == ProxyProtocol.hysteria + + def test_from_value_unknown_returns_none(self): + assert ProxyProtocol.from_value("unknown") is None + + def test_from_value_empty_string_returns_none(self): + assert ProxyProtocol.from_value("") is None + + def test_from_value_uppercase_returns_none(self): + # Names are case-sensitive (uses dict lookup by name) + assert ProxyProtocol.from_value("VMESS") is None + assert ProxyProtocol.from_value("Vless") is None + + def test_from_value_partial_name_returns_none(self): + assert ProxyProtocol.from_value("vmes") is None + assert ProxyProtocol.from_value("vless2") is None + + def test_from_value_all_protocols(self): + """Every protocol name should resolve to its enum member.""" + for protocol in ProxyProtocol: + result = ProxyProtocol.from_value(protocol.name) + assert result is protocol, f"from_value({protocol.name!r}) returned {result!r}, expected {protocol!r}" + + def test_from_value_returns_protocol_type(self): + result = ProxyProtocol.from_value("shadowsocks") + assert isinstance(result, ProxyProtocol) + + +class TestProxyProtocolFrozenset: + """Test ProxyProtocol usage in frozensets (as used in the codebase).""" + + def test_protocol_in_frozenset(self): + protocols = frozenset(ProxyProtocol) + assert ProxyProtocol.vmess in protocols + assert ProxyProtocol.wireguard in protocols + + def test_frozenset_of_single_protocol(self): + wireguard_set = frozenset((ProxyProtocol.wireguard,)) + assert ProxyProtocol.wireguard in wireguard_set + assert ProxyProtocol.vmess not in wireguard_set + + def test_frozenset_membership_check(self): + allowed = frozenset({ProxyProtocol.vmess, ProxyProtocol.vless}) + assert ProxyProtocol.vmess in allowed + assert ProxyProtocol.vless in allowed + assert ProxyProtocol.trojan not in allowed + assert ProxyProtocol.shadowsocks not in allowed + + def test_all_protocols_frozenset(self): + all_protocols = frozenset(ProxyProtocol) + assert len(all_protocols) == 6 + + +class TestProxyProtocolByNameDict: + """Test the module-level _PROXY_PROTOCOL_BY_NAME dict.""" + + def test_dict_contains_all_protocols(self): + assert len(_PROXY_PROTOCOL_BY_NAME) == 6 + + def test_dict_maps_name_to_protocol(self): + assert _PROXY_PROTOCOL_BY_NAME["vmess"] == ProxyProtocol.vmess + assert _PROXY_PROTOCOL_BY_NAME["vless"] == ProxyProtocol.vless + assert _PROXY_PROTOCOL_BY_NAME["trojan"] == ProxyProtocol.trojan + assert _PROXY_PROTOCOL_BY_NAME["shadowsocks"] == ProxyProtocol.shadowsocks + assert _PROXY_PROTOCOL_BY_NAME["wireguard"] == ProxyProtocol.wireguard + assert _PROXY_PROTOCOL_BY_NAME["hysteria"] == ProxyProtocol.hysteria + + def test_dict_missing_key_raises_key_error(self): + with pytest.raises(KeyError): + _ = _PROXY_PROTOCOL_BY_NAME["unknown"] \ No newline at end of file diff --git a/tests/test_xray_protocols.py b/tests/test_xray_protocols.py new file mode 100644 index 00000000..d07dd224 --- /dev/null +++ b/tests/test_xray_protocols.py @@ -0,0 +1,267 @@ +""" +Tests for protocol-related changes in app/core/xray.py and app/core/wireguard.py: +- _protocols_from_inbounds_by_tag: extracts ProxyProtocol from inbound configs +- XRayConfig.protocols property: frozenset of detected protocols +- WireGuardConfig.protocols property: always frozenset({ProxyProtocol.wireguard}) +- AbstractCore.protocols abstract property added +""" + +import pytest + +from app.models.protocol import ProxyProtocol + + +class TestProtocolsFromInboundsByTag: + """Tests for the _protocols_from_inbounds_by_tag function in app/core/xray.py.""" + + def _call(self, inbounds_by_tag: dict) -> frozenset: + from app.core.xray import _protocols_from_inbounds_by_tag + + return _protocols_from_inbounds_by_tag(inbounds_by_tag) + + def test_empty_inbounds_returns_empty_frozenset(self): + result = self._call({}) + assert result == frozenset() + assert isinstance(result, frozenset) + + def test_single_vmess_inbound(self): + inbounds = {"tag1": {"protocol": "vmess"}} + result = self._call(inbounds) + assert result == frozenset({ProxyProtocol.vmess}) + + def test_single_vless_inbound(self): + inbounds = {"tag1": {"protocol": "vless"}} + result = self._call(inbounds) + assert result == frozenset({ProxyProtocol.vless}) + + def test_single_trojan_inbound(self): + inbounds = {"tag1": {"protocol": "trojan"}} + result = self._call(inbounds) + assert result == frozenset({ProxyProtocol.trojan}) + + def test_single_shadowsocks_inbound(self): + inbounds = {"tag1": {"protocol": "shadowsocks"}} + result = self._call(inbounds) + assert result == frozenset({ProxyProtocol.shadowsocks}) + + def test_single_hysteria_inbound(self): + inbounds = {"tag1": {"protocol": "hysteria"}} + result = self._call(inbounds) + assert result == frozenset({ProxyProtocol.hysteria}) + + def test_multiple_same_protocol_returns_single_entry(self): + inbounds = { + "tag1": {"protocol": "vmess"}, + "tag2": {"protocol": "vmess"}, + "tag3": {"protocol": "vmess"}, + } + result = self._call(inbounds) + assert result == frozenset({ProxyProtocol.vmess}) + assert len(result) == 1 + + def test_multiple_distinct_protocols(self): + inbounds = { + "tag1": {"protocol": "vmess"}, + "tag2": {"protocol": "vless"}, + "tag3": {"protocol": "trojan"}, + "tag4": {"protocol": "shadowsocks"}, + } + result = self._call(inbounds) + assert result == frozenset({ + ProxyProtocol.vmess, + ProxyProtocol.vless, + ProxyProtocol.trojan, + ProxyProtocol.shadowsocks, + }) + + def test_unknown_protocol_excluded(self): + inbounds = { + "tag1": {"protocol": "unknown_protocol"}, + "tag2": {"protocol": "vmess"}, + } + result = self._call(inbounds) + assert result == frozenset({ProxyProtocol.vmess}) + + def test_all_unknown_protocols_returns_empty(self): + inbounds = { + "tag1": {"protocol": "http"}, + "tag2": {"protocol": "socks"}, + } + result = self._call(inbounds) + assert result == frozenset() + + def test_mixed_known_and_unknown_protocols(self): + inbounds = { + "tag1": {"protocol": "vless"}, + "tag2": {"protocol": "not-a-protocol"}, + "tag3": {"protocol": "shadowsocks"}, + } + result = self._call(inbounds) + assert ProxyProtocol.vless in result + assert ProxyProtocol.shadowsocks in result + assert len(result) == 2 + + def test_returns_frozenset_type(self): + result = self._call({"tag1": {"protocol": "vmess"}}) + assert isinstance(result, frozenset) + + def test_result_is_immutable(self): + result = self._call({"tag1": {"protocol": "vmess"}}) + with pytest.raises(AttributeError): + result.add(ProxyProtocol.vless) + + +class TestXRayConfigProtocols: + """Tests for XRayConfig.protocols property.""" + + def _make_xray_config_from_json(self, inbounds_by_tag: dict) -> object: + """Use from_json to construct XRayConfig with specific inbounds_by_tag.""" + from app.core.xray import XRayConfig + + data = { + "config": {}, + "exclude_inbound_tags": [], + "fallbacks_inbound_tags": [], + "inbounds": list(inbounds_by_tag.keys()), + "inbounds_by_tag": inbounds_by_tag, + } + return XRayConfig.from_json(data) + + def test_empty_inbounds_by_tag_gives_empty_protocols(self): + config = self._make_xray_config_from_json({}) + assert config.protocols == frozenset() + + def test_vmess_inbound_gives_vmess_protocol(self): + config = self._make_xray_config_from_json({"tag1": {"protocol": "vmess"}}) + assert ProxyProtocol.vmess in config.protocols + + def test_vless_inbound_gives_vless_protocol(self): + config = self._make_xray_config_from_json({"tag1": {"protocol": "vless"}}) + assert ProxyProtocol.vless in config.protocols + + def test_multiple_protocols_all_present(self): + inbounds_by_tag = { + "vmess_tag": {"protocol": "vmess"}, + "vless_tag": {"protocol": "vless"}, + "trojan_tag": {"protocol": "trojan"}, + } + config = self._make_xray_config_from_json(inbounds_by_tag) + assert config.protocols == frozenset({ + ProxyProtocol.vmess, + ProxyProtocol.vless, + ProxyProtocol.trojan, + }) + + def test_unknown_protocol_not_in_protocols(self): + config = self._make_xray_config_from_json({"http_proxy": {"protocol": "http"}}) + assert config.protocols == frozenset() + + def test_protocols_is_frozenset(self): + config = self._make_xray_config_from_json({"tag1": {"protocol": "shadowsocks"}}) + assert isinstance(config.protocols, frozenset) + + def test_protocols_set_from_resolve_inbounds(self): + """ + Test that when XRayConfig is built from a valid minimal config, + _protocols is set from _resolve_inbounds (via _protocols_from_inbounds_by_tag). + """ + from app.core.xray import XRayConfig + + # Minimal but valid xray config + minimal_config = { + "inbounds": [ + { + "tag": "vless-in", + "protocol": "vless", + "port": 1234, + "settings": {"clients": [], "decryption": "none"}, + "streamSettings": {"network": "tcp"}, + } + ], + "outbounds": [{"tag": "direct", "protocol": "freedom"}], + } + config = XRayConfig(minimal_config) + assert ProxyProtocol.vless in config.protocols + + def test_from_json_sets_protocols_from_inbounds_by_tag(self): + """from_json reconstructs protocols from inbounds_by_tag.""" + from app.core.xray import XRayConfig + + data = { + "config": {}, + "exclude_inbound_tags": [], + "fallbacks_inbound_tags": [], + "inbounds": ["ss-tag"], + "inbounds_by_tag": {"ss-tag": {"protocol": "shadowsocks"}}, + } + config = XRayConfig.from_json(data) + assert config.protocols == frozenset({ProxyProtocol.shadowsocks}) + + +class TestWireGuardConfigProtocols: + """Tests for WireGuardConfig.protocols property.""" + + def test_protocols_always_returns_wireguard(self): + from app.core.wireguard import WireGuardConfig + + config = WireGuardConfig(skip_validation=True) + assert config.protocols == frozenset({ProxyProtocol.wireguard}) + + def test_protocols_is_frozenset(self): + from app.core.wireguard import WireGuardConfig + + config = WireGuardConfig(skip_validation=True) + assert isinstance(config.protocols, frozenset) + + def test_protocols_contains_only_wireguard(self): + from app.core.wireguard import WireGuardConfig + + config = WireGuardConfig(skip_validation=True) + assert len(config.protocols) == 1 + assert ProxyProtocol.wireguard in config.protocols + + def test_protocols_does_not_contain_other_protocols(self): + from app.core.wireguard import WireGuardConfig + + config = WireGuardConfig(skip_validation=True) + assert ProxyProtocol.vmess not in config.protocols + assert ProxyProtocol.vless not in config.protocols + assert ProxyProtocol.trojan not in config.protocols + assert ProxyProtocol.shadowsocks not in config.protocols + assert ProxyProtocol.hysteria not in config.protocols + + def test_module_level_constant_is_frozenset_with_wireguard(self): + from app.core.wireguard import _WIREGUARD_PROTOCOLS + + assert _WIREGUARD_PROTOCOLS == frozenset({ProxyProtocol.wireguard}) + + def test_protocols_returns_same_constant(self): + from app.core.wireguard import WireGuardConfig, _WIREGUARD_PROTOCOLS + + config = WireGuardConfig(skip_validation=True) + assert config.protocols is _WIREGUARD_PROTOCOLS + + +class TestAbstractCoreProtocolsProperty: + """Tests for the new protocols abstract property in AbstractCore.""" + + def test_abstract_core_has_protocols_property(self): + from app.core.abstract_core import AbstractCore + + # protocols should be an abstractmethod + assert "protocols" in AbstractCore.__abstractmethods__ + + def test_xray_config_satisfies_protocols_abstract(self): + """XRayConfig must implement the protocols property.""" + from app.core.xray import XRayConfig + + assert hasattr(XRayConfig, "protocols") + # Check it's a property + assert isinstance(XRayConfig.__dict__.get("protocols"), property) + + def test_wireguard_config_satisfies_protocols_abstract(self): + """WireGuardConfig must implement the protocols property.""" + from app.core.wireguard import WireGuardConfig + + assert hasattr(WireGuardConfig, "protocols") + assert isinstance(WireGuardConfig.__dict__.get("protocols"), property)