diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index edabd5f778..171274bce5 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -236,8 +236,8 @@ def apply( if isinstance(node, ast.Slice): slice_node = node elif isinstance(getattr(node, "slice", None), ast.Slice): - # This is the syntax with the inlined slice: K[3:4] - slice_node = node.slice + # This is the previously allowed syntax with the inlined slice: I[0:2] + raise parser.interval_error else: # It is a single value and will therefore be (value):(value+1) slice_node = cls._slice_from_value(node) @@ -258,12 +258,36 @@ def apply( def visit_Subscript(self, node: ast.Subscript) -> nodes.AxisBound: # This allows for the syntax - # `region[I[0] : I[2], J[0] : J[2]]` + # `region[I[0] : I[0] + 2, J[0] : J[0] + 2]` # to exist if not isinstance(node.value, ast.Name): raise self.interval_error if node.value.id != self.axis_name: - raise self.interval_error + raise GTScriptSyntaxError( + "Invalid horizontal range specification:" + f"Expected axis {self.axis_name}, got {node.value.id}" + ) + if isinstance(node.slice, ast.Constant): + if node.slice.value != 0: + raise GTScriptSyntaxError( + "Invalid horizontal range specification:" + f"Expected specification {self.axis_name}[0] or {self.axis_name}[-1]" + f", got {self.axis_name}[{node.slice.value}]" + ) + elif isinstance(node.slice, ast.UnaryOp): + if not isinstance(node.slice.operand, ast.Constant) or node.slice.operand.value not in ( + 0, + 1, + ): + raise GTScriptSyntaxError( + "Invalid horizontal range specification:" + f"Expected specification {self.axis_name}[0] or {self.axis_name}[-1]." + ) + else: + raise GTScriptSyntaxError( + "Invalid horizontal range specification:" + f"Expected axis {self.axis_name}, got {node.value.id}" + ) index = self.visit(node.slice) @@ -285,7 +309,9 @@ def visit_Name(self, node: ast.Name) -> nodes.Ref: f"Using field `{node.id}` with a K-Axis as a bound for an interval is invalid." ) return nodes.FieldRef.at_center( - name=node.id, axes=self.fields[node.id].axes, loc=nodes.Location.from_ast_node(node) + name=node.id, + axes=self.fields[node.id].axes, + loc=nodes.Location.from_ast_node(node), ) # Handle the scalar accesses return nodes.VarRef(name=node.id, loc=nodes.Location.from_ast_node(node)) @@ -1197,7 +1223,8 @@ def visit_Name(self, node: ast.Name) -> nodes.Ref: if self._is_parameter(symbol): return nodes.VarRef( - name=symbol, loc=nodes.Location.from_ast_node(node, scope=self.stencil_name) + name=symbol, + loc=nodes.Location.from_ast_node(node, scope=self.stencil_name), ) if self._is_local_symbol(symbol): @@ -1247,7 +1274,8 @@ def _eval_new_spatial_index( ) if axis_index < last_index: raise GTScriptSyntaxError( - message=f"Axis {value.name} is specified out of order", loc=index_node + message=f"Axis {value.name} is specified out of order", + loc=index_node, ) if axis_index == last_index: raise GTScriptSyntaxError( @@ -1373,7 +1401,9 @@ def visit_UnaryOp(self, node: ast.UnaryOp): return eval("{op}{arg}".format(op=op.python_symbol, arg=arg)) return nodes.UnaryOpExpr( - op=op, arg=arg, loc=nodes.Location.from_ast_node(node, scope=self.stencil_name) + op=op, + arg=arg, + loc=nodes.Location.from_ast_node(node, scope=self.stencil_name), ) def visit_UAdd(self, node: ast.UAdd) -> nodes.UnaryOperator: @@ -1474,7 +1504,10 @@ def visit_Compare(self, node: ast.Compare) -> nodes.BinOpExpr: args.append(lhs) result = nodes.BinOpExpr( - op=op, lhs=lhs, rhs=rhs, loc=nodes.Location.from_ast_node(node, scope=self.stencil_name) + op=op, + lhs=lhs, + rhs=rhs, + loc=nodes.Location.from_ast_node(node, scope=self.stencil_name), ) return result @@ -1768,7 +1801,8 @@ def _resolve_assign( "and not yet implemented for the `gt:X` backends." ) warn_experimental_feature( - feature="2D temporaries", ADR="experimental/2d-temporaries.md" + feature="2D temporaries", + ADR="experimental/2d-temporaries.md", ) axes = self._domain_from_gtscript_axis(field_desc.axes).axes_names diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py index 86af266ba4..3c97675fd5 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py @@ -450,17 +450,23 @@ def two_optional_fields( @register def horizontal_regions(field_in: Field3D, field_out: Field3D): with computation(PARALLEL), interval(...): - with horizontal(region[I[0] : I[2], J[0] : J[2]], region[I[-3] : I[-1], J[-3] : J[-1]]): + with horizontal( + region[I[0] : I[0] + 2, J[0] : J[0] + 2], region[I[-1] - 2 : I[-1], J[-1] - 2 : J[-1]] + ): field_out = field_in + 1.0 - with horizontal(region[I[0] : I[2], J[-3] : J[-1]], region[I[-3] : I[-1], J[0] : J[2]]): + with horizontal( + region[I[0] : I[0] + 2, J[-1] - 2 : J[-1]], region[I[-1] - 2 : I[-1], J[0] : J[0] + 2] + ): field_out = field_in - 1.0 @register def horizontal_region_with_conditional(field_in: Field3D, field_out: Field3D): with computation(PARALLEL), interval(...): - with horizontal(region[I[0] : I[2], J[0] : J[2]], region[I[-3] : I[-1], J[-3] : J[-1]]): + with horizontal( + region[I[0] : I[0] + 2, J[0] : J[0] + 2], region[I[-1] - 2 : I[-1], J[-1] - 2 : J[-1]] + ): if field_in > 0: field_out = field_in + 1.0 else: diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index 032dc3bb5e..91419f1932 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -923,11 +923,17 @@ class TestHorizontalRegionsCorners(gt_testing.StencilTestSuite): def definition(field_in, field_out): with computation(PARALLEL), interval(...): - with horizontal(region[I[0] : I[2], J[0] : J[2]], region[I[-3] : I[-1], J[-3] : J[-1]]): + with horizontal( + region[I[0] : I[0] + 2, J[0] : J[0] + 2], + region[I[-1] - 2 : I[-1], J[-1] - 2 : J[-1]], + ): field_out = ( # noqa: F841 [unused-variable] field_in + 1.0 ) - with horizontal(region[I[0] : I[2], J[-3] : J[-1]], region[I[-3] : I[-1], J[0] : J[2]]): + with horizontal( + region[I[0] : I[0] + 2, J[-1] - 2 : J[-1]], + region[I[-1] - 2 : I[-1], J[0] : J[0] + 2], + ): field_out = ( # noqa: F841 [unused-variable] field_in - 1.0 ) diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index 38d9054032..aa86673d25 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -808,7 +808,7 @@ def definition_func(field: gtscript.Field[float]): class TestRegions: def test_one_interval_only(self): def stencil(in_f: gtscript.Field[np.float64]): - with computation(PARALLEL), interval(...), horizontal(region[I[0:3], :]): + with computation(PARALLEL), interval(...), horizontal(region[I[0] : I[0] + 3, :]): in_f = 1.0 def_ir = parse_definition( @@ -841,7 +841,7 @@ def stencil(in_f: gtscript.Field[np.float64]): stencil, name=inspect.stack()[0][3], module=self.__class__.__name__, - externals={"i1": I[1]}, + externals={"i1": I[0] + 1}, ) assert len(def_ir.computations) == 1 @@ -914,6 +914,54 @@ def stencil(in_f: gtscript.Field[np.float64]): ): parse_definition(stencil, name=inspect.stack()[0][3], module=self.__class__.__name__) + def test_axis_offset(self): + def stencil(field: gtscript.Field[float]): # type: ignore + with computation(PARALLEL), interval(...): + with horizontal(region[I[0] : I[2], :]): + field[0, 0, 0] = 0 + + with pytest.raises( + gt_frontend.GTScriptSyntaxError, + match="Invalid horizontal range specification", + ): + parse_definition( + stencil, + name=inspect.stack()[0][3], + module=self.__class__.__name__, + ) + + def test_axis_offset_negative(self): + def stencil(field: gtscript.Field[float]): # type: ignore + with computation(PARALLEL), interval(...): + with horizontal(region[I[-3] : I[-1], :]): + field[0, 0, 0] = 0 + + with pytest.raises( + gt_frontend.GTScriptSyntaxError, + match="Invalid horizontal range specification", + ): + parse_definition( + stencil, + name=inspect.stack()[0][3], + module=self.__class__.__name__, + ) + + def test_axis_slice(self): + def stencil(field: gtscript.Field[float]): # type: ignore + with computation(PARALLEL), interval(...): + with horizontal(region[I[0:2], :]): + field[0, 0, 0] = 0 + + with pytest.raises( + gt_frontend.GTScriptSyntaxError, + match="Invalid interval range specification", + ): + parse_definition( + stencil, + name=inspect.stack()[0][3], + module=self.__class__.__name__, + ) + class TestExternalsWithSubroutines: def test_all_legal_combinations(self):