Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
exclude: ^(tests/|code_to_optimize/|pie_test_set/|experiments/|\.codeflash/)

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.8
Expand All @@ -9,7 +11,8 @@ repos:
hooks:
- id: mypy
name: mypy
entry: uv run mypy --non-interactive --config-file pyproject.toml
entry: uv run mypy --non-interactive --config-file pyproject.toml @mypy_allowlist.txt
language: system
pass_filenames: false
types: [python]
require_serial: true
110 changes: 59 additions & 51 deletions codeflash/verification/comparator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import _thread
import array
import ast
Expand All @@ -15,6 +17,8 @@
import weakref
import xml.etree.ElementTree as ET
from collections import ChainMap, OrderedDict, deque
from functools import cache
from importlib import import_module
from importlib.util import find_spec
from typing import Any, Optional

Expand All @@ -35,33 +39,6 @@
HAS_NUMBA = find_spec("numba") is not None
HAS_PYARROW = find_spec("pyarrow") is not None

if HAS_NUMPY:
import numpy as np
if HAS_SCIPY:
import scipy # type: ignore # noqa: PGH003
if HAS_JAX:
import jax # type: ignore # noqa: PGH003
import jax.numpy as jnp # type: ignore # noqa: PGH003
if HAS_XARRAY:
import xarray # type: ignore # noqa: PGH003
if HAS_TENSORFLOW:
import tensorflow as tf # type: ignore # noqa: PGH003
if HAS_SQLALCHEMY:
import sqlalchemy # type: ignore # noqa: PGH003
if HAS_PYARROW:
import pyarrow as pa # type: ignore # noqa: PGH003
if HAS_PANDAS:
import pandas # noqa: ICN001
if HAS_TORCH:
import torch # type: ignore # noqa: PGH003
if HAS_NUMBA:
import numba # type: ignore # noqa: PGH003
from numba.core.dispatcher import Dispatcher # type: ignore # noqa: PGH003
from numba.typed import Dict as NumbaDict # type: ignore # noqa: PGH003
from numba.typed import List as NumbaList # type: ignore # noqa: PGH003
if HAS_PYRSISTENT:
import pyrsistent # type: ignore # noqa: PGH003

# Pattern to match pytest temp directories: /tmp/pytest-of-<user>/pytest-<N>/
# These paths vary between test runs but are logically equivalent
PYTEST_TEMP_PATH_PATTERN = re.compile(r"/tmp/pytest-of-[^/]+/pytest-\d+/") # noqa: S108
Expand Down Expand Up @@ -117,6 +94,18 @@
)


@cache
def _optional_module(module_name: str) -> Optional[Any]:
if find_spec(module_name) is None:
return None
return import_module(module_name)
Comment on lines +97 to +101
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this is for import time then just remove the entire block and import on demand https://github.com/KRRT7/codeflash/pull/16/changes#diff-a6f2493dbc3f1ec576d03e2d188986f94ca475e3b4cf79ba3dff06309ec01997



def _object_module_matches(obj: Any, module_prefix: str) -> bool:
module_name = type(obj).__module__
return module_name == module_prefix or module_name.startswith(f"{module_prefix}.")


def _normalize_temp_path(path: str) -> str:
"""Normalize temporary file paths by replacing session-specific components.

Expand All @@ -133,7 +122,7 @@ def _is_temp_path(s: str) -> bool:
return PYTEST_TEMP_PATH_PATTERN.search(s) is not None or PYTHON_TEMPFILE_PATTERN.search(s) is not None


def _extract_exception_from_message(msg: str) -> Optional[BaseException]: # noqa: FA100
def _extract_exception_from_message(msg: str) -> Optional[BaseException]:
"""Try to extract a wrapped exception type from an error message.

Looks for patterns like "got ExceptionType('..." that indicate a wrapped exception.
Expand All @@ -153,7 +142,7 @@ def _extract_exception_from_message(msg: str) -> Optional[BaseException]: # noq
return None


def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: # noqa: FA100
def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]:
"""Get the wrapped exception if this is a simple wrapper.

Returns the inner exception if:
Expand Down Expand Up @@ -272,22 +261,26 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
return False
return comparator(orig_referent, new_referent, superset_obj)

if HAS_JAX:
if HAS_JAX and (_object_module_matches(orig, "jax") or _object_module_matches(orig, "jaxlib")):
jax = _optional_module("jax")
jnp = _optional_module("jax.numpy")
# Handle JAX arrays first to avoid boolean context errors in other conditions
if isinstance(orig, jax.Array):
if jax is not None and jnp is not None and isinstance(orig, jax.Array):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
return bool(jnp.allclose(orig, new, equal_nan=True))

# Handle xarray objects before numpy to avoid boolean context errors
if HAS_XARRAY:
if isinstance(orig, (xarray.Dataset, xarray.DataArray)):
if HAS_XARRAY and _object_module_matches(orig, "xarray"):
xarray = _optional_module("xarray")
if xarray is not None and isinstance(orig, (xarray.Dataset, xarray.DataArray)):
return orig.identical(new)

# Handle TensorFlow objects early to avoid boolean context errors
if HAS_TENSORFLOW:
if HAS_TENSORFLOW and _object_module_matches(orig, "tensorflow"):
tf = _optional_module("tensorflow")
if isinstance(orig, tf.Tensor):
if orig.dtype != new.dtype:
return False
Expand All @@ -313,9 +306,7 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
if not comparator(orig.dense_shape.numpy(), new.dense_shape.numpy(), superset_obj):
return False
return comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and comparator(
orig.values.numpy(), # noqa: PD011
new.values.numpy(), # noqa: PD011
superset_obj,
orig.values.numpy(), new.values.numpy(), superset_obj
)

if isinstance(orig, tf.RaggedTensor):
Expand All @@ -325,7 +316,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
return False
return comparator(orig.to_list(), new.to_list(), superset_obj)

if HAS_SQLALCHEMY:
if HAS_SQLALCHEMY and (hasattr(orig, "_sa_instance_state") or _object_module_matches(orig, "sqlalchemy")):
sqlalchemy = _optional_module("sqlalchemy")
try:
insp = sqlalchemy.inspection.inspect(orig)
insp = sqlalchemy.inspection.inspect(new)
Expand All @@ -342,7 +334,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
pass

# scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)):
scipy_sparse = _optional_module("scipy.sparse") if HAS_SCIPY and _object_module_matches(orig, "scipy") else None
if isinstance(orig, dict) and not (scipy_sparse is not None and isinstance(orig, scipy_sparse.spmatrix)):
if superset_obj:
return all(k in new and comparator(v, new[k], superset_obj) for k, v in orig.items())
if len(orig) != len(new):
Expand All @@ -366,7 +359,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
if isinstance(orig, _DICT_ITEMS_TYPE):
return comparator(dict(orig), dict(new), superset_obj)

if HAS_NUMPY:
np = _optional_module("numpy") if HAS_NUMPY and _object_module_matches(orig, "numpy") else None
if np is not None:
if isinstance(orig, (np.datetime64, np.timedelta64)):
# Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime
if np.isnat(orig) and np.isnat(new):
Expand Down Expand Up @@ -420,14 +414,15 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
new_state = new.get_state(legacy=False)
return comparator(orig_state, new_state, superset_obj)

if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
if scipy_sparse is not None and isinstance(orig, scipy_sparse.spmatrix):
if orig.dtype != new.dtype:
return False
if orig.get_shape() != new.get_shape():
return False
return (orig != new).nnz == 0

if HAS_PYARROW:
if HAS_PYARROW and _object_module_matches(orig, "pyarrow"):
pa = _optional_module("pyarrow")
if isinstance(orig, pa.Table):
if orig.schema != new.schema:
return False
Expand Down Expand Up @@ -469,7 +464,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
if isinstance(orig, (pa.Schema, pa.Field, pa.DataType)):
return bool(orig.equals(new))

if HAS_PANDAS:
if HAS_PANDAS and _object_module_matches(orig, "pandas"):
pandas = _optional_module("pandas")
if isinstance(
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
):
Expand All @@ -489,17 +485,18 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:

# This should be at the end of all numpy checking
try:
if HAS_NUMPY and np.isnan(orig):
if np is not None and np.isnan(orig):
return np.isnan(new)
except Exception:
pass
try:
if HAS_NUMPY and np.isinf(orig):
if np is not None and np.isinf(orig):
return np.isinf(new)
except Exception:
pass

if HAS_TORCH:
if HAS_TORCH and _object_module_matches(orig, "torch"):
torch = _optional_module("torch")
if isinstance(orig, torch.Tensor):
if orig.dtype != new.dtype:
return False
Expand All @@ -517,15 +514,23 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
if isinstance(orig, torch.device):
return orig == new

if HAS_NUMBA:
if HAS_NUMBA and _object_module_matches(orig, "numba"):
numba = _optional_module("numba")
numba_dispatcher = _optional_module("numba.core.dispatcher")
numba_typed = _optional_module("numba.typed")
if numba is None or numba_dispatcher is None or numba_typed is None:
return False
dispatcher = numba_dispatcher.Dispatcher
numba_dict = numba_typed.Dict
numba_list = numba_typed.List
# Handle numba typed List
if isinstance(orig, NumbaList):
if isinstance(orig, numba_list):
if len(orig) != len(new):
return False
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))

# Handle numba typed Dict
if isinstance(orig, NumbaDict):
if isinstance(orig, numba_dict):
if superset_obj:
# Allow new dict to have more keys, but all orig keys must exist with equal values
return all(key in new and comparator(orig[key], new[key], superset_obj) for key in orig)
Expand All @@ -543,12 +548,15 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
return orig == new

# Handle numba JIT-compiled functions (CPUDispatcher, etc.)
if isinstance(orig, Dispatcher):
if isinstance(orig, dispatcher):
# Compare by identity of the underlying Python function
# Two JIT functions are equal if they wrap the same Python function
return orig.py_func is new.py_func

if HAS_PYRSISTENT:
if HAS_PYRSISTENT and _object_module_matches(orig, "pyrsistent"):
pyrsistent = _optional_module("pyrsistent")
if pyrsistent is None:
return False
if isinstance(
orig,
(
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ warn_unreachable = true
install_types = true
plugins = ["pydantic.mypy"]

exclude = ["tests/", "code_to_optimize/", "pie_test_set/", "experiments/"]
exclude = ["tests/", "code_to_optimize/", "pie_test_set/", "experiments/", "^\\.codeflash/"]

[[tool.mypy.overrides]]
module = ["jedi", "jedi.api.classes", "inquirer", "inquirer.themes", "numba", "dill"]
Expand Down Expand Up @@ -379,4 +379,3 @@ markers = [
[build-system]
requires = ["hatchling", "uv-dynamic-versioning"]
build-backend = "hatchling.build"

22 changes: 22 additions & 0 deletions tests/test_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import datetime
import decimal
import re
import subprocess
import sys
import uuid
import weakref
Expand All @@ -29,6 +30,27 @@
from codeflash.verification.equivalence import compare_test_results


def test_comparator_import_does_not_load_optional_numeric_modules() -> None:
result = subprocess.run(
[
sys.executable,
"-c",
(
"import sys; "
"import codeflash.verification.comparator; "
"loaded = {'numpy', 'pandas', 'xarray', 'numexpr'} & set(sys.modules); "
"print(','.join(sorted(loaded))); "
"raise SystemExit(bool(loaded))"
),
],
capture_output=True,
text=True,
encoding="utf-8",
check=False,
)
assert result.returncode == 0, result.stdout + result.stderr


def test_basic_python_objects() -> None:
a = 5
b = 5
Expand Down
Loading