diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 4c9bd5115b..da56f9918d 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -28,6 +28,7 @@ ) from quadrants.lang.buffer_view import BufferView from quadrants.lang.exception import ( + QuadrantsCompilationError, QuadrantsSyntaxError, ) from quadrants.lang.matrix import MatrixType @@ -211,10 +212,6 @@ def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None: ``ctx.global_context.struct_ndarray_launch_info`` so the launch path can populate the corresponding slots in the launch context. """ - from quadrants.lang.util import cook_dtype # pylint: disable=C0415 - - cache = ctx.global_context.ndarray_to_any_array - launch_info = ctx.global_context.struct_ndarray_launch_info def _walk_obj(obj, arg_idx, path): if dataclasses.is_dataclass(obj) and not isinstance(obj, type): @@ -234,25 +231,16 @@ def _walk_obj(obj, arg_idx, path): _register_ndarray(attr_val, arg_idx, (*path, attr_name)) def _register_ndarray(nd, arg_idx, attr_chain): - key = id(nd) - if key in cache: - return - from quadrants._lib import core as _qd_core # pylint: disable=C0415 - - element_type = cook_dtype(nd.element_type) - ndim = len(nd._physical_shape) - needs_grad = nd.grad is not None - layout = getattr(nd, "_qd_layout", None) - name = f"__qd_struct_nd_{key}" - arg_id_vec = impl.get_runtime().compiling_callable.insert_ndarray_param( - element_type, ndim, name, needs_grad - ) - arr = any_array.AnyArray( - _qd_core.make_external_tensor_expr(element_type, ndim, arg_id_vec, needs_grad, BoundaryMode.UNSAFE), - _qd_layout=layout, + param_name = ctx.func.arg_metas[arg_idx].name + attr_path = ".".join(attr_chain) + raise QuadrantsCompilationError( + f"Kernel parameter '{param_name}' is annotated as qd.template(), but " + f"'{param_name}.{attr_path}' is a qd.ndarray. Passing ndarrays through " + f"template structs is not supported because it bypasses argument pruning " + f"and degrades launch performance. Use a concrete struct annotation " + f"(e.g. a @dataclass type hint) instead of qd.template() for struct " + f"parameters that contain ndarrays." ) - cache[key] = arr - launch_info.append((arg_id_vec[0], arg_idx, attr_chain)) assert ctx.py_args is not None for i, arg_meta in enumerate(ctx.func.arg_metas): diff --git a/tests/python/test_tensor_annotation_in_func.py b/tests/python/test_tensor_annotation_in_func.py index 059885d7dd..874a1c870c 100644 --- a/tests/python/test_tensor_annotation_in_func.py +++ b/tests/python/test_tensor_annotation_in_func.py @@ -13,6 +13,7 @@ import pytest import quadrants as qd +from quadrants.lang.exception import QuadrantsCompilationError from tests import test_utils @@ -80,8 +81,12 @@ def fill(st: qd.template()): for i in range(N): st.vals[i] = i * 3 - fill(s) - np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 3) + if backend == qd.Backend.NDARRAY: + with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"): + fill(s) + else: + fill(s) + np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 3) # --------------------------------------------------------------------------- @@ -114,8 +119,12 @@ def run(st: qd.template()): st.vals[i] = i inc_all(st) - run(s) - np.testing.assert_array_equal(t.to_numpy(), np.arange(N) + 10) + if backend == qd.Backend.NDARRAY: + with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"): + run(s) + else: + run(s) + np.testing.assert_array_equal(t.to_numpy(), np.arange(N) + 10) # --------------------------------------------------------------------------- @@ -148,8 +157,12 @@ def run(st: qd.template()): for i in range(N): write(st, i, i * 7) - run(s) - np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 7) + if backend == qd.Backend.NDARRAY: + with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"): + run(s) + else: + run(s) + np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 7) # --------------------------------------------------------------------------- @@ -181,8 +194,12 @@ def run(st: qd.template()): for i in range(N): scaled_fill(st, i) - run(s) - np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 5) + if backend == qd.Backend.NDARRAY: + with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"): + run(s) + else: + run(s) + np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 5) # --------------------------------------------------------------------------- @@ -254,10 +271,14 @@ def fill(st: qd.template()): for j in qd.static(range(3)): st.vals[i][j] = i * 10.0 + j - fill(s) - arr = t.to_numpy() - assert arr.shape == (N, 3) - assert arr[2, 1] == pytest.approx(21.0) + if backend == qd.Backend.NDARRAY: + with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"): + fill(s) + else: + fill(s) + arr = t.to_numpy() + assert arr.shape == (N, 3) + assert arr[2, 1] == pytest.approx(21.0) # ---------------------------------------------------------------------------