diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index bd22aebe57..c9bea908a8 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -20,7 +20,7 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi @WhereBuiltinFunction def concat_where( - mask: common.Field, + mask: common.Domain, true_field: common.Field | core_defs.ScalarT | Tuple, false_field: common.Field | core_defs.ScalarT | Tuple, /, diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 1210e96efc..028761e9fa 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -58,6 +58,8 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ts.OffsetType elif t is core_defs.ScalarT: return ts.ScalarType + elif t is common.Domain: + return ts.DomainType elif t is type: return ( ts.FunctionType @@ -127,7 +129,7 @@ def __gt_type__(self) -> ts.FunctionType: ) -MaskT = TypeVar("MaskT", bound=common.Field) +MaskT = TypeVar("MaskT", bound=Union[common.Field, common.Domain]) FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple]) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 6b40cbb77f..ae3d3c6437 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -20,6 +20,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.foast_passes.utils import compute_assign_indices +from gt4py.next.iterator import ir as itir from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -570,6 +571,19 @@ def _deduce_compare_type( self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: # check both types compatible + if ( + isinstance(left.type, ts.DimensionType) + and isinstance(right.type, ts.ScalarType) + and right.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ): + return ts.DomainType(dims=[left.type.dim]) + if ( + isinstance(right.type, ts.DimensionType) + and isinstance(left.type, ts.ScalarType) + and left.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ): + return ts.DomainType(dims=[right.type.dim]) + # TODO for arg in (left, right): if not type_info.is_arithmetic(arg.type): raise errors.DSLError( @@ -908,6 +922,7 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: ) try: + # TODO(tehrengruber): the construct_tuple_type function doesn't look correct if isinstance(true_branch_type, ts.TupleType) and isinstance( false_branch_type, ts.TupleType ): @@ -943,7 +958,24 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: location=node.location, ) - _visit_concat_where = _visit_where + def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: + true_branch_type = node.args[1].type + false_branch_type = node.args[2].type + if true_branch_type != false_branch_type: + raise errors.DSLError( + node.location, + f"Incompatible argument in call to '{node.func!s}': expected " + f"'{true_branch_type}' and '{false_branch_type}' to be equal.", + ) + return_type = true_branch_type + + return foast.Call( + func=node.func, + args=node.args, + kwargs=node.kwargs, + type=return_type, + location=node.location, + ) def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> foast.Call: arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 4519b4e571..636222aa95 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -224,7 +224,9 @@ def visit_Assign( def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: return im.sym(node.id) - def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: + def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef | itir.AxisLiteral: + if isinstance(node.type, ts.DimensionType): + return itir.AxisLiteral(value=node.type.dim.value, kind=node.type.dim.kind) return im.ref(node.id) def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: @@ -261,6 +263,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC ) def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: + # TODO: double-check if we need the changes in the original PR return self._lower_and_map(node.op.value, node.left, node.right) def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: @@ -394,7 +397,13 @@ def create_if( return im.let(cond_symref_name, cond_)(result) - _visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where + def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + if not isinstance(node.type, ts.TupleType): # to keep the IR simpler + return im.call("concat_where")(*self.visit(node.args)) + else: + raise NotImplementedError() + + # TODO: tuple case def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: expr = self.visit(node.args[0], **kwargs) @@ -476,8 +485,9 @@ def _map( """ Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists. """ + # TODO double-check that this code is consistent with the changes in the original PR if all( - isinstance(t, ts.ScalarType) + isinstance(t, (ts.ScalarType, ts.DimensionType)) for arg_type in original_arg_types for t in type_info.primitive_constituents(arg_type) ): diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index e875709631..0521d027c4 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -169,6 +169,10 @@ class FunctionDefinition(Node, SymbolTableTrait): "if_", "index", # `index(dim)` creates a dim-field that has the current index at each point "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) + "concat_where", + "in", + "inf", # TODO: discuss + "neg_inf", # TODO: discuss *ARITHMETIC_BUILTINS, *TYPEBUILTINS, } diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 4a023f7535..b66f21cf60 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -174,3 +174,49 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) + + +def domain_intersection(*domains: SymbolicDomain) -> SymbolicDomain: + """Return the (set) intersection of a list of domains.""" + new_domain_ranges = {} + assert all(domain.grid_type == domains[0].grid_type for domain in domains) + for dim in domains[0].ranges.keys(): + start = functools.reduce( + lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), + [domain.ranges[dim].start for domain in domains], + ) + stop = functools.reduce( + lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), + [domain.ranges[dim].stop for domain in domains], + ) + new_domain_ranges[dim] = SymbolicRange(start, stop) + + return SymbolicDomain(domains[0].grid_type, new_domain_ranges) + + +def domain_complement(domain: SymbolicDomain) -> SymbolicDomain: + """Return the (set) complement of a domain.""" + dims_dict = {} + for dim in domain.ranges.keys(): + lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop + if lb == im.ref("neg_inf"): + dims_dict[dim] = SymbolicRange(start=ub, stop=im.ref("inf")) + elif ub == im.ref("inf"): + dims_dict[dim] = SymbolicRange(start=im.ref("neg_inf"), stop=lb) + else: + raise ValueError("Invalid domain ranges") + return SymbolicDomain(domain.grid_type, dims_dict) + + +def promote_to_same_dimensions( + domain_small: SymbolicDomain, domain_large: SymbolicDomain +) -> SymbolicDomain: + """Return an extended domain based on a smaller input domain and a larger domain containing the target dimensions.""" + dims_dict = {} + for dim in domain_large.ranges.keys(): + if dim in domain_small.ranges.keys(): + lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop + dims_dict[dim] = SymbolicRange(lb, ub) + else: + dims_dict[dim] = SymbolicRange(im.ref("neg_inf"), im.ref("inf")) + return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 0839e95b5b..91bfd8b50d 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -246,6 +246,11 @@ def if_(cond, true_val, false_val): return call("if_")(cond, true_val, false_val) +def concat_where(cond, true_field, false_field): + """Create a concat_where FunCall, shorthand for ``call("concat_where")(expr)``.""" + return call("concat_where")(cond, true_field, false_field) + + def lift(expr): """Create a lift FunCall, shorthand for ``call(call("lift")(expr))``.""" return call(call("lift")(expr)) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 2084ab2518..6fe26f886b 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -8,7 +8,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import embedded, ir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class ConstantFolding(PreserveLocationVisitor, NodeTranslator): @@ -21,12 +21,42 @@ def visit_FunCall(self, node: ir.FunCall): new_node = self.generic_visit(node) if ( - isinstance(new_node.fun, ir.SymRef) - and new_node.fun.id in ["minimum", "maximum"] + cpm.is_call_to(new_node, ("minimum", "maximum")) and new_node.args[0] == new_node.args[1] ): # `minimum(a, a)` -> `a` return new_node.args[0] + if cpm.is_call_to(new_node, "minimum"): # TODO: add tests + # `minimum(neg_inf, neg_inf)` -> `neg_inf` + if cpm.is_ref_to(new_node.args[0], "neg_inf") or cpm.is_ref_to( + new_node.args[1], "neg_inf" + ): + return im.ref("neg_inf") + # `minimum(inf, a)` -> `a` + elif cpm.is_ref_to(new_node.args[0], "inf"): + return new_node.args[1] + # `minimum(a, inf)` -> `a` + elif cpm.is_ref_to(new_node.args[1], "inf"): + return new_node.args[0] + + if cpm.is_call_to(new_node, "maximum"): # TODO: add tests + # `minimum(inf, inf)` -> `inf` + if cpm.is_ref_to(new_node.args[0], "inf") or cpm.is_ref_to(new_node.args[1], "inf"): + return im.ref("inf") + # `minimum(neg_inf, a)` -> `a` + elif cpm.is_ref_to(new_node.args[0], "neg_inf"): + return new_node.args[1] + # `minimum(a, neg_inf)` -> `a` + elif cpm.is_ref_to(new_node.args[1], "neg_inf"): + return new_node.args[0] + if cpm.is_call_to(new_node, ("less", "less_equal")) and cpm.is_ref_to( + new_node.args[0], "neg_inf" + ): + return im.literal_from_value(True) # TODO: add tests + if cpm.is_call_to(new_node, ("greater", "greater_equal")) and cpm.is_ref_to( + new_node.args[0], "inf" + ): + return im.literal_from_value(True) # TODO: add tests if ( isinstance(new_node.fun, ir.SymRef) and new_node.fun.id == "if_" @@ -52,6 +82,6 @@ def visit_FunCall(self, node: ir.FunCall): ] new_node = im.literal_from_value(fun(*arg_values)) except ValueError: - pass # happens for inf and neginf + pass # happens for SymRefs which are not inf or neg_inf return new_node diff --git a/src/gt4py/next/iterator/transforms/expand_library_functions.py b/src/gt4py/next/iterator/transforms/expand_library_functions.py new file mode 100644 index 0000000000..0da3ff925c --- /dev/null +++ b/src/gt4py/next/iterator/transforms/expand_library_functions.py @@ -0,0 +1,39 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from functools import reduce + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) + + +class ExpandLibraryFunctions(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + if cpm.is_call_to(node, "in"): + ret = [] + pos, domain = node.args + for i, (k, v) in enumerate( + domain_utils.SymbolicDomain.from_expr(node.args[1]).ranges.items() + ): + ret.append( + im.and_( + im.less_equal(v.start, im.tuple_get(i, pos)), + im.less(im.tuple_get(i, pos), v.stop), + ) + ) # TODO: avoid pos duplication + return reduce(im.and_, ret) + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f26d3f9ec2..cab17e0202 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -363,6 +363,36 @@ def _infer_if( return result_expr, actual_domains +def _infer_concat_where( + expr: itir.Expr, + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: + assert cpm.is_call_to(expr, "concat_where") + assert isinstance(domain, domain_utils.SymbolicDomain) + infered_args_expr = [] + actual_domains: AccessedDomains = {} + cond, true_field, false_field = expr.args + symbolic_cond = domain_utils.SymbolicDomain.from_expr(cond) + for arg in [true_field, false_field]: + if arg == true_field: + extended_cond = domain_utils.promote_to_same_dimensions(symbolic_cond, domain) + domain_ = domain_utils.domain_intersection(domain, extended_cond) + elif arg == false_field: + cond_complement = domain_utils.domain_complement(symbolic_cond) + extended_cond_complement = domain_utils.promote_to_same_dimensions( + cond_complement, domain + ) + domain_ = domain_utils.domain_intersection(domain, extended_cond_complement) + + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain_, **kwargs) + infered_args_expr.append(infered_arg_expr) + actual_domains = _merge_domains(actual_domains, actual_domains_arg) + + result_expr = im.call(expr.fun)(cond, *infered_args_expr) + return result_expr, actual_domains + + def _infer_expr( expr: itir.Expr, domain: DomainAccess, @@ -382,6 +412,8 @@ def _infer_expr( return _infer_tuple_get(expr, domain, **kwargs) elif cpm.is_call_to(expr, "if_"): return _infer_if(expr, domain, **kwargs) + elif cpm.is_call_to(expr, "concat_where"): + return _infer_concat_where(expr, domain, **kwargs) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py new file mode 100644 index 0000000000..f86070c1a1 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -0,0 +1,83 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import common +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im + + +class InferDomainOps(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + if cpm.is_call_to(node, ir.BINARY_MATH_COMPARISON_BUILTINS): # TODO: add tests + arg1, arg2 = node.args + fun = node.fun + if isinstance(arg1, ir.AxisLiteral) and isinstance(arg2, ir.Literal): + dim = common.Dimension(value=arg2.value, kind=common.DimensionKind.VERTICAL) + value = int(arg2.value) + reverse = False + elif isinstance(arg1, ir.Literal) and isinstance(arg2, ir.AxisLiteral): + dim = common.Dimension(value=arg2.value, kind=common.DimensionKind.VERTICAL) + value = int(arg1.value) + reverse = True + else: + raise ValueError(f"{node.args} need to be a 'ir.AxisLiteral' and an 'ir.Literal'.") + assert isinstance(fun, ir.SymRef) + min_: int | str + max_: int | str + match fun.id: + case ir.SymbolRef("less"): + if reverse: + min_ = value + 1 + max_ = "inf" + else: + min_ = "neg_inf" + max_ = value - 1 + case ir.SymbolRef("less_equal"): + if reverse: + min_ = value + max_ = "inf" + else: + min_ = "neg_inf" + max_ = value + case ir.SymbolRef("greater"): + if reverse: + min_ = "neg_inf" + max_ = value - 1 + else: + min_ = value + 1 + max_ = "inf" + case ir.SymbolRef("greater_equal"): + if reverse: + min_ = "neg_inf" + max_ = value + else: + min_ = value + max_ = "inf" + case ir.SymbolRef("eq"): + min_ = max_ = value + case ir.SymbolRef("not_eq"): + min1 = "neg_inf" + max1 = value - 1 + min2 = value + 1 + max2 = "inf" + return im.call("and_")( + im.domain(common.GridType.CARTESIAN, {dim: (min1, max1)}), + im.domain(common.GridType.CARTESIAN, {dim: (min2, max2)}), + ) + case _: + raise NotImplementedError + + return im.domain(common.GridType.CARTESIAN, {dim: (min_, max_)}) + + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 6906f81e3f..edc0fd0fe3 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -12,12 +12,15 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( + expand_library_functions, fuse_as_fieldop, global_tmps, infer_domain, + infer_domain_ops, inline_dynamic_shifts, inline_fundefs, inline_lifts, + transform_concat_where, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -83,11 +86,15 @@ def apply_common_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet + ir = infer_domain_ops.InferDomainOps.apply(ir) ir = infer_domain.infer_program( ir, offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, ) + ir = transform_concat_where.TransformConcatWhere.apply(ir) + ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) + # ir = ConstantFolding.apply(ir) # todo: remove for _ in range(10): inlined = ir diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py new file mode 100644 index 0000000000..92fff34592 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -0,0 +1,34 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) + + +class TransformConcatWhere(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + if cpm.is_call_to(node, "concat_where"): + cond_expr, field_a, field_b = node.args + cond = domain_utils.SymbolicDomain.from_expr(cond_expr).ranges.keys() + dims = [im.call("index")(ir.AxisLiteral(value=k.value, kind=k.kind)) for k in cond] + return im.as_fieldop( + im.lambda_("pos", "a", "b")( + im.if_(im.call("in")(im.deref("pos"), cond_expr), im.deref("a"), im.deref("b")) + ) + )(im.make_tuple(*dims), field_a, field_b) + + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index d0d39cbd34..4c96377895 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -509,7 +509,7 @@ def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType: domain = self.visit(node.domain, ctx=ctx) - assert isinstance(domain, it_ts.DomainType) + assert isinstance(domain, ts.DomainType) assert domain.dims != "unknown" assert node.dtype return type_info.apply_to_primitive_constituents( diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index 7825bf1c98..30c79c7c94 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -16,10 +16,6 @@ class NamedRangeType(ts.TypeSpec): dim: common.Dimension -class DomainType(ts.DataType): - dims: list[common.Dimension] | Literal["unknown"] - - class OffsetLiteralType(ts.TypeSpec): value: ts.ScalarType | common.Dimension diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6e9936c4af..d740fc58a4 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -112,7 +112,12 @@ def _(arg: ts.ScalarType) -> ts.ScalarType: @_register_builtin_type_synthesizer( fun_names=itir.BINARY_MATH_COMPARISON_BUILTINS | itir.BINARY_LOGICAL_BUILTINS ) -def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType: +def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: + if isinstance(lhs, ts.ScalarType) and isinstance(rhs, ts.DimensionType): + return ts.DomainType(dims=[rhs.dim]) + if isinstance(lhs, ts.DimensionType) and isinstance(rhs, ts.ScalarType): + return ts.DomainType(dims=[lhs.dim]) + assert isinstance(lhs, ts.ScalarType) and isinstance(rhs, ts.ScalarType) return ts.ScalarType(kind=ts.ScalarKind.BOOL) @@ -183,9 +188,9 @@ def named_range( @_register_builtin_type_synthesizer(fun_names=["cartesian_domain", "unstructured_domain"]) -def _(*args: it_ts.NamedRangeType) -> it_ts.DomainType: +def _(*args: it_ts.NamedRangeType) -> ts.DomainType: assert all(isinstance(arg, it_ts.NamedRangeType) for arg in args) - return it_ts.DomainType(dims=[arg.dim for arg in args]) + return ts.DomainType(dims=[arg.dim for arg in args]) @_register_builtin_type_synthesizer @@ -201,6 +206,16 @@ def index(arg: ts.DimensionType) -> ts.FieldType: ) +@_register_builtin_type_synthesizer +def concat_where( + domain: ts.DomainType, + true_field: ts.FieldType | ts.TupleType, + false_field: ts.FieldType | ts.TupleType, +) -> ts.FieldType | ts.TupleType: + assert true_field == false_field + return true_field + + @_register_builtin_type_synthesizer def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> ts.ListType: assert ( @@ -244,7 +259,7 @@ def apply_lift( def _convert_as_fieldop_input_to_iterator( - domain: it_ts.DomainType, input_: ts.TypeSpec + domain: ts.DomainType, input_: ts.TypeSpec ) -> it_ts.IteratorType: # get the dimensions of all non-zero-dimensional field inputs and check they agree all_input_dims = ( @@ -284,7 +299,7 @@ def _convert_as_fieldop_input_to_iterator( @_register_builtin_type_synthesizer def as_fieldop( stencil: TypeSynthesizer, - domain: Optional[it_ts.DomainType] = None, + domain: Optional[ts.DomainType] = None, *, offset_provider_type: common.OffsetProviderType, ) -> TypeSynthesizer: @@ -299,7 +314,7 @@ def as_fieldop( # `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)` # it is unclear if the result has dimension I, J or J, I. if domain is None: - domain = it_ts.DomainType(dims="unknown") + domain = ts.DomainType(dims="unknown") @TypeSynthesizer def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index c1c0f0b5e1..fd946075d1 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Iterator, Optional, Sequence, Union +from typing import Iterator, Literal, Optional, Sequence, Union from gt4py.eve import datamodels as eve_datamodels, type_definitions as eve_types from gt4py.next import common @@ -127,3 +127,7 @@ def __str__(self) -> str: kwarg_strs = [f"{key}: {value}" for key, value in self.pos_or_kw_args.items()] args_str = ", ".join((*arg_strs, *kwarg_strs)) return f"({args_str}) -> {self.returns}" + + +class DomainType(DataType): + dims: list[common.Dimension] | Literal["unknown"] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 364434029f..cf96bbe885 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -37,6 +37,25 @@ def testee( cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) +def test_dimension(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + ) -> cases.IJKField: + return concat_where(KDim <= 2, boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] <= 0, boundary.asnumpy(), interior.asnumpy() + ) + + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) # TODO + + def test_boundary_horizontal_slice(cartesian_case): @gtx.field_operator def testee( diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index d4d7c60d69..8af157ebcc 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -72,6 +72,7 @@ def expression_test_cases(): return ( # itir expr, type + # TODO: write test for IDim < 10, concat_where (im.call("abs")(1), int_type), (im.call("power")(2.0, 2), float64_type), (im.plus(1, 2), int_type), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 0bf8dcb65d..e0a46b48f6 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -60,3 +60,10 @@ def test_constant_folding_literal_maximum(): expected = im.literal_from_value(2) actual = ConstantFolding.apply(testee) assert actual == expected + + +def test_constant_folding_inf_maximum(): + testee = im.call("maximum")(im.literal_from_value(1), im.ref("inf")) + expected = im.ref("inf") + actual = ConstantFolding.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 779ab738cb..034e4993d8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -1093,3 +1093,117 @@ def test_never_accessed_domain_tuple(offset_provider): "in_field2": infer_domain.DomainAccessDescriptor.NEVER, } run_test_expr(testee, testee, domain, expected_domains, offset_provider) + + +def test_concat_where(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 4)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 4)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (4, 11)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +# Todo: 2 dimensional test with cond im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 4)}) +# Todo: nested concat wheres + + +def test_concat_where_two_dimensions(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 10)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10), JDim: (10, 30)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (10, 20), JDim: (10, 30)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_concat_where_two_dimensions_J(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {JDim: (20, "inf")}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (20, 30)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_nested_concat_where_two_dimensions(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 20)}) + domain_cond1 = im.domain(common.GridType.CARTESIAN, {JDim: (10, "inf")}) + domain_cond2 = im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 20)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (20, 30), JDim: (10, 20)}) + domain3 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 10)}) + testee = im.concat_where( + domain_cond1, + im.concat_where( + domain_cond2, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ), + im.as_fieldop("deref")("in_field3"), + ) + + expected = im.concat_where( + domain_cond1, # 0, 30; 10,20 + im.concat_where( + domain_cond2, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ), + im.as_fieldop("deref", domain3)("in_field3"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2, "in_field3": domain3} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains)