From 59fb788eac0e71284a39eafa70461052eaa28acc Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 20:04:35 -0700 Subject: [PATCH 01/24] [Fix] Add eval_str=True to inspect.signature() calls for stringified annotations Support `from __future__ import annotations` by resolving stringified type annotations at signature inspection time. Catch NameError from invalid string annotations and raise a clear QuadrantsSyntaxError. --- python/quadrants/lang/_func_base.py | 7 +++++-- python/quadrants/lang/_kernel_impl_dataclass.py | 2 +- python/quadrants/lang/_perf_dispatch.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/quadrants/lang/_func_base.py b/python/quadrants/lang/_func_base.py index dd5cdbac8b..f5d334458a 100644 --- a/python/quadrants/lang/_func_base.py +++ b/python/quadrants/lang/_func_base.py @@ -97,7 +97,10 @@ def check_parameter_annotations(self) -> None: Note: NOT in the hot path. Just run once, on function registration """ - sig = inspect.signature(self.func) + try: + sig = inspect.signature(self.func, eval_str=True) + except NameError as e: + raise QuadrantsSyntaxError(f"Invalid type annotation of Taichi kernel: {e}") from e if hasattr(self.func, "__wrapped__"): raise_exception( QuadrantsSyntaxError, @@ -189,7 +192,7 @@ def _populate_global_vars_for_templates( for i in template_slot_locations: template_var_name = argument_metas[i].name global_vars[template_var_name] = py_args[i] - parameters = inspect.signature(fn).parameters + parameters = inspect.signature(fn, eval_str=True).parameters for i, (parameter_name, parameter) in enumerate(parameters.items()): if is_dataclass(parameter.annotation): _kernel_impl_dataclass.populate_global_vars_from_dataclass( diff --git a/python/quadrants/lang/_kernel_impl_dataclass.py b/python/quadrants/lang/_kernel_impl_dataclass.py index c5d7bd530e..531878249e 100644 --- a/python/quadrants/lang/_kernel_impl_dataclass.py +++ b/python/quadrants/lang/_kernel_impl_dataclass.py @@ -73,7 +73,7 @@ def extract_struct_locals_from_context(ctx: ASTTransformerFuncContext) -> set[st """ struct_locals = set() assert ctx.func is not None - sig = inspect.signature(ctx.func.func) + sig = inspect.signature(ctx.func.func, eval_str=True) parameters = sig.parameters for param_name, parameter in parameters.items(): if dataclasses.is_dataclass(parameter.annotation): diff --git a/python/quadrants/lang/_perf_dispatch.py b/python/quadrants/lang/_perf_dispatch.py index 4bc21844a5..ec29bc136a 100644 --- a/python/quadrants/lang/_perf_dispatch.py +++ b/python/quadrants/lang/_perf_dispatch.py @@ -58,7 +58,7 @@ def __init__( self.num_active = num_active if num_active is not None else NUM_ACTIVE self.repeat_after_count = repeat_after_count if repeat_after_count is not None else REPEAT_AFTER_COUNT self.repeat_after_seconds = repeat_after_seconds if repeat_after_seconds is not None else REPEAT_AFTER_SECONDS - sig = inspect.signature(fn) + sig = inspect.signature(fn, eval_str=True) self._param_types: dict[str, Any] = {} for param_name, param in sig.parameters.items(): self._param_types[param_name] = param.annotation From 544bbc981bda671d43c50f73801511574cb52892 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:31:46 -0700 Subject: [PATCH 02/24] [Fix] Add eval_str=True to perf_dispatch register decorator The register decorator's inspect.signature() call was missing eval_str=True, inconsistent with all other call sites. --- python/quadrants/lang/_perf_dispatch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/quadrants/lang/_perf_dispatch.py b/python/quadrants/lang/_perf_dispatch.py index ec29bc136a..feb1f83fa2 100644 --- a/python/quadrants/lang/_perf_dispatch.py +++ b/python/quadrants/lang/_perf_dispatch.py @@ -99,7 +99,7 @@ def register( dispatch_impl_set = self._dispatch_impl_set def decorator(func: Callable | QuadrantsCallable) -> DispatchImpl: - sig = inspect.signature(func) + sig = inspect.signature(func, eval_str=True) log_str = f"perf_dispatch registering {func.__name__}" # type: ignore _logging.debug(log_str) if QD_PERFDISPATCH_PRINT_DEBUG: From 77fccda38e5b8b82d3322527341486eaddb21e55 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:32:56 -0700 Subject: [PATCH 03/24] [Fix] Add consistent error handling at all eval_str=True call sites Wrap all inspect.signature(eval_str=True) calls with try/except that catches both NameError and AttributeError, re-raising as QuadrantsSyntaxError for clear user-facing errors. Previously only check_parameter_annotations had error handling and it only caught NameError. --- python/quadrants/lang/_func_base.py | 7 +++++-- python/quadrants/lang/_kernel_impl_dataclass.py | 7 ++++++- python/quadrants/lang/_perf_dispatch.py | 10 ++++++++-- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/python/quadrants/lang/_func_base.py b/python/quadrants/lang/_func_base.py index f5d334458a..16f33b66d2 100644 --- a/python/quadrants/lang/_func_base.py +++ b/python/quadrants/lang/_func_base.py @@ -99,7 +99,7 @@ def check_parameter_annotations(self) -> None: """ try: sig = inspect.signature(self.func, eval_str=True) - except NameError as e: + except (NameError, AttributeError) as e: raise QuadrantsSyntaxError(f"Invalid type annotation of Taichi kernel: {e}") from e if hasattr(self.func, "__wrapped__"): raise_exception( @@ -192,7 +192,10 @@ def _populate_global_vars_for_templates( for i in template_slot_locations: template_var_name = argument_metas[i].name global_vars[template_var_name] = py_args[i] - parameters = inspect.signature(fn, eval_str=True).parameters + try: + parameters = inspect.signature(fn, eval_str=True).parameters + except (NameError, AttributeError) as e: + raise QuadrantsSyntaxError(f"Invalid type annotation of Taichi kernel: {e}") from e for i, (parameter_name, parameter) in enumerate(parameters.items()): if is_dataclass(parameter.annotation): _kernel_impl_dataclass.populate_global_vars_from_dataclass( diff --git a/python/quadrants/lang/_kernel_impl_dataclass.py b/python/quadrants/lang/_kernel_impl_dataclass.py index 531878249e..f220e03c36 100644 --- a/python/quadrants/lang/_kernel_impl_dataclass.py +++ b/python/quadrants/lang/_kernel_impl_dataclass.py @@ -73,7 +73,12 @@ def extract_struct_locals_from_context(ctx: ASTTransformerFuncContext) -> set[st """ struct_locals = set() assert ctx.func is not None - sig = inspect.signature(ctx.func.func, eval_str=True) + try: + sig = inspect.signature(ctx.func.func, eval_str=True) + except (NameError, AttributeError) as e: + from quadrants.lang.exception import QuadrantsSyntaxError + + raise QuadrantsSyntaxError(f"Invalid type annotation of Taichi kernel: {e}") from e parameters = sig.parameters for param_name, parameter in parameters.items(): if dataclasses.is_dataclass(parameter.annotation): diff --git a/python/quadrants/lang/_perf_dispatch.py b/python/quadrants/lang/_perf_dispatch.py index feb1f83fa2..cf99b97639 100644 --- a/python/quadrants/lang/_perf_dispatch.py +++ b/python/quadrants/lang/_perf_dispatch.py @@ -58,7 +58,10 @@ def __init__( self.num_active = num_active if num_active is not None else NUM_ACTIVE self.repeat_after_count = repeat_after_count if repeat_after_count is not None else REPEAT_AFTER_COUNT self.repeat_after_seconds = repeat_after_seconds if repeat_after_seconds is not None else REPEAT_AFTER_SECONDS - sig = inspect.signature(fn, eval_str=True) + try: + sig = inspect.signature(fn, eval_str=True) + except (NameError, AttributeError) as e: + raise QuadrantsSyntaxError(f"Invalid type annotation: {e}") from e self._param_types: dict[str, Any] = {} for param_name, param in sig.parameters.items(): self._param_types[param_name] = param.annotation @@ -99,7 +102,10 @@ def register( dispatch_impl_set = self._dispatch_impl_set def decorator(func: Callable | QuadrantsCallable) -> DispatchImpl: - sig = inspect.signature(func, eval_str=True) + try: + sig = inspect.signature(func, eval_str=True) + except (NameError, AttributeError) as e: + raise QuadrantsSyntaxError(f"Invalid type annotation: {e}") from e log_str = f"perf_dispatch registering {func.__name__}" # type: ignore _logging.debug(log_str) if QD_PERFDISPATCH_PRINT_DEBUG: From 849fe2dae24ab8f7fbae1341c86f780b06c60259 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:37:02 -0700 Subject: [PATCH 04/24] [Refactor] Extract get_func_signature helper for eval_str=True calls Centralize the inspect.signature(eval_str=True) + error handling pattern into a shared get_func_signature() helper in exception.py, replacing 5 inline try/except blocks across _func_base.py, _kernel_impl_dataclass.py, and _perf_dispatch.py. --- python/quadrants/lang/_func_base.py | 11 +++-------- python/quadrants/lang/_kernel_impl_dataclass.py | 8 ++------ python/quadrants/lang/_perf_dispatch.py | 13 +++---------- python/quadrants/lang/exception.py | 10 ++++++++++ 4 files changed, 18 insertions(+), 24 deletions(-) diff --git a/python/quadrants/lang/_func_base.py b/python/quadrants/lang/_func_base.py index 16f33b66d2..078e921f53 100644 --- a/python/quadrants/lang/_func_base.py +++ b/python/quadrants/lang/_func_base.py @@ -29,6 +29,7 @@ QuadrantsRuntimeError, QuadrantsRuntimeTypeError, QuadrantsSyntaxError, + get_func_signature, ) from quadrants.lang.kernel_arguments import ArgMetadata from quadrants.lang.matrix import MatrixType @@ -97,10 +98,7 @@ def check_parameter_annotations(self) -> None: Note: NOT in the hot path. Just run once, on function registration """ - try: - sig = inspect.signature(self.func, eval_str=True) - except (NameError, AttributeError) as e: - raise QuadrantsSyntaxError(f"Invalid type annotation of Taichi kernel: {e}") from e + sig = get_func_signature(self.func) if hasattr(self.func, "__wrapped__"): raise_exception( QuadrantsSyntaxError, @@ -192,10 +190,7 @@ def _populate_global_vars_for_templates( for i in template_slot_locations: template_var_name = argument_metas[i].name global_vars[template_var_name] = py_args[i] - try: - parameters = inspect.signature(fn, eval_str=True).parameters - except (NameError, AttributeError) as e: - raise QuadrantsSyntaxError(f"Invalid type annotation of Taichi kernel: {e}") from e + parameters = get_func_signature(fn).parameters for i, (parameter_name, parameter) in enumerate(parameters.items()): if is_dataclass(parameter.annotation): _kernel_impl_dataclass.populate_global_vars_from_dataclass( diff --git a/python/quadrants/lang/_kernel_impl_dataclass.py b/python/quadrants/lang/_kernel_impl_dataclass.py index f220e03c36..ebd298dff5 100644 --- a/python/quadrants/lang/_kernel_impl_dataclass.py +++ b/python/quadrants/lang/_kernel_impl_dataclass.py @@ -1,6 +1,5 @@ import ast import dataclasses -import inspect from typing import Any from quadrants.lang import util @@ -73,12 +72,9 @@ def extract_struct_locals_from_context(ctx: ASTTransformerFuncContext) -> set[st """ struct_locals = set() assert ctx.func is not None - try: - sig = inspect.signature(ctx.func.func, eval_str=True) - except (NameError, AttributeError) as e: - from quadrants.lang.exception import QuadrantsSyntaxError + from quadrants.lang.exception import get_func_signature - raise QuadrantsSyntaxError(f"Invalid type annotation of Taichi kernel: {e}") from e + sig = get_func_signature(ctx.func.func) parameters = sig.parameters for param_name, parameter in parameters.items(): if dataclasses.is_dataclass(parameter.annotation): diff --git a/python/quadrants/lang/_perf_dispatch.py b/python/quadrants/lang/_perf_dispatch.py index cf99b97639..a1999199d7 100644 --- a/python/quadrants/lang/_perf_dispatch.py +++ b/python/quadrants/lang/_perf_dispatch.py @@ -1,4 +1,3 @@ -import inspect import os import time from collections import defaultdict @@ -8,7 +7,7 @@ from . import impl from ._exceptions import raise_exception from ._quadrants_callable import QuadrantsCallable -from .exception import QuadrantsRuntimeError, QuadrantsSyntaxError +from .exception import QuadrantsRuntimeError, QuadrantsSyntaxError, get_func_signature NUM_WARMUP: int = 3 NUM_ACTIVE: int = 1 @@ -58,10 +57,7 @@ def __init__( self.num_active = num_active if num_active is not None else NUM_ACTIVE self.repeat_after_count = repeat_after_count if repeat_after_count is not None else REPEAT_AFTER_COUNT self.repeat_after_seconds = repeat_after_seconds if repeat_after_seconds is not None else REPEAT_AFTER_SECONDS - try: - sig = inspect.signature(fn, eval_str=True) - except (NameError, AttributeError) as e: - raise QuadrantsSyntaxError(f"Invalid type annotation: {e}") from e + sig = get_func_signature(fn) self._param_types: dict[str, Any] = {} for param_name, param in sig.parameters.items(): self._param_types[param_name] = param.annotation @@ -102,10 +98,7 @@ def register( dispatch_impl_set = self._dispatch_impl_set def decorator(func: Callable | QuadrantsCallable) -> DispatchImpl: - try: - sig = inspect.signature(func, eval_str=True) - except (NameError, AttributeError) as e: - raise QuadrantsSyntaxError(f"Invalid type annotation: {e}") from e + sig = get_func_signature(func) log_str = f"perf_dispatch registering {func.__name__}" # type: ignore _logging.debug(log_str) if QD_PERFDISPATCH_PRINT_DEBUG: diff --git a/python/quadrants/lang/exception.py b/python/quadrants/lang/exception.py index 771dd56b02..beaf2eeb0e 100644 --- a/python/quadrants/lang/exception.py +++ b/python/quadrants/lang/exception.py @@ -57,6 +57,16 @@ def get_ret(needed, provided): return QuadrantsRuntimeTypeError(f"Return (type={provided}) cannot be converted into required type {needed}") +def get_func_signature(func): + """Call inspect.signature with eval_str=True, converting annotation errors to QuadrantsSyntaxError.""" + import inspect + + try: + return inspect.signature(func, eval_str=True) + except (NameError, AttributeError) as e: + raise QuadrantsSyntaxError(f"Invalid type annotation of Taichi kernel: {e}") from e + + def handle_exception_from_cpp(exc): if isinstance(exc, core.QuadrantsTypeError): return QuadrantsTypeError(str(exc)) From 46db1c700737e6f7d8cacc373f897395595097c5 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:37:39 -0700 Subject: [PATCH 05/24] [Test] Add test for kernels with from __future__ import annotations Verify that kernel parameter annotations are correctly resolved when the module uses PEP 563 stringified annotations. --- tests/python/test_future_annotations.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 tests/python/test_future_annotations.py diff --git a/tests/python/test_future_annotations.py b/tests/python/test_future_annotations.py new file mode 100644 index 0000000000..c359679f7a --- /dev/null +++ b/tests/python/test_future_annotations.py @@ -0,0 +1,25 @@ +"""Test that kernels work with `from __future__ import annotations` (PEP 563).""" + +from __future__ import annotations + +import quadrants as qd + +from tests import test_utils + + +@qd.kernel +def add_kernel(a: qd.types.NDArray[qd.i32, 1], b: qd.types.NDArray[qd.i32, 1]) -> None: + for i in a: + a[i] = a[i] + b[i] + + +@test_utils.test() +def test_future_annotations_kernel(): + a = qd.ndarray(qd.i32, (4,)) + b = qd.ndarray(qd.i32, (4,)) + for i in range(4): + a[i] = i + b[i] = 10 + add_kernel(a, b) + for i in range(4): + assert a[i] == i + 10 From 74dcea65b64eb7def77de3fbb6d8f3a57508922c Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 20:04:53 -0700 Subject: [PATCH 06/24] [Fix] Fix NDArray single-arg subscript crash NdarrayType.__class_getitem__ crashed when called with a single arg (e.g. NdarrayType[dtype]) because it tried to unpack a non-tuple. Wrap single args in a tuple before passing to __init__. --- python/quadrants/types/ndarray_type.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/quadrants/types/ndarray_type.py b/python/quadrants/types/ndarray_type.py index 994fa70b50..9ed077b677 100644 --- a/python/quadrants/types/ndarray_type.py +++ b/python/quadrants/types/ndarray_type.py @@ -94,8 +94,10 @@ def __init__( self.boundary = int(to_boundary_enum(boundary)) @classmethod - def __class_getitem__(cls, args, **kwargs): - return cls(*args, **kwargs) + def __class_getitem__(cls, args): + if not isinstance(args, tuple): + args = (args,) + return cls(*args) def check_matched(self, ndarray_type: NdarrayTypeMetadata, arg_name: str): # FIXME(Haidong) Cannot use Vector/MatrixType due to circular import From 71dcd622626322b91a61c2440e6524d52187b433 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 20:06:29 -0700 Subject: [PATCH 07/24] [Refactor] Add cook_dtype calls at C++ boundaries (no-op preparatory refactor) Add cook_dtype() calls at all points where dtype values are passed to C++ code. Make PyQuadrants.default_fp/ip/up into properties that always store DataTypeCxx. Rename shadowed dtype var in create_field_member. All changes are behavioral no-ops with current code, preparing for a future refactor of primitive dtypes into Python classes. --- python/quadrants/lang/expr.py | 14 +++++++-- python/quadrants/lang/impl.py | 37 ++++++++++++++++++++---- python/quadrants/lang/matrix.py | 2 +- python/quadrants/linalg/sparse_matrix.py | 11 ++++--- python/quadrants/linalg/sparse_solver.py | 8 +++-- 5 files changed, 56 insertions(+), 16 deletions(-) diff --git a/python/quadrants/lang/expr.py b/python/quadrants/lang/expr.py index 0369349d6b..aedacc2876 100644 --- a/python/quadrants/lang/expr.py +++ b/python/quadrants/lang/expr.py @@ -7,7 +7,12 @@ from quadrants.lang.common_ops import QuadrantsOperations from quadrants.lang.exception import QuadrantsCompilationError, QuadrantsTypeError from quadrants.lang.matrix import make_matrix -from quadrants.lang.util import is_matrix_class, is_quadrants_class, to_numpy_type +from quadrants.lang.util import ( + cook_dtype, + is_matrix_class, + is_quadrants_class, + to_numpy_type, +) from quadrants.types import primitive_types from quadrants.types.primitive_types import integer_types, real_types @@ -109,12 +114,16 @@ def _clamp_unsigned_to_range(npty, val: np.integer | int) -> np.integer | int: def make_constant_expr(val, dtype): + if dtype is not None: + dtype = cook_dtype(dtype) + if isinstance(val, (bool, np.bool_)): - constant_dtype = primitive_types.u1 + constant_dtype = cook_dtype(primitive_types.u1) return Expr(_qd_core.make_const_expr_bool(constant_dtype, val)) if isinstance(val, (float, np.floating)): constant_dtype = impl.get_runtime().default_fp if dtype is None else dtype + constant_dtype = cook_dtype(constant_dtype) if constant_dtype not in real_types: raise QuadrantsTypeError( "Floating-point literals must be annotated with a floating-point type. For type casting, use `qd.cast`." @@ -123,6 +132,7 @@ def make_constant_expr(val, dtype): if isinstance(val, (int, np.integer)): constant_dtype = impl.get_runtime().default_ip if dtype is None else dtype + constant_dtype = cook_dtype(constant_dtype) if constant_dtype not in integer_types: raise QuadrantsTypeError( "Integer literals must be annotated with a integer type. For type casting, use `qd.cast`." diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index 08144b9b10..0036a3bb78 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -83,6 +83,7 @@ def expr_init_shared_array(shape, element_type): ast_builder = get_runtime().compiling_callable.ast_builder() debug_info = _qd_core.DebugInfo(get_runtime().get_current_src_info()) + element_type = cook_dtype(element_type) return ast_builder.expr_alloca_shared_array(shape, element_type, debug_info) @@ -355,9 +356,9 @@ def __init__(self, kernels=None): self.grad_vars = [] self.dual_vars = [] self.matrix_fields = [] - self.default_fp = f32 - self.default_ip = i32 - self.default_up = u32 + self._default_fp = cook_dtype(f32) + self._default_ip = cook_dtype(i32) + self._default_up = cook_dtype(u32) self.print_full_traceback: bool = False self.target_tape = None self.fwd_mode_manager = None @@ -371,6 +372,30 @@ def __init__(self, kernels=None): self.unrolling_limit: int = 0 self.src_ll_cache: bool = True + @property + def default_fp(self) -> DataTypeCxx: + return self._default_fp + + @default_fp.setter + def default_fp(self, value: Any) -> None: + self._default_fp = cook_dtype(value) + + @property + def default_ip(self) -> DataTypeCxx: + return self._default_ip + + @default_ip.setter + def default_ip(self, value: Any) -> None: + self._default_ip = cook_dtype(value) + + @property + def default_up(self) -> DataTypeCxx: + return self._default_up + + @default_up.setter + def default_up(self, value: Any) -> None: + self._default_up = cook_dtype(value) + @property def compiling_callable(self) -> KernelCxx | Kernel | Function: if self._compiling_callable is None: @@ -737,10 +762,10 @@ def create_field_member(dtype, name, needs_grad, needs_dual): if prog.config().debug: # adjoint checkbit x_grad_checkbit = Expr(prog.make_id_expr("")) - dtype = u8 + checkbit_dtype = u8 if prog.config().arch == _qd_core.vulkan: - dtype = i32 - x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(dtype)) + checkbit_dtype = i32 + x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(checkbit_dtype)) x_grad_checkbit.ptr.set_name(name + ".grad_checkbit") x_grad_checkbit.ptr.set_grad_type(SNodeGradType.ADJOINT_CHECKBIT) x.ptr.set_adjoint_checkbit(x_grad_checkbit.ptr) diff --git a/python/quadrants/lang/matrix.py b/python/quadrants/lang/matrix.py index 01bf6eef64..6f44a0947f 100644 --- a/python/quadrants/lang/matrix.py +++ b/python/quadrants/lang/matrix.py @@ -176,7 +176,7 @@ def make_matrix(arr, dt=None): if len(arr) == 0: # the only usage of an empty vector is to serve as field indices shape = [0] - dt = primitive_types.i32 + dt = cook_dtype(primitive_types.i32) else: if isinstance(arr[0], Iterable): # matrix shape = [len(arr), len(arr[0])] diff --git a/python/quadrants/linalg/sparse_matrix.py b/python/quadrants/linalg/sparse_matrix.py index 7eb4f40be2..09cfd75a35 100644 --- a/python/quadrants/linalg/sparse_matrix.py +++ b/python/quadrants/linalg/sparse_matrix.py @@ -9,6 +9,7 @@ from quadrants.lang.exception import QuadrantsRuntimeError from quadrants.lang.field import Field from quadrants.lang.impl import get_runtime +from quadrants.lang.util import cook_dtype from quadrants.types import f32 @@ -24,11 +25,12 @@ class SparseMatrix: """ def __init__(self, n=None, m=None, sm=None, dtype=f32, storage_format="col_major"): - self.dtype = dtype + dtype_cxx = cook_dtype(dtype) + self.dtype = dtype_cxx if sm is None: self.n = n self.m = m if m else n - self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype, storage_format) + self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype_cxx, storage_format) else: self.n = sm.num_rows() self.m = sm.num_cols() @@ -247,7 +249,8 @@ def __init__( ): self.num_rows = num_rows self.num_cols = num_cols if num_cols else num_rows - self.dtype = dtype + dtype_cxx = cook_dtype(dtype) + self.dtype = dtype_cxx if num_rows is not None: quadrants_arch = get_runtime().prog.config().arch if quadrants_arch in [ @@ -259,7 +262,7 @@ def __init__( num_rows, num_cols, max_num_triplets, - dtype, + dtype_cxx, storage_format, ) self.ptr.create_ndarray(get_runtime().prog) diff --git a/python/quadrants/linalg/sparse_solver.py b/python/quadrants/linalg/sparse_solver.py index 3544d1a957..e66de69b06 100644 --- a/python/quadrants/linalg/sparse_solver.py +++ b/python/quadrants/linalg/sparse_solver.py @@ -8,6 +8,7 @@ from quadrants.lang.exception import QuadrantsRuntimeError from quadrants.lang.field import Field from quadrants.lang.impl import get_runtime +from quadrants.lang.util import cook_dtype from quadrants.linalg.sparse_matrix import SparseMatrix from quadrants.types.primitive_types import f32 @@ -24,7 +25,8 @@ class SparseSolver: def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"): self.matrix = None - self.dtype = dtype + dtype_cxx = cook_dtype(dtype) + self.dtype = dtype_cxx solver_type_list = ["LLT", "LDLT", "LU"] solver_ordering = ["AMD", "COLAMD"] if solver_type in solver_type_list and ordering in solver_ordering: @@ -35,9 +37,9 @@ def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"): or quadrants_arch == _qd_core.Arch.cuda ), "SparseSolver only supports CPU and CUDA for now." if quadrants_arch == _qd_core.Arch.cuda: - self.solver = _qd_core.make_cusparse_solver(dtype, solver_type, ordering) + self.solver = _qd_core.make_cusparse_solver(dtype_cxx, solver_type, ordering) else: - self.solver = _qd_core.make_sparse_solver(dtype, solver_type, ordering) + self.solver = _qd_core.make_sparse_solver(dtype_cxx, solver_type, ordering) else: raise QuadrantsRuntimeError( f"The solver type {solver_type} with {ordering} is not supported for now. Only {solver_type_list} with {solver_ordering} are supported." From 050f3a73cf795370e047e7ebf2a0a4b3026c49ca Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:33:57 -0700 Subject: [PATCH 08/24] [Test] Add test for single-arg NDArray subscript syntax Test that NDArray[dtype] (without ndim) works and produces an NdarrayType with the correct dtype and ndim=None. --- tests/python/test_ndarray_typing.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/python/test_ndarray_typing.py b/tests/python/test_ndarray_typing.py index 0ce6b4b7dc..531c786ce5 100644 --- a/tests/python/test_ndarray_typing.py +++ b/tests/python/test_ndarray_typing.py @@ -16,3 +16,9 @@ def test_ndarray_typing_square_brackets(): b[1, 1] = 5 some_kernel(a, b) assert a[1, 1] == 5 + 2 + + +def test_ndarray_typing_single_arg(): + t = qd.types.NDArray[qd.i32] + assert t.dtype == qd.i32 + assert t.ndim is None From 1c6b67335c61ff0e4cd9b4512b70401eb05d844c Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:38:28 -0700 Subject: [PATCH 09/24] [Doc] Document checkbit_dtype rename to avoid shadowing outer dtype The previous code overwrote the outer `dtype` parameter in the debug checkbit block, causing x_dual to be created with the wrong dtype. --- python/quadrants/lang/impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index 0036a3bb78..7ef6d18f04 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -760,7 +760,7 @@ def create_field_member(dtype, name, needs_grad, needs_dual): pyquadrants.grad_vars.append(x_grad) if prog.config().debug: - # adjoint checkbit + # adjoint checkbit — use a separate var to avoid shadowing the outer `dtype` x_grad_checkbit = Expr(prog.make_id_expr("")) checkbit_dtype = u8 if prog.config().arch == _qd_core.vulkan: From a694d81500b1df037bd323b6f20e149d4119298b Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:40:03 -0700 Subject: [PATCH 10/24] [Test] Add regression test for debug-mode dual field dtype shadowing Verify that forward-mode AD produces correct results when debug=True, guarding against the previous bug where the checkbit block's local dtype variable shadowed the outer dtype parameter. --- tests/python/test_ad_basics_fwd.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/python/test_ad_basics_fwd.py b/tests/python/test_ad_basics_fwd.py index fc37ef582c..760db9a7c4 100644 --- a/tests/python/test_ad_basics_fwd.py +++ b/tests/python/test_ad_basics_fwd.py @@ -124,3 +124,22 @@ def clear_dual_test(): with qd.ad.FwdMode(loss=loss, param=x): clear_dual_test() assert y.dual[None] == 4.0 + + +@test_utils.test(debug=True) +def test_dual_field_dtype_preserved_in_debug_mode(): + """Regression: debug-mode checkbit must not shadow the outer dtype.""" + x = qd.field(qd.f64, shape=(), needs_dual=True) + loss = qd.field(qd.f64, shape=(), needs_dual=True) + + x[None] = 3.0 + + @qd.kernel + def compute(): + loss[None] = x[None] * x[None] + + with qd.ad.FwdMode(loss=loss, param=x): + compute() + + assert loss[None] == 9.0 + assert loss.dual[None] == 6.0 From a5b414b7484697a09c240196721195050469e924 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:47:55 -0700 Subject: [PATCH 11/24] Remove inline comment about checkbit_dtype rename The regression test covers this; no need for a code comment. --- python/quadrants/lang/impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index 7ef6d18f04..0036a3bb78 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -760,7 +760,7 @@ def create_field_member(dtype, name, needs_grad, needs_dual): pyquadrants.grad_vars.append(x_grad) if prog.config().debug: - # adjoint checkbit — use a separate var to avoid shadowing the outer `dtype` + # adjoint checkbit x_grad_checkbit = Expr(prog.make_id_expr("")) checkbit_dtype = u8 if prog.config().arch == _qd_core.vulkan: From 215544f15d1a8296f32d4da35110d095b5a8651d Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:52:02 -0700 Subject: [PATCH 12/24] Remove redundant cook_dtype calls in make_constant_expr When dtype is provided it is already cooked at the top of the function, so the per-branch cook_dtype(constant_dtype) was a no-op. Now only the fallback default_fp/default_ip paths are cooked. --- python/quadrants/lang/expr.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/quadrants/lang/expr.py b/python/quadrants/lang/expr.py index aedacc2876..8112aa1544 100644 --- a/python/quadrants/lang/expr.py +++ b/python/quadrants/lang/expr.py @@ -122,8 +122,7 @@ def make_constant_expr(val, dtype): return Expr(_qd_core.make_const_expr_bool(constant_dtype, val)) if isinstance(val, (float, np.floating)): - constant_dtype = impl.get_runtime().default_fp if dtype is None else dtype - constant_dtype = cook_dtype(constant_dtype) + constant_dtype = dtype if dtype is not None else cook_dtype(impl.get_runtime().default_fp) if constant_dtype not in real_types: raise QuadrantsTypeError( "Floating-point literals must be annotated with a floating-point type. For type casting, use `qd.cast`." @@ -131,8 +130,7 @@ def make_constant_expr(val, dtype): return Expr(_qd_core.make_const_expr_fp(constant_dtype, val)) if isinstance(val, (int, np.integer)): - constant_dtype = impl.get_runtime().default_ip if dtype is None else dtype - constant_dtype = cook_dtype(constant_dtype) + constant_dtype = dtype if dtype is not None else cook_dtype(impl.get_runtime().default_ip) if constant_dtype not in integer_types: raise QuadrantsTypeError( "Integer literals must be annotated with a integer type. For type casting, use `qd.cast`." From 2876fa15917ad8760cb55b376fa7aa921c03b08d Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:52:36 -0700 Subject: [PATCH 13/24] Add comment clarifying cook_dtype strategy in make_constant_expr --- python/quadrants/lang/expr.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/quadrants/lang/expr.py b/python/quadrants/lang/expr.py index 8112aa1544..deee4de570 100644 --- a/python/quadrants/lang/expr.py +++ b/python/quadrants/lang/expr.py @@ -114,6 +114,8 @@ def _clamp_unsigned_to_range(npty, val: np.integer | int) -> np.integer | int: def make_constant_expr(val, dtype): + # Normalise dtype once up front so the per-branch fallbacks only need to + # cook the runtime defaults (default_fp / default_ip). if dtype is not None: dtype = cook_dtype(dtype) From 9160366e095fbf6fb506b6a49a58ff8de2edd7ff Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 19 Apr 2026 15:46:20 -0700 Subject: [PATCH 14/24] [Lint] Move inspect/get_func_signature imports to top-level Pylint flagged C0415 (import-outside-toplevel) for the local inspect import in exception.py and the local get_func_signature import in _kernel_impl_dataclass.py. Hoist both to module top-level imports. Made-with: Cursor --- python/quadrants/lang/_kernel_impl_dataclass.py | 3 +-- python/quadrants/lang/exception.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/quadrants/lang/_kernel_impl_dataclass.py b/python/quadrants/lang/_kernel_impl_dataclass.py index ebd298dff5..baec2c68ab 100644 --- a/python/quadrants/lang/_kernel_impl_dataclass.py +++ b/python/quadrants/lang/_kernel_impl_dataclass.py @@ -7,6 +7,7 @@ from quadrants.lang.ast import ( ASTTransformerFuncContext, ) +from quadrants.lang.exception import get_func_signature from quadrants.lang.kernel_arguments import ArgMetadata @@ -72,8 +73,6 @@ def extract_struct_locals_from_context(ctx: ASTTransformerFuncContext) -> set[st """ struct_locals = set() assert ctx.func is not None - from quadrants.lang.exception import get_func_signature - sig = get_func_signature(ctx.func.func) parameters = sig.parameters for param_name, parameter in parameters.items(): diff --git a/python/quadrants/lang/exception.py b/python/quadrants/lang/exception.py index beaf2eeb0e..8db8c5bffc 100644 --- a/python/quadrants/lang/exception.py +++ b/python/quadrants/lang/exception.py @@ -1,5 +1,7 @@ # type: ignore +import inspect + from quadrants._lib import core @@ -59,8 +61,6 @@ def get_ret(needed, provided): def get_func_signature(func): """Call inspect.signature with eval_str=True, converting annotation errors to QuadrantsSyntaxError.""" - import inspect - try: return inspect.signature(func, eval_str=True) except (NameError, AttributeError) as e: From 04d8e085bf55aeda031ebf703ee595a9599af703 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 19 Apr 2026 15:49:30 -0700 Subject: [PATCH 15/24] fix(lint): move imports to module top in exception and _kernel_impl_dataclass Made-with: Cursor --- python/quadrants/lang/_kernel_impl_dataclass.py | 3 +-- python/quadrants/lang/exception.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/quadrants/lang/_kernel_impl_dataclass.py b/python/quadrants/lang/_kernel_impl_dataclass.py index ebd298dff5..baec2c68ab 100644 --- a/python/quadrants/lang/_kernel_impl_dataclass.py +++ b/python/quadrants/lang/_kernel_impl_dataclass.py @@ -7,6 +7,7 @@ from quadrants.lang.ast import ( ASTTransformerFuncContext, ) +from quadrants.lang.exception import get_func_signature from quadrants.lang.kernel_arguments import ArgMetadata @@ -72,8 +73,6 @@ def extract_struct_locals_from_context(ctx: ASTTransformerFuncContext) -> set[st """ struct_locals = set() assert ctx.func is not None - from quadrants.lang.exception import get_func_signature - sig = get_func_signature(ctx.func.func) parameters = sig.parameters for param_name, parameter in parameters.items(): diff --git a/python/quadrants/lang/exception.py b/python/quadrants/lang/exception.py index beaf2eeb0e..8db8c5bffc 100644 --- a/python/quadrants/lang/exception.py +++ b/python/quadrants/lang/exception.py @@ -1,5 +1,7 @@ # type: ignore +import inspect + from quadrants._lib import core @@ -59,8 +61,6 @@ def get_ret(needed, provided): def get_func_signature(func): """Call inspect.signature with eval_str=True, converting annotation errors to QuadrantsSyntaxError.""" - import inspect - try: return inspect.signature(func, eval_str=True) except (NameError, AttributeError) as e: From 970d7c9f1a25d12820990bad017bbce6de2777cd Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 19 Apr 2026 15:51:18 -0700 Subject: [PATCH 16/24] [Lint] Move imports to top-level to satisfy pylint Fixes pylint C0415 (import-outside-toplevel) on `inspect` in `exception.py` and `get_func_signature` in `_kernel_impl_dataclass.py`. Made-with: Cursor --- python/quadrants/lang/_kernel_impl_dataclass.py | 3 +-- python/quadrants/lang/exception.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/quadrants/lang/_kernel_impl_dataclass.py b/python/quadrants/lang/_kernel_impl_dataclass.py index ebd298dff5..baec2c68ab 100644 --- a/python/quadrants/lang/_kernel_impl_dataclass.py +++ b/python/quadrants/lang/_kernel_impl_dataclass.py @@ -7,6 +7,7 @@ from quadrants.lang.ast import ( ASTTransformerFuncContext, ) +from quadrants.lang.exception import get_func_signature from quadrants.lang.kernel_arguments import ArgMetadata @@ -72,8 +73,6 @@ def extract_struct_locals_from_context(ctx: ASTTransformerFuncContext) -> set[st """ struct_locals = set() assert ctx.func is not None - from quadrants.lang.exception import get_func_signature - sig = get_func_signature(ctx.func.func) parameters = sig.parameters for param_name, parameter in parameters.items(): diff --git a/python/quadrants/lang/exception.py b/python/quadrants/lang/exception.py index beaf2eeb0e..8db8c5bffc 100644 --- a/python/quadrants/lang/exception.py +++ b/python/quadrants/lang/exception.py @@ -1,5 +1,7 @@ # type: ignore +import inspect + from quadrants._lib import core @@ -59,8 +61,6 @@ def get_ret(needed, provided): def get_func_signature(func): """Call inspect.signature with eval_str=True, converting annotation errors to QuadrantsSyntaxError.""" - import inspect - try: return inspect.signature(func, eval_str=True) except (NameError, AttributeError) as e: From dcc12164f3bdaf76bd0bd04a79eb4770942025bd Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 19 Apr 2026 16:38:32 -0700 Subject: [PATCH 17/24] [Test] Require data64 for dual_field_dtype debug-mode test The Vulkan/Metal backends on macOS lack f64 support and crash when running this test, which uses qd.f64 fields. Add require=qd.extension.data64 to skip on backends without double-precision support. Made-with: Cursor --- tests/python/test_ad_basics_fwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_ad_basics_fwd.py b/tests/python/test_ad_basics_fwd.py index 760db9a7c4..9dcbd1f060 100644 --- a/tests/python/test_ad_basics_fwd.py +++ b/tests/python/test_ad_basics_fwd.py @@ -126,7 +126,7 @@ def clear_dual_test(): assert y.dual[None] == 4.0 -@test_utils.test(debug=True) +@test_utils.test(require=qd.extension.data64, debug=True) def test_dual_field_dtype_preserved_in_debug_mode(): """Regression: debug-mode checkbit must not shadow the outer dtype.""" x = qd.field(qd.f64, shape=(), needs_dual=True) From 5f79631e1014f1ccbc3e0b2ebde9e77a3d66e092 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 20 Apr 2026 01:20:13 -0700 Subject: [PATCH 18/24] [Refactor] Move get_func_signature to dedicated _signature module - Move get_func_signature out of exception.py into a new quadrants.lang._signature module so the helper lives somewhere whose name actually describes its purpose. - Add explicit type annotations (Callable -> inspect.Signature); the new module is not blanket-`# type: ignore`d so the signature is type-checked. - Generalize the error message from "Invalid type annotation of Taichi kernel" to "Invalid type annotation in ``" since the helper is also used by @perf_dispatch and non-kernel FuncBase paths. - Update import sites in _func_base, _kernel_impl_dataclass, and _perf_dispatch to pull from the new module. --- python/quadrants/lang/_func_base.py | 2 +- .../quadrants/lang/_kernel_impl_dataclass.py | 2 +- python/quadrants/lang/_perf_dispatch.py | 3 ++- python/quadrants/lang/_signature.py | 23 +++++++++++++++++++ python/quadrants/lang/exception.py | 10 -------- 5 files changed, 27 insertions(+), 13 deletions(-) create mode 100644 python/quadrants/lang/_signature.py diff --git a/python/quadrants/lang/_func_base.py b/python/quadrants/lang/_func_base.py index 8465df3d4c..b2d77c9d6e 100644 --- a/python/quadrants/lang/_func_base.py +++ b/python/quadrants/lang/_func_base.py @@ -25,11 +25,11 @@ from quadrants.lang._ndarray import Ndarray from quadrants.lang._wrap_inspect import get_source_info_and_src from quadrants.lang.ast import ASTTransformerFuncContext +from quadrants.lang._signature import get_func_signature from quadrants.lang.exception import ( QuadrantsRuntimeError, QuadrantsRuntimeTypeError, QuadrantsSyntaxError, - get_func_signature, ) from quadrants.lang.kernel_arguments import ArgMetadata from quadrants.lang.matrix import MatrixType diff --git a/python/quadrants/lang/_kernel_impl_dataclass.py b/python/quadrants/lang/_kernel_impl_dataclass.py index baec2c68ab..51c481d6ae 100644 --- a/python/quadrants/lang/_kernel_impl_dataclass.py +++ b/python/quadrants/lang/_kernel_impl_dataclass.py @@ -7,7 +7,7 @@ from quadrants.lang.ast import ( ASTTransformerFuncContext, ) -from quadrants.lang.exception import get_func_signature +from quadrants.lang._signature import get_func_signature from quadrants.lang.kernel_arguments import ArgMetadata diff --git a/python/quadrants/lang/_perf_dispatch.py b/python/quadrants/lang/_perf_dispatch.py index 2795d49213..39fae64e8d 100644 --- a/python/quadrants/lang/_perf_dispatch.py +++ b/python/quadrants/lang/_perf_dispatch.py @@ -7,7 +7,8 @@ from . import impl from ._exceptions import raise_exception from ._quadrants_callable import QuadrantsCallable -from .exception import QuadrantsRuntimeError, QuadrantsSyntaxError, get_func_signature +from ._signature import get_func_signature +from .exception import QuadrantsRuntimeError, QuadrantsSyntaxError NUM_WARMUP: int = 3 NUM_ACTIVE: int = 1 diff --git a/python/quadrants/lang/_signature.py b/python/quadrants/lang/_signature.py new file mode 100644 index 0000000000..4798acfa07 --- /dev/null +++ b/python/quadrants/lang/_signature.py @@ -0,0 +1,23 @@ +import inspect +from typing import Callable + +from quadrants.lang.exception import QuadrantsSyntaxError + + +def get_func_signature(func: Callable) -> inspect.Signature: + """Call ``inspect.signature`` with ``eval_str=True``. + + ``eval_str=True`` resolves stringified annotations (PEP 563 / + ``from __future__ import annotations``) to real type objects so downstream + code can introspect them (e.g. ``dataclasses.is_dataclass``). + + Annotation-evaluation failures (``NameError``, ``AttributeError``) are + re-raised as :class:`QuadrantsSyntaxError` with the offending function's + qualified name, so users get a Quadrants-flavored error rather than a raw + ``inspect`` traceback. + """ + try: + return inspect.signature(func, eval_str=True) + except (NameError, AttributeError) as e: + qualname = getattr(func, "__qualname__", repr(func)) + raise QuadrantsSyntaxError(f"Invalid type annotation in `{qualname}`: {e}") from e diff --git a/python/quadrants/lang/exception.py b/python/quadrants/lang/exception.py index 8db8c5bffc..771dd56b02 100644 --- a/python/quadrants/lang/exception.py +++ b/python/quadrants/lang/exception.py @@ -1,7 +1,5 @@ # type: ignore -import inspect - from quadrants._lib import core @@ -59,14 +57,6 @@ def get_ret(needed, provided): return QuadrantsRuntimeTypeError(f"Return (type={provided}) cannot be converted into required type {needed}") -def get_func_signature(func): - """Call inspect.signature with eval_str=True, converting annotation errors to QuadrantsSyntaxError.""" - try: - return inspect.signature(func, eval_str=True) - except (NameError, AttributeError) as e: - raise QuadrantsSyntaxError(f"Invalid type annotation of Taichi kernel: {e}") from e - - def handle_exception_from_cpp(exc): if isinstance(exc, core.QuadrantsTypeError): return QuadrantsTypeError(str(exc)) From 74936177bc706a4763c3ca4868d4270b9ff507e2 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 20 Apr 2026 01:36:43 -0700 Subject: [PATCH 19/24] [Lint] Sort imports after _signature module addition --- python/quadrants/lang/_func_base.py | 2 +- python/quadrants/lang/_kernel_impl_dataclass.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/quadrants/lang/_func_base.py b/python/quadrants/lang/_func_base.py index b2d77c9d6e..96ab76afa5 100644 --- a/python/quadrants/lang/_func_base.py +++ b/python/quadrants/lang/_func_base.py @@ -23,9 +23,9 @@ from quadrants.lang import _kernel_impl_dataclass, impl from quadrants.lang._dataclass_util import create_flat_name from quadrants.lang._ndarray import Ndarray +from quadrants.lang._signature import get_func_signature from quadrants.lang._wrap_inspect import get_source_info_and_src from quadrants.lang.ast import ASTTransformerFuncContext -from quadrants.lang._signature import get_func_signature from quadrants.lang.exception import ( QuadrantsRuntimeError, QuadrantsRuntimeTypeError, diff --git a/python/quadrants/lang/_kernel_impl_dataclass.py b/python/quadrants/lang/_kernel_impl_dataclass.py index 51c481d6ae..fe862cfa3e 100644 --- a/python/quadrants/lang/_kernel_impl_dataclass.py +++ b/python/quadrants/lang/_kernel_impl_dataclass.py @@ -4,10 +4,10 @@ from quadrants.lang import util from quadrants.lang._dataclass_util import create_flat_name +from quadrants.lang._signature import get_func_signature from quadrants.lang.ast import ( ASTTransformerFuncContext, ) -from quadrants.lang._signature import get_func_signature from quadrants.lang.kernel_arguments import ArgMetadata From 24f1558356acf7579d14ede95a39ea80122fc0f3 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 20 Apr 2026 01:55:37 -0700 Subject: [PATCH 20/24] [Fix] Catch SyntaxError and TypeError in get_func_signature inspect.signature(func, eval_str=True) can raise SyntaxError for syntactically invalid string annotations (e.g. `a: "NDArray["`) and TypeError for annotations that cannot be evaluated as types. These were previously leaking as raw Python tracebacks instead of being wrapped as QuadrantsSyntaxError, contradicting the helper's docstring. --- python/quadrants/lang/_signature.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/quadrants/lang/_signature.py b/python/quadrants/lang/_signature.py index 4798acfa07..2909c7f6fb 100644 --- a/python/quadrants/lang/_signature.py +++ b/python/quadrants/lang/_signature.py @@ -11,13 +11,15 @@ def get_func_signature(func: Callable) -> inspect.Signature: ``from __future__ import annotations``) to real type objects so downstream code can introspect them (e.g. ``dataclasses.is_dataclass``). - Annotation-evaluation failures (``NameError``, ``AttributeError``) are - re-raised as :class:`QuadrantsSyntaxError` with the offending function's - qualified name, so users get a Quadrants-flavored error rather than a raw - ``inspect`` traceback. + Annotation-evaluation failures (``NameError`` / ``AttributeError`` for + unresolved references, ``SyntaxError`` for malformed string annotations + such as ``"NDArray["``, and ``TypeError`` for annotations that cannot be + evaluated as types) are re-raised as :class:`QuadrantsSyntaxError` with + the offending function's qualified name, so users get a Quadrants-flavored + error rather than a raw ``inspect`` traceback. """ try: return inspect.signature(func, eval_str=True) - except (NameError, AttributeError) as e: + except (NameError, AttributeError, SyntaxError, TypeError) as e: qualname = getattr(func, "__qualname__", repr(func)) raise QuadrantsSyntaxError(f"Invalid type annotation in `{qualname}`: {e}") from e From faeda40a453b3ae40b739f94a3159bc162bd71fc Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 20 Apr 2026 01:56:47 -0700 Subject: [PATCH 21/24] [Fix] Stop catching TypeError in get_func_signature inspect.signature itself raises TypeError for non-introspectable objects (some C built-ins, etc.), so wrapping that as "invalid type annotation" would be misleading. Narrow the except tuple back to (NameError, AttributeError, SyntaxError) and document why TypeError is excluded. --- python/quadrants/lang/_signature.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/quadrants/lang/_signature.py b/python/quadrants/lang/_signature.py index 2909c7f6fb..57b55ec5c0 100644 --- a/python/quadrants/lang/_signature.py +++ b/python/quadrants/lang/_signature.py @@ -13,13 +13,17 @@ def get_func_signature(func: Callable) -> inspect.Signature: Annotation-evaluation failures (``NameError`` / ``AttributeError`` for unresolved references, ``SyntaxError`` for malformed string annotations - such as ``"NDArray["``, and ``TypeError`` for annotations that cannot be - evaluated as types) are re-raised as :class:`QuadrantsSyntaxError` with - the offending function's qualified name, so users get a Quadrants-flavored - error rather than a raw ``inspect`` traceback. + such as ``"NDArray["``) are re-raised as :class:`QuadrantsSyntaxError` + with the offending function's qualified name, so users get a + Quadrants-flavored error rather than a raw ``inspect`` traceback. + + Note: ``TypeError`` is intentionally not caught here, since + ``inspect.signature`` itself raises ``TypeError`` for non-introspectable + objects -- wrapping that as "invalid type annotation" would be + misleading. """ try: return inspect.signature(func, eval_str=True) - except (NameError, AttributeError, SyntaxError, TypeError) as e: + except (NameError, AttributeError, SyntaxError) as e: qualname = getattr(func, "__qualname__", repr(func)) raise QuadrantsSyntaxError(f"Invalid type annotation in `{qualname}`: {e}") from e From 3d8f73c9adf65d545f6bdf1f8582e48490519e6e Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 20 Apr 2026 02:10:31 -0700 Subject: [PATCH 22/24] [Test] Run dual_field_dtype regression on all backends using f32 The previous version used qd.f64 and was gated on qd.extension.data64, so it was skipped on backends without f64 support (Vulkan/Metal on macOS) -- meaning the regression wasn't exercised everywhere. Rewrite with f32 (universally supported) and pick values whose dual is a non-integer exactly representable in f32 (x=1.25, dual=2x=2.5). Under the original bug the dual field was created as u8 (or i32 on Vulkan), which would truncate 2.5 to 2 and fail the assertion -- so the test still catches the regression on every backend. --- tests/python/test_ad_basics_fwd.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/python/test_ad_basics_fwd.py b/tests/python/test_ad_basics_fwd.py index 9dcbd1f060..0c0ddf2615 100644 --- a/tests/python/test_ad_basics_fwd.py +++ b/tests/python/test_ad_basics_fwd.py @@ -126,13 +126,19 @@ def clear_dual_test(): assert y.dual[None] == 4.0 -@test_utils.test(require=qd.extension.data64, debug=True) +@test_utils.test(debug=True) def test_dual_field_dtype_preserved_in_debug_mode(): - """Regression: debug-mode checkbit must not shadow the outer dtype.""" - x = qd.field(qd.f64, shape=(), needs_dual=True) - loss = qd.field(qd.f64, shape=(), needs_dual=True) + """Regression: debug-mode checkbit must not shadow the outer dtype. - x[None] = 3.0 + Uses ``f32`` (supported on every backend) and picks values whose dual + is a non-integer exactly representable in ``f32``: under the old bug + the dual field was created as ``u8`` (or ``i32`` on Vulkan), which + would truncate ``2.5`` to ``2`` and fail the assertion. + """ + x = qd.field(qd.f32, shape=(), needs_dual=True) + loss = qd.field(qd.f32, shape=(), needs_dual=True) + + x[None] = 1.25 @qd.kernel def compute(): @@ -141,5 +147,5 @@ def compute(): with qd.ad.FwdMode(loss=loss, param=x): compute() - assert loss[None] == 9.0 - assert loss.dual[None] == 6.0 + assert loss[None] == 1.5625 + assert loss.dual[None] == 2.5 From 6095453bcfeb8a39afaa118289316551c97ac715 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 20 Apr 2026 02:12:41 -0700 Subject: [PATCH 23/24] [Test] Parametrize dual_field_dtype regression over f32 and f64 Run the regression test on both f32 and f64, using test_utils.skip_if_f64_unsupported to skip f64 only on backends that don't reliably support it (Metal, Vulkan). f32 still runs everywhere so the bug is exercised on every backend; f64 adds extra coverage where supported. --- tests/python/test_ad_basics_fwd.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/python/test_ad_basics_fwd.py b/tests/python/test_ad_basics_fwd.py index 0c0ddf2615..989286c23d 100644 --- a/tests/python/test_ad_basics_fwd.py +++ b/tests/python/test_ad_basics_fwd.py @@ -1,3 +1,5 @@ +import pytest + import quadrants as qd from tests import test_utils @@ -126,17 +128,20 @@ def clear_dual_test(): assert y.dual[None] == 4.0 +@pytest.mark.parametrize("dtype", [qd.f32, qd.f64]) @test_utils.test(debug=True) -def test_dual_field_dtype_preserved_in_debug_mode(): +def test_dual_field_dtype_preserved_in_debug_mode(dtype): """Regression: debug-mode checkbit must not shadow the outer dtype. - Uses ``f32`` (supported on every backend) and picks values whose dual - is a non-integer exactly representable in ``f32``: under the old bug - the dual field was created as ``u8`` (or ``i32`` on Vulkan), which - would truncate ``2.5`` to ``2`` and fail the assertion. + Picks values whose dual is a non-integer exactly representable in + both ``f32`` and ``f64`` (``x=1.25`` -> ``dual=2.5``): under the old + bug the dual field was created as ``u8`` (or ``i32`` on Vulkan), + which would truncate ``2.5`` to ``2`` and fail the assertion. """ - x = qd.field(qd.f32, shape=(), needs_dual=True) - loss = qd.field(qd.f32, shape=(), needs_dual=True) + test_utils.skip_if_f64_unsupported(dtype) + + x = qd.field(dtype, shape=(), needs_dual=True) + loss = qd.field(dtype, shape=(), needs_dual=True) x[None] = 1.25 From 2a885b153f378bb767a0183731182a14b71aa937 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Thu, 23 Apr 2026 19:50:44 -0700 Subject: [PATCH 24/24] [CI] Disable pytest-timeout plugin to avoid conflict with pytest_hardtle The AMD GPU runner has pytest-timeout installed, which registers the same hooks (pytest_timeout_cancel_timer, pytest_timeout_set_timer) as our custom pytest_hardtle plugin, causing a ValueError at startup. Made-with: Cursor --- tests/run_tests.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/run_tests.py b/tests/run_tests.py index a454003002..36059493cc 100644 --- a/tests/run_tests.py +++ b/tests/run_tests.py @@ -52,6 +52,8 @@ def _test_python(args, default_dir="python"): pytest_args += [ "--durations=15", "-p", + "no:timeout", + "-p", "pytest_hardtle", f"--timeout={args.timeout}", ]