diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index dc6f24e9dd..30726ee862 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -109,6 +109,41 @@ def __add__(self, offset: int) -> Connectivity: def __sub__(self, offset: int) -> Connectivity: return self + (-offset) + def __gt__(self, value: core_defs.IntegralScalar) -> Domain: + return Domain(dims=(self,), ranges=(UnitRange(value + 1, Infinity.POSITIVE),)) + + def __ge__(self, value: core_defs.IntegralScalar) -> Domain: + return Domain(dims=(self,), ranges=(UnitRange(value, Infinity.POSITIVE),)) + + def __lt__(self, value: core_defs.IntegralScalar) -> Domain: + return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value),)) + + def __le__(self, value: core_defs.IntegralScalar) -> Domain: + # TODO add test + return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value + 1),)) + + def __eq__(self, value: Dimension | core_defs.IntegralScalar) -> bool | Domain: + if isinstance(value, Dimension): + return self.value == value.value + elif isinstance(value, core_defs.INTEGRAL_TYPES): + # TODO probably only within valid embedded context? + return Domain(dims=(self,), ranges=(UnitRange(value, value + 1),)) + else: + return False + + def __ne__(self, value: Dimension | core_defs.IntegralScalar) -> bool | tuple[Domain, Domain]: + # TODO add test + if isinstance(value, Dimension): + return self.value != value.value + elif isinstance(value, core_defs.INTEGRAL_TYPES): + # TODO probably only within valid embedded context? 
+ return ( + Domain(self, UnitRange(Infinity.NEGATIVE, value)), + Domain(self, UnitRange(value + 1, Infinity.POSITIVE)), + ) + else: + return True + class Infinity(enum.Enum): """Describes an unbounded `UnitRange`.""" @@ -500,6 +535,24 @@ def __and__(self, other: Domain) -> Domain: ) return Domain(dims=broadcast_dims, ranges=intersected_ranges) + def __or__(self, other: Domain) -> Domain: + # TODO support arbitrary union of domains + # TODO add tests + if self.ndim > 1 or other.ndim > 1: + raise NotImplementedError("Union of multidimensional domains is not supported.") + if self.ndim == 0: + return other + if other.ndim == 0: + return self + sorted_ = sorted((self, other), key=lambda x: x.ranges[0].start) + if sorted_[0].ranges[0].stop >= sorted_[1].ranges[0].start: + return Domain( + dims=(self.dims[0],), + ranges=(UnitRange(sorted_[0].ranges[0].start, sorted_[1].ranges[0].stop),), + ) + else: + return (sorted_[0], sorted_[1]) + @functools.cached_property def slice_at(self) -> utils.IndexerCallable[slice, Domain]: """ diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 25ce060c7c..a2101e6c99 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -810,25 +810,6 @@ def _hyperslice( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _compute_mask_slices( - mask: core_defs.NDArrayObject, -) -> list[tuple[bool, slice]]: - """Take a 1-dimensional mask and return a sequence of mappings from boolean values to slices.""" - # TODO: does it make sense to upgrade this naive algorithm to numpy? - assert mask.ndim == 1 - cur = bool(mask[0].item()) - ind = 0 - res = [] - for i in range(1, mask.shape[0]): - # Use `.item()` to extract the scalar from a 0-d array in case of e.g. 
cupy - if (mask_i := bool(mask[i].item())) != cur: - res.append((cur, slice(ind, i))) - cur = mask_i - ind = i - res.append((cur, slice(ind, mask.shape[0]))) - return res - - def _trim_empty_domains( lst: Iterable[tuple[bool, common.Domain]], ) -> list[tuple[bool, common.Domain]]: @@ -896,82 +877,108 @@ def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[c def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: # TODO(havogt): this function could be extended to a general concat - # currently only concatenate along the given dimension and requires the fields to be ordered + # currently only concatenate along the given dimension + sorted_fields = sorted(fields, key=lambda f: f.domain[dim].unit_range.start) if ( - len(fields) > 1 - and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty() + len(sorted_fields) > 1 + and not embedded_common.domain_intersection(*[f.domain for f in sorted_fields]).is_empty() ): raise ValueError("Fields to concatenate must not overlap.") - new_domain = _stack_domains(*[f.domain for f in fields], dim=dim) + new_domain = _stack_domains(*[f.domain for f in sorted_fields], dim=dim) if new_domain is None: raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.") - nd_array_class = _get_nd_array_class(*fields) + nd_array_class = _get_nd_array_class(*sorted_fields) return nd_array_class.from_array( nd_array_class.array_ns.concatenate( - [nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields], + [ + nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) + for f in sorted_fields + ], axis=new_domain.dim_index(dim, allow_missing=False), ), domain=new_domain, ) +def _invert_domain( + domains: common.Domain | tuple[common.Domain], +) -> common.Domain | tuple[common.Domain, ...]: + if not isinstance(domains, tuple): + domains = (domains,) + + assert all(d.ndim == 1 for d in domains) + dim = domains[0].dims[0] + 
assert all(d.dims[0] == dim for d in domains) + sorted_domains = sorted(domains, key=lambda d: d.ranges[0].start) + + result = [] + if domains[0].ranges[0].start is not common.Infinity.NEGATIVE: + result.append( + common.Domain( + dims=(dim,), + ranges=(common.UnitRange(common.Infinity.NEGATIVE, domains[0].ranges[0].start),), + ) + ) + for i in range(len(sorted_domains) - 1): + if sorted_domains[i].ranges[0].stop != sorted_domains[i + 1].ranges[0].start: + result.append( + common.Domain( + dims=(dim,), + ranges=( + common.UnitRange( + sorted_domains[i].ranges[0].stop, sorted_domains[i + 1].ranges[0].start + ), + ), + ) + ) + if domains[-1].ranges[0].stop is not common.Infinity.POSITIVE: + result.append( + common.Domain( + dims=(dim,), + ranges=(common.UnitRange(domains[-1].ranges[0].stop, common.Infinity.POSITIVE),), + ) + ) + return tuple(result) + + +def _intersect_multiple( + domain: common.Domain, domains: common.Domain | tuple[common.Domain] +) -> tuple[common.Domain, ...]: + if not isinstance(domains, tuple): + domains = (domains,) + + return tuple( + intersection + for d in domains + if not (intersection := embedded_common.domain_intersection(domain, d)).is_empty() + ) + + def _concat_where( - mask_field: common.Field, true_field: common.Field, false_field: common.Field + masks: common.Domain | tuple[common.Domain, ...], + true_field: common.Field, + false_field: common.Field, ) -> common.Field: - cls_ = _get_nd_array_class(mask_field, true_field, false_field) - xp = cls_.array_ns - if mask_field.domain.ndim != 1: + if not isinstance(masks, tuple): + masks = (masks,) + if any(m.ndim for m in masks) != 1: raise NotImplementedError( "'concat_where': Can only concatenate fields with a 1-dimensional mask." 
) - mask_dim = mask_field.domain.dims[0] + mask_dim = masks[0].dims[0] # intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim) - # TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils - # compute the consecutive ranges (first relative, then domain) of true and false values - mask_values_to_slices_mapping: Iterable[tuple[bool, slice]] = _compute_mask_slices( - mask_field.ndarray - ) - mask_values_to_domain_mapping: Iterable[tuple[bool, common.Domain]] = ( - (mask, mask_field.domain.slice_at[domain_slice]) - for mask, domain_slice in mask_values_to_slices_mapping - ) - # mask domains intersected with the respective fields - mask_values_to_intersected_domains_mapping: Iterable[tuple[bool, common.Domain]] = ( - ( - mask_value, - embedded_common.domain_intersection( - t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain - ), - ) - for mask_value, mask_domain in mask_values_to_domain_mapping - ) - - # remove the empty domains from the beginning and end - mask_values_to_intersected_domains_mapping = _trim_empty_domains( - mask_values_to_intersected_domains_mapping - ) - if any(d.is_empty() for _, d in mask_values_to_intersected_domains_mapping): - raise embedded_exceptions.NonContiguousDomain( - f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in mask_values_to_intersected_domains_mapping]}." 
- ) + true_domains = _intersect_multiple(t_broadcasted.domain, masks) + t_slices = tuple(t_broadcasted[d] for d in true_domains) - # slice the fields with the domain ranges - transformed = [ - t_broadcasted[d] if v else f_broadcasted[d] - for v, d in mask_values_to_intersected_domains_mapping - ] + inverted_masks = _invert_domain(masks) + false_domains = _intersect_multiple(f_broadcasted.domain, inverted_masks) + f_slices = tuple(f_broadcasted[d] for d in false_domains) - # stack the fields together - if transformed: - return _concat(*transformed, dim=mask_dim) - else: - result_domain = common.Domain(common.NamedRange(mask_dim, common.UnitRange(0, 0))) - result_array = xp.empty(result_domain.shape) - return cls_.from_array(result_array, domain=result_domain) + return _concat(*f_slices, *t_slices, dim=mask_dim) NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index d96209bbb1..98e47e31ae 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -406,7 +406,12 @@ def create_if( return im.let(cond_symref_name, cond_)(result) - _visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where + def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + domain, true_branch, false_branch = self.visit(node.args, **kwargs) + # TODO: use this case again. 
breaks domain inference in fused_velocity_advection_stencil_1_to_7 + # because some tuple elements are never accessed and the collapse tuple + # does not propagate across concat where + return im.concat_where(domain, true_branch, false_branch) def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return im.call("broadcast")(*self.visit(node.args, **kwargs)) @@ -488,7 +493,7 @@ def _map( Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists. """ if all( - isinstance(t, ts.ScalarType) + isinstance(t, (ts.ScalarType, ts.DimensionType, ts.DomainType)) for arg_type in original_arg_types for t in type_info.primitive_constituents(arg_type) ): diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 9b065de658..d2bcded0cc 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -407,6 +407,16 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() +@builtin_dispatch +def concat_where(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def in_(*args): + raise BackendNotSelectedError() + + UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"} UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { @@ -494,6 +504,8 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "scan", "tuple_get", "unstructured_domain", + "concat_where", + "in_", *ARITHMETIC_BUILTINS, *TYPE_BUILTINS, } diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index d86a584b8d..d7c455763e 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1802,6 +1802,16 @@ def index(axis: common.Dimension) -> common.Field: return IndexField(axis) +@builtins.concat_where.register(EMBEDDED) +def concat_where(*args): + raise NotImplementedError("To be implemented in frontend embedded.") + + +@builtins.in_.register(EMBEDDED) +def in_(*args): + 
raise NotImplementedError("To be implemented in frontend embedded.") + + def closure( domain_: runtime.CartesianDomain | runtime.UnstructuredDomain, sten: Callable[..., Any], diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index d82ef93015..e25eaeee1e 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -5,8 +5,10 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations -from typing import ClassVar, List, Optional, Union +import typing +from typing import TYPE_CHECKING, ClassVar, List, Optional, Union import gt4py.eve as eve from gt4py.eve import Coerced, SymbolName, SymbolRef @@ -63,6 +65,22 @@ class NoneLiteral(Expr): _none_literal: int = 0 +class InfinityLiteral(Expr): + # TODO(tehrengruber): self referential `ClassVar` not supported in eve. + if TYPE_CHECKING: + POSITIVE: ClassVar[InfinityLiteral] + NEGATIVE: ClassVar[InfinityLiteral] + + name: typing.Literal["POSITIVE", "NEGATIVE"] + + def __str__(self): + return f"{type(self).__name__}.{self.name}" + + +InfinityLiteral.NEGATIVE = InfinityLiteral(name="NEGATIVE") +InfinityLiteral.POSITIVE = InfinityLiteral(name="POSITIVE") + + class OffsetLiteral(Expr): value: Union[int, str] @@ -142,4 +160,5 @@ class Program(Node, ValidatedSymbolTableTrait): Program.__hash__ = Node.__hash__ # type: ignore[method-assign] SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign] IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign] +InfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] Temporary.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index afcf02927a..c01b0a0dcc 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -14,7 +14,7 @@ from gt4py.next import common from 
gt4py.next.iterator import builtins, ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding @@ -51,7 +51,7 @@ def translate(self, distance: int) -> SymbolicRange: @dataclasses.dataclass(frozen=True) class SymbolicDomain: - grid_type: Literal["unstructured_domain", "cartesian_domain"] + grid_type: common.GridType ranges: dict[ common.Dimension, SymbolicRange ] # TODO(havogt): remove `AxisLiteral` by `Dimension` everywhere @@ -61,25 +61,19 @@ def __hash__(self) -> int: @classmethod def from_expr(cls, node: itir.Node) -> SymbolicDomain: - assert isinstance(node, itir.FunCall) and node.fun in [ - im.ref("unstructured_domain"), - im.ref("cartesian_domain"), - ] + assert cpm.is_call_to(node, ("unstructured_domain", "cartesian_domain")) + grid_type = getattr(common.GridType, node.fun.id[: -len("_domain")].upper()) ranges: dict[common.Dimension, SymbolicRange] = {} for named_range in node.args: - assert ( - isinstance(named_range, itir.FunCall) - and isinstance(named_range.fun, itir.SymRef) - and named_range.fun.id == "named_range" - ) + assert cpm.is_call_to(named_range, "named_range") axis_literal, lower_bound, upper_bound = named_range.args assert isinstance(axis_literal, itir.AxisLiteral) ranges[common.Dimension(value=axis_literal.value, kind=axis_literal.kind)] = ( SymbolicRange(lower_bound, upper_bound) ) - return cls(node.fun.id, ranges) # type: ignore[attr-defined] # ensure by assert above + return cls(grid_type, ranges) def as_expr(self) -> itir.FunCall: converted_ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]] = { @@ -183,3 +177,74 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) + + +def 
domain_intersection(*domains: SymbolicDomain) -> SymbolicDomain: + """Return the (set) intersection of a list of domains.""" + new_domain_ranges = {} + assert all(domain.grid_type == domains[0].grid_type for domain in domains) + for dim in domains[0].ranges.keys(): + start = functools.reduce( + lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), + [domain.ranges[dim].start for domain in domains], + ) + stop = functools.reduce( + lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), + [domain.ranges[dim].stop for domain in domains], + ) + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr + new_domain_ranges[dim] = SymbolicRange(start, stop) + + return SymbolicDomain(domains[0].grid_type, new_domain_ranges) + + +def domain_complement(domain: SymbolicDomain) -> SymbolicDomain: + """Return the (set) complement of a domain.""" + dims_dict = {} + for dim in domain.ranges.keys(): + lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop + # `]-inf, a[` -> `[a, inf[` + if lb == itir.InfinityLiteral.NEGATIVE: + dims_dict[dim] = SymbolicRange(start=ub, stop=itir.InfinityLiteral.POSITIVE) + # `[a, inf]` -> `]-inf, a]` + elif ub == itir.InfinityLiteral.POSITIVE: + dims_dict[dim] = SymbolicRange(start=itir.InfinityLiteral.NEGATIVE, stop=lb) + else: + raise ValueError("Invalid domain ranges") + return SymbolicDomain(domain.grid_type, dims_dict) + + +def promote_to_same_dimensions( + domain_small: SymbolicDomain, domain_large: SymbolicDomain +) -> SymbolicDomain: + """Return an extended domain based on a smaller input domain and a larger domain containing the target dimensions.""" + dims_dict = {} + for dim in domain_large.ranges.keys(): + if dim in domain_small.ranges.keys(): + lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop + dims_dict[dim] = SymbolicRange(lb, ub) + else: + dims_dict[dim] = 
SymbolicRange( + itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE + ) + return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured + + +def is_finite(range_or_domain: SymbolicRange | SymbolicDomain) -> bool: + """ + Return whether a range is unbounded in (at least) one direction. + + The expression is required to be constant folded before for the result to be reliable. + """ + if isinstance(range_ := range_or_domain, SymbolicRange): + # TODO: assert no infinity literal in here + if any( + v in [itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE] + for v in [range_.start, range_.stop] + ): + return False + return True + elif isinstance(domain := range_or_domain, SymbolicDomain): + return all(is_finite(range_) for range_ in domain.ranges.values()) + raise ValueError("Expected a SymbolicRange or SymbolicDomain.") diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 7acdeb2f61..739aa5d90d 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -243,6 +243,12 @@ def if_(cond, true_val, false_val): return call("if_")(cond, true_val, false_val) +def concat_where(cond, true_field, false_field): + """Create a concat_where FunCall, shorthand for ``call("concat_where")(expr)``.""" + + return call("concat_where")(cond, true_field, false_field) + + def lift(expr): """Create a lift FunCall, shorthand for ``call(call("lift")(expr))``.""" return call(call("lift")(expr)) @@ -437,18 +443,18 @@ def domain( """ if isinstance(grid_type, common.GridType): grid_type = f"{grid_type!s}_domain" - return call(grid_type)( + expr = call(grid_type)( *[ call("named_range")( - itir.AxisLiteral(value=d.value, kind=d.kind) - if isinstance(d, common.Dimension) - else itir.AxisLiteral(value=d), + axis_literal(d), r[0], r[1], ) for d, r in ranges.items() ] ) + expr.type = ts.DomainType(dims=list(ranges.keys())) + return expr def 
as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Callable: @@ -478,7 +484,8 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Cal def _populate_domain_annex_wrapper(*args, **kwargs): node = result(*args, **kwargs) # note: if the domain is not a direct construction, e.g. because it is only a reference - # to a domain defined in a let, don't populate the annex + # to a domain defined in a let, don't populate the annex, since we can not create a + # symbolic domain for it. if domain and cpm.is_call_to(domain, ("cartesian_domain", "unstructured_domain")): node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) return node @@ -515,6 +522,10 @@ def _impl(*its: itir.Expr) -> itir.FunCall: return _impl +def axis_literal(dim: common.Dimension) -> itir.AxisLiteral: + return itir.AxisLiteral(value=dim.value, kind=dim.kind) + + def cast_as_fieldop(type_: str, domain: Optional[itir.FunCall] = None): """ Promotes the function `cast_` to a field_operator. 
diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 0c75679c8c..100e46e44c 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -12,6 +12,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import inline_lambdas @@ -234,3 +235,21 @@ def unique_symbol(sym: SymOrStr, reserved_names: Iterable[str]) -> SymOrStr: name = name + "_" return name + + +def grid_type_from_domain(domain: itir.FunCall) -> common.GridType: + if cpm.is_call_to(domain, "cartesian_domain"): + return common.GridType.CARTESIAN + else: + assert cpm.is_call_to(domain, "unstructured_domain") + return common.GridType.UNSTRUCTURED + + +def grid_type_from_program(program: itir.Program) -> common.GridType: + domains = program.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set() + grid_types = {grid_type_from_domain(d) for d in domains} + if len(grid_types) != 1: + raise ValueError( + f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." 
+ ) + return grid_types.pop() diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 7acbf5d23d..5063e26392 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -133,6 +133,13 @@ def visit_Sym(self, node: ir.Sym, *, prec: int) -> list[str]: def visit_Literal(self, node: ir.Literal, *, prec: int) -> list[str]: return [str(node.value)] + def visit_InfinityLiteral(self, node: ir.InfinityLiteral, *, prec: int) -> list[str]: + if node == ir.InfinityLiteral.POSITIVE: + return ["∞"] + elif node == ir.InfinityLiteral.NEGATIVE: + return ["-∞"] + raise AssertionError() + def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]: return [str(node.value) + "ₒ"] diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index eae97b51b4..85a4854998 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -50,7 +50,7 @@ def _with_altered_iterator_position_dims( ) -def _is_trivial_make_tuple_call(node: ir.Expr): +def _is_trivial_make_tuple_call(node: itir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" if not cpm.is_call_to(node, "make_tuple"): return False @@ -112,7 +112,7 @@ def _flattened_as_fieldop_param_el_name(param: str, idx: int) -> str: # should revisit the pattern here and try to find a more general mechanism. @dataclasses.dataclass(frozen=True, kw_only=True) class CollapseTuple( - fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor + fixed_point_transformation.CombinedFixedPointTransform, eve.PreserveLocationVisitor ): """ Simplifies `make_tuple`, `tuple_get` calls. 
@@ -307,9 +307,10 @@ def transform_propagate_tuple_get(self, node: itir.FunCall, **kwargs) -> Optiona self.fp_transform(im.tuple_get(idx.value, expr.fun.expr), **kwargs) ) )(*expr.args) - elif cpm.is_call_to(expr, "if_"): + elif cpm.is_call_to(expr, ("if_", "concat_where")): + fun = expr.fun cond, true_branch, false_branch = expr.args - return im.if_( + return im.call(fun)( cond, self.fp_transform(im.tuple_get(idx.value, true_branch), **kwargs), self.fp_transform(im.tuple_get(idx.value, false_branch), **kwargs), diff --git a/src/gt4py/next/iterator/transforms/concat_where/__init__.py b/src/gt4py/next/iterator/transforms/concat_where/__init__.py new file mode 100644 index 0000000000..a9c3fb2576 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where/__init__.py @@ -0,0 +1,18 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.next.iterator.transforms.concat_where.expand_tuple_args import expand_tuple_args +from gt4py.next.iterator.transforms.concat_where.simplify_domain_argument import ( + simplify_domain_argument, +) +from gt4py.next.iterator.transforms.concat_where.transform_to_as_fieldop import ( + transform_to_as_fieldop, +) + + +__all__ = ["expand_tuple_args", "simplify_domain_argument", "transform_to_as_fieldop"] diff --git a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py new file mode 100644 index 0000000000..ea4086976f --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py @@ -0,0 +1,62 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_specifications as ts + + +class _ExpandTupleArgs(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + + @classmethod + def apply( + cls, + node: itir.Node, + *, + offset_provider_type: common.OffsetProviderType, + allow_undeclared_symbols: bool = False, + ) -> itir.Node: + node = type_inference.infer( + node, + offset_provider_type=offset_provider_type, + allow_undeclared_symbols=allow_undeclared_symbols, + ) + return cls().visit(node) + + def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: + node = self.generic_visit(node) + + # `concat_where(cond, {a, b}, {c, d})` + # -> `{concat_where(cond, a, c), concat_where(cond, a, c)}` + if cpm.is_call_to(node, "concat_where") and isinstance( + type_inference.reinfer(node.args[1]).type, ts.TupleType + ): + cond, true_branch, false_branch = node.args + new_els = [] + assert isinstance(true_branch.type, ts.TupleType) + for i in range(len(true_branch.type.types)): + new_els.append( + im.concat_where(cond, im.tuple_get(i, "__tb"), im.tuple_get(i, "__fb")) + ) + + new_node = im.let(("__tb", true_branch), ("__fb", false_branch))( + im.make_tuple(*new_els) + ) + return new_node + + return node + + +expand_tuple_args = _ExpandTupleArgs.apply diff --git a/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py b/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py new file mode 100644 index 0000000000..b626bcb5a1 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py @@ -0,0 +1,93 @@ +# GT4Py - GridTools Framework +# +# 
Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Optional + +from gt4py.eve import PreserveLocationVisitor +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) +from gt4py.next.iterator.ir_utils.domain_utils import SymbolicDomain +from gt4py.next.iterator.transforms import fixed_point_transformation + + +def _range_complement( + range_: domain_utils.SymbolicRange, +) -> tuple[domain_utils.SymbolicRange, domain_utils.SymbolicRange]: + # `[a, b[` -> `[-inf, a[` ∪ `[b, inf[` # noqa: RUF003 + assert not any(isinstance(b, itir.InfinityLiteral) for b in [range_.start, range_.stop]) + return ( + domain_utils.SymbolicRange(itir.InfinityLiteral.NEGATIVE, range_.start), + domain_utils.SymbolicRange(range_.stop, itir.InfinityLiteral.POSITIVE), + ) + + +class _SimplifyDomainArgument( + PreserveLocationVisitor, fixed_point_transformation.FixedPointTransformation +): + @classmethod + def apply(cls, node: itir.Node): + return cls().visit(node) + + def transform(self, node: itir.Node) -> Optional[itir.Node]: # type: ignore[override] # ignore kwargs for simplicity + if cpm.is_call_to(node, "concat_where"): + cond_expr, field_a, field_b = node.args + if cpm.is_call_to(cond_expr, "and_"): + conds = cond_expr.args + return im.let(("__cwsda_field_a", field_a), ("__cwsda_field_b", field_b))( + self.fp_transform( + im.concat_where( + conds[0], + self.fp_transform( + im.concat_where(conds[1], "__cwsda_field_a", "__cwsda_field_b") + ), + "__cwsda_field_b", + ) + ) + ) + if cpm.is_call_to(cond_expr, "or_"): + conds = cond_expr.args + return im.let(("__cwsda_field_a", field_a), ("__cwsda_field_b", field_b))( + self.fp_transform( + im.concat_where( + conds[0], + "__cwsda_field_a", + self.fp_transform( + im.concat_where(conds[1], "__cwsda_field_a", 
"__cwsda_field_b") + ), + ) + ) + ) + + # concat_where([1, 2[, a, b) -> concat_where([-inf, 1] | [2, inf[, b, a) + if cpm.is_call_to(cond_expr, ("cartesian_domain", "unstructured_domain")): + domain = SymbolicDomain.from_expr(cond_expr) + if len(domain.ranges) == 1: + dim, range_ = next(iter(domain.ranges.items())) + if domain_utils.is_finite(range_): + complement = _range_complement(range_) + new_domains = [ + im.domain(domain.grid_type, {dim: (cr.start, cr.stop)}) + for cr in complement + ] + return self.fp_transform( + im.concat_where(im.call("or_")(*new_domains), field_b, field_a) + ) + else: + # TODO(tehrengruber): Implement. Note that this case can not be triggered by + # the frontend yet since domains can only be created by expressions like + # `IDim < 10`. + raise NotImplementedError() + + return None + + +simplify_domain_argument = _SimplifyDomainArgument.apply diff --git a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py new file mode 100644 index 0000000000..1a9b2d13ef --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py @@ -0,0 +1,94 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import functools + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) +from gt4py.next.iterator.transforms import symbol_ref_utils +from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_specifications as ts + + +def _in(pos: itir.Expr, domain: itir.Expr) -> itir.Expr: + """ + Given a position and a domain return an expression that evaluates to `True` if the position is inside the domain. + + `in_({i, j, k}, u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩` + -> `i0 <= i < i1 & j0 <= j < j1 & k0 <= k < k1` + """ + ret = [] + for i, v in enumerate(domain_utils.SymbolicDomain.from_expr(domain).ranges.values()): + ret.append( + im.and_( + im.less_equal(v.start, im.tuple_get(i, pos)), + im.less(im.tuple_get(i, pos), v.stop), + ) + ) + return functools.reduce(im.and_, ret) + + +class _TransformToAsFieldop(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + + @classmethod + def apply(cls, node: itir.Node): + """ + Transform `concat_where` expressions into equivalent `as_fieldop` expressions. + + Note that (backward) domain inference may not be executed after this pass as it can not + correctly infer the accessed domains when the value selection is represented as an `if_` + inside the `as_fieldop. 
+ """ + node = cls().visit(node) + node = type_inference.SanitizeTypes().visit(node) + return node + + def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: + node = self.generic_visit(node) + if cpm.is_call_to(node, "concat_where"): + cond, true_branch, false_branch = node.args + assert isinstance(cond.type, ts.DomainType) + position = [im.index(dim) for dim in cond.type.dims] + refs = symbol_ref_utils.collect_symbol_refs(cond) + + domains: tuple[domain_utils.SymbolicDomain, ...] = utils.flatten_nested_tuple( + node.annex.domain + ) + assert all(domain == domains[0] for domain in domains), ( + "At this point all `concat_where` arguments should be posed on the same domain." + ) + assert isinstance(domains[0], domain_utils.SymbolicDomain) + domain_expr = domains[0].as_expr() + + return im.as_fieldop( + im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1", *refs)( + im.let(*zip(refs, map(im.deref, refs), strict=True))( + im.if_( + _in(im.deref("__tcw_pos"), cond), + im.deref("__tcw_arg0"), + im.deref("__tcw_arg1"), + ) + ) + ), + domain_expr, + )(im.make_tuple(*position), true_branch, false_branch, *refs) + + return node + + +transform_to_as_fieldop = _TransformToAsFieldop.apply diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 774e7a6702..bfb8378be6 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -56,7 +56,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: @dataclasses.dataclass(frozen=True, kw_only=True) class ConstantFolding( - fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor + fixed_point_transformation.CombinedFixedPointTransform, eve.PreserveLocationVisitor ): PRESERVED_ANNEX_ATTRS = ( "type", @@ -98,6 +98,8 @@ class Transformation(enum.Flag): # `if_(True, true_branch, false_branch)` -> `true_branch` FOLD_IF = enum.auto() + 
FOLD_INFINITY_ARITHMETIC = enum.auto() + @classmethod def all(self) -> ConstantFolding.Transformation: return functools.reduce(operator.or_, self.__members__.values()) @@ -239,3 +241,57 @@ def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: assert node.args[0].value == "False" return node.args[2] return None + + def transform_fold_infinity_arithmetic(self, node: ir.FunCall) -> Optional[ir.Node]: + if cpm.is_call_to(node, "plus"): + for arg in node.args: + # `a + inf` -> `inf` + if arg == ir.InfinityLiteral.POSITIVE: + return ir.InfinityLiteral.POSITIVE + # `a + (-inf)` -> `-inf` + if arg == ir.InfinityLiteral.NEGATIVE: + return ir.InfinityLiteral.NEGATIVE + + if cpm.is_call_to(node, "minimum"): + a, b = node.args + for arg, other_arg in ((a, b), (b, a)): + # `minimum(inf, a)` -> `a` + if arg == ir.InfinityLiteral.POSITIVE: + return other_arg + # `minimum(-inf, a)` -> `-inf` + if arg == ir.InfinityLiteral.NEGATIVE: + return ir.InfinityLiteral.NEGATIVE + + if cpm.is_call_to(node, "maximum"): + a, b = node.args + for arg, other_arg in ((a, b), (b, a)): + # `maximum(inf, a)` -> `inf` + if arg == ir.InfinityLiteral.POSITIVE: + return ir.InfinityLiteral.POSITIVE + # `maximum(-inf, a)` -> `a` + if arg == ir.InfinityLiteral.NEGATIVE: + return other_arg + + if cpm.is_call_to(node, ("less", "less_equal")): + a, b = node.args + # `-inf < v` -> `True` + # `v < inf` -> `True` + if a == ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE: + return im.literal_from_value(True) + # `inf < v` -> `False` + # `v < -inf ` -> `False` + if a == ir.InfinityLiteral.POSITIVE or b == ir.InfinityLiteral.NEGATIVE: + return im.literal_from_value(False) + + if cpm.is_call_to(node, ("greater", "greater_equal")): + a, b = node.args + # `inf > v` -> `True` + # `v > -inf ` -> `True` + if a == ir.InfinityLiteral.POSITIVE or b == ir.InfinityLiteral.NEGATIVE: + return im.literal_from_value(True) + # `-inf > v` -> `False` + # `v > inf` -> `False` + if a == 
ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE: + return im.literal_from_value(False) + + return None diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 87241e1ba8..2fcbd5df0d 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -87,9 +87,19 @@ def _is_collectable_expr(node: itir.Node) -> bool: # backend (single pass eager depth first visit approach) # do also not collect lifts or applied lifts as they become invisible to the lift inliner # otherwise - if cpm.is_call_to(node, ("lift", "shift", "reduce", "map_")) or cpm.is_applied_lift(node): + # do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement + # instead of an as_fieldop + if cpm.is_call_to( + node, ("lift", "shift", "reduce", "map_", "index") + ) or cpm.is_applied_lift(node): return False return True + # do also not collect make_tuple(index) nodes because otherwise the right hand side of SetAts becomes a let statement + # instead of an as_fieldop + if cpm.is_call_to(node, "make_tuple") and all( + cpm.is_call_to(arg, "index") for arg in node.args + ): + return False elif isinstance(node, itir.Lambda): return True diff --git a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py index b48e0cc3cf..261008725b 100644 --- a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py +++ b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py @@ -15,20 +15,11 @@ from gt4py.next.iterator.type_system import inference as itir_type_inference -@dataclasses.dataclass(frozen=True, kw_only=True) class FixedPointTransformation(eve.NodeTranslator): """ - Transformation pass that transforms until no transformation is applicable anymore. + Base class for iterative transformations that converge when a fixed-point is reached. 
""" - #: Enum of all transformation (names). The transformations need to be defined as methods - #: named `transform_`. - Transformation: ClassVar[Type[enum.Flag]] - - #: All transformations enabled in this instance, e.g. `Transformation.T1 & Transformation.T2`. - #: Usually the default value is chosen to be all transformations. - enabled_transformations: enum.Flag - REINFER_TYPES: ClassVar[bool] = False def visit(self, node, **kwargs): @@ -43,18 +34,44 @@ def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: new_node = self.transform(node, **kwargs) if new_node is None: break + else: + new_node = self.post_transform(node, new_node) assert new_node != node node = new_node return node - def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: - """ - Transform node once. + def post_transform(self, node: ir.Node, new_node: ir.Node) -> ir.Node: + if self.REINFER_TYPES: + itir_type_inference.reinfer(new_node) + self._preserve_annex(node, new_node) + return new_node - Execute transformations until one is applicable. As soon as a transformation occured - the function will return the transformed node. Note that the transformation itself - may call other transformations on child nodes again. - """ + """ + Transform node once. + + Execute transformation if applicable. When a transformation occurred the function will return + the transformed node. Note that the transformation itself may call other transformations on + child nodes again. + """ + + def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: ... + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class CombinedFixedPointTransform(FixedPointTransformation): + """ + Base class for a set of iterative transformations that converge when a fixed-point is reached. + """ + + #: Enum of all transformation (names). The transformations need to be defined as methods + #: named `transform_`. + Transformation: ClassVar[Type[enum.Flag]] + + #: All transformations enabled in this instance, e.g. 
`Transformation.T1 & Transformation.T2`. + #: Usually the default value is chosen to be all transformations. + enabled_transformations: enum.Flag + + def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: for transformation in self.Transformation: if self.enabled_transformations & transformation: assert isinstance(transformation.name, str) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 26a8bcad1c..4b3a258396 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -231,7 +231,7 @@ def _make_tuple_element_inline_predicate(node: itir.Expr): @dataclasses.dataclass(frozen=True, kw_only=True) class FuseAsFieldOp( - fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor + fixed_point_transformation.CombinedFixedPointTransform, eve.PreserveLocationVisitor ): """ Merge multiple `as_fieldop` calls into one. @@ -450,7 +450,4 @@ def visit(self, node, **kwargs): node = super().visit(node, **kwargs) - if isinstance(node, itir.Expr) and hasattr(node.annex, "domain"): - node.annex.domain = node.annex.domain - return node diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index e1e6d74a72..daeec2b675 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -329,11 +329,16 @@ def create_global_tmps( This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its arguments into temporaries. 
""" - offset_provider_type = common.offset_provider_to_type(offset_provider) + # TODO: document why to keep existing domains, add test program = infer_domain.infer_program( - program, offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes + program, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + keep_existing_domains=True, + ) + program = type_inference.infer( + program, offset_provider_type=common.offset_provider_to_type(offset_provider) ) - program = type_inference.infer(program, offset_provider_type=offset_provider_type) if not uids: uids = eve_utils.UIDGenerator(prefix="__tmp") diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 3c6fbedf3d..422d9c9c13 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -21,6 +21,7 @@ ir_makers as im, misc as ir_misc, ) +from gt4py.next.iterator.ir_utils.domain_utils import SymbolicDomain from gt4py.next.iterator.transforms import constant_folding, trace_shifts from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.type_system import type_info, type_specifications as ts @@ -56,6 +57,7 @@ class InferenceOptions(typing.TypedDict): offset_provider: common.OffsetProvider | common.OffsetProviderType symbolic_domain_sizes: Optional[dict[str, str]] allow_uninferred: bool + keep_existing_domains: bool class DomainAnnexDebugger(eve.NodeVisitor): @@ -182,11 +184,16 @@ def _infer_as_fieldop( offset_provider: common.OffsetProvider | common.OffsetProviderType, symbolic_domain_sizes: Optional[dict[str, str]], allow_uninferred: bool, + keep_existing_domains: bool, ) -> tuple[itir.FunCall, AccessedDomains]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") if not allow_uninferred and target_domain is DomainAccessDescriptor.NEVER: raise 
ValueError("'target_domain' cannot be 'NEVER' unless `allow_uninferred=True`.") + + if len(applied_fieldop.fun.args) == 2 and keep_existing_domains: + target_domain = SymbolicDomain.from_expr(applied_fieldop.fun.args[1]) + # FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`. if isinstance(target_domain, tuple): target_domain = _domain_union(*flatten_nested_tuple(target_domain)) # type: ignore[arg-type] # mypy not smart enough @@ -226,6 +233,7 @@ def _infer_as_fieldop( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, allow_uninferred=allow_uninferred, + keep_existing_domains=keep_existing_domains, ) transformed_inputs.append(transformed_input) @@ -348,6 +356,40 @@ def _infer_if( return result_expr, actual_domains +def _infer_concat_where( + expr: itir.Expr, + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: + assert cpm.is_call_to(expr, "concat_where") + infered_args_expr = [] + actual_domains: AccessedDomains = {} + cond, true_field, false_field = expr.args + symbolic_cond = domain_utils.SymbolicDomain.from_expr(cond) + cond_complement = domain_utils.domain_complement(symbolic_cond) + + for arg in [true_field, false_field]: + + @tree_map + def mapper(d: NonTupleDomainAccess): + if isinstance(d, DomainAccessDescriptor): + return d + promoted_cond = domain_utils.promote_to_same_dimensions( + symbolic_cond if arg == true_field else cond_complement, # noqa: B023 # function is never used outside the loop + d, + ) + return domain_utils.domain_intersection(d, promoted_cond) + + domain_ = mapper(domain) + + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain_, **kwargs) + infered_args_expr.append(infered_arg_expr) + actual_domains = _merge_domains(actual_domains, actual_domains_arg) + + result_expr = im.call(expr.fun)(cond, *infered_args_expr) + return result_expr, actual_domains + + def _infer_broadcast( expr: itir.Expr, domain: 
DomainAccess, @@ -380,6 +422,8 @@ def _infer_expr( return _infer_tuple_get(expr, domain, **kwargs) elif cpm.is_call_to(expr, "if_"): return _infer_if(expr, domain, **kwargs) + elif cpm.is_call_to(expr, "concat_where"): + return _infer_concat_where(expr, domain, **kwargs) elif cpm.is_call_to(expr, "broadcast"): return _infer_broadcast(expr, domain, **kwargs) elif ( @@ -399,6 +443,7 @@ def infer_expr( offset_provider: common.OffsetProvider | common.OffsetProviderType, symbolic_domain_sizes: Optional[dict[str, str]] = None, allow_uninferred: bool = False, + keep_existing_domains: bool = False, ) -> tuple[itir.Expr, AccessedDomains]: """ Infer the domain of all field subexpressions of `expr`. @@ -413,6 +458,10 @@ def infer_expr( name that evaluates to the length of that axis. - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. because of a dynamic shift) or never accessed. + # TODO: describe why this is needed with concat_where (if inside as_fieldop might shrinken the + actually access domain) + - keep_existing_domains: If `True`, keep existing domains in `as_fieldop` expressions and + use them to propagate the domain further. Returns: A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) @@ -458,8 +507,10 @@ def infer_expr( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, allow_uninferred=allow_uninferred, + keep_existing_domains=keep_existing_domains, ) - expr.annex.domain = domain + if not keep_existing_domains or not hasattr(expr.annex, "domain"): + expr.annex.domain = domain return expr, accessed_domains @@ -497,6 +548,8 @@ def infer_program( offset_provider: common.OffsetProvider | common.OffsetProviderType, symbolic_domain_sizes: Optional[dict[str, str]] = None, allow_uninferred: bool = False, + # TODO: add test + keep_existing_domains: bool = False, ) -> itir.Program: """ Infer the domain of all field subexpressions inside a program. 
@@ -522,6 +575,7 @@ def infer_program( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, allow_uninferred=allow_uninferred, + keep_existing_domains=keep_existing_domains, ) for stmt in program.body ], diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py new file mode 100644 index 0000000000..6447ceca32 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -0,0 +1,107 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import common +from gt4py.next.iterator import builtins, ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, + misc as ir_misc, +) +from gt4py.next.iterator.type_system import inference +from gt4py.next.type_system import type_specifications as ts + + +@dataclasses.dataclass +class InferDomainOps(PreserveLocationVisitor, NodeTranslator): + grid_type: common.GridType + + @classmethod + def apply(cls, program: itir.Program): + return cls(grid_type=ir_misc.grid_type_from_program(program)).visit(program, recurse=True) + + def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node: + if kwargs["recurse"]: + node = self.generic_visit(node, **kwargs) + + # e.g. `IDim < a` + if cpm.is_call_to(node, builtins.BINARY_MATH_COMPARISON_BUILTINS) and any( + isinstance(arg, itir.AxisLiteral) for arg in node.args + ): + arg1, arg2 = node.args + if isinstance(arg2, itir.AxisLiteral): + # take complementary operation if we have e.g. 
`0 < IDim` use `IDim > 0` + complementary_op = { + "less": "greater", + "less_equal": "greater_equal", + "greater": "less", + "greater_equal": "less_equal", + "eq": "eq", + "not_eq": "not_eq", + } + return self.visit( + im.call(complementary_op[node.fun.id])(arg2, arg1), + **{**kwargs, "recurse": False}, + ) + + inference.reinfer(arg1) + assert isinstance(arg1.type, ts.DimensionType) + dim: common.Dimension = arg1.type.dim + value: itir.Expr = arg2 + + if cpm.is_call_to(node, ("less", "less_equal", "greater", "greater_equal", "eq")): + min_: itir.Expr + max_: itir.Expr + + # `IDim < 1` + if cpm.is_call_to(node, "less"): + min_ = itir.InfinityLiteral.NEGATIVE + max_ = value + # `IDim <= 1` + elif cpm.is_call_to(node, "less_equal"): + min_ = itir.InfinityLiteral.NEGATIVE + max_ = im.plus(value, 1) + # `IDim > 1` + elif cpm.is_call_to(node, "greater"): + min_ = im.plus(value, 1) + max_ = itir.InfinityLiteral.POSITIVE + # `IDim >= 1` + elif cpm.is_call_to(node, "greater_equal"): + min_ = value + max_ = itir.InfinityLiteral.POSITIVE + # `IDim == 1` + elif cpm.is_call_to(node, "eq"): + min_ = value + max_ = im.plus(value, 1) + + domain = domain_utils.SymbolicDomain( + self.grid_type, + ranges={dim: domain_utils.SymbolicRange(start=min_, stop=max_)}, + ) + + return domain.as_expr() + elif cpm.is_call_to(node, "not_eq"): + # `IDim != a` -> `IDim < a & IDim > a` + return self.visit( + im.call("and_")( + self.visit( + im.less(im.axis_literal(dim), value), **(kwargs | {"recurse": False}) + ), + self.visit( + im.greater(im.axis_literal(dim), value), **(kwargs | {"recurse": False}) + ), + ), + **(kwargs | {"recurse": False}), + ) + else: + raise AssertionError() + + return node diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index 03b20d14fe..2b8767e4a2 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -32,6 +32,7 @@ def 
prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: """ Remove all function declarations that are never called. + >>> from gt4py.next import common >>> from gt4py.next.iterator.ir_utils import ir_makers as im >>> fun1 = itir.FunctionDefinition( ... id="fun1", @@ -43,6 +44,7 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: ... params=[im.sym("a")], ... expr=im.deref("a"), ... ) + >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) >>> program = itir.Program( ... id="testee", ... function_definitions=[fun1, fun2], @@ -51,7 +53,7 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: ... body=[ ... itir.SetAt( ... expr=im.call("fun1")("inp"), - ... domain=im.domain("cartesian_domain", {"IDim": (0, 10)}), + ... domain=im.domain("cartesian_domain", {IDim: (0, 10)}), ... target=im.ref("out"), ... ) ... ], diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index eef2c1bab0..e2e4217b7f 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -12,10 +12,12 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( + concat_where, dead_code_elimination, fuse_as_fieldop, global_tmps, infer_domain, + infer_domain_ops, inline_dynamic_shifts, inline_fundefs, inline_lifts, @@ -81,6 +83,10 @@ def apply_common_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet + ir = infer_domain_ops.InferDomainOps.apply(ir) + ir = concat_where.simplify_domain_argument(ir) + + ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( ir, offset_provider=offset_provider, @@ -88,6 +94,8 @@ def apply_common_transforms( ) ir = 
remove_broadcast.RemoveBroadcast.apply(ir) + ir = concat_where.transform_to_as_fieldop(ir) + for _ in range(10): inlined = ir @@ -183,6 +191,11 @@ def apply_fieldview_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet + + ir = infer_domain_ops.InferDomainOps.apply(ir) + ir = concat_where.simplify_domain_argument(ir) + ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program + ir = infer_domain.infer_program(ir, offset_provider=offset_provider) ir = remove_broadcast.RemoveBroadcast.apply(ir) return ir diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index de91b9ee87..8173ceebbb 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -274,6 +274,9 @@ class TraceShifts(PreserveLocationVisitor, NodeTranslator): def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: return Sentinel.VALUE + def visit_InfinityLiteral(self, node: ir.InfinityLiteral, *, ctx: dict[str, Any]): + return Sentinel.VALUE + def visit_SymRef(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: if node.id in ctx: return ctx[node.id] diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 237ad35140..c33e3a71d0 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -511,6 +511,12 @@ def visit_Literal(self, node: itir.Literal, **kwargs) -> ts.ScalarType: assert isinstance(node.type, ts.ScalarType) return node.type + def visit_InfinityLiteral(self, node: itir.InfinityLiteral, **kwargs) -> ts.ScalarType: + return ts.ScalarType(kind=ts.ScalarKind.INT32) + + def visit_NegInfinityLiteral(self, node: itir.InfinityLiteral, **kwargs) -> ts.ScalarType: + return ts.ScalarType(kind=ts.ScalarKind.INT32) + def visit_SymRef( self, node: 
itir.SymRef, *, ctx: dict[str, ts.TypeSpec] ) -> ts.TypeSpec | type_synthesizer.TypeSynthesizer: diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 8f683c2ff9..ce99532645 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -15,7 +15,7 @@ from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Callable, Iterable, Optional, Union -from gt4py.next import common +from gt4py.next import common, utils from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.type_system import type_specifications as it_ts @@ -149,13 +149,31 @@ def _(arg: ts.ScalarType) -> ts.ScalarType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) -@_register_builtin_type_synthesizer( - fun_names=builtins.BINARY_MATH_COMPARISON_BUILTINS | builtins.BINARY_LOGICAL_BUILTINS -) -def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType: +def synthesize_binary_math_comparison_builtins( + lhs, rhs +) -> ts.ScalarType | ts.TupleType | ts.DomainType: + if isinstance(lhs, ts.ScalarType) and isinstance(rhs, ts.DimensionType): + return ts.DomainType(dims=[rhs.dim]) + if isinstance(lhs, ts.DimensionType) and isinstance(rhs, ts.ScalarType): + return ts.DomainType(dims=[lhs.dim]) + assert all(isinstance(lhs, (ts.ScalarType, ts.DeferredType)) for arg in (lhs, rhs)) return ts.ScalarType(kind=ts.ScalarKind.BOOL) +@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_MATH_COMPARISON_BUILTINS) +def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: + return synthesize_binary_math_comparison_builtins(lhs, rhs) + + +@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_LOGICAL_BUILTINS) +def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: + if isinstance(lhs, ts.DomainType) and isinstance(rhs, ts.DomainType): + assert 
lhs.dims != "unknown" and rhs.dims != "unknown" + return ts.DomainType(dims=common.promote_dims(lhs.dims, rhs.dims)) + else: + return synthesize_binary_math_comparison_builtins(lhs, rhs) + + @_register_builtin_type_synthesizer def deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.DataType | ts.DeferredType: if isinstance(it, ts.DeferredType): @@ -245,6 +263,39 @@ def index(arg: ts.DimensionType) -> ts.FieldType: ) +@_register_builtin_type_synthesizer +def concat_where( + domain: ts.DomainType, + true_field: ts.FieldType | ts.TupleType | ts.DeferredType, + false_field: ts.FieldType | ts.TupleType | ts.DeferredType, +) -> ts.FieldType | ts.TupleType | ts.DeferredType: + if isinstance(true_field, ts.DeferredType) or isinstance(false_field, ts.DeferredType): + return ts.DeferredType(constraint=None) + + @utils.tree_map( + collection_type=ts.TupleType, + result_collection_constructor=lambda el: ts.TupleType(types=list(el)), + ) + def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType): + if any(isinstance(b, ts.DeferredType) for b in [tb, fb]): + return ts.DeferredType(constraint=ts.FieldType) + + tb_dtype, fb_dtype = (type_info.extract_dtype(b) for b in [tb, fb]) + + assert tb_dtype == fb_dtype, ( + f"Field arguments must be of same dtype, got '{tb_dtype}' != '{fb_dtype}'." 
+ ) + dtype = tb_dtype + + return_dims = common.promote_dims( + domain.dims, type_info.extract_dims(type_info.promote(tb, fb)) + ) + return_type = ts.FieldType(dims=return_dims, dtype=dtype) + return return_type + + return deduce_return_type(true_field, false_field) + + @_register_builtin_type_synthesizer def broadcast( arg: ts.FieldType | ts.ScalarType | ts.DeferredType, dims: tuple[ts.DimensionType] @@ -321,11 +372,7 @@ def _collect_and_check_dimensions(input_: ts.TypeSpec) -> list[common.Dimension] .filter(lambda dims: len(dims) > 0) .to_list() ) - if all_input_dims: - assert all(cur_input_dims == all_input_dims[0] for cur_input_dims in all_input_dims) - return all_input_dims[0] - - return [] + return common.promote_dims(*all_input_dims) def _convert_as_fieldop_input_to_iterator( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 702bd48dec..f7445461c0 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -125,7 +125,11 @@ def _values_validator( ) -> None: if not all( isinstance(el, (SidFromScalar, SidComposite)) - or _is_tuple_expr_of(lambda expr: isinstance(expr, (SymRef, Literal)), el) + or _is_tuple_expr_of( + lambda expr: isinstance(expr, (SymRef, Literal)) + or (isinstance(expr, FunCall) and expr.fun == SymRef(id="index")), + el, + ) for el in value ): raise ValueError( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 696cfc62ea..a445390583 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -84,24 +84,6 @@ def _get_domains(nodes: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]: return result -def _extract_grid_type(domain: itir.FunCall) -> common.GridType: - if domain.fun == 
itir.SymRef(id="cartesian_domain"): - return common.GridType.CARTESIAN - else: - assert domain.fun == itir.SymRef(id="unstructured_domain") - return common.GridType.UNSTRUCTURED - - -def _get_gridtype(body: list[itir.Stmt]) -> common.GridType: - domains = _get_domains(body) - grid_types = {_extract_grid_type(d) for d in domains} - if len(grid_types) != 1: - raise ValueError( - f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." - ) - return grid_types.pop() - - def _name_from_named_range(named_range_call: itir.FunCall) -> str: assert isinstance(named_range_call, itir.FunCall) and named_range_call.fun == itir.SymRef( id="named_range" @@ -342,7 +324,7 @@ def apply( raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) - grid_type = _get_gridtype(node.body) + grid_type = ir_utils_misc.grid_type_from_program(node) if grid_type == common.GridType.UNSTRUCTURED: node = _CannonicalizeUnstructuredDomain.apply(node) return cls( diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py index ac7df22ccc..a531653aad 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py @@ -355,6 +355,13 @@ def get_field_layout( if len(domain) == 0: return [], [], [] domain_dims, domain_lbs, domain_ubs = zip(*domain) + # after introduction of concat_where, the strict order of lower and upper bounds is not guaranteed + domain_ubs = tuple( + [ + dace.symbolic.pystr_to_symbolic(f"max({lb}, {ub})") + for lb, ub in zip(domain_lbs, domain_ubs, strict=True) + ] + ) domain_sizes = [(ub - lb) for lb, ub in zip(domain_lbs, domain_ubs)] return list(domain_dims), list(domain_lbs), domain_sizes @@ -509,24 +516,27 @@ def extract_domain(node: 
gtir.Node) -> FieldopDomain: domain = [] - def parse_range_boundary(expr: gtir.Expr) -> str: - return dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(expr)) - if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): for named_range in node.args: assert cpm.is_call_to(named_range, "named_range") assert len(named_range.args) == 3 axis = named_range.args[0] assert isinstance(axis, gtir.AxisLiteral) - lower_bound, upper_bound = (parse_range_boundary(arg) for arg in named_range.args[1:3]) + lower_bound, upper_bound = ( + gtir_sdfg_utils.get_symbolic(arg) for arg in named_range.args[1:3] + ) dim = gtx_common.Dimension(axis.value, axis.kind) domain.append((dim, lower_bound, upper_bound)) elif isinstance(node, domain_utils.SymbolicDomain): - assert str(node.grid_type) in {"cartesian_domain", "unstructured_domain"} + assert isinstance(node.grid_type, gtx_common.GridType) for dim, drange in node.ranges.items(): domain.append( - (dim, parse_range_boundary(drange.start), parse_range_boundary(drange.stop)) + ( + dim, + gtir_sdfg_utils.get_symbolic(drange.start), + gtir_sdfg_utils.get_symbolic(drange.stop), + ) ) else: @@ -597,6 +607,319 @@ def translate_as_fieldop( ) +def _make_concat_field_slice( + sdfg: dace.SDFG, + state: dace.SDFGState, + f: FieldopData, + f_desc: dace.data.Array, + concat_dim: gtx_common.Dimension, + concat_dim_index: int, + concat_dim_origin: dace.symbolic.SymbolicType, +) -> tuple[FieldopData, dace.data.Array]: + """ + Helper function called by `translate_concat_where` to create a slice along the + concat dimension, that is a new array with an extra diimension and a single level. + This allows to treat 'f' as a slice and concatanate it to the other argument field. 
+ """ + assert isinstance(f.gt_type, ts.FieldType) + dims = [*f.gt_type.dims[:concat_dim_index], concat_dim, *f.gt_type.dims[concat_dim_index:]] + origin = tuple([*f.origin[:concat_dim_index], concat_dim_origin, *f.origin[concat_dim_index:]]) + shape = tuple([*f_desc.shape[:concat_dim_index], 1, *f_desc.shape[concat_dim_index:]]) + slice_data, slice_data_desc = sdfg.add_temp_transient(shape, f_desc.dtype) + slice_node = state.add_access(slice_data) + state.add_nedge( + f.dc_node, + slice_node, + dace.Memlet( + data=f.dc_node.data, + subset=dace_subsets.Range.from_array(f_desc), + other_subset=dace_subsets.Range.from_array(slice_data_desc), + ), + ) + fslice = FieldopData(slice_node, ts.FieldType(dims=dims, dtype=f.gt_type.dtype), origin) + return fslice, slice_data_desc + + +def _make_concat_scalar_broadcast( + sdfg: dace.SDFG, + state: dace.SDFGState, + inp: FieldopData, + inp_desc: dace.data.Array, + domain: FieldopDomain, + concat_dim_index: int, +) -> tuple[FieldopData, dace.data.Array]: + """ + Helper function called by `translate_concat_where` to create a mapped tasklet + that broadcasts one scalar value from the 1D-array 'f' on the given domain. 
+ """ + assert isinstance(inp.gt_type, ts.FieldType) + assert len(inp.gt_type.dims) == 1 + out_dims, out_origin, out_shape = get_field_layout(domain) + out_type = ts.FieldType(dims=out_dims, dtype=inp.gt_type.dtype) + + out_name, out_desc = sdfg.add_temp_transient(out_shape, inp_desc.dtype) + out_node = state.add_access(out_name) + + map_variables = [gtir_sdfg_utils.get_map_variable(dim) for dim in out_dims] + inp_index = ( + "0" + if isinstance(inp.dc_node.desc(sdfg), dace.data.Scalar) + else ( + f"({map_variables[concat_dim_index]} + {out_origin[concat_dim_index] - inp.origin[0]})" + ) + ) + state.add_mapped_tasklet( + "broadcast", + map_ranges={ + index: r + for index, r in zip(map_variables, dace_subsets.Range.from_array(out_desc), strict=True) + }, + code="__out = __inp", + inputs={"__inp": dace.Memlet(data=inp.dc_node.data, subset=inp_index)}, + outputs={"__out": dace.Memlet(data=out_name, subset=",".join(map_variables))}, + input_nodes={inp.dc_node}, + output_nodes={out_node}, + external_edges=True, + ) + + out_field = FieldopData(out_node, out_type, tuple(out_origin)) + return out_field, out_desc + + +def translate_concat_where( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + """ + Lowers a `concat_where` expression to a dataflow where two memlets write + disjoint subsets on one data access node. + """ + assert cpm.is_call_to(node, "concat_where") + assert len(node.args) == 3 + + # First argument is a domain expression that defines the mask of the true branch: + # we extract the dimension along which we need to concatenate the field arguments, + # and determine whether the true branch argument should be on the lower or upper + # range with respect to the boundary value. 
+ mask_domain = extract_domain(node.args[0]) + if len(mask_domain) != 1: + raise NotImplementedError("Expected `concat_where` along single axis.") + concat_dim, mask_lower_bound, mask_upper_bound = mask_domain[0] + + def concatenate_inputs( + node_domain: gtir.Expr, + tb_node_domain: gtir.Expr, + fb_node_domain: gtir.Expr, + tb_field: FieldopData, + fb_field: FieldopData, + ) -> FieldopData: + tb_data_desc, fb_data_desc = (inp.dc_node.desc(sdfg) for inp in [tb_field, fb_field]) + assert tb_data_desc.dtype == fb_data_desc.dtype + + tb_domain, fb_domain = ( + extract_domain(domain) for domain in [tb_node_domain, fb_node_domain] + ) + + # expect unbound range in the concat domain expression on lower or upper range + if mask_lower_bound == gtir_sdfg_utils.get_symbolic(gtir.InfinityLiteral.NEGATIVE): + concat_dim_bound = mask_upper_bound + lower, lower_desc, lower_domain = (tb_field, tb_data_desc, tb_domain) + upper, upper_desc, upper_domain = (fb_field, fb_data_desc, fb_domain) + elif mask_upper_bound == gtir_sdfg_utils.get_symbolic(gtir.InfinityLiteral.POSITIVE): + concat_dim_bound = mask_lower_bound + lower, lower_desc, lower_domain = (fb_field, fb_data_desc, fb_domain) + upper, upper_desc, upper_domain = (tb_field, tb_data_desc, tb_domain) + else: + raise ValueError(f"Unexpected concat mask {node.args[0]}.") + + # we use the concat domain, stored in the annex, as the domain of output field + output_domain = extract_domain(node_domain) + output_dims, output_origin, output_shape = get_field_layout(output_domain) + concat_dim_index = output_dims.index(concat_dim) + + # in case one of the arguments is a scalar value, we convert it to a single-element + # 1D field with the dimension of the concat expression + if isinstance(lower.gt_type, ts.ScalarType): + assert len(lower_domain) == 0 + assert isinstance(upper.gt_type, ts.FieldType) + lower = FieldopData( + lower.dc_node, + ts.FieldType(dims=[concat_dim], dtype=lower.gt_type), + origin=(concat_dim_bound - 1,), + ) + 
lower_bound = output_domain[concat_dim_index][1] + lower_domain = [(concat_dim, lower_bound, concat_dim_bound)] + elif isinstance(upper.gt_type, ts.ScalarType): + assert len(upper_domain) == 0 + assert isinstance(lower.gt_type, ts.FieldType) + upper = FieldopData( + upper.dc_node, + ts.FieldType(dims=[concat_dim], dtype=upper.gt_type), + origin=(concat_dim_bound,), + ) + upper_bound = output_domain[concat_dim_index][2] + upper_domain = [(concat_dim, concat_dim_bound, upper_bound)] + + if concat_dim not in lower.gt_type.dims: + assert lower.gt_type.dims == [ + *upper.gt_type.dims[0:concat_dim_index], + *upper.gt_type.dims[concat_dim_index + 1 :], + ] + lower, lower_desc = _make_concat_field_slice( + sdfg, state, lower, lower_desc, concat_dim, concat_dim_index, concat_dim_bound - 1 + ) + lower_bound = dace.symbolic.pystr_to_symbolic( + f"max({concat_dim_bound - 1}, {output_domain[concat_dim_index][1]})" + ) + lower_domain.insert(concat_dim_index, (concat_dim, lower_bound, concat_dim_bound)) + elif concat_dim not in upper.gt_type.dims: + assert upper.gt_type.dims == [ + *lower.gt_type.dims[0:concat_dim_index], + *lower.gt_type.dims[concat_dim_index + 1 :], + ] + upper, upper_desc = _make_concat_field_slice( + sdfg, state, upper, upper_desc, concat_dim, concat_dim_index, concat_dim_bound + ) + upper_bound = dace.symbolic.pystr_to_symbolic( + f"min({concat_dim_bound + 1}, {output_domain[concat_dim_index][2]})" + ) + upper_domain.insert(concat_dim_index, (concat_dim, concat_dim_bound, upper_bound)) + elif len(lower.gt_type.dims) == 1 and len(output_domain) > 1: + assert len(lower_domain) == 1 and lower_domain[0][0] == concat_dim + lower_domain = [ + *output_domain[:concat_dim_index], + lower_domain[0], + *output_domain[concat_dim_index + 1 :], + ] + lower, lower_desc = _make_concat_scalar_broadcast( + sdfg, state, lower, lower_desc, lower_domain, concat_dim_index + ) + elif len(upper.gt_type.dims) == 1 and len(output_domain) > 1: + assert len(upper_domain) == 1 and 
upper_domain[0][0] == concat_dim + upper_domain = [ + *output_domain[:concat_dim_index], + upper_domain[0], + *output_domain[concat_dim_index + 1 :], + ] + upper, upper_desc = _make_concat_scalar_broadcast( + sdfg, state, upper, upper_desc, upper_domain, concat_dim_index + ) + elif lower.gt_type.dims != upper.gt_type.dims: + raise NotImplementedError( + "concat_where on fields with different domain is not supported." + ) + + # ensure that the arguments have the same domain as the concat result + assert all(ftype.dims == output_dims for ftype in (lower.gt_type, upper.gt_type)) + + lower_range_0 = output_domain[concat_dim_index][1] + lower_range_1 = dace.symbolic.pystr_to_symbolic( + f"max({lower_range_0}, {lower_domain[concat_dim_index][2]})" + ) + lower_range_size = lower_range_1 - lower_range_0 + + upper_range_1 = output_domain[concat_dim_index][2] + upper_range_0 = dace.symbolic.pystr_to_symbolic( + f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})" + ) + upper_range_size = upper_range_1 - upper_range_0 + + output, output_desc = sdfg_builder.add_temp_array(sdfg, output_shape, lower_desc.dtype) + output_node = state.add_access(output) + + lower_subset = dace_subsets.Range( + [ + ( + lower_range_0 - lower.origin[dim_index], + lower_range_1 - lower.origin[dim_index] - 1, + 1, + ) + if dim_index == concat_dim_index + else ( + output_domain[dim_index][1] - lower.origin[dim_index], + output_domain[dim_index][1] - lower.origin[dim_index] + size - 1, + 1, + ) + for dim_index, size in enumerate(output_desc.shape) + ] + ) + # we write the data of the lower range into the output array starting from the index zero + lower_output_subset = dace_subsets.Range( + [ + (0, lower_range_size - 1, 1) if dim_index == concat_dim_index else (0, size - 1, 1) + for dim_index, size in enumerate(output_desc.shape) + ] + ) + state.add_nedge( + lower.dc_node, + output_node, + dace.Memlet( + data=lower.dc_node.data, + subset=lower_subset, + other_subset=lower_output_subset, + 
dynamic=True, # this memlet could be empty, but this is known only at runtime + ), + ) + + upper_subset = dace_subsets.Range( + [ + ( + upper_range_0 - upper.origin[dim_index], + upper_range_1 - upper.origin[dim_index] - 1, + 1, + ) + if dim_index == concat_dim_index + else ( + output_domain[dim_index][1] - upper.origin[dim_index], + output_domain[dim_index][1] - upper.origin[dim_index] + size - 1, + 1, + ) + for dim_index, size in enumerate(output_desc.shape) + ] + ) + # the upper range should be written next to the lower range, so the destination + # subset does not start from index zero + upper_output_subset = dace_subsets.Range( + [ + ( + lower_range_size, + lower_range_size + upper_range_size - 1, + 1, + ) + if dim_index == concat_dim_index + else (0, size - 1, 1) + for dim_index, size in enumerate(output_desc.shape) + ] + ) + state.add_nedge( + upper.dc_node, + output_node, + dace.Memlet( + data=upper.dc_node.data, + subset=upper_subset, + other_subset=upper_output_subset, + dynamic=True, # this memlet could be empty, but this is known only at runtime + ), + ) + + return FieldopData(output_node, lower.gt_type, origin=tuple(output_origin)) + + # we visit the field arguments for the true and false branch + tb, fb = (sdfg_builder.visit(node.args[i], sdfg=sdfg, head_state=state) for i in [1, 2]) + + return ( + concatenate_inputs( + node.annex.domain, node.args[1].annex.domain, node.args[2].annex.domain, tb, fb + ) + if isinstance(node.type, ts.FieldType) + else gtx_utils.tree_map(concatenate_inputs)( + node.annex.domain, node.args[1].annex.domain, node.args[2].annex.domain, tb, fb + ) + ) + + def _construct_if_branch_output( sdfg: dace.SDFG, state: dace.SDFGState, @@ -1073,6 +1396,7 @@ def translate_symbol_ref( # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ translate_as_fieldop, + translate_concat_where, translate_if, translate_index, 
translate_literal, diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index 9eefafe166..5c6c4ab513 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -587,16 +587,16 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: ) deref_node = self._add_tasklet( "deref", - {"field"} | set(index_connectors), - {"val"}, - code=f"val = field[{index_internals}]", + {"__field"} | set(index_connectors), + {"__val"}, + code=f"__val = __field[{index_internals}]", ) # add new termination point for the field parameter self._add_input_data_edge( arg_expr.field, dace_subsets.Range.from_array(field_desc), deref_node, - "field", + "__field", src_offset=[offset for (_, offset) in arg_expr.field_domain], ) @@ -622,7 +622,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: else: assert isinstance(index_expr, SymbolExpr) - return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") + return self._construct_tasklet_result(field_desc.dtype, deref_node, "__val") def _visit_if_branch_arg( self, @@ -1090,13 +1090,16 @@ def _visit_list_get(self, node: gtir.FunCall) -> ValueExpr: ) elif isinstance(index_arg, ValueExpr): tasklet_node = self._add_tasklet( - "list_get", inputs={"index", "list"}, outputs={"value"}, code="value = list[index]" + "list_get", + inputs={"__index", "list"}, + outputs={"__val"}, + code="__val = list[__index]", ) self._add_edge( index_arg.dc_node, None, tasklet_node, - "index", + "__index", dace.Memlet(data=index_arg.dc_node.data, subset="0"), ) self._add_edge( @@ -1107,7 +1110,7 @@ def _visit_list_get(self, node: gtir.FunCall) -> ValueExpr: self.sdfg.make_array_memlet(list_arg.dc_node.data), ) self._add_edge( - tasklet_node, "value", result_node, None, dace.Memlet(data=result, subset="0") + tasklet_node, "__val", result_node, None, dace.Memlet(data=result, 
subset="0") ) else: raise TypeError(f"Unexpected value {index_arg} as index argument.") diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py index 199783d893..1219262f51 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py @@ -11,6 +11,7 @@ from typing import Any, Callable import numpy as np +import sympy from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt @@ -143,6 +144,11 @@ def visit_FunCall(self, node: gtir.FunCall, args_map: dict[str, gtir.Node]) -> s return format_builtin(builtin_name, *args) raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + def visit_InfinityLiteral( + self, node: gtir.InfinityLiteral, args_map: dict[str, gtir.Node] + ) -> str: + return str(sympy.oo) if node == gtir.InfinityLiteral.POSITIVE else str(-sympy.oo) + def visit_SymRef(self, node: gtir.SymRef, args_map: dict[str, gtir.Node]) -> str: symbol = str(node.id) if symbol in args_map: diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py index 7988a4aad0..dcd4d14fa5 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py @@ -774,7 +774,9 @@ def visit_FunCall( head_state: dace.SDFGState, ) -> gtir_builtin_translators.FieldopResult: # use specialized dataflow builder classes for each builtin function - if cpm.is_call_to(node, "if_"): + if cpm.is_call_to(node, "concat_where"): + return gtir_builtin_translators.translate_concat_where(node, sdfg, head_state, self) + elif cpm.is_call_to(node, "if_"): return gtir_builtin_translators.translate_if(node, sdfg, head_state, self) elif cpm.is_call_to(node, "index"): return gtir_builtin_translators.translate_index(node, sdfg, head_state, 
self) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_sdfg_utils.py b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg_utils.py index 9a27cad21c..03c8999242 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_sdfg_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg_utils.py @@ -16,6 +16,7 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.program_processors.runners.dace import gtir_python_codegen from gt4py.next.type_system import type_specifications as ts @@ -97,6 +98,8 @@ def replace_invalid_symbols(ir: gtir.Program) -> gtir.Program: """ class ReplaceSymbols(eve.PreserveLocationVisitor, eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + T = TypeVar("T", gtir.Sym, gtir.SymRef) def _replace_sym(self, node: T, symtable: Dict[str, str]) -> T: @@ -127,3 +130,13 @@ def visit_SymRef(self, node: gtir.SymRef, *, symtable: Dict[str, str]) -> gtir.S # assert that the new symbol names are not used in the IR assert ir_sym_ids.isdisjoint(invalid_symbols_mapping.values()) return ReplaceSymbols().visit(ir, symtable=invalid_symbols_mapping) + + +def get_symbolic(node: gtir.Expr) -> dace.symbolic.SymbolicType: + """ + Specialized visit method for symbolic expressions. + + Returns: + A dace symbolic expression of the given GTIR. 
+ """ + return dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(node)) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index e7ef2c7c74..e395bcf991 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -223,7 +223,9 @@ class Params: run_gtfn_gpu = GTFNBackendFactory(gpu=True) -run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True) +run_gtfn_gpu_cached = GTFNBackendFactory( + gpu=True, cached=True, otf_workflow__cached_translation=True +) run_gtfn_no_transforms = GTFNBackendFactory( otf_workflow__bare_translation__enable_itir_transforms=False diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 2ccbb94443..0bff0b0aa7 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -144,7 +144,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] # Markers to skip because of missing features in the domain inference DOMAIN_INFERENCE_SKIP_LIST = [ @@ -179,6 +178,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE), (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] +GTIR_EMBEDDED_SKIP_LIST = ROUNDTRIP_SKIP_LIST + [ + (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), +] GTFN_SKIP_TEST_LIST = ( COMMON_SKIP_TEST_LIST + DOMAIN_INFERENCE_SKIP_LIST @@ -219,5 +221,5 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (ALL, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ], - ProgramBackendId.GTIR_EMBEDDED: 
ROUNDTRIP_SKIP_LIST, + ProgramBackendId.GTIR_EMBEDDED: GTIR_EMBEDDED_SKIP_LIST, } diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 1360fd44cf..abf5707cf6 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -8,6 +8,7 @@ from __future__ import annotations +import copy import dataclasses import functools import inspect @@ -66,6 +67,7 @@ JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] +JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type] KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type] IJField: TypeAlias = gtx.Field[[IDim, JDim], np.int32] # type: ignore [valid-type] IKField: TypeAlias = gtx.Field[[IDim, KDim], np.int32] # type: ignore [valid-type] @@ -107,7 +109,7 @@ def scalar(self, dtype: np.typing.DTypeLike) -> ScalarValue: def field( self, allocator: next_allocators.FieldBufferAllocatorProtocol, - sizes: dict[gtx.Dimension, int], + domain: gtx.Domain, dtype: np.typing.DTypeLike, ) -> FieldValue: ... 
@@ -140,11 +142,11 @@ def scalar_value(self) -> ScalarValue: def field( self, allocator: next_allocators.FieldBufferAllocatorProtocol, - sizes: dict[gtx.Dimension, int], + domain: gtx.Domain, dtype: np.typing.DTypeLike, ) -> FieldValue: return constructors.full( - domain=common.domain(sizes), fill_value=self.value, dtype=dtype, allocator=allocator + domain=domain, fill_value=self.value, dtype=dtype, allocator=allocator ) @@ -166,16 +168,17 @@ def scalar_value(self) -> ScalarValue: def field( self, allocator: next_allocators.FieldBufferAllocatorProtocol, - sizes: dict[gtx.Dimension, int], + domain: gtx.Domain, dtype: np.typing.DTypeLike, ) -> FieldValue: - if len(sizes) > 1: + if len(domain.dims) > 1: raise ValueError( - f"'IndexInitializer' only supports fields with a single 'Dimension', got {sizes}." + f"'IndexInitializer' only supports fields with a single 'Dimension', got {domain}." ) - n_data = list(sizes.values())[0] return constructors.as_field( - domain=common.domain(sizes), data=np.arange(0, n_data, dtype=dtype), allocator=allocator + domain=domain, + data=np.arange(domain.ranges[0].start, domain.ranges[0].stop, dtype=dtype), + allocator=allocator, ) def from_case( @@ -207,16 +210,15 @@ def scalar_value(self) -> ScalarValue: def field( self, allocator: next_allocators.FieldBufferAllocatorProtocol, - sizes: dict[gtx.Dimension, int], + domain: common.Domain, dtype: np.typing.DTypeLike, ) -> FieldValue: start = self.start - svals = tuple(sizes.values()) - n_data = int(np.prod(svals)) - self.start += n_data + assert isinstance(domain.size, int) + self.start += domain.size return constructors.as_field( - common.domain(sizes), - np.arange(start, start + n_data, dtype=dtype).reshape(svals), + common.domain(domain), + np.arange(start, self.start, dtype=dtype).reshape(domain.shape), allocator=allocator, ) @@ -329,6 +331,7 @@ def allocate( name: str, *, sizes: Optional[dict[gtx.Dimension, int]] = None, + domain: Optional[dict[gtx.Dimension, tuple[int, int]] | 
gtx.Domain] = None, strategy: Optional[DataInitializer] = None, dtype: Optional[np.typing.DTypeLike] = None, extend: Optional[dict[gtx.Dimension, tuple[int, int]]] = None, @@ -350,9 +353,23 @@ def allocate( Useful for shifted fields, which must start off bigger than the output field in the shifted dimension. """ - sizes = extend_sizes( - case.default_sizes | (sizes or {}), extend + if sizes: + assert not domain and all(dim in case.default_sizes for dim in sizes) + domain = { + dim: (0, sizes[dim] if dim in sizes else default_size) + for dim, default_size in case.default_sizes.items() + } + + if not domain: + domain = {dim: (0, size) for dim, size in case.default_sizes.items()} + + if not isinstance(domain, gtx.Domain): + domain = gtx.domain(domain) + + domain = extend_domain( + domain, extend ) # TODO: this should take into account the Domain of the allocated field + arg_type = get_param_types(fieldview_prog)[name] if strategy is None: if name in ["out", RETURN]: @@ -362,7 +379,7 @@ def allocate( return _allocate_from_type( case=case, arg_type=arg_type, - sizes=sizes, + domain=domain, dtype=dtype, strategy=strategy.from_case(case=case, fieldview_prog=fieldview_prog, arg_name=name), ) @@ -551,14 +568,14 @@ def unstructured_case_3d(unstructured_case): return dataclasses.replace( unstructured_case, default_sizes={**unstructured_case.default_sizes, KDim: 10}, - offset_provider={**unstructured_case.offset_provider, "KOff": KDim}, + offset_provider={**unstructured_case.offset_provider, "Koff": KDim}, ) def _allocate_from_type( case: Case, arg_type: ts.TypeSpec, - sizes: dict[gtx.Dimension, int], + domain: gtx.Domain, strategy: DataInitializer, dtype: Optional[np.typing.DTypeLike] = None, tuple_start: Optional[int] = None, @@ -568,7 +585,7 @@ def _allocate_from_type( case ts.FieldType(dims=dims, dtype=arg_dtype): return strategy.field( allocator=case.allocator, - sizes={dim: sizes[dim] for dim in dims}, + domain=common.domain(tuple(domain[dim] for dim in dims)), 
dtype=dtype or arg_dtype.kind.name.lower(), ) case ts.ScalarType(kind=kind): @@ -577,7 +594,7 @@ def _allocate_from_type( return tuple( ( _allocate_from_type( - case=case, arg_type=t, sizes=sizes, dtype=dtype, strategy=strategy + case=case, arg_type=t, domain=domain, dtype=dtype, strategy=strategy ) for t in types ) @@ -613,15 +630,26 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> raise TypeError(f"Can not get size for parameter of type '{param_type}'.") -def extend_sizes( - sizes: dict[gtx.Dimension, int], extend: Optional[dict[gtx.Dimension, tuple[int, int]]] = None +def extend_domain( + domain: gtx.Domain, extend: Optional[dict[gtx.Dimension, tuple[int, int]]] = None ) -> dict[gtx.Dimension, int]: """Calculate the sizes per dimension given a set of extensions.""" - sizes = sizes.copy() if extend: + domain = copy.deepcopy(domain) for dim, (lower, upper) in extend.items(): - sizes[dim] += upper - lower - return sizes + domain = domain.replace( + dim, + common.named_range( + ( + dim, + ( + domain[dim].unit_range.start - lower, + domain[dim].unit_range.stop + upper, + ), + ) + ), + ) + return domain def get_default_data( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 44ef9b62f0..0ef99d6b50 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -9,8 +9,10 @@ import numpy as np from typing import Tuple import pytest -from next_tests.integration_tests.cases import KDim, cartesian_case +from next_tests.integration_tests.cases import IDim, JDim, KDim, cartesian_case from gt4py import next as gtx +from gt4py.next import errors +from gt4py.next import broadcast from gt4py.next.ffront.experimental import concat_where from next_tests.integration_tests import cases from 
next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -20,108 +22,288 @@ pytestmark = pytest.mark.uses_concat_where -def test_boundary_same_size_fields(cartesian_case): +def test_concat_where_simple(cartesian_case): @gtx.field_operator - def testee( - k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField - ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: + return concat_where(KDim > 0, air, ground) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + ground = cases.allocate(cartesian_case, testee, "ground")() + air = cases.allocate(cartesian_case, testee, "air")() + + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where(k[np.newaxis, np.newaxis, :] == 0, ground.asnumpy(), air.asnumpy()) + cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) + + +def test_concat_where(cartesian_case): + @gtx.field_operator + def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, ground, air) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + ground = cases.allocate(cartesian_case, testee, "ground")() + air = cases.allocate(cartesian_case, testee, "air")() + + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where(k[np.newaxis, np.newaxis, :] == 0, ground.asnumpy(), air.asnumpy()) + cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) + + +def test_concat_where_non_overlapping(cartesian_case): + @gtx.field_operator + def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, ground, air) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + ground = cases.allocate( + cartesian_case, testee, "ground", domain=out.domain.slice_at[:, :, 0:1] + )() + air = cases.allocate(cartesian_case, testee, "air", domain=out.domain.slice_at[:, :, 1:])() + 
+ ref = np.concatenate((ground.asnumpy(), air.asnumpy()), axis=2) + cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) + + +def test_concat_where_scalar_broadcast(cartesian_case): + @gtx.field_operator + def testee(a: np.int32, b: cases.IJKField, N: np.int32) -> cases.IJKField: + return concat_where(KDim < N - 1, a, b) + + a = 3 + b = cases.allocate(cartesian_case, testee, "b")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.concatenate( + ( + np.full((*out.domain.shape[0:2], out.domain.shape[2] - 1), a), + b.asnumpy()[:, :, -1:], + ), + axis=2, + ) + cases.verify(cartesian_case, testee, a, b, cartesian_case.default_sizes[KDim], out=out, ref=ref) + + +def test_concat_where_scalar_broadcast_on_empty_branch(cartesian_case): + @gtx.field_operator + def testee(a: np.int32, b: cases.KField, N: np.int32) -> cases.KField: + return concat_where(KDim < N, a, b) + + a = 3 + b = cases.allocate(cartesian_case, testee, "b")() + out = cases.allocate(cartesian_case, testee, cases.RETURN, domain=b.domain.slice_at[1:])() + + ref = b.asnumpy()[1:] + cases.verify(cartesian_case, testee, a, b, 1, out=out, ref=ref) + + +def test_concat_where_single_level_broadcast(cartesian_case): + @gtx.field_operator + def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, a, b) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + a = cases.allocate( + cartesian_case, testee, "a", domain=gtx.domain({KDim: out.domain.shape[2]}) + )() + b = cases.allocate(cartesian_case, testee, "b", domain=out.domain.slice_at[:, :, 1:])() + + ref = np.concatenate( + ( + np.tile(a.asnumpy()[0], (*b.domain.shape[0:2], 1)), + b.asnumpy(), + ), + axis=2, + ) + cases.verify(cartesian_case, testee, a, b, out=out, ref=ref) + + +def test_concat_where_single_level_restricted_domain_broadcast(cartesian_case): + @gtx.field_operator + def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField: + return concat_where(KDim 
== 0, a, b) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + # note: this field is only defined on K: 0, 1, i.e., contains only a single value + a = cases.allocate(cartesian_case, testee, "a", domain=gtx.domain({KDim: (0, 1)}))() + b = cases.allocate(cartesian_case, testee, "b", domain=out.domain.slice_at[:, :, 1:])() + + ref = np.concatenate( + ( + np.tile(a.asnumpy()[0], (*b.domain.shape[0:2], 1)), + b.asnumpy(), + ), + axis=2, + ) + cases.verify(cartesian_case, testee, a, b, out=out, ref=ref) + + +def test_boundary_single_layer_3d_bc(cartesian_case): + @gtx.field_operator + def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, boundary, interior) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() - boundary = cases.allocate(cartesian_case, testee, "boundary")() + boundary = cases.allocate(cartesian_case, testee, "boundary", sizes={KDim: 1})() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + k = np.arange(0, cartesian_case.default_sizes[KDim]) ref = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, boundary.asnumpy(), interior.asnumpy() + k[np.newaxis, np.newaxis, :] == 0, + np.broadcast_to(boundary.asnumpy(), interior.shape), + interior.asnumpy(), ) - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) -def test_boundary_horizontal_slice(cartesian_case): +def test_boundary_single_layer_2d_bc(cartesian_case): @gtx.field_operator - def testee( - k: cases.KField, interior: cases.IJKField, boundary: cases.IJField - ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: + return concat_where(KDim == 0, boundary, interior) - k = cases.allocate(cartesian_case, testee, "k", 
strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + k = np.arange(0, cartesian_case.default_sizes[KDim]) ref = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + k[np.newaxis, np.newaxis, :] == 0, boundary.asnumpy()[:, :, np.newaxis], interior.asnumpy(), ) - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) -def test_boundary_single_layer(cartesian_case): +def test_boundary_single_layer_2d_bc_on_empty_branch(cartesian_case): @gtx.field_operator - def testee( - k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField - ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: + return concat_where(KDim == 0, boundary, interior) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() - boundary = cases.allocate(cartesian_case, testee, "boundary", sizes={KDim: 1})() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate( + cartesian_case, testee, cases.RETURN, domain=interior.domain.slice_at[:, :, 1:] + )() + + ref = interior.asnumpy()[:, :, 1:] + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) + + +def test_dimension_two_nested_conditions(cartesian_case): + @gtx.field_operator + def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField: + return concat_where((KDim < 2), boundary, concat_where((KDim >= 5), boundary, interior)) + + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, 
cases.RETURN)() + k = np.arange(0, cartesian_case.default_sizes[KDim]) ref = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, - np.broadcast_to(boundary.asnumpy(), interior.shape), + (k[np.newaxis, np.newaxis, :] < 2) | (k[np.newaxis, np.newaxis, :] >= 5), + boundary.asnumpy(), interior.asnumpy(), ) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) + - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) +def test_dimension_two_conditions_and(cartesian_case): + @gtx.field_operator + def testee(interior: cases.KField, boundary: cases.KField, nlev: np.int32) -> cases.KField: + return concat_where((0 < KDim) & (KDim < (nlev - 1)), interior, boundary) + + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + nlev = cartesian_case.default_sizes[KDim] + k = np.arange(0, nlev) + ref = np.where((0 < k) & (k < (nlev - 1)), interior.asnumpy(), boundary.asnumpy()) + cases.verify(cartesian_case, testee, interior, boundary, nlev, out=out, ref=ref) -def test_alternating_mask(cartesian_case): +def test_dimension_two_conditions_eq(cartesian_case): @gtx.field_operator - def testee(k: cases.KField, f0: cases.IJKField, f1: cases.IJKField) -> cases.IJKField: - return concat_where(k % 2 == 0, f1, f0) + def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: + return concat_where((KDim == 2), interior, boundary) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() - f0 = cases.allocate(cartesian_case, testee, "f0")() - f1 = cases.allocate(cartesian_case, testee, "f1")() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() - ref = np.where(k.asnumpy()[np.newaxis, np.newaxis, :] % 2 == 0, 
f1.asnumpy(), f0.asnumpy()) + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where(k == 2, interior.asnumpy(), boundary.asnumpy()) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) - cases.verify(cartesian_case, testee, k, f0, f1, out=out, ref=ref) + +def test_dimension_two_conditions_or(cartesian_case): + @gtx.field_operator + def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: + return concat_where(((KDim < 2) | (KDim >= 5)), boundary, interior) + + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where((k < 2) | (k >= 5), boundary.asnumpy(), interior.asnumpy()) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) + + +def test_lap_like(cartesian_case): + @gtx.field_operator + def testee( + input: cases.IJField, boundary: np.int32, shape: tuple[np.int32, np.int32] + ) -> cases.IJField: + # TODO add support for multi-dimensional concat_where masks + return concat_where( + (IDim == 0) | (IDim == shape[0] - 1), + boundary, + concat_where((JDim == 0) | (JDim == shape[1] - 1), boundary, input), + ) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + input = cases.allocate( + cartesian_case, testee, "input", domain=out.domain.slice_at[1:-1, 1:-1] + )() + boundary = 2 + + ref = np.full(out.domain.shape, np.nan) + ref[0, :] = boundary + ref[:, 0] = boundary + ref[-1, :] = boundary + ref[:, -1] = boundary + ref[1:-1, 1:-1] = input.asnumpy() + cases.verify(cartesian_case, testee, input, boundary, out.domain.shape, out=out, ref=ref) @pytest.mark.uses_tuple_returns def test_with_tuples(cartesian_case): @gtx.field_operator def testee( - k: cases.KField, interior0: cases.IJKField, boundary0: cases.IJField, interior1: cases.IJKField, boundary1: cases.IJField, - ) 
-> Tuple[cases.IJKField, cases.IJKField]: - return concat_where(k == 0, (boundary0, boundary1), (interior0, interior1)) + ) -> tuple[cases.IJKField, cases.IJKField]: + return concat_where(KDim == 0, (boundary0, boundary1), (interior0, interior1)) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior0 = cases.allocate(cartesian_case, testee, "interior0")() boundary0 = cases.allocate(cartesian_case, testee, "boundary0")() interior1 = cases.allocate(cartesian_case, testee, "interior1")() boundary1 = cases.allocate(cartesian_case, testee, "boundary1")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + k = np.arange(0, cartesian_case.default_sizes[KDim]) ref0 = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + k[np.newaxis, np.newaxis, :] == 0, boundary0.asnumpy()[:, :, np.newaxis], interior0.asnumpy(), ) ref1 = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + k[np.newaxis, np.newaxis, :] == 0, boundary1.asnumpy()[:, :, np.newaxis], interior1.asnumpy(), ) @@ -129,7 +311,71 @@ def testee( cases.verify( cartesian_case, testee, - k, + interior0, + boundary0, + interior1, + boundary1, + out=out, + ref=(ref0, ref1), + ) + + +def test_nested_conditions_with_empty_branches(cartesian_case): + @gtx.field_operator + def testee(interior: cases.IField, boundary: cases.IField, N: gtx.int32) -> cases.IField: + interior = concat_where(IDim == 0, boundary, interior) + interior = concat_where((1 <= IDim) & (IDim < N - 1), interior * 2, interior) + interior = concat_where(IDim == N - 1, boundary, interior) + return interior + + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + N = cartesian_case.default_sizes[IDim] + + i = np.arange(0, cartesian_case.default_sizes[IDim]) + ref = np.where( + (i[:] == 0) | (i[:] == N - 1), + boundary.asnumpy(), + interior.asnumpy() * 2, + ) 
+ cases.verify(cartesian_case, testee, interior, boundary, N, out=out, ref=ref) + + +@pytest.mark.uses_tuple_returns +def test_with_tuples_different_domain(cartesian_case): + @gtx.field_operator + def testee( + interior0: cases.IJKField, + boundary0: cases.IJKField, + interior1: cases.KField, + boundary1: cases.KField, + ) -> tuple[cases.IJKField, cases.IJKField]: + a, b = concat_where(KDim == 0, (boundary0, boundary1), (interior0, interior1)) + # the broadcast is only needed since we can not return fields on different domains yet + return a, broadcast(b, (IDim, JDim, KDim)) + + interior0 = cases.allocate(cartesian_case, testee, "interior0")() + boundary0 = cases.allocate(cartesian_case, testee, "boundary0")() + interior1 = cases.allocate(cartesian_case, testee, "interior1")() + boundary1 = cases.allocate(cartesian_case, testee, "boundary1")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref0 = np.where( + k[np.newaxis, np.newaxis, :] == 0, + boundary0.asnumpy(), + interior0.asnumpy(), + ) + ref1 = np.where( + k == 0, + boundary1.asnumpy(), + interior1.asnumpy(), + ) + + cases.verify( + cartesian_case, + testee, interior0, boundary0, interior1, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 803ab0c6bc..1a1984a71b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -135,7 +135,7 @@ def test_neighbor_sum(unstructured_case_3d, fop): @pytest.mark.uses_unstructured_shift -def test_reduction_execution_with_offset(unstructured_case): +def test_reduction_execution_with_offset(unstructured_case_3d): EKField: TypeAlias = gtx.Field[[Edge, KDim], np.int32] VKField: TypeAlias = gtx.Field[[Vertex, KDim], np.int32] @@ -152,12 
+152,12 @@ def fencil_op(edge_f: EKField) -> VKField: def fencil(edge_f: EKField, out: VKField): fencil_op(edge_f, out=out) - v2e_table = unstructured_case.offset_provider["V2E"].asnumpy() - field = cases.allocate(unstructured_case, fencil, "edge_f", sizes={KDim: 2})() - out = cases.allocate(unstructured_case, fencil_op, cases.RETURN, sizes={KDim: 1})() + v2e_table = unstructured_case_3d.offset_provider["V2E"].asnumpy() + field = cases.allocate(unstructured_case_3d, fencil, "edge_f", sizes={KDim: 2})() + out = cases.allocate(unstructured_case_3d, fencil_op, cases.RETURN, sizes={KDim: 1})() cases.verify( - unstructured_case, + unstructured_case_3d, fencil, field, out, @@ -168,7 +168,7 @@ def fencil(edge_f: EKField, out: VKField): initial=0, where=v2e_table != common._DEFAULT_SKIP_VALUE, ).reshape(out.shape), - offset_provider=unstructured_case.offset_provider | {"Koff": KDim}, + offset_provider=unstructured_case_3d.offset_provider, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 4a574f256a..0589463777 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -259,6 +259,23 @@ def expression_test_cases(): ), ts.TupleType(types=[float_i_field, float_i_field]), ), + # concat_where + ( + im.concat_where( + im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}), + im.ref("a", float_i_field), + im.ref("b", float_ij_field), + ), + float_ij_field, + ), + ( + im.concat_where( + im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}), + im.ref("a", ts.TupleType(types=[float_i_field] * 2)), + im.ref("b", ts.TupleType(types=[float_i_field] * 2)), + ), + ts.TupleType(types=[float_i_field] * 2), + ), ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 
d97869156b..f0ecd8aff0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause from gt4py.next import common +import pytest from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.type_system import type_specifications as ts @@ -114,7 +115,27 @@ def test_simple_tuple_get_make_tuple(): assert expected == actual -def test_propagate_tuple_get(): +@pytest.mark.parametrize("fun", ["if_", "concat_where"]) +def test_propagate_tuple_get(fun): + testee = im.tuple_get( + 0, im.call(fun)("cond", im.make_tuple("el1", "el2"), im.make_tuple("el1", "el2")) + ) + expected = im.call(fun)( + "cond", + im.tuple_get(0, im.make_tuple("el1", "el2")), + im.tuple_get(0, im.make_tuple("el1", "el2")), + ) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + enabled_transformations=CollapseTuple.Transformation.PROPAGATE_TUPLE_GET, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert expected == actual + + +def test_propagate_tuple_get_let(): expected = im.let(("el1", 1), ("el2", 2))(im.tuple_get(0, im.make_tuple("el1", "el2"))) testee = im.tuple_get(0, im.let(("el1", 1), ("el2", 2))(im.make_tuple("el1", "el2"))) actual = CollapseTuple.apply( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_expand_tuple_args.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_expand_tuple_args.py new file mode 100644 index 0000000000..42ad292043 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_expand_tuple_args.py @@ -0,0 +1,53 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All 
rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from gt4py.next import common +import pytest +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils +from gt4py.next.iterator.transforms import ( + concat_where, + inline_lambdas, + infer_domain, + collapse_tuple, +) +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.type_system import type_specifications as it_ts + +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) +field_type = ts.FieldType(dims=[IDim], dtype=int_type) + + +def test_trivial(): + cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 1)}) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 2)}) + symbolic_domain = domain_utils.SymbolicDomain.from_expr(domain) + + testee = im.concat_where( + cond, + im.make_tuple(im.ref("a", field_type), im.ref("c", field_type)), + im.make_tuple(im.ref("b", field_type), im.ref("d", field_type)), + ) + testee, _ = infer_domain.infer_expr( + testee, + (symbolic_domain, symbolic_domain), + keep_existing_domains=True, + offset_provider={}, + ) + + expected = im.make_tuple(im.concat_where(cond, "a", "b"), im.concat_where(cond, "c", "d")) + + actual = concat_where.expand_tuple_args( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + + actual = collapse_tuple.CollapseTuple.apply( + actual, allow_undeclared_symbols=True, within_stencil=False + ) + + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py new file mode 100644 index 0000000000..beca1084b4 --- /dev/null +++ 
b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py @@ -0,0 +1,56 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from gt4py.next import common +import pytest +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import ( + concat_where, + inline_lambdas, +) +from gt4py.next.type_system import type_specifications as ts + +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) +field_type = ts.FieldType(dims=[IDim], dtype=int_type) + + +def test_data(): + return [ + # testee, expected + ( + im.concat_where(im.and_("cond1", "cond2"), "a", "b"), + im.concat_where("cond1", im.concat_where("cond2", "a", "b"), "b"), + ), + ( + im.concat_where(im.or_("cond1", "cond2"), "a", "b"), + im.concat_where("cond1", "a", im.concat_where("cond2", "a", "b")), + ), + ( + im.concat_where(im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}), "a", "b"), + im.concat_where( + im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 0)}), + "b", + im.concat_where( + im.domain( + common.GridType.CARTESIAN, {IDim: (1, itir.InfinityLiteral.POSITIVE)} + ), + "b", + "a", + ), + ), + ), + ] + + +@pytest.mark.parametrize("testee, expected", test_data()) +def test_nested_concat_where(testee, expected): + actual = concat_where.simplify_domain_argument(testee) + actual = inline_lambdas.InlineLambdas.apply(actual, opcount_preserving=True) + + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py new file mode 100644 index 
0000000000..2517c7ab55 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py @@ -0,0 +1,89 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from gt4py.next import common + +from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils +from gt4py.next.iterator.transforms import concat_where, inline_lambdas +from gt4py.next.iterator.transforms.concat_where import transform_to_as_fieldop +from gt4py.next.iterator.transforms.concat_where.transform_to_as_fieldop import _in +from gt4py.next.type_system import type_specifications as ts + +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) +JDim = common.Dimension(value="JDim", kind=common.DimensionKind.HORIZONTAL) + + +def test_in_helper(): + pos = im.make_tuple(0, 1) + bounds = { + IDim: (3, 4), + JDim: (5, 6), + } + expected = im.and_( + im.and_( + im.less_equal(bounds[IDim][0], im.tuple_get(0, pos)), + im.less(im.tuple_get(0, pos), bounds[IDim][1]), + ), + im.and_( + im.less_equal(bounds[JDim][0], im.tuple_get(1, pos)), + im.less(im.tuple_get(1, pos), bounds[JDim][1]), + ), + ) + actual = _in(pos, im.domain(common.GridType.CARTESIAN, bounds)) + assert actual == expected + + +def test_trivial(): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 2)}) + + cond = im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}) + testee = im.concat_where(cond, "true_branch", "false_branch") + testee.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + expected = im.as_fieldop( + im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1")( + im.if_( + _in(im.deref("__tcw_pos"), cond), + im.deref("__tcw_arg0"), + im.deref("__tcw_arg1"), + ) + ), + domain, + )(im.make_tuple(im.index(IDim)), "true_branch", 
"false_branch") + + actual = concat_where.transform_to_as_fieldop(testee) + actual = inline_lambdas.InlineLambdas.apply(actual) # simplify + + assert actual == expected + + +def test_capturing_cond(): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}) + + cond = im.domain(common.GridType.CARTESIAN, {IDim: ("start", "stop")}) + testee = im.concat_where(cond, "true_branch", "false_branch") + testee.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + expected = im.as_fieldop( + im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1", "start", "stop")( + im.if_( + _in( + im.deref("__tcw_pos"), + im.domain( + common.GridType.CARTESIAN, {IDim: (im.deref("start"), im.deref("stop"))} + ), + ), + im.deref("__tcw_arg0"), + im.deref("__tcw_arg1"), + ) + ), + domain, + )(im.make_tuple(im.index(IDim)), "true_branch", "false_branch", "start", "stop") + + actual = concat_where.transform_to_as_fieldop(testee) + actual = inline_lambdas.InlineLambdas.apply(actual) # simplify + + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index d9dea7e2d5..bbc6433348 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -200,3 +200,90 @@ def test_constant_folding(test_case): testee, expected = test_case actual = constant_folding.ConstantFolding.apply(testee) assert actual == im.ensure_expr(expected) + + +# TODO: integrate into test structure above +def test_constant_folding_inf_maximum(): + testee = im.call("maximum")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE) + expected = itir.InfinityLiteral.POSITIVE + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")(itir.InfinityLiteral.POSITIVE, 
im.literal_from_value(1)) + expected = itir.InfinityLiteral.POSITIVE + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE) + expected = im.literal_from_value(1) + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) + expected = im.literal_from_value(1) + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + +def test_constant_folding_inf_minimum(): + testee = im.call("minimum")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE) + expected = im.literal_from_value(1) + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) + expected = im.literal_from_value(1) + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE) + expected = itir.InfinityLiteral.NEGATIVE + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) + expected = itir.InfinityLiteral.NEGATIVE + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + +def test_constant_greater_less(): + testee = im.call("greater")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE) + expected = im.literal_from_value(False) + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("greater")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE) + expected = im.literal_from_value(True) + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + testee = 
im.call("less")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE) + expected = im.literal_from_value(True) + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("less")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE) + expected = im.literal_from_value(False) + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("greater")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) + expected = im.literal_from_value(True) + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("greater")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) + expected = im.literal_from_value(False) + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("less")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) + expected = im.literal_from_value(False) + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("less")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) + expected = im.literal_from_value(True) + actual = constant_folding.ConstantFolding.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index d157edde1a..0a090dc525 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -1225,6 +1225,120 @@ def test_never_accessed_domain_tuple(offset_provider): run_test_expr(testee, testee, domain, expected_domains, offset_provider) +def test_concat_where(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: 
(itir.InfinityLiteral.NEGATIVE, 4)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 4)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (4, 11)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +# Todo: 2 dimensional test with cond im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 4)}) +# Todo: nested concat wheres + + +def test_concat_where_two_dimensions(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 10)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10), JDim: (10, 30)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (10, 20), JDim: (10, 30)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == 
constant_fold_accessed_domains(actual_domains) + + +def test_concat_where_two_dimensions_J(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {JDim: (20, itir.InfinityLiteral.POSITIVE)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (20, 30)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_nested_concat_where_two_dimensions(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 20)}) + domain_cond1 = im.domain(common.GridType.CARTESIAN, {JDim: (10, itir.InfinityLiteral.POSITIVE)}) + domain_cond2 = im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 20)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (20, 30), JDim: (10, 20)}) + domain3 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 10)}) + testee = im.concat_where( + domain_cond1, + im.concat_where( + domain_cond2, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ), + im.as_fieldop("deref")("in_field3"), + ) + + expected = im.concat_where( + domain_cond1, # 0, 30; 10,20 + im.concat_where( + domain_cond2, + 
im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ), + im.as_fieldop("deref", domain3)("in_field3"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2, "in_field3": domain3} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + def test_broadcast(offset_provider): testee = im.call("broadcast")("in_field", im.make_tuple(itir.AxisLiteral(value="IDim"))) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_infer_domain_ops.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_infer_domain_ops.py new file mode 100644 index 0000000000..77ba3719be --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_infer_domain_ops.py @@ -0,0 +1,71 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause
+
+import pytest
+import textwrap
+
+from gt4py.eve.utils import UIDGenerator
+from gt4py.next import common
+from gt4py.next.iterator import ir as itir
+from gt4py.next.iterator.ir_utils import ir_makers as im
+from gt4py.next.iterator.transforms.infer_domain_ops import InferDomainOps
+from gt4py.next.iterator.transforms.constant_folding import ConstantFolding
+
+from next_tests.integration_tests.cases import IDim, JDim, KDim
+
+
+def test_data():
+    return [
+        (
+            im.less(im.axis_literal(IDim), 1),
+            im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 1)}),
+        ),
+        (
+            im.less_equal(im.axis_literal(IDim), 1),
+            im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 2)}),
+        ),
+        (
+            im.greater(im.axis_literal(IDim), 1),
+            im.domain(common.GridType.CARTESIAN, {IDim: (2, itir.InfinityLiteral.POSITIVE)}),
+        ),
+        (
+            im.greater_equal(im.axis_literal(IDim), 1),
+            im.domain(common.GridType.CARTESIAN, {IDim: (1, itir.InfinityLiteral.POSITIVE)}),
+        ),
+        (
+            im.less(1, im.axis_literal(IDim)),
+            im.domain(common.GridType.CARTESIAN, {IDim: (2, itir.InfinityLiteral.POSITIVE)}),
+        ),
+        (
+            im.less_equal(1, im.axis_literal(IDim)),
+            im.domain(common.GridType.CARTESIAN, {IDim: (1, itir.InfinityLiteral.POSITIVE)}),
+        ),
+        (
+            im.greater(1, im.axis_literal(IDim)),
+            im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 1)}),
+        ),
+        (
+            im.greater_equal(1, im.axis_literal(IDim)),
+            im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 2)}),
+        ),
+        (im.eq(1, im.axis_literal(IDim)), im.domain(common.GridType.CARTESIAN, {IDim: (1, 2)})),
+        (
+            im.not_eq(1, im.axis_literal(IDim)),
+            im.and_(
+                im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 1)}),
+                im.domain(common.GridType.CARTESIAN, {IDim: (2, itir.InfinityLiteral.POSITIVE)}),
+            ),
+        ),
+    ]
+
+
+@pytest.mark.parametrize("testee,expected", test_data())
+def test_trivial(testee, expected):
+    actual = InferDomainOps(grid_type=common.GridType.CARTESIAN).visit(testee, recurse=True)
+    actual = ConstantFolding.apply(actual)  # simplify expr to get simpler expected expressions
+    assert actual == expected
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py
index 8e56967a3a..7a46296ce7 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py
@@ -28,6 +28,7 @@
     Cell,
     Edge,
     IDim,
+    JDim,
     MeshDescriptor,
     V2EDim,
     Vertex,
@@ -2092,3 +2093,135 @@ def test_gtir_index():
 
     sdfg(v, **FSYMBOLS)
     np.allclose(v, ref)
+
+
+def test_gtir_concat_where():
+    SUBSET_SIZE = 5
+    assert SUBSET_SIZE < N
+    domain = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (0, N)})
+    domain_cond_lhs = im.domain(
+        gtx_common.GridType.CARTESIAN, {IDim: (gtir.InfinityLiteral.NEGATIVE, N - SUBSET_SIZE)}
+    )
+    domain_cond_rhs = im.domain(
+        gtx_common.GridType.CARTESIAN, {IDim: (SUBSET_SIZE, gtir.InfinityLiteral.POSITIVE)}
+    )
+    domain_lhs = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (0, N - SUBSET_SIZE)})
+    domain_rhs = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (N - SUBSET_SIZE, N)})
+
+    concat_expr_lhs = im.concat_where(
+        domain_cond_lhs,
+        im.as_fieldop("deref", domain_lhs)("x"),
+        im.as_fieldop("deref", domain_rhs)("y"),
+    )
+    concat_expr_rhs = im.concat_where(
+        domain_cond_rhs,
+        im.as_fieldop("deref", domain_rhs)("y"),
+        im.as_fieldop("deref", domain_lhs)("x"),
+    )
+
+    a = np.random.rand(N)
+    b = np.random.rand(N)
+    ref = np.concatenate((a[:SUBSET_SIZE], b[SUBSET_SIZE:]))
+
+    for concat_expr, suffix in [(concat_expr_lhs, "lhs"), (concat_expr_rhs, "rhs")]:
+        testee = gtir.Program(
+            id=f"gtir_concat_where_{suffix}",
+            function_definitions=[],
+            params=[
+                gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)),
+                gtir.Sym(id="y", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)),
+                gtir.Sym(id="z", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)),
+            ],
+            declarations=[],
+            body=[
+                gtir.SetAt(
+                    expr=concat_expr,
+                    domain=domain,
+                    target=gtir.SymRef(id="z"),
+                )
+            ],
+        )
+
+        # run domain inference in order to add the domain annex information to the concat_where node.
+        testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS)
+        sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS)
+        c = np.empty_like(a)
+
+        sdfg(a, b, c, **FSYMBOLS)
+        assert np.allclose(c, ref)
+
+
+def test_gtir_concat_where_two_dimensions():
+    M, N = (30, 20)
+    domain = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 20)})
+    domain_cond1 = im.domain(
+        gtx_common.GridType.CARTESIAN, {JDim: (10, gtir.InfinityLiteral.POSITIVE)}
+    )
+    domain_cond2 = im.domain(
+        gtx_common.GridType.CARTESIAN, {IDim: (gtir.InfinityLiteral.NEGATIVE, 20)}
+    )
+    domain1 = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)})
+    domain2 = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (20, 30), JDim: (10, 20)})
+    domain3 = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 10)})
+
+    testee = gtir.Program(
+        id="gtir_concat_where_two_dimensions",
+        function_definitions=[],
+        params=[
+            gtir.Sym(id="x", type=ts.FieldType(dims=[IDim, JDim], dtype=FLOAT_TYPE)),
+            gtir.Sym(id="y", type=ts.FieldType(dims=[IDim, JDim], dtype=FLOAT_TYPE)),
+            gtir.Sym(id="w", type=ts.FieldType(dims=[IDim, JDim], dtype=FLOAT_TYPE)),
+            gtir.Sym(id="z", type=ts.FieldType(dims=[IDim, JDim], dtype=FLOAT_TYPE)),
+        ],
+        declarations=[],
+        body=[
+            gtir.SetAt(
+                expr=im.concat_where(
+                    domain_cond1,  # 0, 30; 10,20
+                    im.concat_where(
+                        domain_cond2,
+                        im.as_fieldop("deref", domain1)("x"),
+                        im.as_fieldop("deref", domain2)("y"),
+                    ),
+                    im.as_fieldop("deref", domain3)("w"),
+                ),
+                domain=domain,
+                target=gtir.SymRef(id="z"),
+            )
+        ],
+    )
+
+    a = np.random.rand(M, N)
+    b = np.random.rand(M, N)
+    c = np.random.rand(M, N)
+    d = np.empty_like(a)
+    ref = np.concatenate(
+        (c[:, :10], np.concatenate((a[:20, :], b[20:, :]), axis=0)[:, 10:]), axis=1
+    )
+
+    field_symbols = {
+        "__x_0_range_1": a.shape[0],
+        "__x_1_range_1": a.shape[1],
+        "__x_stride_0": a.strides[0] // a.itemsize,
+        "__x_stride_1": a.strides[1] // a.itemsize,
+        "__y_0_range_1": b.shape[0],
+        "__y_1_range_1": b.shape[1],
+        "__y_stride_0": b.strides[0] // b.itemsize,
+        "__y_stride_1": b.strides[1] // b.itemsize,
+        "__w_0_range_1": c.shape[0],
+        "__w_1_range_1": c.shape[1],
+        "__w_stride_0": c.strides[0] // c.itemsize,
+        "__w_stride_1": c.strides[1] // c.itemsize,
+        "__z_0_range_1": d.shape[0],
+        "__z_1_range_1": d.shape[1],
+        "__z_stride_0": d.strides[0] // d.itemsize,
+        "__z_stride_1": d.strides[1] // d.itemsize,
+    }
+
+    # run domain inference in order to add the domain annex information to the concat_where node.
+    testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS)
+    sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS)
+
+    sdfg(a, b, c, d, **field_symbols)
+
+    assert np.allclose(d, ref)