Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@


class _CacheClearable(Protocol):
def cache_clear(self) -> None: ...
def cache_clear(self) -> None:
pass


_SOURCE_SENSITIVE_CACHED_FUNCTIONS: set[_CacheClearable] = set()
Expand Down
43 changes: 29 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
}
_FRAMEWORK_AVAILABILITY: dict[str, bool] = {}

# Reduced CI lanes only install the dependency set needed for this curated test
# subset. Keep this centralized so adding a supported Python version requires one
# deliberate update instead of duplicating version checks in hook logic.
RESTRICTED_PYTHON_VERSIONS = frozenset({(3, 10), (3, 12), (3, 13)})


def _check_framework(name: str) -> bool:
"""Check whether a framework package can be imported, with lazy caching."""
Expand Down Expand Up @@ -80,8 +85,8 @@ def _detect_symlink_support() -> bool:

def pytest_runtest_setup(item):
"""Skip tests based on Python version and framework availability."""
# Skip problematic tests on Python 3.10, 3.12, and 3.13 to ensure CI passes
if sys.version_info[:2] in [(3, 10), (3, 12), (3, 13)]:
# Skip problematic tests on restricted Python versions to ensure CI passes
if sys.version_info[:2] in RESTRICTED_PYTHON_VERSIONS:
test_file = str(item.fspath)

# Only allow core XGBoost scanner tests and basic unit tests on problematic Python versions
Expand Down Expand Up @@ -301,20 +306,21 @@ def temp_model_dir(tmp_path):


@pytest.fixture
def mock_progress_callback():
"""Return a mock progress callback function that records calls."""
progress_messages = []
progress_percentages = []
def mock_progress_callback() -> "ProgressCallbackRecorder":
"""Return a mock progress callback object that records calls."""
return ProgressCallbackRecorder()


def progress_callback(message, percentage):
progress_messages.append(message)
progress_percentages.append(percentage)
class ProgressCallbackRecorder:
"""Record progress callback calls for assertions in integration tests."""

# Add the recorded messages and percentages as attributes
progress_callback.messages = progress_messages # type: ignore[attr-defined]
progress_callback.percentages = progress_percentages # type: ignore[attr-defined]
def __init__(self) -> None:
self.messages: list[str] = []
self.percentages: list[float] = []

return progress_callback
def __call__(self, message: str, percentage: float) -> None:
self.messages.append(message)
self.percentages.append(percentage)


@pytest.fixture
Expand Down Expand Up @@ -415,7 +421,16 @@ def pytest_configure_node(node: Any) -> None:
global _xdist_status_reporter

if _xdist_status_reporter is None:
status_dir = node.config._tmp_path_factory.mktemp("modelaudit-pytest-xdist") # type: ignore[attr-defined]
# pytest does not expose a public API at this xdist hook point for creating
# a controller-owned temp directory. Pytest 8.4+ provides this internal
# factory; fail clearly if a future pytest release removes it.
tmp_path_factory = getattr(node.config, "_tmp_path_factory", None)
if tmp_path_factory is None:
raise RuntimeError(
"pytest internal API changed: '_tmp_path_factory' is unavailable in pytest_configure_node"
)

status_dir = tmp_path_factory.mktemp("modelaudit-pytest-xdist")
_xdist_status_reporter = XdistWorkerStatusReporter.from_environment(status_dir)
if _xdist_status_reporter is None:
return
Expand Down
Loading