Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 64 additions & 36 deletions packages/modelaudit-picklescan/src/modelaudit_picklescan/call_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import sys
import sysconfig
import threading
from collections import deque
from collections.abc import Callable, Iterable, Iterator, Mapping
from contextlib import contextmanager
Expand All @@ -14,8 +15,16 @@
from functools import lru_cache
from importlib.machinery import EXTENSION_SUFFIXES, BuiltinImporter, FrozenImporter, ModuleSpec, PathFinder
from pathlib import Path
from typing import Protocol, TypeVar, cast

# Bound per-pass import/callable fan-out for untrusted inputs. The 32-reference
# cap has kept call-graph enrichment useful while preventing pathological scan
# growth; raising it improves completeness at a runtime cost, lowering it can
# reduce detection coverage.
_MAX_IMPORT_REFERENCES = 32
# Limit per-module source reads to 1 MiB so AST parsing remains bounded on large
# inputs. This is an explicit coverage/performance tradeoff and can be tuned if
# scan precision or throughput needs change.
_MAX_SOURCE_BYTES = 1024 * 1024
_MAX_CALL_GRAPH_DEPTH = 4
_MAX_VISITED_FUNCTIONS = 64
Expand Down Expand Up @@ -45,6 +54,21 @@
"_SHARED_SOURCE_SENSITIVE_CACHE_DEPTH",
default=0,
)
# Guards cache-generation bookkeeping (clear + depth check/increment) so
# concurrent passes do not race a cache clear against a depth read.
_SHARED_SOURCE_SENSITIVE_CACHE_LOCK = threading.RLock()
# Type variable for the registration decorator below: the decorator returns
# its argument unchanged, so the wrapped callable keeps its precise type.
_CachedFunctionT = TypeVar("_CachedFunctionT", bound=Callable[..., object])


class _CacheClearable(Protocol):
    """Structural type for anything exposing lru_cache's ``cache_clear``."""

    def cache_clear(self) -> None: ...


# Registry of every lru_cache-wrapped function whose results depend on module
# source read from disk; cleared as one unit between cache generations.
_SOURCE_SENSITIVE_CACHED_FUNCTIONS: set[_CacheClearable] = set()


def _register_source_sensitive_cache(function: _CachedFunctionT) -> _CachedFunctionT:
    """Decorator: record *function* so generation resets can clear its cache.

    Returns the function unchanged; only the registry side effect matters.
    """
    clearable = cast(_CacheClearable, function)
    _SOURCE_SENSITIVE_CACHED_FUNCTIONS.add(clearable)
    return function


_CLASS_ENTRYPOINT_METHODS = (
"__getattribute__",
Expand Down Expand Up @@ -290,7 +314,9 @@ def find_dangerous_call_graphs(
positional_arg_counts = _callable_invocation_positional_arg_counts(callable_invocations)
callable_references = _iter_callable_invocation_references(callable_invocations)
invoked_references = {
(str(reference.get("module", "")), str(reference.get("name", ""))) for reference in callable_references
(str(reference.get("module", "")), str(reference.get("name", "")))
for reference in callable_references
if str(reference.get("module", "")) and str(reference.get("name", ""))
}

for reference in _iter_call_graph_references(import_references, callable_references, invoked_references):
Expand Down Expand Up @@ -454,15 +480,18 @@ def find_unanalyzed_callable_call_graph_references(
@contextmanager
def shared_source_sensitive_caches() -> Iterator[None]:
    """Share one fresh cache generation across related enrichment passes.

    The outermost entry clears every registered source-sensitive cache once;
    nested entries (tracked via a context-local depth counter) reuse that
    generation instead of clearing again.
    """
    # NOTE: the pasted diff interleaved the pre- and post-change bodies; this
    # reconstructs the post-change (locked) implementation.  The lock covers
    # only the bookkeeping — depth read, conditional clear, depth increment —
    # so the wrapped passes themselves run without holding it.  TODO(review):
    # confirm the lock is not meant to be held across the yield.
    with _SHARED_SOURCE_SENSITIVE_CACHE_LOCK:
        depth = _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.get()
        if depth == 0:
            _clear_source_sensitive_caches_now()
        token = _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.set(depth + 1)
    try:
        yield
    finally:
        # Restore the caller's depth even if the pass raised.
        _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.reset(token)


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _safe_call_graph_entrypoints(function_name: str) -> tuple[str, ...]:
try:
Expand Down Expand Up @@ -700,6 +729,7 @@ def _is_skippable_torch_extension_global_reference(module: str, name: str) -> bo
return not _has_static_torch_extension_global_target(module, name)


@_register_source_sensitive_cache
@lru_cache(maxsize=256)
def _has_static_torch_extension_global_target(module: str, name: str) -> bool:
analysis = _analyze_module(module)
Expand Down Expand Up @@ -735,37 +765,14 @@ def has_unanalyzed_call_graph_import_references(import_references: object) -> bo


def _clear_source_sensitive_caches() -> None:
    """Clear source-sensitive caches unless a shared generation is active.

    Inside a ``shared_source_sensitive_caches()`` generation the caches are
    deliberately kept warm — the outermost entry already cleared them — so
    this is a no-op while the context-local depth counter is positive.
    """
    # The pasted diff showed both the unlocked (old) and locked (new) bodies;
    # this is the post-change version.  The lock keeps the depth check and the
    # clear atomic with respect to concurrent generation bookkeeping.
    with _SHARED_SOURCE_SENSITIVE_CACHE_LOCK:
        if _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.get() > 0:
            return
        _clear_source_sensitive_caches_now()


def _clear_source_sensitive_caches_now() -> None:
    """Unconditionally drop every registered source-sensitive cache.

    Iterates the ``_SOURCE_SENSITIVE_CACHED_FUNCTIONS`` registry populated by
    the ``@_register_source_sensitive_cache`` decorator, replacing the old
    hand-maintained 23-function tuple (still visible in the pasted diff as the
    removed lines) so newly cached functions cannot be forgotten here.
    """
    for function in _SOURCE_SENSITIVE_CACHED_FUNCTIONS:
        function.cache_clear()


Expand Down Expand Up @@ -842,6 +849,7 @@ def _find_meta_path_module_spec_without_imports(module_name: str) -> ModuleSpec
return None


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _find_sink_path(start: str) -> tuple[str, ...] | None:
return _find_matching_call_path(
Expand All @@ -851,6 +859,7 @@ def _find_sink_path(start: str) -> tuple[str, ...] | None:
)


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _find_invoked_import_execution_path(
start: str,
Expand Down Expand Up @@ -881,11 +890,13 @@ def _find_invoked_import_execution_path(
return path


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _find_file_open_path(start: str) -> tuple[str, ...] | None:
return _find_matching_call_path(start, _file_open_sink)


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _find_file_write_path(start: str) -> tuple[str, ...] | None:
return _find_matching_call_path(start, _file_write_sink)
Expand Down Expand Up @@ -997,6 +1008,7 @@ def _calls_for_function(function_name: str) -> tuple[str, ...] | None:
return analysis.calls_by_function.get(f"{module_name}.{qualified_name}")


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _call_graph_entrypoints(function_name: str) -> tuple[str, ...]:
resolved = _resolve_function_target(function_name)
Expand All @@ -1009,6 +1021,7 @@ def _call_graph_entrypoints(function_name: str) -> tuple[str, ...]:
return _class_entrypoints(class_target)


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _resolve_function_target(function_name: str) -> str | None:
alias_target = _static_import_reference_alias(function_name)
Expand Down Expand Up @@ -1113,6 +1126,7 @@ def _static_import_reference_alias(function_name: str) -> str | None:
return None


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _resolve_wildcard_reexport_alias(module_name: str, qualified_name: str) -> str | None:
return _resolve_wildcard_reexport_alias_inner(module_name, qualified_name, set(), 0)
Expand Down Expand Up @@ -1150,6 +1164,7 @@ def _resolve_wildcard_reexport_alias_inner(
return None


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _wildcard_export_summary(module_name: str) -> _WildcardExportSummary | None:
context = _module_source_context(module_name)
Expand All @@ -1158,6 +1173,7 @@ def _wildcard_export_summary(module_name: str) -> _WildcardExportSummary | None:
return _collect_module_export_summary(context.module_statements, module_name, context.is_package)


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _resolve_class_target(function_name: str) -> str | None:
module_name, qualified_name = _split_function_name(function_name)
Expand Down Expand Up @@ -1193,6 +1209,7 @@ def _class_entrypoints(class_name: str) -> tuple[str, ...]:
return analysis.class_entrypoints.get(f"{module_name}.{qualified_name}", ())


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _split_function_name(function_name: str) -> tuple[str | None, str]:
parts = function_name.split(".")
Expand All @@ -1205,6 +1222,7 @@ def _split_function_name(function_name: str) -> tuple[str | None, str]:
return None, function_name


@_register_source_sensitive_cache
@lru_cache(maxsize=1024)
def _analyze_module(module_name: str) -> _ModuleAnalysis | None:
context = _module_source_context(module_name)
Expand Down Expand Up @@ -1259,6 +1277,7 @@ def _analyze_module(module_name: str) -> _ModuleAnalysis | None:
)


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _module_source_context(module_name: str) -> _ModuleSourceContext | None:
source_path = _resolve_module_source(module_name)
Expand Down Expand Up @@ -1367,6 +1386,7 @@ def _collect_import_aliases(nodes: Iterable[ast.AST], module_name: str, is_packa
return aliases


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _collect_function_import_aliases(
function_node: ast.FunctionDef | ast.AsyncFunctionDef,
Expand Down Expand Up @@ -1530,6 +1550,7 @@ def _class_source_context_for_target(
return _source_class_context(class_target)


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _source_function_context(
function_name: str,
Expand Down Expand Up @@ -1584,6 +1605,7 @@ def _inherited_source_function_context(
return inherited_method.module_name, inherited_is_package, inherited_method.node


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _source_class_context(class_name: str) -> _ClassSourceContext | None:
module_name, qualified_name = _split_source_qualified_name(class_name)
Expand Down Expand Up @@ -1927,6 +1949,7 @@ def _class_base_targets(
return tuple(targets)


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _constructor_parameter_self_attribute_targets(class_name: str, parameter_name: str) -> tuple[str, ...]:
module_name, qualified_name = _split_source_qualified_name(class_name)
Expand Down Expand Up @@ -2295,6 +2318,7 @@ def _has_required_user_arguments(function_node: ast.FunctionDef | ast.AsyncFunct
return has_required_keyword_only


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _can_invoke_function_with_positional_args(function_name: str, positional_arg_count: int) -> bool:
resolved = _resolve_function_target(function_name)
Expand All @@ -2307,6 +2331,7 @@ def _can_invoke_function_with_positional_args(function_name: str, positional_arg
return _can_enter_function_with_positional_args(function_node, positional_arg_count)


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _can_follow_import_execution_fallback(function_name: str, positional_arg_count: int) -> bool:
resolved = _resolve_function_target(function_name)
Expand Down Expand Up @@ -2608,6 +2633,7 @@ def _function_instance_alias_value(
return _resolve_class_target(call_target)


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _iter_call_nodes(function_node: ast.FunctionDef | ast.AsyncFunctionDef) -> tuple[ast.Call, ...]:
calls: list[ast.Call] = []
Expand Down Expand Up @@ -2647,6 +2673,7 @@ def _visit_nested_function_signature(self, node: ast.FunctionDef | ast.AsyncFunc
return tuple(calls)


@_register_source_sensitive_cache
@lru_cache(maxsize=4096)
def _parameter_controlled_names(function_node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
controlled = _initial_parameter_controlled_names(function_node)
Expand Down Expand Up @@ -2819,6 +2846,7 @@ def _resolve_import_from_module(module_name: str, is_package: bool, level: int,
return ".".join(part for part in parts if part)


@_register_source_sensitive_cache
@lru_cache(maxsize=1024)
def _resolve_module_source(module_name: str) -> Path | None:
parts = module_name.split(".")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,8 @@ def _global_call_payload(module: str, name: str, *arg_operands: bytes) -> bytes:


def _clear_call_graph_caches() -> None:
    """Reset every source-sensitive cache in the call_graph module under test.

    Iterates the production registry instead of the old hand-maintained name
    list (the removed ``getattr``/``cache_clear`` loop in the pasted diff) so
    newly registered caches are cleared automatically without test changes.
    """
    for function in call_graph._SOURCE_SENSITIVE_CACHED_FUNCTIONS:
        function.cache_clear()


def test_wildcard_summary_and_analysis_share_module_parse(
Expand Down
Loading
Loading