Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from quadrants.lang.buffer_view import BufferView
from quadrants.lang.exception import (
QuadrantsCompilationError,
QuadrantsSyntaxError,
)
from quadrants.lang.matrix import MatrixType
Expand Down Expand Up @@ -211,10 +212,6 @@ def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None:
``ctx.global_context.struct_ndarray_launch_info`` so the launch path can populate the corresponding slots in the
launch context.
"""
from quadrants.lang.util import cook_dtype # pylint: disable=C0415

cache = ctx.global_context.ndarray_to_any_array
launch_info = ctx.global_context.struct_ndarray_launch_info

def _walk_obj(obj, arg_idx, path):
if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
Expand All @@ -234,25 +231,16 @@ def _walk_obj(obj, arg_idx, path):
_register_ndarray(attr_val, arg_idx, (*path, attr_name))

def _register_ndarray(nd, arg_idx, attr_chain):
key = id(nd)
if key in cache:
return
from quadrants._lib import core as _qd_core # pylint: disable=C0415

element_type = cook_dtype(nd.element_type)
ndim = len(nd._physical_shape)
needs_grad = nd.grad is not None
layout = getattr(nd, "_qd_layout", None)
name = f"__qd_struct_nd_{key}"
arg_id_vec = impl.get_runtime().compiling_callable.insert_ndarray_param(
element_type, ndim, name, needs_grad
)
arr = any_array.AnyArray(
_qd_core.make_external_tensor_expr(element_type, ndim, arg_id_vec, needs_grad, BoundaryMode.UNSAFE),
_qd_layout=layout,
param_name = ctx.func.arg_metas[arg_idx].name
attr_path = ".".join(attr_chain)
raise QuadrantsCompilationError(
f"Kernel parameter '{param_name}' is annotated as qd.template(), but "
f"'{param_name}.{attr_path}' is a qd.ndarray. Passing ndarrays through "
f"template structs is not supported because it bypasses argument pruning "
f"and degrades launch performance. Use a concrete struct annotation "
f"(e.g. a @dataclass type hint) instead of qd.template() for struct "
Comment on lines +236 to +241
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Document new template-struct ndarray compile failure

This change introduces a user-visible API behavior change by turning previously valid qd.template() struct usage with qd.ndarray fields into a hard QuadrantsCompilationError in _predeclare_struct_ndarrays, but the commit does not update any user docs under docs/. Per AGENTS.md ("Keep user-facing docs in sync with public API changes"), this should be documented so users understand the new restriction and migration path.

Useful? React with 👍 / 👎.

f"parameters that contain ndarrays."
)
cache[key] = arr
launch_info.append((arg_id_vec[0], arg_idx, attr_chain))

assert ctx.py_args is not None
for i, arg_meta in enumerate(ctx.func.arg_metas):
Expand Down
45 changes: 33 additions & 12 deletions tests/python/test_tensor_annotation_in_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pytest

import quadrants as qd
from quadrants.lang.exception import QuadrantsCompilationError

from tests import test_utils

Expand Down Expand Up @@ -80,8 +81,12 @@ def fill(st: qd.template()):
for i in range(N):
st.vals[i] = i * 3

fill(s)
np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 3)
if backend == qd.Backend.NDARRAY:
with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"):
fill(s)
else:
fill(s)
np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 3)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -114,8 +119,12 @@ def run(st: qd.template()):
st.vals[i] = i
inc_all(st)

run(s)
np.testing.assert_array_equal(t.to_numpy(), np.arange(N) + 10)
if backend == qd.Backend.NDARRAY:
with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"):
run(s)
else:
run(s)
np.testing.assert_array_equal(t.to_numpy(), np.arange(N) + 10)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -148,8 +157,12 @@ def run(st: qd.template()):
for i in range(N):
write(st, i, i * 7)

run(s)
np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 7)
if backend == qd.Backend.NDARRAY:
with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"):
run(s)
else:
run(s)
np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 7)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -181,8 +194,12 @@ def run(st: qd.template()):
for i in range(N):
scaled_fill(st, i)

run(s)
np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 5)
if backend == qd.Backend.NDARRAY:
with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"):
run(s)
else:
run(s)
np.testing.assert_array_equal(t.to_numpy(), np.arange(N) * 5)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -254,10 +271,14 @@ def fill(st: qd.template()):
for j in qd.static(range(3)):
st.vals[i][j] = i * 10.0 + j

fill(s)
arr = t.to_numpy()
assert arr.shape == (N, 3)
assert arr[2, 1] == pytest.approx(21.0)
if backend == qd.Backend.NDARRAY:
with pytest.raises(QuadrantsCompilationError, match="qd.template.*qd.ndarray"):
fill(s)
else:
fill(s)
arr = t.to_numpy()
assert arr.shape == (N, 3)
assert arr[2, 1] == pytest.approx(21.0)


# ---------------------------------------------------------------------------
Expand Down
Loading