@@ -129,24 +129,6 @@ def kernel(X):
129129 assert_equal (x , ref )
130130
131131
132- @pytest .mark .parametrize ("value,expected_dtype" , [
133- ((1 ,), ct .int32 ),
134- ((- 2 ** 42 ,), ct .int64 ),
135- ((2 ** 63 ,), ct .uint64 ),
136- ((2.5 ,), ct .float32 ),
137- ((True ,), ct .bool_ ),
138- ((1 , True ), ct .int32 ),
139- ((1 , 2.5 , True , 2 ** 63 ), ct .float32 ),
140- ])
141- def test_astile_dtype_infer_const (value , expected_dtype ):
142- @ct .kernel
143- def kernel ():
144- t = ct .astile (value )
145- ct .static_assert (t .dtype == expected_dtype )
146-
147- ct .launch (torch .cuda .current_stream (), (1 ,), kernel , ())
148-
149-
150132def test_astile_scalar_runtime ():
151133 @ct .kernel
152134 def kernel (X , a : float ):
@@ -182,24 +164,6 @@ def kernel(X, a: int, b: int, c: float, d: bool):
182164 assert_equal (x , ref )
183165
184166
185- @pytest .mark .parametrize ("ann1,ann2,val1,val2,expected_dtype" , [
186- (int , int , 1 , 2 , ct .int32 ),
187- (int , ct .ScalarInt64 , 1 , 2 , ct .int64 ),
188- (float , float , 1.5 , 2.5 , ct .float32 ),
189- (bool , bool , True , False , ct .bool_ ),
190- (int , float , 1 , 2.5 , ct .float32 ),
191- (int , bool , 1 , True , ct .int32 ),
192- (float , bool , 2.5 , True , ct .float32 ),
193- ])
194- def test_astile_dtype_infer_runtime (ann1 , ann2 , val1 , val2 , expected_dtype ):
195- @ct .kernel
196- def kernel (a : ann1 , b : ann2 ):
197- t = ct .astile ((a , b ))
198- ct .static_assert (t .dtype == expected_dtype )
199-
200- ct .launch (torch .cuda .current_stream (), (1 ,), kernel , (val1 , val2 ))
201-
202-
203167def test_astile_3d_mixed ():
204168 @ct .kernel
205169 def kernel (X , a : int , b : int , c : float , d : bool ):
@@ -263,7 +227,7 @@ def kernel(X):
263227def test_astile_top_level_not_supported ():
264228 @ct .kernel
265229 def kernel ():
266- ct .astile (ct .full ((4 ,), 1 , dtype = ct .int32 ))
230+ ct .astile (ct .full ((4 ,), 1 , dtype = ct .int32 ), dtype = ct . int32 )
267231 with pytest .raises (TileTypeError ,
268232 match = r"Expected a scalar or \(possibly nested\) tuple of scalars" ):
269233 ct .launch (torch .cuda .current_stream (), (1 ,), kernel , ())
0 commit comments