Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
59fb788
[Fix] Add eval_str=True to inspect.signature() calls for stringified …
hughperkins Mar 12, 2026
544bbc9
[Fix] Add eval_str=True to perf_dispatch register decorator
hughperkins Mar 12, 2026
77fccda
[Fix] Add consistent error handling at all eval_str=True call sites
hughperkins Mar 12, 2026
849fe2d
[Refactor] Extract get_func_signature helper for eval_str=True calls
hughperkins Mar 12, 2026
46db1c7
[Test] Add test for kernels with from __future__ import annotations
hughperkins Mar 12, 2026
74dcea6
[Fix] Fix NDArray single-arg subscript crash
hughperkins Mar 12, 2026
71dcd62
[Refactor] Add cook_dtype calls at C++ boundaries (no-op preparatory …
hughperkins Mar 12, 2026
050f3a7
[Test] Add test for single-arg NDArray subscript syntax
hughperkins Mar 12, 2026
1c6b673
[Doc] Document checkbit_dtype rename to avoid shadowing outer dtype
hughperkins Mar 12, 2026
a694d81
[Test] Add regression test for debug-mode dual field dtype shadowing
hughperkins Mar 12, 2026
a5b414b
Remove inline comment about checkbit_dtype rename
hughperkins Mar 12, 2026
215544f
Remove redundant cook_dtype calls in make_constant_expr
hughperkins Mar 12, 2026
2876fa1
Add comment clarifying cook_dtype strategy in make_constant_expr
hughperkins Mar 12, 2026
9160366
[Lint] Move inspect/get_func_signature imports to top-level
hughperkins Apr 19, 2026
5ec31d9
Merge remote-tracking branch 'origin/main' into hp/typing-t4-1-eval-str
hughperkins Apr 19, 2026
04d8e08
fix(lint): move imports to module top in exception and _kernel_impl_d…
hughperkins Apr 19, 2026
970d7c9
[Lint] Move imports to top-level to satisfy pylint
hughperkins Apr 19, 2026
dcc1216
[Test] Require data64 for dual_field_dtype debug-mode test
hughperkins Apr 19, 2026
5a42b98
Merge remote-tracking branch 'origin/main' into hp/typing-t4-3-cook-d…
hughperkins Apr 20, 2026
84e64a4
Merge remote-tracking branch 'origin/main' into hp/typing-t4-1-eval-str
hughperkins Apr 20, 2026
5f79631
[Refactor] Move get_func_signature to dedicated _signature module
hughperkins Apr 20, 2026
7493617
[Lint] Sort imports after _signature module addition
hughperkins Apr 20, 2026
24f1558
[Fix] Catch SyntaxError and TypeError in get_func_signature
hughperkins Apr 20, 2026
faeda40
[Fix] Stop catching TypeError in get_func_signature
hughperkins Apr 20, 2026
3539c76
Merge branch 'hp/typing-t4-1-eval-str' into hp/typing-t4-2-ndarray-su…
hughperkins Apr 20, 2026
0d406e3
Merge branch 'hp/typing-t4-2-ndarray-subscript' into hp/typing-t4-3-c…
hughperkins Apr 20, 2026
3d8f73c
[Test] Run dual_field_dtype regression on all backends using f32
hughperkins Apr 20, 2026
6095453
[Test] Parametrize dual_field_dtype regression over f32 and f64
hughperkins Apr 20, 2026
7f46a78
Merge remote-tracking branch 'origin/main' into hp/typing-t4-3-cook-d…
hughperkins Apr 24, 2026
2a885b1
[CI] Disable pytest-timeout plugin to avoid conflict with pytest_hardtle
hughperkins Apr 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions python/quadrants/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from quadrants.lang.common_ops import QuadrantsOperations
from quadrants.lang.exception import QuadrantsCompilationError, QuadrantsTypeError
from quadrants.lang.matrix import make_matrix
from quadrants.lang.util import is_matrix_class, is_quadrants_class, to_numpy_type
from quadrants.lang.util import (
cook_dtype,
is_matrix_class,
is_quadrants_class,
to_numpy_type,
)
from quadrants.types import primitive_types
from quadrants.types.primitive_types import integer_types, real_types

Expand Down Expand Up @@ -109,20 +114,25 @@ def _clamp_unsigned_to_range(npty, val: np.integer | int) -> np.integer | int:


def make_constant_expr(val, dtype):
# Normalise dtype once up front so the per-branch fallbacks only need to
# cook the runtime defaults (default_fp / default_ip).
if dtype is not None:
dtype = cook_dtype(dtype)

if isinstance(val, (bool, np.bool_)):
constant_dtype = primitive_types.u1
constant_dtype = cook_dtype(primitive_types.u1)
return Expr(_qd_core.make_const_expr_bool(constant_dtype, val))

if isinstance(val, (float, np.floating)):
constant_dtype = impl.get_runtime().default_fp if dtype is None else dtype
constant_dtype = dtype if dtype is not None else cook_dtype(impl.get_runtime().default_fp)
if constant_dtype not in real_types:
raise QuadrantsTypeError(
"Floating-point literals must be annotated with a floating-point type. For type casting, use `qd.cast`."
)
return Expr(_qd_core.make_const_expr_fp(constant_dtype, val))

if isinstance(val, (int, np.integer)):
constant_dtype = impl.get_runtime().default_ip if dtype is None else dtype
constant_dtype = dtype if dtype is not None else cook_dtype(impl.get_runtime().default_ip)
if constant_dtype not in integer_types:
raise QuadrantsTypeError(
"Integer literals must be annotated with a integer type. For type casting, use `qd.cast`."
Expand Down
37 changes: 31 additions & 6 deletions python/quadrants/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
def expr_init_shared_array(shape, element_type):
ast_builder = get_runtime().compiling_callable.ast_builder()
debug_info = _qd_core.DebugInfo(get_runtime().get_current_src_info())
element_type = cook_dtype(element_type)
return ast_builder.expr_alloca_shared_array(shape, element_type, debug_info)


Expand Down Expand Up @@ -365,9 +366,9 @@ def __init__(self, kernels=None):
self.grad_vars = []
self.dual_vars = []
self.matrix_fields = []
self.default_fp = f32
self.default_ip = i32
self.default_up = u32
self._default_fp = cook_dtype(f32)
self._default_ip = cook_dtype(i32)
self._default_up = cook_dtype(u32)
self.print_full_traceback: bool = False
self.target_tape = None
self.fwd_mode_manager = None
Expand All @@ -381,6 +382,30 @@ def __init__(self, kernels=None):
self.unrolling_limit: int = 0
self.src_ll_cache: bool = True

@property
def default_fp(self) -> DataTypeCxx:
return self._default_fp

@default_fp.setter
def default_fp(self, value: Any) -> None:
self._default_fp = cook_dtype(value)

@property
def default_ip(self) -> DataTypeCxx:
return self._default_ip

@default_ip.setter
def default_ip(self, value: Any) -> None:
self._default_ip = cook_dtype(value)

@property
def default_up(self) -> DataTypeCxx:
return self._default_up

@default_up.setter
def default_up(self, value: Any) -> None:
self._default_up = cook_dtype(value)

@property
def compiling_callable(self) -> KernelCxx | Kernel | Function:
if self._compiling_callable is None:
Expand Down Expand Up @@ -747,10 +772,10 @@ def create_field_member(dtype, name, needs_grad, needs_dual):
if prog.config().debug:
# adjoint checkbit
x_grad_checkbit = Expr(prog.make_id_expr(""))
dtype = u8
checkbit_dtype = u8
if prog.config().arch == _qd_core.vulkan:
dtype = i32
x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(dtype))
checkbit_dtype = i32
x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(checkbit_dtype))
x_grad_checkbit.ptr.set_name(name + ".grad_checkbit")
x_grad_checkbit.ptr.set_grad_type(SNodeGradType.ADJOINT_CHECKBIT)
x.ptr.set_adjoint_checkbit(x_grad_checkbit.ptr)
Expand Down
2 changes: 1 addition & 1 deletion python/quadrants/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def make_matrix(arr, dt=None):
if len(arr) == 0:
# the only usage of an empty vector is to serve as field indices
shape = [0]
dt = primitive_types.i32
dt = cook_dtype(primitive_types.i32)
else:
if isinstance(arr[0], Iterable): # matrix
shape = [len(arr), len(arr[0])]
Expand Down
11 changes: 7 additions & 4 deletions python/quadrants/linalg/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from quadrants.lang.exception import QuadrantsRuntimeError
from quadrants.lang.field import Field
from quadrants.lang.impl import get_runtime
from quadrants.lang.util import cook_dtype
from quadrants.types import f32


Expand All @@ -24,11 +25,12 @@ class SparseMatrix:
"""

def __init__(self, n=None, m=None, sm=None, dtype=f32, storage_format="col_major"):
self.dtype = dtype
dtype_cxx = cook_dtype(dtype)
self.dtype = dtype_cxx
if sm is None:
self.n = n
self.m = m if m else n
self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype, storage_format)
self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype_cxx, storage_format)
else:
self.n = sm.num_rows()
self.m = sm.num_cols()
Expand Down Expand Up @@ -247,7 +249,8 @@ def __init__(
):
self.num_rows = num_rows
self.num_cols = num_cols if num_cols else num_rows
self.dtype = dtype
dtype_cxx = cook_dtype(dtype)
self.dtype = dtype_cxx
if num_rows is not None:
quadrants_arch = get_runtime().prog.config().arch
if quadrants_arch in [
Expand All @@ -259,7 +262,7 @@ def __init__(
num_rows,
num_cols,
max_num_triplets,
dtype,
dtype_cxx,
storage_format,
)
self.ptr.create_ndarray(get_runtime().prog)
Expand Down
8 changes: 5 additions & 3 deletions python/quadrants/linalg/sparse_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from quadrants.lang.exception import QuadrantsRuntimeError
from quadrants.lang.field import Field
from quadrants.lang.impl import get_runtime
from quadrants.lang.util import cook_dtype
from quadrants.linalg.sparse_matrix import SparseMatrix
from quadrants.types.primitive_types import f32

Expand All @@ -24,7 +25,8 @@ class SparseSolver:

def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
self.matrix = None
self.dtype = dtype
dtype_cxx = cook_dtype(dtype)
self.dtype = dtype_cxx
solver_type_list = ["LLT", "LDLT", "LU"]
solver_ordering = ["AMD", "COLAMD"]
if solver_type in solver_type_list and ordering in solver_ordering:
Expand All @@ -35,9 +37,9 @@ def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
or quadrants_arch == _qd_core.Arch.cuda
), "SparseSolver only supports CPU and CUDA for now."
if quadrants_arch == _qd_core.Arch.cuda:
self.solver = _qd_core.make_cusparse_solver(dtype, solver_type, ordering)
self.solver = _qd_core.make_cusparse_solver(dtype_cxx, solver_type, ordering)
else:
self.solver = _qd_core.make_sparse_solver(dtype, solver_type, ordering)
self.solver = _qd_core.make_sparse_solver(dtype_cxx, solver_type, ordering)
else:
raise QuadrantsRuntimeError(
f"The solver type {solver_type} with {ordering} is not supported for now. Only {solver_type_list} with {solver_ordering} are supported."
Expand Down
30 changes: 30 additions & 0 deletions tests/python/test_ad_basics_fwd.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

import quadrants as qd

from tests import test_utils
Expand Down Expand Up @@ -124,3 +126,31 @@ def clear_dual_test():
with qd.ad.FwdMode(loss=loss, param=x):
clear_dual_test()
assert y.dual[None] == 4.0


@pytest.mark.parametrize("dtype", [qd.f32, qd.f64])
@test_utils.test(debug=True)
def test_dual_field_dtype_preserved_in_debug_mode(dtype):
"""Regression: debug-mode checkbit must not shadow the outer dtype.

Picks values whose dual is a non-integer exactly representable in
both ``f32`` and ``f64`` (``x=1.25`` -> ``dual=2.5``): under the old
bug the dual field was created as ``u8`` (or ``i32`` on Vulkan),
which would truncate ``2.5`` to ``2`` and fail the assertion.
"""
test_utils.skip_if_f64_unsupported(dtype)

x = qd.field(dtype, shape=(), needs_dual=True)
loss = qd.field(dtype, shape=(), needs_dual=True)

x[None] = 1.25

@qd.kernel
def compute():
loss[None] = x[None] * x[None]

with qd.ad.FwdMode(loss=loss, param=x):
compute()

assert loss[None] == 1.5625
assert loss.dual[None] == 2.5
2 changes: 2 additions & 0 deletions tests/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def _test_python(args, default_dir="python"):
pytest_args += [
"--durations=15",
"-p",
"no:timeout",
"-p",
"pytest_hardtle",
f"--timeout={args.timeout}",
]
Expand Down
Loading