# --- src/gt4py/cartesian/frontend/defir_builder.py ---


class DefIRBuilder:
    """Assemble a Definition IR from the DefIR nodes built while parsing the Python AST.

    The builder itself only stores the stencil name; every other ingredient is
    passed to :meth:`build`.
    """

    def __init__(self, stencil_name: str) -> None:
        self.stencil_name = stencil_name

    def build(
        self,
        domain: Domain,
        api_signature: list[ArgumentInfo],
        fields_decls: dict[str, FieldDecl],
        parameter_decls: dict[str, VarDecl],
        computations: list[ComputationBlock],
        externals: dict[str, Any] | None = None,
        sources: dict[str, str] | None = None,
        docstring: str = "",
        loc: Location | None = None,
    ) -> StencilDefinition:
        """Assemble signature, fields and computation nodes into a ``StencilDefinition``.

        Args:
            domain: Domain forwarded unchanged to the ``StencilDefinition``.
            api_signature: Stencil arguments; their order defines the order of
                ``api_fields`` and ``parameters``.
            fields_decls: Field declarations by name. Also handed to the
                post-processing passes below.
            parameter_decls: Parameter declarations by name.
            computations: Computation blocks of the stencil.
            externals: Forwarded unchanged to the ``StencilDefinition``.
            sources: Forwarded unchanged to the ``StencilDefinition``.
            docstring: Forwarded unchanged to the ``StencilDefinition``.
            loc: Forwarded unchanged to the ``StencilDefinition``.

        Returns:
            The definition returned by ``UnrollVectorAssignments.apply``, after it
            has passed ``DataDimensionsChecker.apply``.

        Raises:
            GTScriptSyntaxError: Raised by ``DataDimensionsChecker`` when a field
                reference does not index all of its data dimensions.
        """
        # Only signature entries that have a declaration are kept, in signature order.
        api_fields = [
            fields_decls[item.name] for item in api_signature if item.name in fields_decls
        ]
        parameters = [
            parameter_decls[item.name] for item in api_signature if item.name in parameter_decls
        ]

        stencil_definition = StencilDefinition(
            name=self.stencil_name,  # type: ignore[call-arg]
            domain=domain,
            api_signature=api_signature,
            api_fields=api_fields,
            parameters=parameters,
            computations=computations,
            externals=externals,
            sources=sources,
            docstring=docstring,
            loc=loc,
        )

        stencil_definition = UnrollVectorAssignments.apply(
            stencil_definition,
            fields_decls=fields_decls,
        )

        # Run the check after unrolling: vector assignments are allowed to reference
        # fields without a data index, everything that is left must be fully indexed.
        DataDimensionsChecker.apply(stencil_definition, fields_decls)

        return stencil_definition


# --- src/gt4py/cartesian/frontend/defir_to_gtir.py ---


class DataDimensionsChecker(IRNodeVisitor):
    """Check that every field reference indexes all of its data dimensions.

    Call on a fully visited Definition IR.
    """

    @classmethod
    def apply(cls, def_ir: StencilDefinition, field_decls: dict[str, FieldDecl]) -> None:
        """Visit ``def_ir`` and raise on the first badly indexed field reference.

        Raises:
            GTScriptSyntaxError: If the number of data indices of a ``FieldRef``
                differs from the number of data dimensions of its declaration.
        """
        return cls().visit(def_ir, field_decls=field_decls)

    def visit_FieldRef(self, node: FieldRef, *, field_decls: dict[str, FieldDecl]) -> None:
        """Compare the data index of ``node`` against the declaration of its field."""
        decl = field_decls[node.name]
        expected = len(decl.data_dims)
        given = len(node.data_index)

        if expected != given:
            if given == 0:
                problem = (
                    f"Field {node.name} has data dimensions but no data dimensions index "
                    "is specified. "
                )
            else:
                # Partially indexed or over-indexed: say so instead of claiming
                # that the index is missing altogether.
                problem = (
                    f"Field {node.name} has {expected} data dimension(s) but {given} "
                    "data dimension index value(s) are specified. "
                )

            hint = ""
            if expected:
                cdims = [0] * len(decl.axes)
                ddims = ["x"] * expected
                hint = f"Use `{node.name}.A{ddims}` or `{node.name}{cdims}{ddims}`."

            raise GTScriptSyntaxError(problem + hint, loc=node.loc)

        # Defining `visit_FieldRef` replaces the generic traversal for this node
        # type, so descend explicitly; otherwise field references nested inside
        # this one (e.g. in an offset or a data index expression) are never checked.
        self.generic_visit(node, field_decls=field_decls)
visit_FieldRef(self, node: FieldRef, *, fields_decls: dict[str, FieldDecl], **kwargs): name = node.name if fields_decls[name].data_dims: field_list: List[Union[FieldRef, List[FieldRef]]] = [] @@ -225,7 +247,7 @@ def visit_FieldRef(self, node: FieldRef, *, fields_decls: Dict[str, FieldDecl], return node - def visit_UnaryOpExpr(self, node: UnaryOpExpr, *, fields_decls: Dict[str, FieldDecl], **kwargs): + def visit_UnaryOpExpr(self, node: UnaryOpExpr, *, fields_decls: dict[str, FieldDecl], **kwargs): if node.op == UnaryOperator.TRANSPOSED: node = self.visit(node.arg, fields_decls=fields_decls, **kwargs) assert isinstance(node, list) and all( @@ -237,7 +259,7 @@ def visit_UnaryOpExpr(self, node: UnaryOpExpr, *, fields_decls: Dict[str, FieldD return self.generic_visit(node, **kwargs) - def visit_BinOpExpr(self, node: BinOpExpr, *, fields_decls: Dict[str, FieldDecl], **kwargs): + def visit_BinOpExpr(self, node: BinOpExpr, *, fields_decls: dict[str, FieldDecl], **kwargs): lhs = self.visit(node.lhs, fields_decls=fields_decls, **kwargs) rhs = self.visit(node.rhs, fields_decls=fields_decls, **kwargs) result: Union[List[BinOpExpr], BinOpExpr] = [] @@ -407,7 +429,7 @@ def visit_StencilDefinition(self, node: StencilDefinition) -> gtir.Stencil: loc=location_to_source_location(node.loc), ) - def visit_ArgumentInfo(self, node: ArgumentInfo, all_params: Dict[str, gtir.Decl]) -> gtir.Decl: + def visit_ArgumentInfo(self, node: ArgumentInfo, all_params: dict[str, gtir.Decl]) -> gtir.Decl: return all_params[node.name] def visit_ComputationBlock(self, node: ComputationBlock) -> gtir.VerticalLoop: diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index 171274bce5..6ff2c55e59 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -35,7 +35,8 @@ from gt4py.cartesian import definitions as gt_definitions, gtscript, utils as gt_utils from gt4py.cartesian.frontend 
import node_util, nodes from gt4py.cartesian.frontend.base import Frontend, register -from gt4py.cartesian.frontend.defir_to_gtir import DefIRToGTIR, UnrollVectorAssignments +from gt4py.cartesian.frontend.defir_builder import DefIRBuilder +from gt4py.cartesian.frontend.defir_to_gtir import DefIRToGTIR from gt4py.cartesian.frontend.exceptions import ( GTScriptAssertionError, GTScriptDataTypeError, @@ -1342,7 +1343,7 @@ def visit_Subscript(self, node: ast.Subscript): if isinstance(result, nodes.VarRef): assert index is not None result.index = index[0] - else: + elif isinstance(result, nodes.FieldRef): if isinstance(index, nodes.AbsoluteKIndex): result.offset = index elif isinstance(node.value, ast.Name): @@ -1391,6 +1392,12 @@ def visit_Subscript(self, node: ast.Subscript): loc=nodes.Location.from_ast_node(node, scope=self.stencil_name), ) + else: + raise GTScriptSyntaxError( + f"Unrecognized node type {type(result)} is subscripted", + loc=nodes.Location.from_ast_node(node, scope=self.stencil_name), + ) + return result # -- Expressions nodes -- @@ -2021,7 +2028,6 @@ def __init__(self, definition, *, options, externals=None, dtypes=None): self.options = options self.build_info = options.build_info self.main_name = options.name - self.definition_ir = None self.external_context = externals or {} self.resolved_externals = {} self.block = None @@ -2464,7 +2470,6 @@ def run(self, backend_name: str): func_node=main_func_node, ) - # Generate definition IR domain = nodes.Domain.LatLonGrid() computations = IRMaker( fields=fields_decls, @@ -2477,27 +2482,17 @@ def run(self, backend_name: str): options=self.options, )(self.ast_root) - self.definition_ir = nodes.StencilDefinition( - name=self.main_name, + return DefIRBuilder(self.main_name).build( domain=domain, api_signature=api_signature, - api_fields=[ - fields_decls[item.name] for item in api_signature if item.name in fields_decls - ], - parameters=[ - parameter_decls[item.name] for item in api_signature if item.name in 
parameter_decls - ], + fields_decls=fields_decls, + parameter_decls=parameter_decls, computations=init_computations + computations, externals=self.resolved_externals, docstring=inspect.getdoc(self.definition) or "", loc=nodes.Location.from_ast_node(self.ast_root.body[0]), ) - self.definition_ir = UnrollVectorAssignments.apply( - self.definition_ir, fields_decls=fields_decls - ) - return self.definition_ir - @register class GTScriptFrontend(Frontend): diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index aa86673d25..f1f9057060 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -10,6 +10,7 @@ import functools import textwrap import types +import re from typing import Any, Callable, Dict, Optional, Type import numpy as np @@ -1719,6 +1720,41 @@ def data_dims_with_at( module=self.__class__.__name__, ) + # Forgetting to index data dimensions + def data_dims_rhs_not_fully_indexed( + out_field: gtscript.Field[gtscript.IJK, np.int32], + global_field: gtscript.Field[gtscript.IJK, (np.int32, (3))], + ): + with computation(PARALLEL), interval(...): + out_field = global_field[0, 0, 0] # Expecting [0,0,0][0] + + with pytest.raises( + gt_frontend.GTScriptSyntaxError, + match="Field global_field has data dimensions but no data dimensions index is specified. 
Use", + ): + parse_definition( + data_dims_rhs_not_fully_indexed, + name=inspect.stack()[0][3], + module=self.__class__.__name__, + ) + + def data_dims_rhs_absolute_not_fully_indexed( + out_field: gtscript.Field[gtscript.IJK, np.int32], + global_field: gtscript.Field[gtscript.IJK, (np.int32, (3))], + ): + with computation(PARALLEL), interval(...): + out_field = global_field.at(K=1) # Expecting ddims=[] + + with pytest.raises( + gt_frontend.GTScriptSyntaxError, + match="Field global_field has data dimensions but no data dimensions index is specified. Use*", + ): + parse_definition( + data_dims_rhs_absolute_not_fully_indexed, + name=inspect.stack()[0][3], + module=self.__class__.__name__, + ) + class TestNestedWithSyntax: def test_nested_with(self):