Skip to content
70 changes: 70 additions & 0 deletions src/gt4py/cartesian/frontend/defir_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any

from gt4py.cartesian.frontend.defir_to_gtir import DataDimensionsChecker, UnrollVectorAssignments
from gt4py.cartesian.frontend.nodes import (
ArgumentInfo,
ComputationBlock,
Domain,
FieldDecl,
Location,
StencilDefinition,
VarDecl,
)


class DefIRBuilder:
"""Assemble DefinitionIR from DefIR nodes build from the Python AST parsing"""

def __init__(self, stencil_name: str) -> None:
self.stencil_name = stencil_name

def build(
self,
domain: Domain,
api_signature: list[ArgumentInfo],
fields_decls: dict[str, FieldDecl],
parameter_decls: dict[str, VarDecl],
computations: list[ComputationBlock],
externals: dict[str, Any] | None = None,
sources: dict[str, str] | None = None,
docstring: str = "",
loc: Location | None = None,
) -> StencilDefinition:
"""Assemble signature, fields and computations nodes into a StencilDefinition"""
api_fields = [
fields_decls[item.name] for item in api_signature if item.name in fields_decls
]
parameters = [
parameter_decls[item.name] for item in api_signature if item.name in parameter_decls
]

stencil_definition = StencilDefinition(
name=self.stencil_name, # type: ignore[call-arg]
domain=domain,
api_signature=api_signature,
api_fields=api_fields,
parameters=parameters,
computations=computations,
externals=externals,
sources=sources,
docstring=docstring,
loc=loc,
)

stencil_definition = UnrollVectorAssignments.apply(
stencil_definition,
fields_decls=fields_decls,
)

# We check fields with data dimensions are all fully indexed
DataDimensionsChecker.apply(stencil_definition, fields_decls)

return stencil_definition
40 changes: 31 additions & 9 deletions src/gt4py/cartesian/frontend/defir_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
import functools
import itertools
import numbers
from typing import Any, Dict, Final, List, Optional, Tuple, Union, cast
from typing import Any, Final, List, Optional, Tuple, Union, cast

import numpy as np

from gt4py.cartesian import utils
from gt4py.cartesian.frontend.exceptions import GTScriptSyntaxError
from gt4py.cartesian.frontend.node_util import (
IRNodeMapper,
IRNodeVisitor,
Expand Down Expand Up @@ -100,20 +101,41 @@ def _make_literal(v: numbers.Number) -> gtir.Literal:
return gtir.Literal(dtype=dtype, value=value)


class DataDimensionsChecker(IRNodeVisitor):
"""Check data dimensions are fully indexed.

Call on a fully visited Definition IR.
"""

@classmethod
def apply(cls, def_ir: StencilDefinition, field_decls: dict[str, FieldDecl]) -> None:
return cls().visit(def_ir, field_decls=field_decls)

def visit_FieldRef(self, node: FieldRef, *, field_decls: dict[str, FieldDecl]) -> None:
if len(field_decls[node.name].data_dims) != len(node.data_index):
cdims = [0] * len(field_decls[node.name].axes)
ddims = ["x"] * len(field_decls[node.name].data_dims)
raise GTScriptSyntaxError(
f"Field {node.name} has data dimensions but no data dimensions index is specified. "
f"Use `{node.name}.A{ddims}` or `{node.name}{cdims}{ddims}`.",
loc=node.loc,
)


class UnrollVectorAssignments(IRNodeMapper):
@classmethod
def apply(cls, root, **kwargs):
return cls().visit(root, **kwargs)

def _is_vector_assignment(self, stmt: Node, fields_decls: Dict[str, FieldDecl]) -> bool:
def _is_vector_assignment(self, stmt: Node, fields_decls: dict[str, FieldDecl]) -> bool:
if not isinstance(stmt, Assign):
return False

# does the referenced field has data dimensions and the access is not element wise
return fields_decls[stmt.target.name].data_dims and not stmt.target.data_index

def visit_StencilDefinition(
self, node: StencilDefinition, *, fields_decls: Dict[str, FieldDecl], **kwargs
self, node: StencilDefinition, *, fields_decls: dict[str, FieldDecl], **kwargs
) -> StencilDefinition:
node = copy.deepcopy(node)

Expand All @@ -138,7 +160,7 @@ def _nested_list_dim(self, a: List) -> List[int]:
return [len(a), *self._nested_list_dim(a[0])]

def visit_Assign(
self, node: Assign, *, fields_decls: Dict[str, FieldDecl], **kwargs
self, node: Assign, *, fields_decls: dict[str, FieldDecl], **kwargs
) -> Union[gtir.ParAssignStmt, List[gtir.ParAssignStmt]]:
if self._is_vector_assignment(node, fields_decls):
assert isinstance(node.target, FieldRef) or isinstance(node.target, VarRef)
Expand Down Expand Up @@ -174,7 +196,7 @@ def visit_Assign(

class UnrollVectorExpressions(IRNodeMapper):
@classmethod
def apply(cls, root, *, expected_dim: Tuple[int, ...], fields_decls: Dict[str, FieldDecl]):
def apply(cls, root, *, expected_dim: Tuple[int, ...], fields_decls: dict[str, FieldDecl]):
result = cls().visit(root, fields_decls=fields_decls)
# if the expression is just a scalar broadcast to the expected dimensions
if not isinstance(result, list):
Expand All @@ -183,7 +205,7 @@ def apply(cls, root, *, expected_dim: Tuple[int, ...], fields_decls: Dict[str, F
)
return result

def visit_FieldRef(self, node: FieldRef, *, fields_decls: Dict[str, FieldDecl], **kwargs):
def visit_FieldRef(self, node: FieldRef, *, fields_decls: dict[str, FieldDecl], **kwargs):
name = node.name
if fields_decls[name].data_dims:
field_list: List[Union[FieldRef, List[FieldRef]]] = []
Expand Down Expand Up @@ -225,7 +247,7 @@ def visit_FieldRef(self, node: FieldRef, *, fields_decls: Dict[str, FieldDecl],

return node

def visit_UnaryOpExpr(self, node: UnaryOpExpr, *, fields_decls: Dict[str, FieldDecl], **kwargs):
def visit_UnaryOpExpr(self, node: UnaryOpExpr, *, fields_decls: dict[str, FieldDecl], **kwargs):
if node.op == UnaryOperator.TRANSPOSED:
node = self.visit(node.arg, fields_decls=fields_decls, **kwargs)
assert isinstance(node, list) and all(
Expand All @@ -237,7 +259,7 @@ def visit_UnaryOpExpr(self, node: UnaryOpExpr, *, fields_decls: Dict[str, FieldD

return self.generic_visit(node, **kwargs)

def visit_BinOpExpr(self, node: BinOpExpr, *, fields_decls: Dict[str, FieldDecl], **kwargs):
def visit_BinOpExpr(self, node: BinOpExpr, *, fields_decls: dict[str, FieldDecl], **kwargs):
lhs = self.visit(node.lhs, fields_decls=fields_decls, **kwargs)
rhs = self.visit(node.rhs, fields_decls=fields_decls, **kwargs)
result: Union[List[BinOpExpr], BinOpExpr] = []
Expand Down Expand Up @@ -407,7 +429,7 @@ def visit_StencilDefinition(self, node: StencilDefinition) -> gtir.Stencil:
loc=location_to_source_location(node.loc),
)

def visit_ArgumentInfo(self, node: ArgumentInfo, all_params: Dict[str, gtir.Decl]) -> gtir.Decl:
def visit_ArgumentInfo(self, node: ArgumentInfo, all_params: dict[str, gtir.Decl]) -> gtir.Decl:
return all_params[node.name]

def visit_ComputationBlock(self, node: ComputationBlock) -> gtir.VerticalLoop:
Expand Down
29 changes: 12 additions & 17 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
from gt4py.cartesian import definitions as gt_definitions, gtscript, utils as gt_utils
from gt4py.cartesian.frontend import node_util, nodes
from gt4py.cartesian.frontend.base import Frontend, register
from gt4py.cartesian.frontend.defir_to_gtir import DefIRToGTIR, UnrollVectorAssignments
from gt4py.cartesian.frontend.defir_builder import DefIRBuilder
from gt4py.cartesian.frontend.defir_to_gtir import DefIRToGTIR
from gt4py.cartesian.frontend.exceptions import (
GTScriptAssertionError,
GTScriptDataTypeError,
Expand Down Expand Up @@ -1342,7 +1343,7 @@ def visit_Subscript(self, node: ast.Subscript):
if isinstance(result, nodes.VarRef):
assert index is not None
result.index = index[0]
else:
elif isinstance(result, nodes.FieldRef):
if isinstance(index, nodes.AbsoluteKIndex):
result.offset = index
elif isinstance(node.value, ast.Name):
Expand Down Expand Up @@ -1391,6 +1392,12 @@ def visit_Subscript(self, node: ast.Subscript):
loc=nodes.Location.from_ast_node(node, scope=self.stencil_name),
)

else:
raise GTScriptSyntaxError(
f"Unrecognized node type {type(result)} is subscripted",
loc=nodes.Location.from_ast_node(node, scope=self.stencil_name),
)

return result

# -- Expressions nodes --
Expand Down Expand Up @@ -2021,7 +2028,6 @@ def __init__(self, definition, *, options, externals=None, dtypes=None):
self.options = options
self.build_info = options.build_info
self.main_name = options.name
self.definition_ir = None
self.external_context = externals or {}
self.resolved_externals = {}
self.block = None
Expand Down Expand Up @@ -2464,7 +2470,6 @@ def run(self, backend_name: str):
func_node=main_func_node,
)

# Generate definition IR
domain = nodes.Domain.LatLonGrid()
computations = IRMaker(
fields=fields_decls,
Expand All @@ -2477,27 +2482,17 @@ def run(self, backend_name: str):
options=self.options,
)(self.ast_root)

self.definition_ir = nodes.StencilDefinition(
name=self.main_name,
return DefIRBuilder(self.main_name).build(
domain=domain,
api_signature=api_signature,
api_fields=[
fields_decls[item.name] for item in api_signature if item.name in fields_decls
],
parameters=[
parameter_decls[item.name] for item in api_signature if item.name in parameter_decls
],
fields_decls=fields_decls,
parameter_decls=parameter_decls,
computations=init_computations + computations,
externals=self.resolved_externals,
docstring=inspect.getdoc(self.definition) or "",
loc=nodes.Location.from_ast_node(self.ast_root.body[0]),
)

self.definition_ir = UnrollVectorAssignments.apply(
self.definition_ir, fields_decls=fields_decls
)
return self.definition_ir


@register
class GTScriptFrontend(Frontend):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import functools
import textwrap
import types
import re
from typing import Any, Callable, Dict, Optional, Type

import numpy as np
Expand Down Expand Up @@ -1719,6 +1720,41 @@ def data_dims_with_at(
module=self.__class__.__name__,
)

# Forgetting to index data dimensions
def data_dims_rhs_not_fully_indexed(
out_field: gtscript.Field[gtscript.IJK, np.int32],
global_field: gtscript.Field[gtscript.IJK, (np.int32, (3))],
):
with computation(PARALLEL), interval(...):
out_field = global_field[0, 0, 0] # Expecting [0,0,0][0]

with pytest.raises(
gt_frontend.GTScriptSyntaxError,
match="Field global_field has data dimensions but no data dimensions index is specified. Use",
):
parse_definition(
data_dims_rhs_not_fully_indexed,
name=inspect.stack()[0][3],
module=self.__class__.__name__,
)

def data_dims_rhs_absolute_not_fully_indexed(
out_field: gtscript.Field[gtscript.IJK, np.int32],
global_field: gtscript.Field[gtscript.IJK, (np.int32, (3))],
):
with computation(PARALLEL), interval(...):
out_field = global_field.at(K=1) # Expecting ddims=[]

with pytest.raises(
gt_frontend.GTScriptSyntaxError,
match="Field global_field has data dimensions but no data dimensions index is specified. Use*",
):
parse_definition(
data_dims_rhs_absolute_not_fully_indexed,
name=inspect.stack()[0][3],
module=self.__class__.__name__,
)


class TestNestedWithSyntax:
def test_nested_with(self):
Expand Down
Loading