Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ def apply(
if isinstance(node, ast.Slice):
slice_node = node
elif isinstance(getattr(node, "slice", None), ast.Slice):
# This is the syntax with the inlined slice: K[3:4]
slice_node = node.slice
# This is the previously allowed syntax with the inlined slice: I[0:2]
raise parser.interval_error
else:
# It is a single value and will therefore be (value):(value+1)
slice_node = cls._slice_from_value(node)
Expand All @@ -258,12 +258,36 @@ def apply(

def visit_Subscript(self, node: ast.Subscript) -> nodes.AxisBound:
# This allows for the syntax
# `region[I[0] : I[2], J[0] : J[2]]`
# `region[I[0] : I[0] + 2, J[0] : J[0] + 2]`
# to exist
if not isinstance(node.value, ast.Name):
raise self.interval_error
if node.value.id != self.axis_name:
raise self.interval_error
raise GTScriptSyntaxError(
"Invalid horizontal range specification:"
f"Expected axis {self.axis_name}, got {node.value.id}"
)
if isinstance(node.slice, ast.Constant):
if node.slice.value != 0:
raise GTScriptSyntaxError(
"Invalid horizontal range specification:"
f"Expected specification {self.axis_name}[0] or {self.axis_name}[-1]"
f", got {self.axis_name}[{node.slice.value}]"
)
elif isinstance(node.slice, ast.UnaryOp):
if not isinstance(node.slice.operand, ast.Constant) or node.slice.operand.value not in (
0,
1,
):
raise GTScriptSyntaxError(
"Invalid horizontal range specification:"
f"Expected specification {self.axis_name}[0] or {self.axis_name}[-1]."
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why there this exception to allow I[-1]? What is different about -1?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the goal is to allow to access both ends of the domain. The way this was previously done was with "standard" slicing notation, so I[-1] is the end of the I-domain. That way you can offset from both sides without having to know the domain-size.

And because we're still in AST-land here, -1 is not resolved into a negative constant but it is a unary-up with the value 1.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see - all good then.

)
else:
raise GTScriptSyntaxError(
"Invalid horizontal range specification:"
f"Expected axis {self.axis_name}, got {node.value.id}"
)

index = self.visit(node.slice)

Expand All @@ -285,7 +309,9 @@ def visit_Name(self, node: ast.Name) -> nodes.Ref:
f"Using field `{node.id}` with a K-Axis as a bound for an interval is invalid."
)
return nodes.FieldRef.at_center(
name=node.id, axes=self.fields[node.id].axes, loc=nodes.Location.from_ast_node(node)
name=node.id,
axes=self.fields[node.id].axes,
loc=nodes.Location.from_ast_node(node),
)
# Handle the scalar accesses
return nodes.VarRef(name=node.id, loc=nodes.Location.from_ast_node(node))
Expand Down Expand Up @@ -1197,7 +1223,8 @@ def visit_Name(self, node: ast.Name) -> nodes.Ref:

if self._is_parameter(symbol):
return nodes.VarRef(
name=symbol, loc=nodes.Location.from_ast_node(node, scope=self.stencil_name)
name=symbol,
loc=nodes.Location.from_ast_node(node, scope=self.stencil_name),
)

if self._is_local_symbol(symbol):
Expand Down Expand Up @@ -1247,7 +1274,8 @@ def _eval_new_spatial_index(
)
if axis_index < last_index:
raise GTScriptSyntaxError(
message=f"Axis {value.name} is specified out of order", loc=index_node
message=f"Axis {value.name} is specified out of order",
loc=index_node,
)
if axis_index == last_index:
raise GTScriptSyntaxError(
Expand Down Expand Up @@ -1373,7 +1401,9 @@ def visit_UnaryOp(self, node: ast.UnaryOp):
return eval("{op}{arg}".format(op=op.python_symbol, arg=arg))

return nodes.UnaryOpExpr(
op=op, arg=arg, loc=nodes.Location.from_ast_node(node, scope=self.stencil_name)
op=op,
arg=arg,
loc=nodes.Location.from_ast_node(node, scope=self.stencil_name),
)

def visit_UAdd(self, node: ast.UAdd) -> nodes.UnaryOperator:
Expand Down Expand Up @@ -1474,7 +1504,10 @@ def visit_Compare(self, node: ast.Compare) -> nodes.BinOpExpr:
args.append(lhs)

result = nodes.BinOpExpr(
op=op, lhs=lhs, rhs=rhs, loc=nodes.Location.from_ast_node(node, scope=self.stencil_name)
op=op,
lhs=lhs,
rhs=rhs,
loc=nodes.Location.from_ast_node(node, scope=self.stencil_name),
)

return result
Expand Down Expand Up @@ -1768,7 +1801,8 @@ def _resolve_assign(
"and not yet implemented for the `gt:X` backends."
)
warn_experimental_feature(
feature="2D temporaries", ADR="experimental/2d-temporaries.md"
feature="2D temporaries",
ADR="experimental/2d-temporaries.md",
)

axes = self._domain_from_gtscript_axis(field_desc.axes).axes_names
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,17 +450,23 @@ def two_optional_fields(
@register
def horizontal_regions(field_in: Field3D, field_out: Field3D):
with computation(PARALLEL), interval(...):
with horizontal(region[I[0] : I[2], J[0] : J[2]], region[I[-3] : I[-1], J[-3] : J[-1]]):
with horizontal(
region[I[0] : I[0] + 2, J[0] : J[0] + 2], region[I[-1] - 2 : I[-1], J[-1] - 2 : J[-1]]
):
field_out = field_in + 1.0

with horizontal(region[I[0] : I[2], J[-3] : J[-1]], region[I[-3] : I[-1], J[0] : J[2]]):
with horizontal(
region[I[0] : I[0] + 2, J[-1] - 2 : J[-1]], region[I[-1] - 2 : I[-1], J[0] : J[0] + 2]
):
field_out = field_in - 1.0


@register
def horizontal_region_with_conditional(field_in: Field3D, field_out: Field3D):
with computation(PARALLEL), interval(...):
with horizontal(region[I[0] : I[2], J[0] : J[2]], region[I[-3] : I[-1], J[-3] : J[-1]]):
with horizontal(
region[I[0] : I[0] + 2, J[0] : J[0] + 2], region[I[-1] - 2 : I[-1], J[-1] - 2 : J[-1]]
):
if field_in > 0:
field_out = field_in + 1.0
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -923,11 +923,17 @@ class TestHorizontalRegionsCorners(gt_testing.StencilTestSuite):

def definition(field_in, field_out):
with computation(PARALLEL), interval(...):
with horizontal(region[I[0] : I[2], J[0] : J[2]], region[I[-3] : I[-1], J[-3] : J[-1]]):
with horizontal(
region[I[0] : I[0] + 2, J[0] : J[0] + 2],
region[I[-1] - 2 : I[-1], J[-1] - 2 : J[-1]],
):
field_out = ( # noqa: F841 [unused-variable]
field_in + 1.0
)
with horizontal(region[I[0] : I[2], J[-3] : J[-1]], region[I[-3] : I[-1], J[0] : J[2]]):
with horizontal(
region[I[0] : I[0] + 2, J[-1] - 2 : J[-1]],
region[I[-1] - 2 : I[-1], J[0] : J[0] + 2],
):
field_out = ( # noqa: F841 [unused-variable]
field_in - 1.0
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ def definition_func(field: gtscript.Field[float]):
class TestRegions:
def test_one_interval_only(self):
def stencil(in_f: gtscript.Field[np.float64]):
with computation(PARALLEL), interval(...), horizontal(region[I[0:3], :]):
with computation(PARALLEL), interval(...), horizontal(region[I[0] : I[0] + 3, :]):
in_f = 1.0

def_ir = parse_definition(
Expand Down Expand Up @@ -841,7 +841,7 @@ def stencil(in_f: gtscript.Field[np.float64]):
stencil,
name=inspect.stack()[0][3],
module=self.__class__.__name__,
externals={"i1": I[1]},
externals={"i1": I[0] + 1},
)

assert len(def_ir.computations) == 1
Expand Down Expand Up @@ -914,6 +914,54 @@ def stencil(in_f: gtscript.Field[np.float64]):
):
parse_definition(stencil, name=inspect.stack()[0][3], module=self.__class__.__name__)

def test_axis_offset(self):
def stencil(field: gtscript.Field[float]): # type: ignore
with computation(PARALLEL), interval(...):
with horizontal(region[I[0] : I[2], :]):
field[0, 0, 0] = 0

with pytest.raises(
gt_frontend.GTScriptSyntaxError,
match="Invalid horizontal range specification",
):
parse_definition(
stencil,
name=inspect.stack()[0][3],
module=self.__class__.__name__,
)

def test_axis_offset_negative(self):
def stencil(field: gtscript.Field[float]): # type: ignore
with computation(PARALLEL), interval(...):
with horizontal(region[I[-3] : I[-1], :]):
field[0, 0, 0] = 0

with pytest.raises(
gt_frontend.GTScriptSyntaxError,
match="Invalid horizontal range specification",
):
parse_definition(
stencil,
name=inspect.stack()[0][3],
module=self.__class__.__name__,
)

def test_axis_slice(self):
def stencil(field: gtscript.Field[float]): # type: ignore
with computation(PARALLEL), interval(...):
with horizontal(region[I[0:2], :]):
field[0, 0, 0] = 0

with pytest.raises(
gt_frontend.GTScriptSyntaxError,
match="Invalid interval range specification",
):
parse_definition(
stencil,
name=inspect.stack()[0][3],
module=self.__class__.__name__,
)


class TestExternalsWithSubroutines:
def test_all_legal_combinations(self):
Expand Down
Loading