From a6f56f9af1990ad0f0077690d79be2f7b13f9e1d Mon Sep 17 00:00:00 2001 From: Michael D'Angelo Date: Sat, 2 May 2026 22:35:03 -0400 Subject: [PATCH 1/3] test: avoid substring assertions in metadata URL coverage --- tests/scanners/test_metadata_scanner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/scanners/test_metadata_scanner.py b/tests/scanners/test_metadata_scanner.py index d3fede98c..b88e0b6a2 100644 --- a/tests/scanners/test_metadata_scanner.py +++ b/tests/scanners/test_metadata_scanner.py @@ -88,8 +88,6 @@ def test_scan_suspicious_urls_in_readme(self): "bit.ly", "ngrok.io", } - assert any("bit.ly" in issue.message for issue in result.issues) - assert any("ngrok.io" in issue.message for issue in result.issues) def test_repeated_benign_urls_are_parsed_once(self, monkeypatch: pytest.MonkeyPatch) -> None: """Skip duplicate benign URLs before reparsing them.""" From 0a67883f178965cc4052ddee7da8e37a935459e2 Mon Sep 17 00:00:00 2001 From: Michael D'Angelo Date: Sat, 2 May 2026 22:36:05 -0400 Subject: [PATCH 2/3] test: normalize sarif formatter imports --- tests/integrations/test_sarif_formatter.py | 31 +++++++++++----------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/tests/integrations/test_sarif_formatter.py b/tests/integrations/test_sarif_formatter.py index 04650eb2f..21a8a12d9 100644 --- a/tests/integrations/test_sarif_formatter.py +++ b/tests/integrations/test_sarif_formatter.py @@ -7,25 +7,24 @@ import pytest import modelaudit.integrations.sarif_formatter as sarif_formatter -from modelaudit.integrations.sarif_formatter import ( - _create_artifacts, - _create_results, - _create_rules, - _create_run, - _get_mime_type, - _get_rule_full_description, - _get_rule_id, - _get_rule_name, - _get_rule_short_description, - _get_tags_for_issue, - _normalize_path_to_uri, - _severity_to_rank, - _severity_to_sarif_level, - format_sarif_output, -) from modelaudit.models import AssetModel, FileHashesModel, FileMetadataModel, create_initial_audit_result from modelaudit.scanners.base import Issue, IssueSeverity +_create_artifacts = sarif_formatter._create_artifacts +_create_results = sarif_formatter._create_results +_create_rules = sarif_formatter._create_rules +_create_run = sarif_formatter._create_run +_get_mime_type = sarif_formatter._get_mime_type +_get_rule_full_description = sarif_formatter._get_rule_full_description +_get_rule_id = sarif_formatter._get_rule_id +_get_rule_name = sarif_formatter._get_rule_name +_get_rule_short_description = sarif_formatter._get_rule_short_description +_get_tags_for_issue = sarif_formatter._get_tags_for_issue +_normalize_path_to_uri = sarif_formatter._normalize_path_to_uri +_severity_to_rank = sarif_formatter._severity_to_rank +_severity_to_sarif_level = sarif_formatter._severity_to_sarif_level +format_sarif_output = sarif_formatter.format_sarif_output + class TestFormatSarifOutput: """Tests for main SARIF output formatting.""" From c187e5cbb05bb1b9cec55c7380928a71b7da276c Mon Sep 17 00:00:00 2001 From: Michael D'Angelo Date: Sat, 2 May 2026 22:43:40 -0400 Subject: [PATCH 3/3] fix: address remaining quality findings --- .../src/modelaudit_picklescan/call_graph.py | 100 +++--- .../test_call_graph_import_statements.py | 22 +- tests/integrations/test_sarif_formatter.py | 18 +- tests/scanners/test_keras_zip_scanner.py | 303 +++++++----------- tests/scanners/test_metadata_scanner.py | 30 +- tests/scanners/test_skops_scanner.py | 30 +- 6 files changed, 237 insertions(+), 266 deletions(-) diff --git 
a/packages/modelaudit-picklescan/src/modelaudit_picklescan/call_graph.py b/packages/modelaudit-picklescan/src/modelaudit_picklescan/call_graph.py index 9fbd6d3fb..ab1039e73 100644 --- a/packages/modelaudit-picklescan/src/modelaudit_picklescan/call_graph.py +++ b/packages/modelaudit-picklescan/src/modelaudit_picklescan/call_graph.py @@ -6,6 +6,7 @@ import os import sys import sysconfig +import threading from collections import deque from collections.abc import Callable, Iterable, Iterator, Mapping from contextlib import contextmanager @@ -14,8 +15,16 @@ from functools import lru_cache from importlib.machinery import EXTENSION_SUFFIXES, BuiltinImporter, FrozenImporter, ModuleSpec, PathFinder from pathlib import Path +from typing import Protocol, TypeVar, cast +# Bound per-pass import/callable fan-out for untrusted inputs. The 32-reference +# cap has kept call-graph enrichment useful while preventing pathological scan +# growth; raising it improves completeness at a runtime cost, while lowering +# it can reduce detection coverage. _MAX_IMPORT_REFERENCES = 32 +# Limit per-module source reads to 1 MiB so AST parsing remains bounded on large +# inputs. This is an explicit coverage/performance tradeoff and can be tuned if +# scan precision or throughput needs change. _MAX_SOURCE_BYTES = 1024 * 1024 _MAX_CALL_GRAPH_DEPTH = 4 _MAX_VISITED_FUNCTIONS = 64 @@ -45,6 +54,21 @@ "_SHARED_SOURCE_SENSITIVE_CACHE_DEPTH", default=0, ) +_SHARED_SOURCE_SENSITIVE_CACHE_LOCK = threading.RLock() +_CachedFunctionT = TypeVar("_CachedFunctionT", bound=Callable[..., object]) + + +class _CacheClearable(Protocol): + def cache_clear(self) -> None: ... + + +_SOURCE_SENSITIVE_CACHED_FUNCTIONS: set[_CacheClearable] = set() + + +def _register_source_sensitive_cache(function: _CachedFunctionT) -> _CachedFunctionT: + _SOURCE_SENSITIVE_CACHED_FUNCTIONS.add(cast(_CacheClearable, function)) + return function + _CLASS_ENTRYPOINT_METHODS = ( "__getattribute__", @@ -290,7 +314,9 @@ def find_dangerous_call_graphs( positional_arg_counts = _callable_invocation_positional_arg_counts(callable_invocations) callable_references = _iter_callable_invocation_references(callable_invocations) invoked_references = { - (str(reference.get("module", "")), str(reference.get("name", ""))) for reference in callable_references + (str(reference.get("module", "")), str(reference.get("name", ""))) + for reference in callable_references + if str(reference.get("module", "")) and str(reference.get("name", "")) } for reference in _iter_call_graph_references(import_references, callable_references, invoked_references): @@ -454,15 +480,18 @@ def find_unanalyzed_callable_call_graph_references( @contextmanager def shared_source_sensitive_caches() -> Iterator[None]: """Share one fresh cache generation across related enrichment passes.""" - if _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.get() == 0: - _clear_source_sensitive_caches_now() - token = _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.set(_SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.get() + 1) - try: - yield - finally: - _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.reset(token) + with _SHARED_SOURCE_SENSITIVE_CACHE_LOCK: + depth = _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.get() + if depth == 0: + _clear_source_sensitive_caches_now() + token = _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.set(depth + 1) + try: + yield + finally: + _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.reset(token) +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _safe_call_graph_entrypoints(function_name: str) -> tuple[str, ...]: try: @@ -700,6 +729,7 @@ def 
_is_skippable_torch_extension_global_reference(module: str, name: str) -> bo return not _has_static_torch_extension_global_target(module, name) +@_register_source_sensitive_cache @lru_cache(maxsize=256) def _has_static_torch_extension_global_target(module: str, name: str) -> bool: analysis = _analyze_module(module) @@ -735,37 +765,14 @@ def has_unanalyzed_call_graph_import_references(import_references: object) -> bo def _clear_source_sensitive_caches() -> None: - if _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.get() > 0: - return - _clear_source_sensitive_caches_now() + with _SHARED_SOURCE_SENSITIVE_CACHE_LOCK: + if _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.get() > 0: + return + _clear_source_sensitive_caches_now() def _clear_source_sensitive_caches_now() -> None: - for function in ( - _safe_call_graph_entrypoints, - _has_static_torch_extension_global_target, - _find_sink_path, - _find_invoked_import_execution_path, - _find_file_open_path, - _find_file_write_path, - _call_graph_entrypoints, - _resolve_function_target, - _resolve_wildcard_reexport_alias, - _wildcard_export_summary, - _resolve_class_target, - _split_function_name, - _module_source_context, - _analyze_module, - _source_function_context, - _source_class_context, - _constructor_parameter_self_attribute_targets, - _iter_call_nodes, - _collect_function_import_aliases, - _parameter_controlled_names, - _can_invoke_function_with_positional_args, - _can_follow_import_execution_fallback, - _resolve_module_source, - ): + for function in _SOURCE_SENSITIVE_CACHED_FUNCTIONS: function.cache_clear() @@ -842,6 +849,7 @@ def _find_meta_path_module_spec_without_imports(module_name: str) -> ModuleSpec return None +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _find_sink_path(start: str) -> tuple[str, ...] | None: return _find_matching_call_path( @@ -851,6 +859,7 @@ def _find_sink_path(start: str) -> tuple[str, ...] | None: ) +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _find_invoked_import_execution_path( start: str, @@ -881,11 +890,13 @@ def _find_invoked_import_execution_path( return path +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _find_file_open_path(start: str) -> tuple[str, ...] | None: return _find_matching_call_path(start, _file_open_sink) +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _find_file_write_path(start: str) -> tuple[str, ...] | None: return _find_matching_call_path(start, _file_write_sink) @@ -997,6 +1008,7 @@ def _calls_for_function(function_name: str) -> tuple[str, ...] 
| None: return analysis.calls_by_function.get(f"{module_name}.{qualified_name}") +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _call_graph_entrypoints(function_name: str) -> tuple[str, ...]: resolved = _resolve_function_target(function_name) @@ -1009,6 +1021,7 @@ def _call_graph_entrypoints(function_name: str) -> tuple[str, ...]: return _class_entrypoints(class_target) +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _resolve_function_target(function_name: str) -> str | None: alias_target = _static_import_reference_alias(function_name) @@ -1113,6 +1126,7 @@ def _static_import_reference_alias(function_name: str) -> str | None: return None +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _resolve_wildcard_reexport_alias(module_name: str, qualified_name: str) -> str | None: return _resolve_wildcard_reexport_alias_inner(module_name, qualified_name, set(), 0) @@ -1150,6 +1164,7 @@ def _resolve_wildcard_reexport_alias_inner( return None +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _wildcard_export_summary(module_name: str) -> _WildcardExportSummary | None: context = _module_source_context(module_name) @@ -1158,6 +1173,7 @@ def _wildcard_export_summary(module_name: str) -> _WildcardExportSummary | None: return _collect_module_export_summary(context.module_statements, module_name, context.is_package) +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _resolve_class_target(function_name: str) -> str | None: module_name, qualified_name = _split_function_name(function_name) @@ -1193,6 +1209,7 @@ def _class_entrypoints(class_name: str) -> tuple[str, ...]: return analysis.class_entrypoints.get(f"{module_name}.{qualified_name}", ()) +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _split_function_name(function_name: str) -> tuple[str | None, str]: parts = function_name.split(".") @@ -1205,6 +1222,7 @@ def _split_function_name(function_name: str) -> tuple[str | None, str]: return None, function_name +@_register_source_sensitive_cache @lru_cache(maxsize=1024) def _analyze_module(module_name: str) -> _ModuleAnalysis | None: context = _module_source_context(module_name) @@ -1259,6 +1277,7 @@ def _analyze_module(module_name: str) -> _ModuleAnalysis | None: ) +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _module_source_context(module_name: str) -> _ModuleSourceContext | None: source_path = _resolve_module_source(module_name) @@ -1367,6 +1386,7 @@ def _collect_import_aliases(nodes: Iterable[ast.AST], module_name: str, is_packa return aliases +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _collect_function_import_aliases( function_node: ast.FunctionDef | ast.AsyncFunctionDef, @@ -1530,6 +1550,7 @@ def _class_source_context_for_target( return _source_class_context(class_target) +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _source_function_context( function_name: str, @@ -1584,6 +1605,7 @@ def _inherited_source_function_context( return inherited_method.module_name, inherited_is_package, inherited_method.node +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _source_class_context(class_name: str) -> _ClassSourceContext | None: module_name, qualified_name = _split_source_qualified_name(class_name) @@ -1927,6 +1949,7 @@ def _class_base_targets( return tuple(targets) +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _constructor_parameter_self_attribute_targets(class_name: str, parameter_name: str) -> tuple[str, ...]: module_name, 
qualified_name = _split_source_qualified_name(class_name) @@ -2295,6 +2318,7 @@ def _has_required_user_arguments(function_node: ast.FunctionDef | ast.AsyncFunct return has_required_keyword_only +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _can_invoke_function_with_positional_args(function_name: str, positional_arg_count: int) -> bool: resolved = _resolve_function_target(function_name) @@ -2307,6 +2331,7 @@ def _can_invoke_function_with_positional_args(function_name: str, positional_arg return _can_enter_function_with_positional_args(function_node, positional_arg_count) +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _can_follow_import_execution_fallback(function_name: str, positional_arg_count: int) -> bool: resolved = _resolve_function_target(function_name) @@ -2608,6 +2633,7 @@ def _function_instance_alias_value( return _resolve_class_target(call_target) +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _iter_call_nodes(function_node: ast.FunctionDef | ast.AsyncFunctionDef) -> tuple[ast.Call, ...]: calls: list[ast.Call] = [] @@ -2647,6 +2673,7 @@ def _visit_nested_function_signature(self, node: ast.FunctionDef | ast.AsyncFunc return tuple(calls) +@_register_source_sensitive_cache @lru_cache(maxsize=4096) def _parameter_controlled_names(function_node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]: controlled = _initial_parameter_controlled_names(function_node) @@ -2819,6 +2846,7 @@ def _resolve_import_from_module(module_name: str, is_package: bool, level: int, return ".".join(part for part in parts if part) +@_register_source_sensitive_cache @lru_cache(maxsize=1024) def _resolve_module_source(module_name: str) -> Path | None: parts = module_name.split(".") diff --git a/packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py b/packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py index 1ea126ce8..11b227445 100644 --- a/packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py +++ b/packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py @@ -70,26 +70,8 @@ def _global_call_payload(module: str, name: str, *arg_operands: bytes) -> bytes: def _clear_call_graph_caches() -> None: - call_graph_module = sys.modules["modelaudit_picklescan.call_graph"] - for function_name in ( - "_analyze_module", - "_call_graph_entrypoints", - "_find_sink_path", - "_has_static_torch_extension_global_target", - "_resolve_class_target", - "_resolve_function_target", - "_resolve_module_source", - "_safe_call_graph_entrypoints", - "_iter_call_nodes", - "_collect_function_import_aliases", - "_parameter_controlled_names", - "_split_function_name", - "_wildcard_export_summary", - "_module_source_context", - ): - cache_clear = getattr(getattr(call_graph_module, function_name), "cache_clear", None) - if cache_clear is not None: - cache_clear() + for function in call_graph._SOURCE_SENSITIVE_CACHED_FUNCTIONS: + function.cache_clear() def test_wildcard_summary_and_analysis_share_module_parse( diff --git a/tests/integrations/test_sarif_formatter.py b/tests/integrations/test_sarif_formatter.py index 21a8a12d9..b7c916882 100644 --- a/tests/integrations/test_sarif_formatter.py +++ b/tests/integrations/test_sarif_formatter.py @@ -149,7 +149,15 @@ def test_primary_issue_filter_runs_once(self, monkeypatch: pytest.MonkeyPatch) - details={"pickle_rule_code": "DANGEROUS_CALL"}, rule_code="S104", timestamp=time.time(), - ) + ), + Issue( + message="Supporting import module", + severity=IssueSeverity.WARNING, + 
location="/test/file.pkl", + details={"supporting_rule_code": True, "primary_rule_code": "S104"}, + rule_code="S100", + timestamp=time.time(), + ), ] result.finalize_statistics() @@ -163,9 +171,11 @@ def counting_primary_sarif_issues(issues: list[Issue]) -> list[Issue]: monkeypatch.setattr(sarif_formatter, "_primary_sarif_issues", counting_primary_sarif_issues) - _create_run(result, ["/test"], verbose=False) + run = _create_run(result, ["/test"], verbose=False) assert call_count == 1 + assert len(run["results"]) == 1 + assert run["results"][0]["message"]["text"] == "Primary dangerous call" def test_invocation_properties(self): """Test invocation includes scan properties.""" @@ -551,9 +561,9 @@ def test_get_tags_for_issue_pickle(self): assert "pickle" in tags assert "deserialization" in tags - def test_get_tags_for_issue_code_execution(self): + def test_get_tags_for_issue_code_execution(self) -> None: """Test tags for code execution issues.""" - issue = Issue(message="eval() import detected", severity=IssueSeverity.WARNING, timestamp=time.time()) + issue = Issue(message="eval() call detected", severity=IssueSeverity.WARNING, timestamp=time.time()) tags = _get_tags_for_issue(issue) assert "code-execution" in tags diff --git a/tests/scanners/test_keras_zip_scanner.py b/tests/scanners/test_keras_zip_scanner.py index 187ac134d..d1fc420b3 100644 --- a/tests/scanners/test_keras_zip_scanner.py +++ b/tests/scanners/test_keras_zip_scanner.py @@ -10,8 +10,6 @@ import base64 import json import marshal -import os -import tempfile import warnings import zipfile from pathlib import Path @@ -50,6 +48,16 @@ def create_configured_keras_zip( return keras_path +def _build_test_keras_zip(config: dict[str, Any] | str, tmp_path: Path, keras_version: str) -> str: + """Create a minimal .keras ZIP archive for CVE regression tests.""" + keras_path = tmp_path / "model.keras" + config_json = config if isinstance(config, str) else json.dumps(config) + with zipfile.ZipFile(keras_path, "w") as zf: + zf.writestr("config.json", config_json) + zf.writestr("metadata.json", json.dumps({"keras_version": keras_version})) + return str(keras_path) + + def _assert_inconclusive_keras_zip_scan(model_path: Path, reason: str, expected_check_name: str) -> None: result = KerasZipScanner().scan(str(model_path)) @@ -790,7 +798,7 @@ def test_scan_skips_oversized_metadata_json_without_warning_noise(self, tmp_path assert "keras_version" not in result.metadata assert not any(issue.severity in (IssueSeverity.WARNING, IssueSeverity.CRITICAL) for issue in result.issues) - def test_lambda_layer_with_exec(self): + def test_lambda_layer_with_exec(self, tmp_path: Path) -> None: """Test detection of Lambda layer with exec() call.""" scanner = KerasZipScanner() @@ -819,36 +827,27 @@ def test_lambda_layer_with_exec(self): }, } - with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as f: - with zipfile.ZipFile(f, "w") as zf: - zf.writestr("config.json", json.dumps(config)) - zf.writestr("metadata.json", json.dumps({"keras_version": "3.0.0"})) - temp_path = f.name - - try: - result = scanner.scan(temp_path) - - # Should detect Lambda layer with exec - assert len(result.issues) > 0, "Should detect Lambda layer with dangerous code" + temp_path = Path(_build_test_keras_zip(config, tmp_path, "3.0.0")) + result = scanner.scan(str(temp_path)) - # Check for critical issue - critical_issues = [i for i in result.issues if i.severity == IssueSeverity.CRITICAL] - assert len(critical_issues) > 0, "Lambda with exec should be CRITICAL" + # Should detect 
Lambda layer with exec + assert len(result.issues) > 0, "Should detect Lambda layer with dangerous code" - # Check that exec was detected - exec_found = False - for issue in result.issues: - if "exec" in issue.message.lower() and "lambda" in issue.message.lower(): - exec_found = True - assert "lambda_1" in issue.message or "lambda_1" in str(issue.details) - break + # Check for critical issue + critical_issues = [i for i in result.issues if i.severity == IssueSeverity.CRITICAL] + assert len(critical_issues) > 0, "Lambda with exec should be CRITICAL" - assert exec_found, "Should detect exec in Lambda layer" + # Check that exec was detected + exec_found = False + for issue in result.issues: + if "exec" in issue.message.lower() and "lambda" in issue.message.lower(): + exec_found = True + assert "lambda_1" in issue.message or "lambda_1" in str(issue.details) + break - finally: - os.unlink(temp_path) + assert exec_found, "Should detect exec in Lambda layer" - def test_multiple_dangerous_patterns(self): + def test_multiple_dangerous_patterns(self, tmp_path: Path) -> None: """Test detection of multiple dangerous patterns in Lambda layers.""" scanner = KerasZipScanner() @@ -877,30 +876,25 @@ def test_multiple_dangerous_patterns(self): }, } - with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as f: - with zipfile.ZipFile(f, "w") as zf: - zf.writestr("config.json", json.dumps(config)) - temp_path = f.name - - try: - result = scanner.scan(temp_path) + temp_path = tmp_path / "multiple_patterns.keras" + with zipfile.ZipFile(temp_path, "w") as zf: + zf.writestr("config.json", json.dumps(config)) - # Should detect dangerous patterns - assert len(result.issues) > 0, "Should detect dangerous patterns" + result = scanner.scan(str(temp_path)) - # Check that multiple patterns were detected - all_messages = " ".join(issue.message for issue in result.issues) - patterns_detected = [] - for pattern in ["eval", "subprocess", "__import__", "pickle"]: - if pattern in all_messages.lower(): - patterns_detected.append(pattern) + # Should detect dangerous patterns + assert len(result.issues) > 0, "Should detect dangerous patterns" - assert len(patterns_detected) > 0, f"Should detect dangerous patterns, found: {patterns_detected}" + # Check that multiple patterns were detected + all_messages = " ".join(issue.message for issue in result.issues) + patterns_detected = [] + for pattern in ["eval", "subprocess", "__import__", "pickle"]: + if pattern in all_messages.lower(): + patterns_detected.append(pattern) - finally: - os.unlink(temp_path) + assert len(patterns_detected) > 0, f"Should detect dangerous patterns, found: {patterns_detected}" - def test_safe_lambda_layer(self): + def test_safe_lambda_layer(self, tmp_path: Path) -> None: """Test that safe Lambda layers are handled appropriately.""" scanner = KerasZipScanner() @@ -923,20 +917,15 @@ def test_safe_lambda_layer(self): }, } - with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as f: - with zipfile.ZipFile(f, "w") as zf: - zf.writestr("config.json", json.dumps(config)) - temp_path = f.name - - try: - result = scanner.scan(temp_path) + temp_path = tmp_path / "safe_lambda.keras" + with zipfile.ZipFile(temp_path, "w") as zf: + zf.writestr("config.json", json.dumps(config)) - # Safe Lambda should not be CRITICAL - critical_issues = [i for i in result.issues if i.severity == IssueSeverity.CRITICAL] - assert len(critical_issues) == 0, "Safe Lambda should not be CRITICAL" + result = scanner.scan(str(temp_path)) - finally: - os.unlink(temp_path) + # 
Safe Lambda should not be CRITICAL + critical_issues = [i for i in result.issues if i.severity == IssueSeverity.CRITICAL] + assert len(critical_issues) == 0, "Safe Lambda should not be CRITICAL" def test_opaque_lambda_bytecode_stays_warning(self, tmp_path: Path) -> None: """Opaque compiled Lambda bytecode should remain a warning-level finding.""" @@ -1125,7 +1114,7 @@ def test_stringlookup_prerelease_versions_treated_as_vulnerable(self, tmp_path: assert cve_checks[0].status == CheckStatus.FAILED assert cve_checks[0].severity == IssueSeverity.WARNING - def test_custom_registered_objects(self): + def test_custom_registered_objects(self, tmp_path: Path) -> None: """Test detection of custom registered objects.""" scanner = KerasZipScanner() @@ -1143,54 +1132,43 @@ def test_custom_registered_objects(self): }, } - with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as f: - with zipfile.ZipFile(f, "w") as zf: - zf.writestr("config.json", json.dumps(config)) - temp_path = f.name - - try: - result = scanner.scan(temp_path) + temp_path = tmp_path / "custom_registered.keras" + with zipfile.ZipFile(temp_path, "w") as zf: + zf.writestr("config.json", json.dumps(config)) - # Should detect custom registered object - custom_found = False - for check in result.checks: - if "custom" in check.message.lower() and "registered" in check.message.lower(): - custom_found = True - break + result = scanner.scan(str(temp_path)) - assert custom_found, "Should detect custom registered objects" + # Should detect custom registered object + custom_found = False + for check in result.checks: + if "custom" in check.message.lower() and "registered" in check.message.lower(): + custom_found = True + break - finally: - os.unlink(temp_path) + assert custom_found, "Should detect custom registered objects" - def test_executable_files_in_zip(self): + def test_executable_files_in_zip(self, tmp_path: Path) -> None: """Test detection of executable files in the ZIP archive.""" scanner = KerasZipScanner() config = {"class_name": "Sequential", "config": {"layers": []}} - with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as f: - with zipfile.ZipFile(f, "w") as zf: - zf.writestr("config.json", json.dumps(config)) - # Add suspicious files - zf.writestr("malicious.py", "import os; os.system('cmd')") - zf.writestr("script.sh", "#!/bin/bash\nrm -rf /") - - temp_path = f.name - - try: - result = scanner.scan(temp_path) + temp_path = tmp_path / "executable_files.keras" + with zipfile.ZipFile(temp_path, "w") as zf: + zf.writestr("config.json", json.dumps(config)) + # Add suspicious files + zf.writestr("malicious.py", "import os; os.system('cmd')") + zf.writestr("script.sh", "#!/bin/bash\nrm -rf /") - # Should detect Python and shell scripts - suspicious_files = [] - for check in result.checks: - if "Python file" in check.message or "Executable file" in check.message: - suspicious_files.append(check.message) + result = scanner.scan(str(temp_path)) - assert len(suspicious_files) >= 2, f"Should detect suspicious files, found: {suspicious_files}" + # Should detect Python and shell scripts + suspicious_files = [] + for check in result.checks: + if "Python file" in check.message or "Executable file" in check.message: + suspicious_files.append(check.message) - finally: - os.unlink(temp_path) + assert len(suspicious_files) >= 2, f"Should detect suspicious files, found: {suspicious_files}" def test_case_insensitive_suspicious_extension_detection(self, tmp_path: Path) -> None: """Uppercase/mixed-case executable extensions should 
be detected.""" @@ -1256,7 +1234,7 @@ def test_executable_extension_near_matches_stay_clean(self, tmp_path: Path) -> N check.name == "Executable File Detection" and check.status == CheckStatus.FAILED for check in result.checks ) - def test_nested_models(self): + def test_nested_models(self, tmp_path: Path) -> None: """Test scanning of nested model structures.""" scanner = KerasZipScanner() @@ -1287,78 +1265,63 @@ def test_nested_models(self): }, } - with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as f: - with zipfile.ZipFile(f, "w") as zf: - zf.writestr("config.json", json.dumps(config)) - temp_path = f.name - - try: - result = scanner.scan(temp_path) + temp_path = tmp_path / "nested_models.keras" + with zipfile.ZipFile(temp_path, "w") as zf: + zf.writestr("config.json", json.dumps(config)) - # Should detect Lambda in nested model - assert len(result.issues) > 0, "Should detect Lambda in nested model" + result = scanner.scan(str(temp_path)) - # Check that __import__ was detected - import_found = False - for issue in result.issues: - if "__import__" in issue.message.lower(): - import_found = True - break + # Should detect Lambda in nested model + assert len(result.issues) > 0, "Should detect Lambda in nested model" - assert import_found, "Should detect __import__ in nested Lambda" + # Check that __import__ was detected + import_found = False + for issue in result.issues: + if "__import__" in issue.message.lower(): + import_found = True + break - finally: - os.unlink(temp_path) + assert import_found, "Should detect __import__ in nested Lambda" - def test_invalid_json_config(self): + def test_invalid_json_config(self, tmp_path: Path) -> None: """Test handling of invalid JSON in config.""" scanner = KerasZipScanner() - with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as f: - with zipfile.ZipFile(f, "w") as zf: - zf.writestr("config.json", "{ invalid json }") - temp_path = f.name - - try: - result = scanner.scan(temp_path) + temp_path = tmp_path / "invalid_json.keras" + with zipfile.ZipFile(temp_path, "w") as zf: + zf.writestr("config.json", "{ invalid json }") - # Should handle invalid JSON gracefully - assert not result.success - json_error_found = False - for check in result.checks: - if "parse" in check.message.lower() and "json" in check.message.lower(): - json_error_found = True - break + result = scanner.scan(str(temp_path)) - assert json_error_found, "Should report JSON parsing error" + # Should handle invalid JSON gracefully + assert not result.success + json_error_found = False + for check in result.checks: + if "parse" in check.message.lower() and "json" in check.message.lower(): + json_error_found = True + break - finally: - os.unlink(temp_path) + assert json_error_found, "Should report JSON parsing error" - def test_missing_config_json(self): + def test_missing_config_json(self, tmp_path: Path) -> None: """Test handling of .keras file without config.json.""" scanner = KerasZipScanner() - with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as f: - with zipfile.ZipFile(f, "w") as zf: - # Only add metadata, no config - zf.writestr("metadata.json", json.dumps({"keras_version": "3.0.0"})) - temp_path = f.name - - try: - result = scanner.scan(temp_path) + temp_path = tmp_path / "missing_config.keras" + with zipfile.ZipFile(temp_path, "w") as zf: + # Only add metadata, no config + zf.writestr("metadata.json", json.dumps({"keras_version": "3.0.0"})) - # Should handle missing config.json - missing_config_found = False - for check in result.checks: 
- if "config.json" in check.message: - missing_config_found = True - break + result = scanner.scan(str(temp_path)) - assert missing_config_found, "Should report missing config.json" + # Should handle missing config.json + missing_config_found = False + for check in result.checks: + if "config.json" in check.message: + missing_config_found = True + break - finally: - os.unlink(temp_path) + assert missing_config_found, "Should report missing config.json" def test_detects_subclassed_model_in_zip(self, tmp_path): """Test that scanner detects subclassed models with custom class names.""" @@ -1767,18 +1730,10 @@ class TestCVE202549655TorchModuleWrapper: def _make_keras_zip(self, config: dict[str, Any], tmp_path: Path) -> str: """Helper to create a .keras ZIP with the given config.json.""" - keras_path = tmp_path / "model.keras" - with zipfile.ZipFile(keras_path, "w") as zf: - zf.writestr("config.json", json.dumps(config)) - zf.writestr("metadata.json", json.dumps({"keras_version": "3.11.0"})) - return str(keras_path) + return _build_test_keras_zip(config, tmp_path, "3.11.0") def _make_keras_zip_with_version(self, config: dict[str, Any], tmp_path: Path, keras_version: str) -> str: - keras_path = tmp_path / "model.keras" - with zipfile.ZipFile(keras_path, "w") as zf: - zf.writestr("config.json", json.dumps(config)) - zf.writestr("metadata.json", json.dumps({"keras_version": keras_version})) - return str(keras_path) + return _build_test_keras_zip(config, tmp_path, keras_version) def test_torch_module_wrapper_detected_critical(self, tmp_path: Path) -> None: """TorchModuleWrapper layer should be flagged as CRITICAL.""" @@ -1946,11 +1901,7 @@ class TestCVE20251550ModuleReferences: def _make_keras_zip(self, config: dict[str, Any], tmp_path: Path) -> str: """Helper to create a .keras ZIP with the given config.json.""" - keras_path = tmp_path / "model.keras" - with zipfile.ZipFile(keras_path, "w") as zf: - zf.writestr("config.json", json.dumps(config)) - zf.writestr("metadata.json", json.dumps({"keras_version": "3.0.0"})) - return str(keras_path) + return _build_test_keras_zip(config, tmp_path, "3.0.0") def test_dangerous_module_os_in_layer(self, tmp_path: Path) -> None: """A layer referencing 'os' module should be flagged as CRITICAL.""" @@ -2226,11 +2177,7 @@ def lower(self) -> str: def _make_keras_zip(self, config_str: str, tmp_path: Path) -> str: """Helper to create a .keras ZIP with raw config string.""" - keras_path = tmp_path / "model.keras" - with zipfile.ZipFile(keras_path, "w") as zf: - zf.writestr("config.json", config_str) - zf.writestr("metadata.json", json.dumps({"keras_version": "3.5.0"})) - return str(keras_path) + return _build_test_keras_zip(config_str, tmp_path, "3.5.0") def test_get_file_with_url_detected(self, tmp_path: Path) -> None: """Config referencing get_file with URL should be CRITICAL.""" @@ -2706,11 +2653,7 @@ class TestCVE20259906UnsafeDeserialization: def _make_keras_zip(self, config_str: str, tmp_path: Path) -> str: """Helper to create a .keras ZIP with raw config string.""" - keras_path = tmp_path / "model.keras" - with zipfile.ZipFile(keras_path, "w") as zf: - zf.writestr("config.json", config_str) - zf.writestr("metadata.json", json.dumps({"keras_version": "3.0.0"})) - return str(keras_path) + return _build_test_keras_zip(config_str, tmp_path, "3.0.0") def test_enable_unsafe_deserialization_detected(self, tmp_path: Path) -> None: """Config referencing enable_unsafe_deserialization should be CRITICAL.""" @@ -2987,12 +2930,8 @@ def 
test_description_text_without_dangerous_tokens_not_flagged(self, tmp_path: P class TestCVE20243660LambdaAttribution: """Test CVE-2024-3660: Lambda layer code injection attribution.""" - def _make_keras_zip(self, config: dict, tmp_path: Path, keras_version: str = "2.10.0") -> str: - keras_path = tmp_path / "model.keras" - with zipfile.ZipFile(keras_path, "w") as zf: - zf.writestr("config.json", json.dumps(config)) - zf.writestr("metadata.json", json.dumps({"keras_version": keras_version})) - return str(keras_path) + def _make_keras_zip(self, config: dict[str, Any], tmp_path: Path, keras_version: str = "2.10.0") -> str: + return _build_test_keras_zip(config, tmp_path, keras_version) def test_lambda_layer_has_cve_2024_3660_attribution(self, tmp_path: Path) -> None: """Lambda layer in .keras file should include CVE-2024-3660 attribution.""" diff --git a/tests/scanners/test_metadata_scanner.py b/tests/scanners/test_metadata_scanner.py index b88e0b6a2..2c598aa22 100644 --- a/tests/scanners/test_metadata_scanner.py +++ b/tests/scanners/test_metadata_scanner.py @@ -55,21 +55,21 @@ def test_cannot_handle_other_files(self): assert not scanner.can_handle("data.txt") assert not scanner.can_handle("random.json") - def test_scan_valid_readme(self): + def test_scan_valid_readme(self) -> None: """Test scanning valid README file.""" scanner = MetadataScanner() with tempfile.TemporaryDirectory() as temp_dir: readme_path = Path(temp_dir) / "README.md" with open(readme_path, "w") as f: - f.write("# My Model\\n\\nThis is a clean README with no security issues.\\n") + f.write("# My Model\n\nThis is a clean README with no security issues.\n") result = scanner.scan(str(readme_path)) assert result.scanner_name == "metadata" assert len(result.issues) == 0 # Clean README should have no issues - def test_scan_suspicious_urls_in_readme(self): + def test_scan_suspicious_urls_in_readme(self) -> None: """Test detection of suspicious URLs in README.""" scanner = MetadataScanner() @@ -77,7 +77,9 @@ def test_scan_suspicious_urls_in_readme(self): readme_path = Path(temp_dir) / "README.md" with open(readme_path, "w") as f: f.write( - "# Model Info\\n\\n- Download: https://bit.ly/suspicious-model\\n- Endpoint: https://ngrok.io/malicious-endpoint\\n" + "# Model Info\n\n" + "- Download: https://bit.ly/suspicious-model\n" + "- Endpoint: https://ngrok.io/malicious-endpoint\n" ) result = scanner.scan(str(readme_path)) @@ -112,14 +114,14 @@ def tracking_urlparse(url: str) -> ParseResult: assert parse_calls == 1 assert result.issues == [] - def test_scan_detects_suspicious_subdomain_hosts(self): + def test_scan_detects_suspicious_subdomain_hosts(self) -> None: """Test suspicious domains are detected through subdomain matching.""" scanner = MetadataScanner() with tempfile.TemporaryDirectory() as temp_dir: readme_path = Path(temp_dir) / "README.md" with open(readme_path, "w") as f: - f.write("# Model Info\\n\\n- Endpoint: https://api.ngrok.io/malicious-endpoint\\n") + f.write("# Model Info\n\n- Endpoint: https://api.ngrok.io/malicious-endpoint\n") result = scanner.scan(str(readme_path)) @@ -129,7 +131,7 @@ def test_scan_detects_suspicious_subdomain_hosts(self): assert issue.details.get("suspicious_domain") == "ngrok.io" assert "https://api.ngrok.io/malicious-endpoint" in str(issue.details.get("url")) - def test_scan_ignores_suspicious_domain_substrings(self): + def test_scan_ignores_suspicious_domain_substrings(self) -> None: """Test URLs are matched by hostname, not generic substring.""" scanner = MetadataScanner() @@ -137,9 +139,9 @@ 
def test_scan_ignores_suspicious_domain_substrings(self): readme_path = Path(temp_dir) / "README.md" with open(readme_path, "w") as f: f.write( - "# Model Info\\n\\n" - "- Docs: https://example.com/guide?redirect=bit.ly/suspicious-model\\n" - "- API: https://safe-ngrok.io/docs\\n" + "# Model Info\n\n" + "- Docs: https://example.com/guide?redirect=bit.ly/suspicious-model\n" + "- API: https://safe-ngrok.io/docs\n" ) result = scanner.scan(str(readme_path)) @@ -219,14 +221,14 @@ def test_scan_exposed_secrets_in_readme(self): assert len(result.issues) >= 1 # Should detect at least one potential secret assert any(issue.severity == IssueSeverity.INFO for issue in result.issues) - def test_scan_ignores_placeholder_secrets(self): + def test_scan_ignores_placeholder_secrets(self) -> None: """Test that obvious placeholders are not flagged as secrets.""" scanner = MetadataScanner() with tempfile.TemporaryDirectory() as temp_dir: readme_path = Path(temp_dir) / "README.md" with open(readme_path, "w") as f: - f.write("# Setup\\n\\nAPI Key: your_api_key_here\\nToken: placeholder_token\\nSecret: XXXXXXXXXX\\n") + f.write("# Setup\n\nAPI Key: your_api_key_here\nToken: placeholder_token\nSecret: XXXXXXXXXX\n") result = scanner.scan(str(readme_path)) @@ -244,14 +246,14 @@ def test_scan_nonexistent_file(self): assert result.issues[0].severity == IssueSeverity.CRITICAL assert "does not exist" in result.issues[0].message - def test_bytes_scanned_reported(self): + def test_bytes_scanned_reported(self) -> None: """Test that bytes scanned is properly reported.""" scanner = MetadataScanner() with tempfile.TemporaryDirectory() as temp_dir: readme_path = Path(temp_dir) / "README.md" with open(readme_path, "w") as f: - f.write("# Test README\\n") + f.write("# Test README\n") expected_size = readme_path.stat().st_size result = scanner.scan(str(readme_path)) diff --git a/tests/scanners/test_skops_scanner.py b/tests/scanners/test_skops_scanner.py index 9e9e49970..43a5ba093 100644 --- a/tests/scanners/test_skops_scanner.py +++ b/tests/scanners/test_skops_scanner.py @@ -1,6 +1,7 @@ """Tests for SkopsScanner covering CVE-2025-54412, CVE-2025-54413, CVE-2025-54886.""" import os +import textwrap import zipfile from pathlib import Path from typing import Any @@ -23,7 +24,10 @@ def _make_numeric_npy(element_count: int = 64) -> bytes: padding_len = (16 - ((10 + header_len) % 16)) % 16 padded_header = header_bytes + (b" " * padding_len) + b"\n" return ( - b"\x93NUMPY\x01\x00" + len(padded_header).to_bytes(2, "little") + padded_header + (b"\x00" * element_count * 8) + b"\x93NUMPY\x01\x00" + + len(padded_header).to_bytes(2, "little") + + padded_header + + (b"\x00" * (element_count * 8)) ) @@ -58,6 +62,8 @@ def _assert_inconclusive_reason(metadata: Any, reason: str) -> None: def test_protocol_probe_reuses_lowered_member_names() -> None: + """Keep ZIP member normalization linear while probing large archives.""" + class CountingMemberName(str): lower_calls = 0 @@ -249,11 +255,13 @@ def test_detects_card_with_get_model(self, tmp_path: Path) -> None: """Test detection of Card.get_model with joblib references.""" skops_file = tmp_path / "malicious.skops" with zipfile.ZipFile(skops_file, "w") as zf: - card_content = """ - # Model Card - This model uses get_model() to load the model. - Fallback to joblib for compatibility. - """ + card_content = textwrap.dedent( + """ + # Model Card + This model uses get_model() to load the model. + Fallback to joblib for compatibility. 
+ """ + ).strip() zf.writestr("model_card.md", card_content) zf.writestr("schema.json", '{"version": "1.0"}') @@ -270,10 +278,12 @@ def test_detects_readme_with_joblib(self, tmp_path: Path) -> None: """Test detection of README with joblib fallback pattern.""" skops_file = tmp_path / "malicious.skops" with zipfile.ZipFile(skops_file, "w") as zf: - readme_content = """ - # Model README - Load the model using joblib.load() if skops fails. - """ + readme_content = textwrap.dedent( + """ + # Model README + Load the model using joblib.load() if skops fails. + """ + ).strip() zf.writestr("README.md", readme_content) zf.writestr("schema.json", '{"version": "1.0"}')