Skip to content

Commit ced0707

Browse files
committed
astile requires dtype
Signed-off-by: Boyan Li <boyanl@nvidia.com>
1 parent ef013c0 commit ced0707

4 files changed

Lines changed: 6 additions & 54 deletions

File tree

src/cuda/tile/_ir/op_impl.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -323,12 +323,6 @@ def require_dtype_spec(var: Var) -> DType:
323323
return ty.dtype
324324

325325

326-
def require_optional_dtype_spec(var: Var) -> DType | None:
327-
if var.is_constant() and var.get_constant() is None:
328-
return None
329-
return require_dtype_spec(var)
330-
331-
332326
def require_constant_pointer_info(var: Var) -> PointerInfo:
333327
ty = var.get_type()
334328
if not isinstance(ty, PointerInfoTy):

src/cuda/tile/_ir/ops.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,6 @@
3636
from .hir import ResolvedName
3737
from .op_impl import (
3838
ImplRegistry, is_0d_tile, require_constant_int, require_constant_int_tuple,
39-
require_optional_dtype_spec,
4039
require_signed_integer_0d_tile_type,
4140
require_tile_type, normalize_axis, require_dtype_spec,
4241
require_constant_bool, require_optional_constant_enum,
@@ -3413,10 +3412,10 @@ def _cat_tuple(tiles: tuple[Var, ...]) -> Var:
34133412

34143413
@impl(ct.astile)
34153414
def astile_impl(value: Var, dtype: Var) -> Var:
3416-
dtype: Optional[DType] = require_optional_dtype_spec(dtype)
3415+
dtype = require_dtype_spec(dtype)
34173416
value_ty = value.get_type()
34183417
if is_0d_tile(value_ty):
3419-
return value if dtype is None else astype(value, dtype)
3418+
return astype(value, dtype)
34203419

34213420
if not isinstance(value_ty, TupleTy):
34223421
raise TileTypeError(
@@ -3425,9 +3424,6 @@ def astile_impl(value: Var, dtype: Var) -> Var:
34253424

34263425
shape = _tuple_shape(value_ty, path=())
34273426
tiles = _flatten_tuple(value)
3428-
dtype = (functools.reduce(promote_dtypes, (require_0d_tile_type(t).dtype for t in tiles))
3429-
if dtype is None
3430-
else dtype)
34313427

34323428
if value.is_constant():
34333429
return _const(shape, value.get_constant(), dtype)

src/cuda/tile/_stub.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1962,20 +1962,18 @@ def zeros(shape, dtype) -> Tile:
19621962

19631963

19641964
@stub
1965-
def astile(value, /, *, dtype: Optional[DType] = None) -> Tile:
1965+
def astile(value, /, *, dtype: DType) -> Tile:
19661966
"""Creates a tile from a value.
19671967
19681968
Args:
19691969
value (scalar | (nested) tuple of scalar): A scalar (yielding a 0-d tile),
19701970
or a (possibly nested) tuple of scalars whose nesting determines the
19711971
tile's shape. Every tuple's length must be a power of two, and sibling tuples
19721972
at each level must have uniform length.
1973-
dtype (DType, optional): The |Data type| of the tile. If ``None``, the
1974-
dtype is inferred from ``value``.
1973+
dtype (DType): The |Data type| of the tile.
19751974
19761975
Returns:
1977-
Tile: A tile shaped from ``value``, with elements cast to ``dtype`` if
1978-
given.
1976+
Tile: A tile shaped from ``value``, with elements cast to ``dtype``.
19791977
19801978
Examples:
19811979

test/test_astile.py

Lines changed: 1 addition & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -129,24 +129,6 @@ def kernel(X):
129129
assert_equal(x, ref)
130130

131131

132-
@pytest.mark.parametrize("value,expected_dtype", [
133-
((1,), ct.int32),
134-
((-2**42,), ct.int64),
135-
((2**63,), ct.uint64),
136-
((2.5,), ct.float32),
137-
((True,), ct.bool_),
138-
((1, True), ct.int32),
139-
((1, 2.5, True, 2**63), ct.float32),
140-
])
141-
def test_astile_dtype_infer_const(value, expected_dtype):
142-
@ct.kernel
143-
def kernel():
144-
t = ct.astile(value)
145-
ct.static_assert(t.dtype == expected_dtype)
146-
147-
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
148-
149-
150132
def test_astile_scalar_runtime():
151133
@ct.kernel
152134
def kernel(X, a: float):
@@ -182,24 +164,6 @@ def kernel(X, a: int, b: int, c: float, d: bool):
182164
assert_equal(x, ref)
183165

184166

185-
@pytest.mark.parametrize("ann1,ann2,val1,val2,expected_dtype", [
186-
(int, int, 1, 2, ct.int32),
187-
(int, ct.ScalarInt64, 1, 2, ct.int64),
188-
(float, float, 1.5, 2.5, ct.float32),
189-
(bool, bool, True, False, ct.bool_),
190-
(int, float, 1, 2.5, ct.float32),
191-
(int, bool, 1, True, ct.int32),
192-
(float, bool, 2.5, True, ct.float32),
193-
])
194-
def test_astile_dtype_infer_runtime(ann1, ann2, val1, val2, expected_dtype):
195-
@ct.kernel
196-
def kernel(a: ann1, b: ann2):
197-
t = ct.astile((a, b))
198-
ct.static_assert(t.dtype == expected_dtype)
199-
200-
ct.launch(torch.cuda.current_stream(), (1,), kernel, (val1, val2))
201-
202-
203167
def test_astile_3d_mixed():
204168
@ct.kernel
205169
def kernel(X, a: int, b: int, c: float, d: bool):
@@ -263,7 +227,7 @@ def kernel(X):
263227
def test_astile_top_level_not_supported():
264228
@ct.kernel
265229
def kernel():
266-
ct.astile(ct.full((4,), 1, dtype=ct.int32))
230+
ct.astile(ct.full((4,), 1, dtype=ct.int32), dtype=ct.int32)
267231
with pytest.raises(TileTypeError,
268232
match=r"Expected a scalar or \(possibly nested\) tuple of scalars"):
269233
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())

0 commit comments

Comments
 (0)