From d6e873245206f24427dada14324c92692db43470 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 25 Oct 2024 16:47:20 +0200 Subject: [PATCH 1/3] Add concat_where frontend and domain inference --- src/gt4py/next/ffront/experimental.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 2 + .../ffront/foast_passes/type_deduction.py | 34 +++++- src/gt4py/next/ffront/foast_to_gtir.py | 14 ++- src/gt4py/next/iterator/ir.py | 3 + .../next/iterator/ir_utils/domain_utils.py | 46 +++++++ src/gt4py/next/iterator/ir_utils/ir_makers.py | 5 + .../iterator/transforms/constant_folding.py | 31 ++++- .../next/iterator/transforms/infer_domain.py | 31 +++++ .../iterator/transforms/infer_domain_ops.py | 81 +++++++++++++ .../next/iterator/transforms/pass_manager.py | 2 + .../type_system/type_specifications.py | 4 - .../iterator/type_system/type_synthesizer.py | 23 +++- .../next/type_system/type_specifications.py | 6 +- .../ffront_tests/test_concat_where.py | 2 +- .../iterator_tests/test_type_inference.py | 1 + .../transforms_tests/test_constant_folding.py | 7 ++ .../transforms_tests/test_domain_inference.py | 114 ++++++++++++++++++ 18 files changed, 390 insertions(+), 18 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/infer_domain_ops.py diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index bd22aebe57..c9bea908a8 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -20,7 +20,7 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi @WhereBuiltinFunction def concat_where( - mask: common.Field, + mask: common.Domain, true_field: common.Field | core_defs.ScalarT | Tuple, false_field: common.Field | core_defs.ScalarT | Tuple, /, diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 1210e96efc..d55af4fa29 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -58,6 +58,8 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ts.OffsetType elif t is core_defs.ScalarT: return ts.ScalarType + elif t is common.Domain: + return ts.DomainType elif t is type: return ( ts.FunctionType diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 6b40cbb77f..ae3d3c6437 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -20,6 +20,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.foast_passes.utils import compute_assign_indices +from gt4py.next.iterator import ir as itir from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -570,6 +571,19 @@ def _deduce_compare_type( self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: # check both types compatible + if ( + isinstance(left.type, ts.DimensionType) + and isinstance(right.type, ts.ScalarType) + and right.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ): + return ts.DomainType(dims=[left.type.dim]) + if ( + isinstance(right.type, ts.DimensionType) + and isinstance(left.type, ts.ScalarType) + and left.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ): + return ts.DomainType(dims=[right.type.dim]) + # TODO for arg in (left, right): if not type_info.is_arithmetic(arg.type): raise errors.DSLError( @@ -908,6 +922,7 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: ) try: + # TODO(tehrengruber): the construct_tuple_type function doesn't look correct if isinstance(true_branch_type, ts.TupleType) and isinstance( false_branch_type, ts.TupleType ): @@ -943,7 +958,24 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: location=node.location, ) - _visit_concat_where = _visit_where + def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: + true_branch_type = node.args[1].type + false_branch_type = node.args[2].type + if true_branch_type != false_branch_type: + raise errors.DSLError( + node.location, + f"Incompatible argument in call to '{node.func!s}': expected " + f"'{true_branch_type}' and '{false_branch_type}' to be equal.", + ) + return_type = true_branch_type + + return foast.Call( + func=node.func, + args=node.args, + kwargs=node.kwargs, + type=return_type, + location=node.location, + ) def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> foast.Call: arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 4519b4e571..82d16202c4 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -225,6 +225,8 @@ def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: return im.sym(node.id) def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: + if isinstance(node.type, ts.DimensionType): + return itir.AxisLiteral(value=node.type.dim.value, kind=node.type.dim.kind) return im.ref(node.id) def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: @@ -261,6 +263,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC ) def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: + # TODO: double-check if we need the changes in the original PR return self._lower_and_map(node.op.value, node.left, node.right) def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: @@ -394,7 +397,13 @@ def create_if( return im.let(cond_symref_name, cond_)(result) - _visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where + def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + if not isinstance(node.type, ts.TupleType): # to keep the IR simpler + return im.call("concat_where")(*self.visit(node.args)) + else: + raise NotImplementedError() + + # TODO: tuple case def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: expr = self.visit(node.args[0], **kwargs) @@ -476,8 +485,9 @@ def _map( """ Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists. """ + # TODO double-check that this code is consistent with the changes in the original PR if all( - isinstance(t, ts.ScalarType) + isinstance(t, ts.ScalarType, ts.DimensionType) for arg_type in original_arg_types for t in type_info.primitive_constituents(arg_type) ): diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index e875709631..61ac0aee74 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -169,6 +169,9 @@ class FunctionDefinition(Node, SymbolTableTrait): "if_", "index", # `index(dim)` creates a dim-field that has the current index at each point "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) + "concat_where", + "inf", # TODO: discuss + "neg_inf", # TODO: discuss *ARITHMETIC_BUILTINS, *TYPEBUILTINS, } diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 4a023f7535..8e549828eb 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -174,3 +174,49 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) + + +def domain_intersection(*domains: SymbolicDomain) -> SymbolicDomain: + """Return the (set) intersection of a list of domains.""" + new_domain_ranges = {} + assert all(domain.grid_type == domains[0].grid_type for domain in domains) + for dim in domains[0].ranges.keys(): + start = functools.reduce( + lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), + [domain.ranges[dim].start for domain in domains], + ) + stop = functools.reduce( + lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), + [domain.ranges[dim].stop for domain in domains], + ) + new_domain_ranges[dim] = SymbolicRange(start, stop) + + return SymbolicDomain(domains[0].grid_type, new_domain_ranges) + + +def domain_complement(domain: SymbolicDomain) -> SymbolicDomain: + """Return the (set) complement of a domain.""" + dims_dict = {} + for dim in domain.ranges.keys(): + lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop + if lb == im.ref("neg_inf"): + dims_dict[dim] = SymbolicRange(int(ub.value), "inf") + elif ub == im.ref("inf"): + dims_dict[dim] = SymbolicRange("neg_inf", int(lb.value)) + else: + raise ValueError("Invalid domain ranges") + return SymbolicDomain(domain.grid_type, dims_dict) + + +def promote_to_same_dimensions( + domain_small: SymbolicDomain, domain_large: SymbolicDomain +) -> SymbolicDomain: + """Return an extended domain based on a smaller input domain and a larger domain containing the target dimensions.""" + dims_dict = {} + for dim in domain_large.ranges.keys(): + if dim in domain_small.ranges.keys(): + lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop + dims_dict[dim] = SymbolicRange(lb, ub) + else: + dims_dict[dim] = SymbolicRange("neg_inf", "inf") + return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 0839e95b5b..91bfd8b50d 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -246,6 +246,11 @@ def if_(cond, true_val, false_val): return call("if_")(cond, true_val, false_val) +def concat_where(cond, true_field, false_field): + """Create a concat_where FunCall, shorthand for ``call("concat_where")(expr)``.""" + return call("concat_where")(cond, true_field, false_field) + + def lift(expr): """Create a lift FunCall, shorthand for ``call(call("lift")(expr))``.""" return call(call("lift")(expr)) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 2084ab2518..f8e86670ed 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -8,7 +8,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import embedded, ir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class ConstantFolding(PreserveLocationVisitor, NodeTranslator): @@ -21,12 +21,35 @@ def visit_FunCall(self, node: ir.FunCall): new_node = self.generic_visit(node) if ( - isinstance(new_node.fun, ir.SymRef) - and new_node.fun.id in ["minimum", "maximum"] + cpm.is_call_to(new_node, ("minimum", "maximum")) and new_node.args[0] == new_node.args[1] ): # `minimum(a, a)` -> `a` return new_node.args[0] + if cpm.is_call_to(new_node, "minimum"): + # `minimum(neg_inf, neg_inf)` -> `neg_inf` + if cpm.is_ref_to(new_node.args[0], "neg_inf") or cpm.is_ref_to( + new_node.args[1], "neg_inf" + ): + return im.ref("neg_inf") + # `minimum(inf, a)` -> `a` + elif cpm.is_ref_to(new_node.args[0], "inf"): + return new_node.args[1] + # `minimum(a, inf)` -> `a` + elif cpm.is_ref_to(new_node.args[1], "inf"): + return new_node.args[0] + + if cpm.is_call_to(new_node, "maximum"): + # `minimum(inf, inf)` -> `inf` + if cpm.is_ref_to(new_node.args[0], "inf") or cpm.is_ref_to(new_node.args[1], "inf"): + return im.ref("inf") + # `minimum(neg_inf, a)` -> `a` + elif cpm.is_ref_to(new_node.args[0], "neg_inf"): + return new_node.args[1] + # `minimum(a, neg_inf)` -> `a` + elif cpm.is_ref_to(new_node.args[1], "neg_inf"): + return new_node.args[0] + if ( isinstance(new_node.fun, ir.SymRef) and new_node.fun.id == "if_" @@ -52,6 +75,6 @@ def visit_FunCall(self, node: ir.FunCall): ] new_node = im.literal_from_value(fun(*arg_values)) except ValueError: - pass # happens for inf and neginf + pass # happens for SymRefs which are not inf or neg_inf return new_node diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f26d3f9ec2..f2044e4b6f 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -363,6 +363,35 @@ def _infer_if( return result_expr, actual_domains +def _infer_concat_where( + expr: itir.Expr, + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: + assert cpm.is_call_to(expr, "concat_where") + infered_args_expr = [] + actual_domains: AccessedDomains = {} + cond, true_field, false_field = expr.args + symbolic_cond = domain_utils.SymbolicDomain.from_expr(cond) + for arg in [true_field, false_field]: + if arg == true_field: + extended_cond = domain_utils.promote_to_same_dimensions(symbolic_cond, domain) + domain_ = domain_utils.domain_intersection(domain, extended_cond) + elif arg == false_field: + cond_complement = domain_utils.domain_complement(symbolic_cond) + extended_cond_complement = domain_utils.promote_to_same_dimensions( + cond_complement, domain + ) + domain_ = domain_utils.domain_intersection(domain, extended_cond_complement) + + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain_, **kwargs) + infered_args_expr.append(infered_arg_expr) + actual_domains = _merge_domains(actual_domains, actual_domains_arg) + + result_expr = im.call(expr.fun)(cond, *infered_args_expr) + return result_expr, actual_domains + + def _infer_expr( expr: itir.Expr, domain: DomainAccess, @@ -382,6 +411,8 @@ def _infer_expr( return _infer_tuple_get(expr, domain, **kwargs) elif cpm.is_call_to(expr, "if_"): return _infer_if(expr, domain, **kwargs) + elif cpm.is_call_to(expr, "concat_where"): + return _infer_concat_where(expr, domain, **kwargs) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py new file mode 100644 index 0000000000..f4422d506a --- /dev/null +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -0,0 +1,81 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import common +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im + + +class InferDomainOps(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + if isinstance(node, ir.FunCall) and cpm.is_call_to( + node, ir.BINARY_MATH_COMPARISON_BUILTINS + ): + if isinstance(node.args[0], ir.AxisLiteral) and isinstance(node.args[1], ir.Literal): + dim = common.Dimension(value=node.args[0].value, kind=common.DimensionKind.VERTICAL) + value = int(node.args[1].value) + reverse = False + elif isinstance(node.args[0], ir.Literal) and isinstance(node.args[1], ir.AxisLiteral): + dim = common.Dimension(value=node.args[1].value, kind=common.DimensionKind.VERTICAL) + value = int(node.args[0].value) + reverse = True + else: + raise ValueError(f"{node.args} need to be a 'ir.AxisLiteral' and an 'ir.Literal'.") + + match node.fun.id: + case ir.SymbolRef("less"): + if reverse: + min = value + 1 + max = "inf" + else: + min = "neg_inf" + max = value - 1 + case ir.SymbolRef("less_equal"): + if reverse: + min = value + max = "inf" + else: + min = "neg_inf" + max = value + case ir.SymbolRef("greater"): + if reverse: + min = "neg_inf" + max = value - 1 + else: + min = value + 1 + max = "inf" + case ir.SymbolRef("greater_equal"): + if reverse: + min = "neg_inf" + max = value + else: + min = value + max = "inf" + case ir.SymbolRef("eq"): + min = max = value + case ir.SymbolRef("not_eq"): + min1 = "neg_inf" + max1 = value - 1 + min2 = value + 1 + max2 = "inf" + return im.call("and_")( + im.domain(common.GridType.CARTESIAN, {dim: (min1, max1)}), + im.domain(common.GridType.CARTESIAN, {dim: (min2, max2)}), + ) + case _: + raise NotImplementedError + + return im.domain(common.GridType.CARTESIAN, {dim: (min, max)}) + + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 6906f81e3f..e92b4b1fd0 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -15,6 +15,7 @@ fuse_as_fieldop, global_tmps, infer_domain, + infer_domain_ops, inline_dynamic_shifts, inline_fundefs, inline_lifts, @@ -83,6 +84,7 @@ def apply_common_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet + ir = infer_domain_ops.InferDomainOps.apply(ir) ir = infer_domain.infer_program( ir, offset_provider=offset_provider, diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index 7825bf1c98..30c79c7c94 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -16,10 +16,6 @@ class NamedRangeType(ts.TypeSpec): dim: common.Dimension -class DomainType(ts.DataType): - dims: list[common.Dimension] | Literal["unknown"] - - class OffsetLiteralType(ts.TypeSpec): value: ts.ScalarType | common.Dimension diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6e9936c4af..7e4e36da98 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -112,7 +112,12 @@ def _(arg: ts.ScalarType) -> ts.ScalarType: @_register_builtin_type_synthesizer( fun_names=itir.BINARY_MATH_COMPARISON_BUILTINS | itir.BINARY_LOGICAL_BUILTINS ) -def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType: +def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: + if isinstance(lhs, ts.ScalarType) and isinstance(rhs, ts.DimensionType): + return ts.DomainType(dims=[rhs.dim]) + if isinstance(lhs, ts.DimensionType) and isinstance(rhs, ts.ScalarType): + return ts.DomainType(dims=[lhs.dim]) + assert isinstance(lhs, ts.ScalarType) and isinstance(rhs, ts.ScalarType) return ts.ScalarType(kind=ts.ScalarKind.BOOL) @@ -183,9 +188,9 @@ def named_range( @_register_builtin_type_synthesizer(fun_names=["cartesian_domain", "unstructured_domain"]) -def _(*args: it_ts.NamedRangeType) -> it_ts.DomainType: +def _(*args: it_ts.NamedRangeType) -> ts.DomainType: assert all(isinstance(arg, it_ts.NamedRangeType) for arg in args) - return it_ts.DomainType(dims=[arg.dim for arg in args]) + return ts.DomainType(dims=[arg.dim for arg in args]) @_register_builtin_type_synthesizer @@ -202,7 +207,17 @@ def index(arg: ts.DimensionType) -> ts.FieldType: @_register_builtin_type_synthesizer -def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> ts.ListType: +def concat_where( + domain: ts.DomainType, + true_field: ts.FieldType | ts.TupleType, + false_field: ts.FieldType | ts.TupleType, +) -> ts.FieldType: + assert true_field == false_field + return true_field + + +@_register_builtin_type_synthesizer +def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType: assert ( isinstance(offset_literal, it_ts.OffsetLiteralType) and isinstance(offset_literal.value, common.Dimension) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index c1c0f0b5e1..fd946075d1 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Iterator, Optional, Sequence, Union +from typing import Iterator, Literal, Optional, Sequence, Union from gt4py.eve import datamodels as eve_datamodels, type_definitions as eve_types from gt4py.next import common @@ -127,3 +127,7 @@ def __str__(self) -> str: kwarg_strs = [f"{key}: {value}" for key, value in self.pos_or_kw_args.items()] args_str = ", ".join((*arg_strs, *kwarg_strs)) return f"({args_str}) -> {self.returns}" + + +class DomainType(DataType): + dims: list[common.Dimension] | Literal["unknown"] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 364434029f..27e6988744 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -23,7 +23,7 @@ def test_boundary_same_size_fields(cartesian_case): def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + return concat_where(KDim <= 2, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index d4d7c60d69..8af157ebcc 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -72,6 +72,7 @@ def expression_test_cases(): return ( # itir expr, type + # TODO: write test for IDim < 10, concat_where (im.call("abs")(1), int_type), (im.call("power")(2.0, 2), float64_type), (im.plus(1, 2), int_type), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 0bf8dcb65d..e0a46b48f6 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -60,3 +60,10 @@ def test_constant_folding_literal_maximum(): expected = im.literal_from_value(2) actual = ConstantFolding.apply(testee) assert actual == expected + + +def test_constant_folding_inf_maximum(): + testee = im.call("maximum")(im.literal_from_value(1), im.ref("inf")) + expected = im.ref("inf") + actual = ConstantFolding.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 779ab738cb..034e4993d8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -1093,3 +1093,117 @@ def test_never_accessed_domain_tuple(offset_provider): "in_field2": infer_domain.DomainAccessDescriptor.NEVER, } run_test_expr(testee, testee, domain, expected_domains, offset_provider) + + +def test_concat_where(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 4)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 4)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (4, 11)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +# Todo: 2 dimensional test with cond im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 4)}) +# Todo: nested concat wheres + + +def test_concat_where_two_dimensions(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 10)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10), JDim: (10, 30)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (10, 20), JDim: (10, 30)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_concat_where_two_dimensions_J(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {JDim: (20, "inf")}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (20, 30)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_nested_concat_where_two_dimensions(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 20)}) + domain_cond1 = im.domain(common.GridType.CARTESIAN, {JDim: (10, "inf")}) + domain_cond2 = im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 20)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (20, 30), JDim: (10, 20)}) + domain3 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 10)}) + testee = im.concat_where( + domain_cond1, + im.concat_where( + domain_cond2, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ), + im.as_fieldop("deref")("in_field3"), + ) + + expected = im.concat_where( + domain_cond1, # 0, 30; 10,20 + im.concat_where( + domain_cond2, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ), + im.as_fieldop("deref", domain3)("in_field3"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2, "in_field3": domain3} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) From 69f6b118549586f51c1e15d9b0ba1dec48c71330 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 25 Oct 2024 19:18:45 +0200 Subject: [PATCH 2/3] Finish domain inference for (nested) concat_where and transform to as_fieldop --- src/gt4py/next/iterator/ir.py | 1 + .../next/iterator/ir_utils/domain_utils.py | 8 +-- .../iterator/transforms/constant_folding.py | 13 +++- .../transforms/expand_library_functions.py | 39 ++++++++++++ .../next/iterator/transforms/infer_domain.py | 1 + .../iterator/transforms/infer_domain_ops.py | 60 ++++++++++--------- .../next/iterator/transforms/pass_manager.py | 5 ++ .../transforms/transform_concat_where.py | 34 +++++++++++ .../next/iterator/type_system/inference.py | 2 +- .../iterator/type_system/type_synthesizer.py | 8 +-- .../ffront_tests/test_concat_where.py | 21 ++++++- 11 files changed, 150 insertions(+), 42 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/expand_library_functions.py create mode 100644 src/gt4py/next/iterator/transforms/transform_concat_where.py diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 61ac0aee74..0521d027c4 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -170,6 +170,7 @@ class FunctionDefinition(Node, SymbolTableTrait): "index", # `index(dim)` creates a dim-field that has the current index at each point "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) "concat_where", + "in", "inf", # TODO: discuss "neg_inf", # TODO: discuss *ARITHMETIC_BUILTINS, diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 8e549828eb..b66f21cf60 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -200,12 +200,12 @@ def domain_complement(domain: SymbolicDomain) -> SymbolicDomain: for dim in domain.ranges.keys(): lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop if lb == im.ref("neg_inf"): - dims_dict[dim] = SymbolicRange(int(ub.value), "inf") + dims_dict[dim] = SymbolicRange(start=ub, stop=im.ref("inf")) elif ub == im.ref("inf"): - dims_dict[dim] = SymbolicRange("neg_inf", int(lb.value)) + dims_dict[dim] = SymbolicRange(start=im.ref("neg_inf"), stop=lb) else: raise ValueError("Invalid domain ranges") - return SymbolicDomain(domain.grid_type, dims_dict) + return SymbolicDomain(domain.grid_type, dims_dict) def promote_to_same_dimensions( @@ -218,5 +218,5 @@ def promote_to_same_dimensions( lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop dims_dict[dim] = SymbolicRange(lb, ub) else: - dims_dict[dim] = SymbolicRange("neg_inf", "inf") + dims_dict[dim] = SymbolicRange(im.ref("neg_inf"), im.ref("inf")) return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index f8e86670ed..6fe26f886b 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -26,7 +26,7 @@ def visit_FunCall(self, node: ir.FunCall): ): # `minimum(a, a)` -> `a` return new_node.args[0] - if cpm.is_call_to(new_node, "minimum"): + if cpm.is_call_to(new_node, "minimum"): # TODO: add tests # `minimum(neg_inf, neg_inf)` -> `neg_inf` if cpm.is_ref_to(new_node.args[0], "neg_inf") or cpm.is_ref_to( new_node.args[1], "neg_inf" @@ -39,7 +39,7 @@ def visit_FunCall(self, node: ir.FunCall): elif cpm.is_ref_to(new_node.args[1], "inf"): return new_node.args[0] - if cpm.is_call_to(new_node, "maximum"): + if cpm.is_call_to(new_node, "maximum"): # TODO: add tests # `minimum(inf, inf)` -> `inf` if cpm.is_ref_to(new_node.args[0], "inf") or cpm.is_ref_to(new_node.args[1], "inf"): return im.ref("inf") @@ -49,7 +49,14 @@ def visit_FunCall(self, node: ir.FunCall): # `minimum(a, neg_inf)` -> `a` elif cpm.is_ref_to(new_node.args[1], "neg_inf"): return new_node.args[0] - + if cpm.is_call_to(new_node, ("less", "less_equal")) and cpm.is_ref_to( + new_node.args[0], "neg_inf" + ): + return im.literal_from_value(True) # TODO: add tests + if cpm.is_call_to(new_node, ("greater", "greater_equal")) and cpm.is_ref_to( + new_node.args[0], "inf" + ): + return im.literal_from_value(True) # TODO: add tests if ( isinstance(new_node.fun, ir.SymRef) and new_node.fun.id == "if_" diff --git a/src/gt4py/next/iterator/transforms/expand_library_functions.py b/src/gt4py/next/iterator/transforms/expand_library_functions.py new file mode 100644 index 0000000000..0da3ff925c --- /dev/null +++ b/src/gt4py/next/iterator/transforms/expand_library_functions.py @@ -0,0 +1,39 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from functools import reduce + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) + + +class ExpandLibraryFunctions(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + if cpm.is_call_to(node, "in"): + ret = [] + pos, domain = node.args + for i, (k, v) in enumerate( + domain_utils.SymbolicDomain.from_expr(node.args[1]).ranges.items() + ): + ret.append( + im.and_( + im.less_equal(v.start, im.tuple_get(i, pos)), + im.less(im.tuple_get(i, pos), v.stop), + ) + ) # TODO: avoid pos duplication + return reduce(im.and_, ret) + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f2044e4b6f..cab17e0202 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -369,6 +369,7 @@ def _infer_concat_where( **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "concat_where") + assert isinstance(domain, domain_utils.SymbolicDomain) infered_args_expr = [] actual_domains: AccessedDomains = {} cond, true_field, false_field = expr.args diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index f4422d506a..f86070c1a1 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -19,51 +19,53 @@ def apply(cls, node: ir.Node): return cls().visit(node) def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: - if isinstance(node, ir.FunCall) and cpm.is_call_to( - node, ir.BINARY_MATH_COMPARISON_BUILTINS - ): - if isinstance(node.args[0], ir.AxisLiteral) and isinstance(node.args[1], ir.Literal): - dim = common.Dimension(value=node.args[0].value, kind=common.DimensionKind.VERTICAL) - value = int(node.args[1].value) + if cpm.is_call_to(node, ir.BINARY_MATH_COMPARISON_BUILTINS): # TODO: add tests + arg1, arg2 = node.args + fun = node.fun + if isinstance(arg1, ir.AxisLiteral) and isinstance(arg2, ir.Literal): + dim = common.Dimension(value=arg2.value, kind=common.DimensionKind.VERTICAL) + value = int(arg2.value) reverse = False - elif isinstance(node.args[0], ir.Literal) and isinstance(node.args[1], ir.AxisLiteral): - dim = common.Dimension(value=node.args[1].value, kind=common.DimensionKind.VERTICAL) - value = int(node.args[0].value) + elif isinstance(arg1, ir.Literal) and isinstance(arg2, ir.AxisLiteral): + dim = common.Dimension(value=arg2.value, kind=common.DimensionKind.VERTICAL) + value = int(arg1.value) reverse = True else: raise ValueError(f"{node.args} need to be a 'ir.AxisLiteral' and an 'ir.Literal'.") - - match node.fun.id: + assert isinstance(fun, ir.SymRef) + min_: int | str + max_: int | str + match fun.id: case ir.SymbolRef("less"): if reverse: - min = value + 1 - max = "inf" + min_ = value + 1 + max_ = "inf" else: - min = "neg_inf" - max = value - 1 + min_ = "neg_inf" + max_ = value - 1 case ir.SymbolRef("less_equal"): if reverse: - min = value - max = "inf" + min_ = value + max_ = "inf" else: - min = "neg_inf" - max = value + min_ = "neg_inf" + max_ = value case ir.SymbolRef("greater"): if reverse: - min = "neg_inf" - max = value - 1 + min_ = "neg_inf" + max_ = value - 1 else: - min = value + 1 - max = "inf" + min_ = value + 1 + max_ = "inf" case ir.SymbolRef("greater_equal"): if reverse: - min = "neg_inf" - max = value + min_ = "neg_inf" + max_ = value else: - min = value - max = "inf" + min_ = value + max_ = "inf" case ir.SymbolRef("eq"): - min = max = value + min_ = max_ = value case ir.SymbolRef("not_eq"): min1 = "neg_inf" max1 = value - 1 @@ -76,6 +78,6 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: case _: raise NotImplementedError - return im.domain(common.GridType.CARTESIAN, {dim: (min, max)}) + return im.domain(common.GridType.CARTESIAN, {dim: (min_, max_)}) return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index e92b4b1fd0..edc0fd0fe3 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -12,6 +12,7 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( + expand_library_functions, fuse_as_fieldop, global_tmps, infer_domain, @@ -19,6 +20,7 @@ inline_dynamic_shifts, inline_fundefs, inline_lifts, + transform_concat_where, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -90,6 +92,9 @@ def apply_common_transforms( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, ) + ir = transform_concat_where.TransformConcatWhere.apply(ir) + ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) + # ir = ConstantFolding.apply(ir) # todo: remove for _ in range(10): inlined = ir diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py new file mode 100644 index 0000000000..92fff34592 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -0,0 +1,34 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) + + +class TransformConcatWhere(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + if cpm.is_call_to(node, "concat_where"): + cond_expr, field_a, field_b = node.args + cond = domain_utils.SymbolicDomain.from_expr(cond_expr).ranges.keys() + dims = [im.call("index")(ir.AxisLiteral(value=k.value, kind=k.kind)) for k in cond] + return im.as_fieldop( + im.lambda_("pos", "a", "b")( + im.if_(im.call("in")(im.deref("pos"), cond_expr), im.deref("a"), im.deref("b")) + ) + )(im.make_tuple(*dims), field_a, field_b) + + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index d0d39cbd34..4c96377895 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -509,7 +509,7 @@ def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType: domain = self.visit(node.domain, ctx=ctx) - assert isinstance(domain, it_ts.DomainType) + assert isinstance(domain, ts.DomainType) assert domain.dims != "unknown" assert node.dtype return type_info.apply_to_primitive_constituents( diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 7e4e36da98..79831d2064 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -213,7 +213,7 @@ def concat_where( false_field: ts.FieldType | ts.TupleType, ) -> ts.FieldType: assert true_field == false_field - return true_field + return true_field # TODO: tuples? @_register_builtin_type_synthesizer @@ -259,7 +259,7 @@ def apply_lift( def _convert_as_fieldop_input_to_iterator( - domain: it_ts.DomainType, input_: ts.TypeSpec + domain: ts.DomainType, input_: ts.TypeSpec ) -> it_ts.IteratorType: # get the dimensions of all non-zero-dimensional field inputs and check they agree all_input_dims = ( @@ -299,7 +299,7 @@ def _convert_as_fieldop_input_to_iterator( @_register_builtin_type_synthesizer def as_fieldop( stencil: TypeSynthesizer, - domain: Optional[it_ts.DomainType] = None, + domain: Optional[ts.DomainType] = None, *, offset_provider_type: common.OffsetProviderType, ) -> TypeSynthesizer: @@ -314,7 +314,7 @@ def as_fieldop( # `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)` # it is unclear if the result has dimension I, J or J, I. if domain is None: - domain = it_ts.DomainType(dims="unknown") + domain = ts.DomainType(dims="unknown") @TypeSynthesizer def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 27e6988744..cf96bbe885 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -23,7 +23,7 @@ def test_boundary_same_size_fields(cartesian_case): def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField ) -> cases.IJKField: - return concat_where(KDim <= 2, boundary, interior) + return concat_where(k == 0, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() @@ -37,6 +37,25 @@ def testee( cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) +def test_dimension(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + ) -> cases.IJKField: + return concat_where(KDim <= 2, boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] <= 0, boundary.asnumpy(), interior.asnumpy() + ) + + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) # TODO + + def test_boundary_horizontal_slice(cartesian_case): @gtx.field_operator def testee( From 05e74c29838defa729f66e6cc89fb01b512be8ae Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 20 Jan 2025 14:16:18 +0100 Subject: [PATCH 3/3] fix merge conflicts --- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/ffront/foast_to_gtir.py | 4 ++-- src/gt4py/next/iterator/type_system/type_synthesizer.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index d55af4fa29..028761e9fa 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -129,7 +129,7 @@ def __gt_type__(self) -> ts.FunctionType: ) -MaskT = TypeVar("MaskT", bound=common.Field) +MaskT = TypeVar("MaskT", bound=Union[common.Field, common.Domain]) FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple]) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 82d16202c4..636222aa95 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -224,7 +224,7 @@ def visit_Assign( def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: return im.sym(node.id) - def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: + def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef | itir.AxisLiteral: if isinstance(node.type, ts.DimensionType): return itir.AxisLiteral(value=node.type.dim.value, kind=node.type.dim.kind) return im.ref(node.id) @@ -487,7 +487,7 @@ def _map( """ # TODO double-check that this code is consistent with the changes in the original PR if all( - isinstance(t, ts.ScalarType, ts.DimensionType) + isinstance(t, (ts.ScalarType, ts.DimensionType)) for arg_type in original_arg_types for t in type_info.primitive_constituents(arg_type) ): diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 79831d2064..d740fc58a4 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -211,13 +211,13 @@ def concat_where( domain: ts.DomainType, true_field: ts.FieldType | ts.TupleType, false_field: ts.FieldType | ts.TupleType, -) -> ts.FieldType: +) -> ts.FieldType | ts.TupleType: assert true_field == false_field - return true_field # TODO: tuples? + return true_field @_register_builtin_type_synthesizer -def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType: +def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> ts.ListType: assert ( isinstance(offset_literal, it_ts.OffsetLiteralType) and isinstance(offset_literal.value, common.Dimension)