Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,19 @@ class VerticalIntervalParser(IntervalParser):
if an `ast.Subscript` is passed, this parses its slice attribute.
"""

def __init__(
self,
axis_name: str,
fields: dict[str, nodes.FieldDecl],
literal_precision: int,
loc: Optional[nodes.Location] = None,
):
super().__init__(axis_name, fields, loc)
self._literal_precision = literal_precision

def _default_int_datatype(self) -> nodes.DataType:
return nodes.DataType.from_dtype(np.dtype(f"i{int(self._literal_precision / 8)}"))

def visit_Name(self, node: ast.Name) -> nodes.Ref:
# Handle the field accesses
if node.id in self.fields:
Expand All @@ -322,9 +335,10 @@ def apply(
node: Union[ast.Slice, ast.Subscript, ast.Constant],
axis_name: str,
fields: dict[str, nodes.FieldDecl],
literal_precision: int,
loc: Optional[nodes.Location] = None,
) -> nodes.AxisInterval:
parser = cls(axis_name, fields, loc)
parser = cls(axis_name, fields, literal_precision, loc)

if isinstance(node, ast.Subscript):
raise parser.interval_error
Expand Down Expand Up @@ -354,7 +368,42 @@ def apply(
return nodes.AxisInterval(start=start, end=end, loc=loc)

def visit_Subscript(self, node: ast.Subscript):
# This was previously allowed but is discontinued now.
# Check that this is a higher dimensional field
if isinstance(node.value, ast.Subscript) and isinstance(node.value.value, ast.Name):
field_name = node.value.value.id
# Ensure the indexing is correct, first we need a 0-offset in i and j
if isinstance(node.value.slice, ast.Tuple):
axis_offsets: list[ast.Constant] = node.value.slice.elts
if not all(offset.value == 0 for offset in axis_offsets):
raise self.interval_error
else:
raise self.interval_error
# then we parse the actual offset in the higher dimension
if isinstance(node.slice, ast.Tuple):
higher_dim_offset = [self.visit(data_idx) for data_idx in node.slice.elts]
else:
higher_dim_offset = [self.visit(node.slice)]
literal_index = [
nodes.ScalarLiteral(value=i, data_type=self._default_int_datatype())
for i in higher_dim_offset
]

return nodes.FieldRef.at_center(
name=field_name,
axes=self.fields[field_name].axes,
loc=nodes.Location.from_ast_node(node),
data_index=literal_index,
)
# This is a non-higher dimensional field, but a normal field accessed with an offset
if isinstance(node.value, ast.Name) and isinstance(node.slice, ast.Tuple):
# We need to check that the offset is 0 everywhere, since no horizontal dependencies are allowed
axis_offsets: list[ast.Constant] = node.slice.elts
if not all(offset.value == 0 for offset in axis_offsets):
raise self.interval_error
# If the offset is 0, we are safe to visit the field
return self.visit(node.value)

# Legal syntax never allows you to arrive here
raise self.interval_error


Expand Down Expand Up @@ -1073,7 +1122,13 @@ def _visit_interval_node(self, node: ast.withitem, loc: nodes.Location):
interval_node = args[0]

seq_name = nodes.Domain.LatLonGrid().sequential_axis.name
interval = VerticalIntervalParser.apply(interval_node, seq_name, self.fields, loc=loc)
interval = VerticalIntervalParser.apply(
interval_node,
seq_name,
self.fields,
loc=loc,
literal_precision=self.literal_int_precision,
)

if (
interval.start.level == nodes.LevelMarker.END
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,22 @@ def definition_func(inout_field: gtscript.Field[float]):


class TestIntervalSyntax:
def test_simple(self):
# Static intervals
def test_ellipsis(self):
def definition_func(field: gtscript.Field[float]):
with computation(PARALLEL), interval(...):
field[0, 0, 0] = 1

def_ir = parse_definition(
definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__
)
assert def_ir.computations[0].interval.start.offset == 0
assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START

assert def_ir.computations[0].interval.end.offset == 0
assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.END

def test_positive_numbers(self):
def definition_func(field: gtscript.Field[float]):
with computation(PARALLEL), interval(0, 1):
field = 0
Expand Down Expand Up @@ -690,6 +705,20 @@ def definition_func(field: gtscript.Field[float]):
level=nodes.LevelMarker.END, offset=0, loc=loc
)

def test_negative_numbers(self):
def definition_func(field: gtscript.Field[float]): # type: ignore
with computation(PARALLEL), interval(1, -2):
field[0, 0, 0] = 1

def_ir = parse_definition(
definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__
)
assert def_ir.computations[0].interval.start.offset == 1
assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START

assert def_ir.computations[0].interval.end.offset == -2
assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.END

def test_externals(self):
def definition_func(field: gtscript.Field[float]):
from gt4py.cartesian.__externals__ import kstart
Expand All @@ -715,6 +744,88 @@ def definition_func(field: gtscript.Field[float]):
level=nodes.LevelMarker.END, offset=-1, loc=loc
)

def test_nonoverlapping_intervals(self):
def definition_func(field: gtscript.Field[float]):
with computation(PARALLEL):
with interval(0, 2):
field = 0
with interval(3, -1):
field = 1
with interval(-1, None):
field = 2

parse_definition(
definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__
)

# Dynamic intervals
def test_dynamic_scalar(self):
def definition_func(field: gtscript.Field[float], scalar: int): # type: ignore
with computation(PARALLEL), interval(0, scalar):
field[0, 0, 0] = 1

def_ir = parse_definition(
definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__
)
assert def_ir.computations[0].interval.start.offset == 0
assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START

assert isinstance(def_ir.computations[0].interval.end.offset, nodes.VarRef)
assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.START

def test_dynamic_field(self):
def definition_func(
field: gtscript.Field[float], # type: ignore
idx_field: gtscript.Field[gtscript.IJ, int], # type: ignore
):
with computation(PARALLEL), interval(0, idx_field):
field[0, 0, 0] = 1

def_ir = parse_definition(
definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__
)
assert def_ir.computations[0].interval.start.offset == 0
assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START

assert isinstance(def_ir.computations[0].interval.end.offset, nodes.FieldRef)
assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.START

def test_dynamic_field_zero_offset(self):
def definition_func(
field: gtscript.Field[float], # type: ignore
idx_field: gtscript.Field[gtscript.IJ, int], # type: ignore
):
with computation(PARALLEL), interval(0, idx_field[0, 0]):
field[0, 0, 0] = 1

def_ir = parse_definition(
definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__
)
assert def_ir.computations[0].interval.start.offset == 0
assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START

assert isinstance(def_ir.computations[0].interval.end.offset, nodes.FieldRef)
assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.START

def test_dynamic_field_higher_dim(self):
def definition_func(
field: gtscript.Field[float], # type: ignore
idx_field: gtscript.Field[gtscript.IJ, (int, 2)], # type: ignore
):
with computation(PARALLEL), interval(0, idx_field[0, 0][1]):
field[0, 0, 0] = 1

def_ir = parse_definition(
definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__
)
assert def_ir.computations[0].interval.start.offset == 0
assert def_ir.computations[0].interval.start.level == nodes.LevelMarker.START

assert isinstance(def_ir.computations[0].interval.end.offset, nodes.FieldRef)
assert def_ir.computations[0].interval.end.offset.name == "idx_field"
assert def_ir.computations[0].interval.end.level == nodes.LevelMarker.START

# Illegal syntax
def test_error_none(self):
def definition_func(field: gtscript.Field[float]):
with computation(PARALLEL), interval(None, -1):
Expand Down Expand Up @@ -790,19 +901,20 @@ def definition_func(field: gtscript.Field[float]):
module=self.__class__.__name__,
)

def test_nonoverlapping_intervals(self):
def definition_func(field: gtscript.Field[float]):
with computation(PARALLEL):
with interval(0, 2):
field = 0
with interval(3, -1):
field = 1
with interval(-1, None):
field = 2
def test_field_index_with_offset(self):
def definition_func(
field: gtscript.Field[float], # type: ignore
idx_field: gtscript.Field[gtscript.IJ, int], # type: ignore
):
with computation(PARALLEL), interval(0, idx_field[0, 1]):
field[0, 0, 0] = 1

parse_definition(
definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__
)
with pytest.raises(
gt_frontend.GTScriptSyntaxError, match="Invalid interval range specification"
):
def_ir = parse_definition(
definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__
)


class TestRegions:
Expand Down
Loading