From 337795f8884f7f497ac1e5ac5beb1c4aaf82e9a2 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 1 May 2026 11:35:58 -0700 Subject: [PATCH 1/2] Reject ndarray fields inside template-annotated struct params MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Passing structs containing ndarrays through qd.template() bypasses argument pruning — every ndarray field gets registered regardless of whether the kernel uses it — and inflates the cached launch context, causing a measured 42% launch overhead on real workloads. Raise a clear QuadrantsCompilationError guiding users to use a concrete dataclass type annotation instead. Co-authored-by: Cursor --- .../ast/ast_transformers/function_def_transformer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index acefe70e4b..1c0a7e98d7 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -27,6 +27,7 @@ ASTTransformerFuncContext, ) from quadrants.lang.exception import ( + QuadrantsCompilationError, QuadrantsSyntaxError, ) from quadrants.lang.matrix import MatrixType @@ -220,6 +221,16 @@ def _walk_obj(obj, arg_idx, path): _register_ndarray(attr_val, arg_idx, (*path, attr_name)) def _register_ndarray(nd, arg_idx, attr_chain): + param_name = ctx.func.arg_metas[arg_idx].name + attr_path = ".".join(attr_chain) + raise QuadrantsCompilationError( + f"Kernel parameter '{param_name}' is annotated as qd.template(), but " + f"'{param_name}.{attr_path}' is a qd.ndarray. Passing ndarrays through " + f"template structs is not supported because it bypasses argument pruning " + f"and degrades launch performance. Use a concrete struct annotation " + f"(e.g. a @dataclass type hint) instead of qd.template() for struct " + f"parameters that contain ndarrays." + ) key = id(nd) if key in cache: return From 733a160fd1d4973941a2880eb768446cea41b9b8 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 1 May 2026 12:17:56 -0700 Subject: [PATCH 2/2] Update tests to expect QuadrantsCompilationError for ndarray-in-template The 5 ndarray-backend variants of struct-via-template tests now assert that passing a struct containing ndarrays through qd.template() raises QuadrantsCompilationError. Field-backend variants are unchanged. Co-authored-by: Cursor --- .../function_def_transformer.py | 23 ---------- .../python/test_tensor_annotation_in_func.py | 45 ++++++++++++++----- 2 files changed, 33 insertions(+), 35 deletions(-) diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 1c0a7e98d7..2e0f000a53 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -198,10 +198,6 @@ def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None: ``ctx.global_context.struct_ndarray_launch_info`` so the launch path can populate the corresponding slots in the launch context. """ - from quadrants.lang.util import cook_dtype # pylint: disable=C0415 - - cache = ctx.global_context.ndarray_to_any_array - launch_info = ctx.global_context.struct_ndarray_launch_info def _walk_obj(obj, arg_idx, path): if dataclasses.is_dataclass(obj) and not isinstance(obj, type): @@ -231,25 +227,6 @@ def _register_ndarray(nd, arg_idx, attr_chain): f"(e.g. a @dataclass type hint) instead of qd.template() for struct " f"parameters that contain ndarrays." ) - key = id(nd) - if key in cache: - return - from quadrants._lib import core as _qd_core # pylint: disable=C0415 - - element_type = cook_dtype(nd.element_type) - ndim = len(nd._physical_shape) - needs_grad = nd.grad is not None - layout = getattr(nd, "_qd_layout", None) - name = f"__qd_struct_nd_{key}" - arg_id_vec = impl.get_runtime().compiling_callable.insert_ndarray_param( - element_type, ndim, name, needs_grad - ) - arr = any_array.AnyArray( - _qd_core.make_external_tensor_expr(element_type, ndim, arg_id_vec, needs_grad, BoundaryMode.UNSAFE), - _qd_layout=layout, - ) - cache[key] = arr - launch_info.append((arg_id_vec[0], arg_idx, attr_chain)) assert ctx.py_args is not None for i, arg_meta in enumerate(ctx.func.arg_metas): diff --git a/tests/python/test_tensor_annotation_in_func.py b/tests/python/test_tensor_annotation_in_func.py index 059885d7dd..874a1c870c 100644 --- a/tests/python/test_tensor_annotation_in_func.py +++ b/tests/python/test_tensor_annotation_in_func.py @@ -13,6 +13,7 @@ import pytest import quadrants as qd +from quadrants.lang.exception import QuadrantsCompilationError from tests import test_utils @@ -80,8 +81,12 @@ def fill(st: qd.template()): for i in range(N): st.vals[i] = i * 3 - fill(s) - np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 3) + if backend == qd.Backend.NDARRAY: + with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"): + fill(s) + else: + fill(s) + np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 3) # --------------------------------------------------------------------------- @@ -114,8 +119,12 @@ def run(st: qd.template()): st.vals[i] = i inc_all(st) - run(s) - np.testing.assert_array_equal(t.to_numpy(), np.arange(N) + 10) + if backend == qd.Backend.NDARRAY: + with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"): + run(s) + else: + run(s) + np.testing.assert_array_equal(t.to_numpy(), np.arange(N) + 10) # --------------------------------------------------------------------------- @@ -148,8 +157,12 @@ def run(st: qd.template()): for i in range(N): write(st, i, i * 7) - run(s) - np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 7) + if backend == qd.Backend.NDARRAY: + with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"): + run(s) + else: + run(s) + np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 7) # --------------------------------------------------------------------------- @@ -181,8 +194,12 @@ def run(st: qd.template()): for i in range(N): scaled_fill(st, i) - run(s) - np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 5) + if backend == qd.Backend.NDARRAY: + with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"): + run(s) + else: + run(s) + np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 5) # --------------------------------------------------------------------------- @@ -254,10 +271,14 @@ def fill(st: qd.template()): for j in qd.static(range(3)): st.vals[i][j] = i * 10.0 + j - fill(s) - arr = t.to_numpy() - assert arr.shape == (N, 3) - assert arr[2, 1] == pytest.approx(21.0) + if backend == qd.Backend.NDARRAY: + with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"): + fill(s) + else: + fill(s) + arr = t.to_numpy() + assert arr.shape == (N, 3) + assert arr[2, 1] == pytest.approx(21.0) # ---------------------------------------------------------------------------