diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_domain.py b/src/gt4py/next/program_processors/runners/dace/gtir_domain.py index 98a1206b20..240910bf01 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_domain.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_domain.py @@ -8,6 +8,7 @@ from __future__ import annotations +import dataclasses from typing import Optional, Sequence, TypeAlias import dace @@ -19,15 +20,24 @@ from gt4py.next.program_processors.runners.dace import gtir_to_sdfg_utils -FieldopDomain: TypeAlias = list[ - tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] -] -""" -Domain of a field operator represented as a list of tuples with 3 elements: - - dimension definition - - symbolic expression for lower bound (inclusive) - - symbolic expression for upper bound (exclusive) -""" +@dataclasses.dataclass(frozen=True) +class FieldopDomainRange: + """ + Represents the range of a field operator domain in one dimension. + + It contains 3 elements: + dim: dimension definition + start: symbolic expression for lower bound (inclusive) + stop: symbolic expression for upper bound (exclusive) + """ + + dim: gtx_common.Dimension + start: dace.symbolic.SymbolicType + stop: dace.symbolic.SymbolicType + + +FieldopDomain: TypeAlias = list[FieldopDomainRange] +"""Domain of a field operator represented as a list of `FieldopDomainRange` for each dimension.""" def extract_domain(node: gtir.Expr) -> FieldopDomain: @@ -49,12 +59,12 @@ def extract_domain(node: gtir.Expr) -> FieldopDomain: gtir_to_sdfg_utils.get_symbolic(arg) for arg in named_range.args[1:3] ) dim = gtx_common.Dimension(axis.value, axis.kind) - domain.append((dim, lower_bound, upper_bound)) + domain.append(FieldopDomainRange(dim, lower_bound, upper_bound)) elif isinstance(node, domain_utils.SymbolicDomain): for dim, drange in node.ranges.items(): domain.append( - ( + FieldopDomainRange( dim, gtir_to_sdfg_utils.get_symbolic(drange.start), gtir_to_sdfg_utils.get_symbolic(drange.stop), @@ -119,6 +129,7 @@ def get_field_layout( """ if len(domain) == 0: return [], [], [] - domain_dims, domain_lbs, domain_ubs = zip(*domain) - domain_sizes = [(ub - lb) for lb, ub in zip(domain_lbs, domain_ubs)] - return list(domain_dims), list(domain_lbs), domain_sizes + domain_dims = [domain_range.dim for domain_range in domain] + domain_origin = [domain_range.start for domain_range in domain] + domain_shape = [(domain_range.stop - domain_range.start) for domain_range in domain] + return domain_dims, domain_origin, domain_shape diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py index 26234b04c5..54118d2cd2 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py @@ -197,7 +197,9 @@ def _make_access_index_for_field( # since the access indices have to follow the order of dimensions in field domain if isinstance(data.gt_type, ts.FieldType) and len(data.gt_type.dims) != 0: assert data.origin is not None - domain_ranges = {dim: (lb, ub) for dim, lb, ub in domain} + domain_ranges = { + domain_range.dim: (domain_range.start, domain_range.stop) for domain_range in domain + } return dace.subsets.Range( (domain_ranges[dim][0] - origin, domain_ranges[dim][1] - origin - 1, 1) for dim, origin in zip(data.gt_type.dims, data.origin, strict=True) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py index 3053165003..feecc8a5e8 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py @@ -185,7 +185,7 @@ def _translate_concat_where_impl( gtir_domain.extract_domain(domain) for domain in [tb_node_domain, fb_node_domain] ) assert len(mask_domain) == 1 - concat_dim, mask_lower_bound, mask_upper_bound = mask_domain[0] + concat_domain = mask_domain[0] # Expect unbound range in the concat domain expression on lower or upper range: # - if the domain expression is unbound on lower side (negative infinite), @@ -193,12 +193,12 @@ def _translate_concat_where_impl( # lower domain. # - viceversa, if the domain expression is unbound on upper side (positive # infinite), the true expression represents the input for the upper domain. - if mask_lower_bound == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.NEGATIVE): - concat_dim_bound = mask_upper_bound + if concat_domain.start == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.NEGATIVE): + concat_dim_bound = concat_domain.stop lower, lower_desc, lower_domain = (tb_field, tb_data_desc, tb_domain) upper, upper_desc, upper_domain = (fb_field, fb_data_desc, fb_domain) - elif mask_upper_bound == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.POSITIVE): - concat_dim_bound = mask_lower_bound + elif concat_domain.stop == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.POSITIVE): + concat_dim_bound = concat_domain.start lower, lower_desc, lower_domain = (fb_field, fb_data_desc, fb_domain) upper, upper_desc, upper_domain = (tb_field, tb_data_desc, tb_domain) else: @@ -207,9 +207,9 @@ def _translate_concat_where_impl( # we use the concat domain, stored in the annex, as the domain of output field output_domain = gtir_domain.extract_domain(node_domain) output_dims, output_origin, output_shape = _get_concat_where_field_layout( - output_domain, concat_dim + output_domain, concat_domain.dim ) - concat_dim_index = output_dims.index(concat_dim) + concat_dim_index = output_dims.index(concat_domain.dim) """ In case one of the arguments is a scalar value, for example: @@ -225,23 +225,27 @@ def testee(a: np.int32, b: cases.IJKField) -> cases.IJKField: assert isinstance(upper.gt_type, ts.FieldType) lower = gtir_to_sdfg_types.FieldopData( lower.dc_node, - ts.FieldType(dims=[concat_dim], dtype=lower.gt_type), + ts.FieldType(dims=[concat_domain.dim], dtype=lower.gt_type), origin=(concat_dim_bound - 1,), ) - lower_bound = output_domain[concat_dim_index][1] - lower_domain = [(concat_dim, lower_bound, concat_dim_bound)] + lower_bound = output_domain[concat_dim_index].start + lower_domain = [ + gtir_domain.FieldopDomainRange(concat_domain.dim, lower_bound, concat_dim_bound) + ] elif isinstance(upper.gt_type, ts.ScalarType): assert len(upper_domain) == 0 assert isinstance(lower.gt_type, ts.FieldType) upper = gtir_to_sdfg_types.FieldopData( upper.dc_node, - ts.FieldType(dims=[concat_dim], dtype=upper.gt_type), + ts.FieldType(dims=[concat_domain.dim], dtype=upper.gt_type), origin=(concat_dim_bound,), ) - upper_bound = output_domain[concat_dim_index][2] - upper_domain = [(concat_dim, concat_dim_bound, upper_bound)] + upper_bound = output_domain[concat_dim_index].stop + upper_domain = [ + gtir_domain.FieldopDomainRange(concat_domain.dim, concat_dim_bound, upper_bound) + ] - if concat_dim not in lower.gt_type.dims: # type: ignore[union-attr] + if concat_domain.dim not in lower.gt_type.dims: # type: ignore[union-attr] """ The field on the lower domain is to be treated as a slice to add as one level in the concat dimension, on the lower bound. @@ -261,13 +265,22 @@ def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: ] ) lower, lower_desc = _make_concat_field_slice( - sdfg, state, lower, lower_desc, concat_dim, concat_dim_index, concat_dim_bound - 1 + sdfg=sdfg, + state=state, + field=lower, + field_desc=lower_desc, + concat_dim=concat_domain.dim, + concat_dim_index=concat_dim_index, + concat_dim_origin=concat_dim_bound - 1, ) lower_bound = dace.symbolic.pystr_to_symbolic( - f"max({concat_dim_bound - 1}, {output_domain[concat_dim_index][1]})" + f"max({concat_dim_bound - 1}, {output_domain[concat_dim_index].start})" + ) + lower_domain.insert( + concat_dim_index, + gtir_domain.FieldopDomainRange(concat_domain.dim, lower_bound, concat_dim_bound), ) - lower_domain.insert(concat_dim_index, (concat_dim, lower_bound, concat_dim_bound)) - elif concat_dim not in upper.gt_type.dims: # type: ignore[union-attr] + elif concat_domain.dim not in upper.gt_type.dims: # type: ignore[union-attr] # Same as previous case, but the field slice is added on the upper bound. assert ( upper.gt_type.dims # type: ignore[union-attr] @@ -277,12 +290,21 @@ def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: ] ) upper, upper_desc = _make_concat_field_slice( - sdfg, state, upper, upper_desc, concat_dim, concat_dim_index, concat_dim_bound + sdfg=sdfg, + state=state, + field=upper, + field_desc=upper_desc, + concat_dim=concat_domain.dim, + concat_dim_index=concat_dim_index, + concat_dim_origin=concat_dim_bound, ) upper_bound = dace.symbolic.pystr_to_symbolic( - f"min({concat_dim_bound + 1}, {output_domain[concat_dim_index][2]})" + f"min({concat_dim_bound + 1}, {output_domain[concat_dim_index].stop})" + ) + upper_domain.insert( + concat_dim_index, + gtir_domain.FieldopDomainRange(concat_domain.dim, concat_dim_bound, upper_bound), ) - upper_domain.insert(concat_dim_index, (concat_dim, concat_dim_bound, upper_bound)) elif isinstance(lower_desc, dace.data.Scalar) or ( len(lower.gt_type.dims) == 1 and len(output_domain) > 1 # type: ignore[union-attr] ): @@ -297,27 +319,37 @@ def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField: return concat_where(KDim == 0, a, b) ``` """ - assert len(lower_domain) == 1 and lower_domain[0][0] == concat_dim + assert len(lower_domain) == 1 and lower_domain[0].dim == concat_domain.dim lower_domain = [ *output_domain[:concat_dim_index], lower_domain[0], *output_domain[concat_dim_index + 1 :], ] lower, lower_desc = _make_concat_scalar_broadcast( - sdfg, state, lower, lower_desc, lower_domain, concat_dim_index + sdfg=sdfg, + state=state, + inp=lower, + inp_desc=lower_desc, + domain=lower_domain, + concat_dim_index=concat_dim_index, ) elif isinstance(upper_desc, dace.data.Scalar) or ( len(upper.gt_type.dims) == 1 and len(output_domain) > 1 # type: ignore[union-attr] ): # Same as previous case, but the scalar value is taken from `upper` input. - assert len(upper_domain) == 1 and upper_domain[0][0] == concat_dim + assert len(upper_domain) == 1 and upper_domain[0].dim == concat_domain.dim upper_domain = [ *output_domain[:concat_dim_index], upper_domain[0], *output_domain[concat_dim_index + 1 :], ] upper, upper_desc = _make_concat_scalar_broadcast( - sdfg, state, upper, upper_desc, upper_domain, concat_dim_index + sdfg=sdfg, + state=state, + inp=upper, + inp_desc=upper_desc, + domain=upper_domain, + concat_dim_index=concat_dim_index, ) else: """ @@ -341,15 +373,15 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField: # ensure that the arguments have the same domain as the concat result assert all(ftype.dims == output_dims for ftype in (lower.gt_type, upper.gt_type)) # type: ignore[union-attr] - lower_range_0 = output_domain[concat_dim_index][1] + lower_range_0 = output_domain[concat_dim_index].start lower_range_1 = dace.symbolic.pystr_to_symbolic( - f"max({lower_range_0}, {lower_domain[concat_dim_index][2]})" + f"max({lower_range_0}, {lower_domain[concat_dim_index].stop})" ) lower_range_size = lower_range_1 - lower_range_0 - upper_range_1 = output_domain[concat_dim_index][2] + upper_range_1 = output_domain[concat_dim_index].stop upper_range_0 = dace.symbolic.pystr_to_symbolic( - f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})" + f"min({upper_range_1}, {upper_domain[concat_dim_index].start})" ) upper_range_size = upper_range_1 - upper_range_0 @@ -391,15 +423,15 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField: else: lower_subset.append( ( - output_domain[dim_index][1] - lower.origin[dim_index], - output_domain[dim_index][1] - lower.origin[dim_index] + size - 1, + output_domain[dim_index].start - lower.origin[dim_index], + output_domain[dim_index].start - lower.origin[dim_index] + size - 1, 1, ) ) upper_subset.append( ( - output_domain[dim_index][1] - upper.origin[dim_index], - output_domain[dim_index][1] - upper.origin[dim_index] + size - 1, + output_domain[dim_index].start - upper.origin[dim_index], + output_domain[dim_index].start - upper.origin[dim_index] + size - 1, 1, ) ) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py index fe75b3d0ad..911b457bab 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py @@ -201,8 +201,10 @@ def _create_field_operator( else: # create map range corresponding to the field operator domain map_range = { - gtir_to_sdfg_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain + gtir_to_sdfg_utils.get_map_variable( + domain_range.dim + ): f"{domain_range.start}:{domain_range.stop}" + for domain_range in domain } map_entry, map_exit = sdfg_builder.add_map("fieldop", state, map_range) @@ -511,8 +513,7 @@ def translate_index( assert "domain" in node.annex domain = gtir_domain.extract_domain(node.annex.domain) assert len(domain) == 1 - dim, _, _ = domain[0] - dim_index = gtir_to_sdfg_utils.get_map_variable(dim) + dim_index = gtir_to_sdfg_utils.get_map_variable(domain[0].dim) index_data, _ = sdfg_builder.add_temp_scalar(sdfg, gtir_to_sdfg_types.INDEX_DTYPE) index_node = state.add_access(index_data) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py index ce819ff470..589a2750e5 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py @@ -210,9 +210,11 @@ def _create_scan_field_operator( "fieldop", state, ndrange={ - gtir_to_sdfg_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - if not sdfg_builder.is_column_axis(dim) + gtir_to_sdfg_utils.get_map_variable( + domain_range.dim + ): f"{domain_range.start}:{domain_range.stop}" + for domain_range in domain + if not sdfg_builder.is_column_axis(domain_range.dim) }, ) @@ -329,22 +331,18 @@ def _lower_lambda_to_nested_sdfg( ) # use the vertical dimension in the domain as scan dimension - scan_domain = [ - (dim, lower_bound, upper_bound) - for dim, lower_bound, upper_bound in domain - if sdfg_builder.is_column_axis(dim) - ] - assert len(scan_domain) == 1 - scan_dim, scan_lower_bound, scan_upper_bound = scan_domain[0] + scan_domain = next( + domain_range for domain_range in domain if sdfg_builder.is_column_axis(domain_range.dim) + ) # extract the scan loop range - scan_loop_var = gtir_to_sdfg_utils.get_map_variable(scan_dim) + scan_loop_var = gtir_to_sdfg_utils.get_map_variable(scan_domain.dim) # in case the scan operator computes a list (not a scalar), we need to add an extra dimension def get_scan_output_shape( scan_init_data: gtir_to_sdfg_types.FieldopData, ) -> list[dace.symbolic.SymExpr]: - scan_column_size = scan_upper_bound - scan_lower_bound + scan_column_size = scan_domain.stop - scan_domain.start if isinstance(scan_init_data.gt_type, ts.ScalarType): return [scan_column_size] assert isinstance(scan_init_data.gt_type, ts.ListType) @@ -391,18 +389,18 @@ def init_scan_carry(sym: gtir.Sym) -> None: if scan_forward: scan_loop = dace.sdfg.state.LoopRegion( label="scan", - condition_expr=f"{scan_loop_var} < {scan_upper_bound}", + condition_expr=f"{scan_loop_var} < {scan_domain.stop}", loop_var=scan_loop_var, - initialize_expr=f"{scan_loop_var} = {scan_lower_bound}", + initialize_expr=f"{scan_loop_var} = {scan_domain.start}", update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", inverted=False, ) else: scan_loop = dace.sdfg.state.LoopRegion( label="scan", - condition_expr=f"{scan_loop_var} >= {scan_lower_bound}", + condition_expr=f"{scan_loop_var} >= {scan_domain.start}", loop_var=scan_loop_var, - initialize_expr=f"{scan_loop_var} = {scan_upper_bound} - 1", + initialize_expr=f"{scan_loop_var} = {scan_domain.stop} - 1", update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", inverted=False, ) @@ -431,7 +429,7 @@ def init_scan_carry(sym: gtir.Sym) -> None: for edge in lambda_input_edges: edge.connect(map_entry=None) # connect the dataflow output nodes, called 'scan_result' below, to a global field called 'output' - output_column_index = dace.symbolic.pystr_to_symbolic(scan_loop_var) - scan_lower_bound + output_column_index = dace.symbolic.pystr_to_symbolic(scan_loop_var) - scan_domain.start def connect_scan_output( scan_output_edge: gtir_dataflow.DataflowOutputEdge, @@ -475,8 +473,8 @@ def connect_scan_output( dace.Memlet.from_array(scan_result_data, scan_result_desc), ) - output_type = ts.FieldType(dims=[scan_dim], dtype=scan_result.gt_dtype) - return gtir_to_sdfg_types.FieldopData(output_node, output_type, origin=(scan_lower_bound,)) + output_type = ts.FieldType(dims=[scan_domain.dim], dtype=scan_result.gt_dtype) + return gtir_to_sdfg_types.FieldopData(output_node, output_type, origin=(scan_domain.start,)) # write the stencil result (value on one vertical level) into a 1D field # with full vertical shape representing one column diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py index 47f8d0c5d6..9f39b2689d 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py @@ -115,7 +115,7 @@ def get_local_view( # The `make_field` constructor converts any local dimension, if present, to `ListType` # element type, while leaving the field domain with all global dimensions. assert all(dim != gtx_common.DimensionKind.LOCAL for dim in self.gt_type.dims) - domain_dims = [dim for dim, _, _ in domain] + domain_dims = [domain_range.dim for domain_range in domain] domain_indices = gtir_domain.get_domain_indices(domain_dims, origin=None) it_indices = { dim: gtir_dataflow.SymbolExpr(index, INDEX_DTYPE)