From dff7b53778c77e35e6c688e91cae1d9b4f37e3ad Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Wed, 25 Mar 2026 14:33:04 +0100 Subject: [PATCH 1/7] remove redundant region syntax --- .../cartesian/frontend/gtscript_frontend.py | 39 ++++++++++++++----- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index edabd5f778..34a182c5d7 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -236,8 +236,8 @@ def apply( if isinstance(node, ast.Slice): slice_node = node elif isinstance(getattr(node, "slice", None), ast.Slice): - # This is the syntax with the inlined slice: K[3:4] - slice_node = node.slice + # This is the previously allowed syntax with the inlined slice: I[3:4] + raise parser.interval_error else: # It is a single value and will therefore be (value):(value+1) slice_node = cls._slice_from_value(node) @@ -258,12 +258,21 @@ def apply( def visit_Subscript(self, node: ast.Subscript) -> nodes.AxisBound: # This allows for the syntax - # `region[I[0] : I[2], J[0] : J[2]]` + # `region[I[0] : I[0] + 2, J[0] : J[0] + 2]` # to exist if not isinstance(node.value, ast.Name): raise self.interval_error if node.value.id != self.axis_name: - raise self.interval_error + raise GTScriptSyntaxError( + "Invalid horizontal range specification:" + f"Expected axis {self.axis_name}, got {node.value.id}" + ) + if not isinstance(node.slice, ast.Constant) or node.slice.value not in (0, -1): + raise GTScriptSyntaxError( + "Invalid horizontal range specification:" + f"Expected specification {self.axis_name}[0] or {self.axis_name}[-1]" + f", got {self.axis_name}[{node.slice.value}]" + ) index = self.visit(node.slice) @@ -285,7 +294,9 @@ def visit_Name(self, node: ast.Name) -> nodes.Ref: f"Using field `{node.id}` with a K-Axis as a bound for an interval is invalid." ) return nodes.FieldRef.at_center( - name=node.id, axes=self.fields[node.id].axes, loc=nodes.Location.from_ast_node(node) + name=node.id, + axes=self.fields[node.id].axes, + loc=nodes.Location.from_ast_node(node), ) # Handle the scalar accesses return nodes.VarRef(name=node.id, loc=nodes.Location.from_ast_node(node)) @@ -1197,7 +1208,8 @@ def visit_Name(self, node: ast.Name) -> nodes.Ref: if self._is_parameter(symbol): return nodes.VarRef( - name=symbol, loc=nodes.Location.from_ast_node(node, scope=self.stencil_name) + name=symbol, + loc=nodes.Location.from_ast_node(node, scope=self.stencil_name), ) if self._is_local_symbol(symbol): @@ -1247,7 +1259,8 @@ def _eval_new_spatial_index( ) if axis_index < last_index: raise GTScriptSyntaxError( - message=f"Axis {value.name} is specified out of order", loc=index_node + message=f"Axis {value.name} is specified out of order", + loc=index_node, ) if axis_index == last_index: raise GTScriptSyntaxError( @@ -1373,7 +1386,9 @@ def visit_UnaryOp(self, node: ast.UnaryOp): return eval("{op}{arg}".format(op=op.python_symbol, arg=arg)) return nodes.UnaryOpExpr( - op=op, arg=arg, loc=nodes.Location.from_ast_node(node, scope=self.stencil_name) + op=op, + arg=arg, + loc=nodes.Location.from_ast_node(node, scope=self.stencil_name), ) def visit_UAdd(self, node: ast.UAdd) -> nodes.UnaryOperator: @@ -1474,7 +1489,10 @@ def visit_Compare(self, node: ast.Compare) -> nodes.BinOpExpr: args.append(lhs) result = nodes.BinOpExpr( - op=op, lhs=lhs, rhs=rhs, loc=nodes.Location.from_ast_node(node, scope=self.stencil_name) + op=op, + lhs=lhs, + rhs=rhs, + loc=nodes.Location.from_ast_node(node, scope=self.stencil_name), ) return result @@ -1768,7 +1786,8 @@ def _resolve_assign( "and not yet implemented for the `gt:X` backends." ) warn_experimental_feature( - feature="2D temporaries", ADR="experimental/2d-temporaries.md" + feature="2D temporaries", + ADR="experimental/2d-temporaries.md", ) axes = self._domain_from_gtscript_axis(field_desc.axes).axes_names From 0f3aeb6f114c02d137404f4a901ddccf27cd1554 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Thu, 12 Mar 2026 14:30:36 +0100 Subject: [PATCH 2/7] ADD TESTS --- .../frontend_tests/test_gtscript_frontend.py | 136 ++++++++++++++++-- 1 file changed, 125 insertions(+), 11 deletions(-) diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index 38d9054032..61fb89dea0 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -658,7 +658,22 @@ def definition_func(inout_field: gtscript.Field[float]): class TestIntervalSyntax: - def test_simple(self): + # Static intervals + def test_ellipsis(self): + def definition_func(field: gtscript.Field[float]): + with computation(PARALLEL), interval(...): + field[0, 0, 0] = 1 + + def_ir = parse_definition( + definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + ) + assert def_ir.computations[0].interval.start.offset == 0 + assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START + + assert def_ir.computations[0].interval.end.offset == 0 + assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.END + + def test_positive_numbers(self): def definition_func(field: gtscript.Field[float]): with computation(PARALLEL), interval(0, 1): field = 0 @@ -690,6 +705,20 @@ def definition_func(field: gtscript.Field[float]): level=nodes.LevelMarker.END, offset=0, loc=loc ) + def test_negative_numbers(self): + def definition_func(field: gtscript.Field[float]): # type: ignore + with computation(PARALLEL), interval(1, -2): + field[0, 0, 0] = 1 + + def_ir = parse_definition( + definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + ) + assert def_ir.computations[0].interval.start.offset == 1 + assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START + + assert def_ir.computations[0].interval.end.offset == -2 + assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.END + def test_externals(self): def definition_func(field: gtscript.Field[float]): from gt4py.cartesian.__externals__ import kstart @@ -715,6 +744,88 @@ def definition_func(field: gtscript.Field[float]): level=nodes.LevelMarker.END, offset=-1, loc=loc ) + def test_nonoverlapping_intervals(self): + def definition_func(field: gtscript.Field[float]): + with computation(PARALLEL): + with interval(0, 2): + field = 0 + with interval(3, -1): + field = 1 + with interval(-1, None): + field = 2 + + parse_definition( + definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + ) + + # Dynamic intervals + def test_dynamic_scalar(self): + def definition_func(field: gtscript.Field[float], scalar: int): # type: ignore + with computation(PARALLEL), interval(0, scalar): + field[0, 0, 0] = 1 + + def_ir = parse_definition( + definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + ) + assert def_ir.computations[0].interval.start.offset == 0 + assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START + + assert isinstance(def_ir.computations[0].interval.end.offset, nodes.VarRef) + assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.START + + def test_dynamic_field(self): + def definition_func( + field: gtscript.Field[float], # type: ignore + idx_field: gtscript.Field[gtscript.IJ, int], # type: ignore + ): + with computation(PARALLEL), interval(0, idx_field): + field[0, 0, 0] = 1 + + def_ir = parse_definition( + definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + ) + assert def_ir.computations[0].interval.start.offset == 0 + assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START + + assert isinstance(def_ir.computations[0].interval.end.offset, nodes.FieldRef) + assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.START + + def test_dynamic_field_zero_offset(self): + def definition_func( + field: gtscript.Field[float], # type: ignore + idx_field: gtscript.Field[gtscript.IJ, int], # type: ignore + ): + with computation(PARALLEL), interval(0, idx_field[0, 0]): + field[0, 0, 0] = 1 + + def_ir = parse_definition( + definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + ) + assert def_ir.computations[0].interval.start.offset == 0 + assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START + + assert isinstance(def_ir.computations[0].interval.end.offset, nodes.FieldRef) + assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.START + + def test_dynamic_field_higher_dim(self): + def definition_func( + field: gtscript.Field[float], # type: ignore + idx_field: gtscript.Field[gtscript.IJ, (int, 2)], # type: ignore + ): + with computation(PARALLEL), interval(0, idx_field[0, 0][1]): + field[0, 0, 0] = 1 + + def_ir = parse_definition( + definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + ) + assert def_ir.computations[0].interval.start.offset == 0 + assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START + + assert isinstance(def_ir.computations[0].interval.end.offset, nodes.FieldRef) + assert def_ir.computations[0].interval.end.offset.name == "idx_field" + assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.START + + # Illegal syntax def test_error_none(self): def definition_func(field: gtscript.Field[float]): with computation(PARALLEL), interval(None, -1): @@ -790,19 +901,22 @@ def definition_func(field: gtscript.Field[float]): module=self.__class__.__name__, ) - def test_nonoverlapping_intervals(self): - def definition_func(field: gtscript.Field[float]): - with computation(PARALLEL): - with interval(0, 2): - field = 0 - with interval(3, -1): - field = 1 - with interval(-1, None): - field = 2 + def test_field_index_with_offset(self): + def definition_func( + field: gtscript.Field[float], # type: ignore + idx_field: gtscript.Field[gtscript.IJ, int], # type: ignore + ): + with computation(PARALLEL), interval(0, idx_field[0, 0]): + field[0, 0, 0] = 1 - parse_definition( + def_ir = parse_definition( definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ ) + assert def_ir.computations[0].interval.start.offset == 0 + assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START + + assert isinstance(def_ir.computations[0].interval.end.offset, nodes.FieldRef) + assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.START class TestRegions: From be8f4efb204766fa90870ce9b70861ed9b7491c3 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Thu, 5 Mar 2026 13:05:58 +0100 Subject: [PATCH 3/7] add hdf work --- .../cartesian/frontend/gtscript_frontend.py | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index 34a182c5d7..a8ad9c425e 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -339,7 +339,42 @@ def apply( return nodes.AxisInterval(start=start, end=end, loc=loc) def visit_Subscript(self, node: ast.Subscript): - # This was previously allowed but is discontinued now. + # Check that this is a higher dimensional field + if isinstance(node.value, ast.Subscript) and isinstance(node.value.value, ast.Name): + field_name = node.value.value.id + # Ensure the indexing is correct, first we need a 0-offset in i and j + if isinstance(node.value.slice, ast.Tuple): + axis_offsets: list[ast.Constant] = node.value.slice.elts + if not all(offset.value == 0 for offset in axis_offsets): + raise self.interval_error + else: + raise self.interval_error + # then we parse the actual offset in the higher dimension + if isinstance(node.slice, ast.Tuple): + higher_dim_offset = [self.visit(data_idx) for data_idx in node.slice.elts] + else: + higher_dim_offset = [self.visit(node.slice)] + literal_index = [ + nodes.ScalarLiteral(value=i, data_type=nodes.DataType.INT32) + for i in higher_dim_offset + ] + + return nodes.FieldRef.at_center( + name=field_name, + axes=self.fields[field_name].axes, + loc=nodes.Location.from_ast_node(node), + data_index=literal_index, + ) + # This is a non-higher dimensional field, but a normal field accessed with an offset + if isinstance(node.value, ast.Name) and isinstance(node.slice, ast.Tuple): + # We need to check that the offset is 0 everywhere, since no horizontal dependencies are allowed + axis_offsets: list[ast.Constant] = node.slice.elts + if not all(offset.value == 0 for offset in axis_offsets): + raise self.interval_error + # If the offset is 0, we are safe to visit the field + return self.visit(node.value) + + # Legal syntax never allows you to arrive here raise self.interval_error From 430d5a0db41909fcc56818383509b27f968efcf9 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Thu, 26 Mar 2026 10:35:52 +0100 Subject: [PATCH 4/7] fix wrong cherrypick --- .../cartesian/frontend/gtscript_frontend.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index a8ad9c425e..01d1ef4375 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -236,8 +236,8 @@ def apply( if isinstance(node, ast.Slice): slice_node = node elif isinstance(getattr(node, "slice", None), ast.Slice): - # This is the previously allowed syntax with the inlined slice: I[3:4] - raise parser.interval_error + # This is the syntax with the inlined slice: K[3:4] + slice_node = node.slice else: # It is a single value and will therefore be (value):(value+1) slice_node = cls._slice_from_value(node) @@ -258,21 +258,12 @@ def apply( def visit_Subscript(self, node: ast.Subscript) -> nodes.AxisBound: # This allows for the syntax - # `region[I[0] : I[0] + 2, J[0] : J[0] + 2]` + # `region[I[0] : I[2], J[0] : J[2]]` # to exist if not isinstance(node.value, ast.Name): raise self.interval_error if node.value.id != self.axis_name: - raise GTScriptSyntaxError( - "Invalid horizontal range specification:" - f"Expected axis {self.axis_name}, got {node.value.id}" - ) - if not isinstance(node.slice, ast.Constant) or node.slice.value not in (0, -1): - raise GTScriptSyntaxError( - "Invalid horizontal range specification:" - f"Expected specification {self.axis_name}[0] or {self.axis_name}[-1]" - f", got {self.axis_name}[{node.slice.value}]" - ) + raise self.interval_error index = self.visit(node.slice) From c2c8a7c0005053ce7eea76d2cec7c895d320eced Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Fri, 27 Mar 2026 16:52:49 +0100 Subject: [PATCH 5/7] reviewer's comments --- .../cartesian/frontend/gtscript_frontend.py | 20 ++++++++++++++++--- .../frontend_tests/test_gtscript_frontend.py | 16 +++++++-------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index 01d1ef4375..b8e6953e9c 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -102,10 +102,12 @@ def __init__( axis_name: str, fields: dict[str, nodes.FieldDecl], loc: Optional[nodes.Location] = None, + literal_precision: Optional[int] = None, ): self.axis_name = axis_name self.fields = fields self.loc = loc + self._literal_precision = literal_precision error_msg = "Invalid interval range specification" @@ -114,6 +116,11 @@ def __init__( self.interval_error = GTScriptSyntaxError(error_msg) + def _default_int_datatype(self) -> nodes.DataType: + if self._literal_precision: + return nodes.DataType.from_dtype(np.dtype(f"i{int(self._literal_precision / 8)}")) + return nodes.DataType.INT64 + @staticmethod def _slice_from_value(node: ast.Expr) -> ast.Slice: """Create an ast.Slice node from a general ast.Expr node.""" @@ -299,8 +306,9 @@ def apply( axis_name: str, fields: dict[str, nodes.FieldDecl], loc: Optional[nodes.Location] = None, + literal_precision: Optional[int] = None, ) -> nodes.AxisInterval: - parser = cls(axis_name, fields, loc) + parser = cls(axis_name, fields, loc, literal_precision) if isinstance(node, ast.Subscript): raise parser.interval_error @@ -346,7 +354,7 @@ def visit_Subscript(self, node: ast.Subscript): else: higher_dim_offset = [self.visit(node.slice)] literal_index = [ - nodes.ScalarLiteral(value=i, data_type=nodes.DataType.INT32) + nodes.ScalarLiteral(value=i, data_type=self._default_int_datatype()) for i in higher_dim_offset ] @@ -1084,7 +1092,13 @@ def _visit_interval_node(self, node: ast.withitem, loc: nodes.Location): interval_node = args[0] seq_name = nodes.Domain.LatLonGrid().sequential_axis.name - interval = VerticalIntervalParser.apply(interval_node, seq_name, self.fields, loc=loc) + interval = VerticalIntervalParser.apply( + interval_node, + seq_name, + self.fields, + loc=loc, + literal_precision=self.literal_int_precision, + ) if ( interval.start.level == nodes.LevelMarker.END diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index 61fb89dea0..f70e7a67e4 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -906,17 +906,15 @@ def definition_func( field: gtscript.Field[float], # type: ignore idx_field: gtscript.Field[gtscript.IJ, int], # type: ignore ): - with computation(PARALLEL), interval(0, idx_field[0, 0]): + with computation(PARALLEL), interval(0, idx_field[0, 1]): field[0, 0, 0] = 1 - def_ir = parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ - ) - assert def_ir.computations[0].interval.start.offset == 0 - assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START - - assert isinstance(def_ir.computations[0].interval.end.offset, nodes.FieldRef) - assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.START + with pytest.raises( + gt_frontend.GTScriptSyntaxError, match="Invalid interval range specification" + ): + def_ir = parse_definition( + definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + ) class TestRegions: From 795709f5c605f576f1747e8b6f50fd827ec49098 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Tue, 31 Mar 2026 17:49:45 +0200 Subject: [PATCH 6/7] review v2 --- .../cartesian/frontend/gtscript_frontend.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index a4a70aee0f..484327914f 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -102,12 +102,10 @@ def __init__( axis_name: str, fields: dict[str, nodes.FieldDecl], loc: Optional[nodes.Location] = None, - literal_precision: Optional[int] = None, ): self.axis_name = axis_name self.fields = fields self.loc = loc - self._literal_precision = literal_precision error_msg = "Invalid interval range specification" @@ -308,6 +306,16 @@ class VerticalIntervalParser(IntervalParser): if an `ast.Subscript` is passed, this parses its slice attribute. """ + def __init__( + self, + axis_name: str, + fields: dict[str, nodes.FieldDecl], + literal_precision: int, + loc: Optional[nodes.Location] = None, + ): + super().__init__(axis_name, fields, loc) + self._literal_precision = literal_precision + def visit_Name(self, node: ast.Name) -> nodes.Ref: # Handle the field accesses if node.id in self.fields: @@ -329,10 +337,10 @@ def apply( node: Union[ast.Slice, ast.Subscript, ast.Constant], axis_name: str, fields: dict[str, nodes.FieldDecl], + literal_precision: int, loc: Optional[nodes.Location] = None, - literal_precision: Optional[int] = None, ) -> nodes.AxisInterval: - parser = cls(axis_name, fields, loc, literal_precision) + parser = cls(axis_name, fields, literal_precision, loc) if isinstance(node, ast.Subscript): raise parser.interval_error From e4724db0e595587caf71207f4fc6f00ab63a95ae Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Wed, 1 Apr 2026 13:42:54 +0200 Subject: [PATCH 7/7] move int precision to vertical only --- src/gt4py/cartesian/frontend/gtscript_frontend.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index 484327914f..1a52d6d683 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -114,11 +114,6 @@ def __init__( self.interval_error = GTScriptSyntaxError(error_msg) - def _default_int_datatype(self) -> nodes.DataType: - if self._literal_precision: - return nodes.DataType.from_dtype(np.dtype(f"i{int(self._literal_precision / 8)}")) - return nodes.DataType.INT64 - @staticmethod def _slice_from_value(node: ast.Expr) -> ast.Slice: """Create an ast.Slice node from a general ast.Expr node.""" @@ -316,6 +311,9 @@ def __init__( super().__init__(axis_name, fields, loc) self._literal_precision = literal_precision + def _default_int_datatype(self) -> nodes.DataType: + return nodes.DataType.from_dtype(np.dtype(f"i{int(self._literal_precision / 8)}")) + def visit_Name(self, node: ast.Name) -> nodes.Ref: # Handle the field accesses if node.id in self.fields: