diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6f0861c36..ea807e628 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,5 @@ +exclude: ^(tests/|code_to_optimize/|pie_test_set/|experiments/|\.codeflash/) + repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.15.8 @@ -9,7 +11,8 @@ repos: hooks: - id: mypy name: mypy - entry: uv run mypy --non-interactive --config-file pyproject.toml + entry: uv run mypy --non-interactive --config-file pyproject.toml @mypy_allowlist.txt language: system + pass_filenames: false types: [python] require_serial: true diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 36b98e6e2..c197cb63d 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import _thread import array import ast @@ -15,6 +17,8 @@ import weakref import xml.etree.ElementTree as ET from collections import ChainMap, OrderedDict, deque +from functools import cache +from importlib import import_module from importlib.util import find_spec from typing import Any, Optional @@ -35,33 +39,6 @@ HAS_NUMBA = find_spec("numba") is not None HAS_PYARROW = find_spec("pyarrow") is not None -if HAS_NUMPY: - import numpy as np -if HAS_SCIPY: - import scipy # type: ignore # noqa: PGH003 -if HAS_JAX: - import jax # type: ignore # noqa: PGH003 - import jax.numpy as jnp # type: ignore # noqa: PGH003 -if HAS_XARRAY: - import xarray # type: ignore # noqa: PGH003 -if HAS_TENSORFLOW: - import tensorflow as tf # type: ignore # noqa: PGH003 -if HAS_SQLALCHEMY: - import sqlalchemy # type: ignore # noqa: PGH003 -if HAS_PYARROW: - import pyarrow as pa # type: ignore # noqa: PGH003 -if HAS_PANDAS: - import pandas # noqa: ICN001 -if HAS_TORCH: - import torch # type: ignore # noqa: PGH003 -if HAS_NUMBA: - import numba # type: ignore # noqa: PGH003 - from numba.core.dispatcher import Dispatcher # type: ignore # noqa: PGH003 - from numba.typed import Dict as NumbaDict # type: ignore # noqa: PGH003 - from numba.typed import List as NumbaList # type: ignore # noqa: PGH003 -if HAS_PYRSISTENT: - import pyrsistent # type: ignore # noqa: PGH003 - # Pattern to match pytest temp directories: /tmp/pytest-of-/pytest-/ # These paths vary between test runs but are logically equivalent PYTEST_TEMP_PATH_PATTERN = re.compile(r"/tmp/pytest-of-[^/]+/pytest-\d+/") # noqa: S108 @@ -117,6 +94,18 @@ ) +@cache +def _optional_module(module_name: str) -> Optional[Any]: + if find_spec(module_name) is None: + return None + return import_module(module_name) + + +def _object_module_matches(obj: Any, module_prefix: str) -> bool: + module_name = type(obj).__module__ + return module_name == module_prefix or module_name.startswith(f"{module_prefix}.") + + def _normalize_temp_path(path: str) -> str: """Normalize temporary file paths by replacing session-specific components. @@ -133,7 +122,7 @@ def _is_temp_path(s: str) -> bool: return PYTEST_TEMP_PATH_PATTERN.search(s) is not None or PYTHON_TEMPFILE_PATTERN.search(s) is not None -def _extract_exception_from_message(msg: str) -> Optional[BaseException]: # noqa: FA100 +def _extract_exception_from_message(msg: str) -> Optional[BaseException]: """Try to extract a wrapped exception type from an error message. Looks for patterns like "got ExceptionType('..." that indicate a wrapped exception. @@ -153,7 +142,7 @@ def _extract_exception_from_message(msg: str) -> Optional[BaseException]: # noq return None -def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: # noqa: FA100 +def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: """Get the wrapped exception if this is a simple wrapper. Returns the inner exception if: @@ -272,9 +261,11 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: return False return comparator(orig_referent, new_referent, superset_obj) - if HAS_JAX: + if HAS_JAX and (_object_module_matches(orig, "jax") or _object_module_matches(orig, "jaxlib")): + jax = _optional_module("jax") + jnp = _optional_module("jax.numpy") # Handle JAX arrays first to avoid boolean context errors in other conditions - if isinstance(orig, jax.Array): + if jax is not None and jnp is not None and isinstance(orig, jax.Array): if orig.dtype != new.dtype: return False if orig.shape != new.shape: @@ -282,12 +273,14 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: return bool(jnp.allclose(orig, new, equal_nan=True)) # Handle xarray objects before numpy to avoid boolean context errors - if HAS_XARRAY: - if isinstance(orig, (xarray.Dataset, xarray.DataArray)): + if HAS_XARRAY and _object_module_matches(orig, "xarray"): + xarray = _optional_module("xarray") + if xarray is not None and isinstance(orig, (xarray.Dataset, xarray.DataArray)): return orig.identical(new) # Handle TensorFlow objects early to avoid boolean context errors - if HAS_TENSORFLOW: + if HAS_TENSORFLOW and _object_module_matches(orig, "tensorflow"): + tf = _optional_module("tensorflow") if isinstance(orig, tf.Tensor): if orig.dtype != new.dtype: return False @@ -313,9 +306,7 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: if not comparator(orig.dense_shape.numpy(), new.dense_shape.numpy(), superset_obj): return False return comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and comparator( - orig.values.numpy(), # noqa: PD011 - new.values.numpy(), # noqa: PD011 - superset_obj, + orig.values.numpy(), new.values.numpy(), superset_obj ) if isinstance(orig, tf.RaggedTensor): @@ -325,7 +316,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: return False return comparator(orig.to_list(), new.to_list(), superset_obj) - if HAS_SQLALCHEMY: + if HAS_SQLALCHEMY and (hasattr(orig, "_sa_instance_state") or _object_module_matches(orig, "sqlalchemy")): + sqlalchemy = _optional_module("sqlalchemy") try: insp = sqlalchemy.inspection.inspect(orig) insp = sqlalchemy.inspection.inspect(new) @@ -342,7 +334,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: pass # scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it - if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)): + scipy_sparse = _optional_module("scipy.sparse") if HAS_SCIPY and _object_module_matches(orig, "scipy") else None + if isinstance(orig, dict) and not (scipy_sparse is not None and isinstance(orig, scipy_sparse.spmatrix)): if superset_obj: return all(k in new and comparator(v, new[k], superset_obj) for k, v in orig.items()) if len(orig) != len(new): @@ -366,7 +359,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: if isinstance(orig, _DICT_ITEMS_TYPE): return comparator(dict(orig), dict(new), superset_obj) - if HAS_NUMPY: + np = _optional_module("numpy") if HAS_NUMPY and _object_module_matches(orig, "numpy") else None + if np is not None: if isinstance(orig, (np.datetime64, np.timedelta64)): # Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime if np.isnat(orig) and np.isnat(new): @@ -420,14 +414,15 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: new_state = new.get_state(legacy=False) return comparator(orig_state, new_state, superset_obj) - if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix): + if scipy_sparse is not None and isinstance(orig, scipy_sparse.spmatrix): if orig.dtype != new.dtype: return False if orig.get_shape() != new.get_shape(): return False return (orig != new).nnz == 0 - if HAS_PYARROW: + if HAS_PYARROW and _object_module_matches(orig, "pyarrow"): + pa = _optional_module("pyarrow") if isinstance(orig, pa.Table): if orig.schema != new.schema: return False @@ -469,7 +464,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: if isinstance(orig, (pa.Schema, pa.Field, pa.DataType)): return bool(orig.equals(new)) - if HAS_PANDAS: + if HAS_PANDAS and _object_module_matches(orig, "pandas"): + pandas = _optional_module("pandas") if isinstance( orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray) ): @@ -489,17 +485,18 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: # This should be at the end of all numpy checking try: - if HAS_NUMPY and np.isnan(orig): + if np is not None and np.isnan(orig): return np.isnan(new) except Exception: pass try: - if HAS_NUMPY and np.isinf(orig): + if np is not None and np.isinf(orig): return np.isinf(new) except Exception: pass - if HAS_TORCH: + if HAS_TORCH and _object_module_matches(orig, "torch"): + torch = _optional_module("torch") if isinstance(orig, torch.Tensor): if orig.dtype != new.dtype: return False @@ -517,15 +514,23 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: if isinstance(orig, torch.device): return orig == new - if HAS_NUMBA: + if HAS_NUMBA and _object_module_matches(orig, "numba"): + numba = _optional_module("numba") + numba_dispatcher = _optional_module("numba.core.dispatcher") + numba_typed = _optional_module("numba.typed") + if numba is None or numba_dispatcher is None or numba_typed is None: + return False + dispatcher = numba_dispatcher.Dispatcher + numba_dict = numba_typed.Dict + numba_list = numba_typed.List # Handle numba typed List - if isinstance(orig, NumbaList): + if isinstance(orig, numba_list): if len(orig) != len(new): return False return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new)) # Handle numba typed Dict - if isinstance(orig, NumbaDict): + if isinstance(orig, numba_dict): if superset_obj: # Allow new dict to have more keys, but all orig keys must exist with equal values return all(key in new and comparator(orig[key], new[key], superset_obj) for key in orig) @@ -543,12 +548,15 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: return orig == new # Handle numba JIT-compiled functions (CPUDispatcher, etc.) - if isinstance(orig, Dispatcher): + if isinstance(orig, dispatcher): # Compare by identity of the underlying Python function # Two JIT functions are equal if they wrap the same Python function return orig.py_func is new.py_func - if HAS_PYRSISTENT: + if HAS_PYRSISTENT and _object_module_matches(orig, "pyrsistent"): + pyrsistent = _optional_module("pyrsistent") + if pyrsistent is None: + return False if isinstance( orig, ( diff --git a/pyproject.toml b/pyproject.toml index 0d2df965d..6635e2173 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -218,7 +218,7 @@ warn_unreachable = true install_types = true plugins = ["pydantic.mypy"] -exclude = ["tests/", "code_to_optimize/", "pie_test_set/", "experiments/"] +exclude = ["tests/", "code_to_optimize/", "pie_test_set/", "experiments/", "^\\.codeflash/"] [[tool.mypy.overrides]] module = ["jedi", "jedi.api.classes", "inquirer", "inquirer.themes", "numba", "dill"] @@ -379,4 +379,3 @@ markers = [ [build-system] requires = ["hatchling", "uv-dynamic-versioning"] build-backend = "hatchling.build" - diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 52d5192a4..7d621d7af 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -5,6 +5,7 @@ import datetime import decimal import re +import subprocess import sys import uuid import weakref @@ -29,6 +30,27 @@ from codeflash.verification.equivalence import compare_test_results +def test_comparator_import_does_not_load_optional_numeric_modules() -> None: + result = subprocess.run( + [ + sys.executable, + "-c", + ( + "import sys; " + "import codeflash.verification.comparator; " + "loaded = {'numpy', 'pandas', 'xarray', 'numexpr'} & set(sys.modules); " + "print(','.join(sorted(loaded))); " + "raise SystemExit(bool(loaded))" + ), + ], + capture_output=True, + text=True, + encoding="utf-8", + check=False, + ) + assert result.returncode == 0, result.stdout + result.stderr + + def test_basic_python_objects() -> None: a = 5 b = 5