2D temporaries could be developed quickly as an experimental feature, with support in the dace:X and debug backends.
The gt:X backends are locked behind an assumption that temporaries are 3D; it is unclear how easy that would be to undo.
The syntax remains to be decided. We propose to reuse the typing of stencil parameters for temporaries, e.g.:
import numpy as np
from gt4py.cartesian.gtscript import FORWARD, IJ, IJK, PARALLEL, Field, computation, interval

def the_stencil(in_field: Field[IJK, np.float64], out_field: Field[IJK, np.float64]):
    with computation(FORWARD), interval(0, 1):
        tmp_2d: Field[IJ, np.float64] = in_field
    with computation(PARALLEL), interval(...):
        out_field = in_field + tmp_2d
To compact the types, we also propose to bring down the shortcuts introduced in our NDSL layer, FloatField and FloatFieldIJ, rewriting the above as:
def the_stencil(in_field: FloatField, out_field: FloatField):
    with computation(FORWARD), interval(0, 1):
        tmp_2d: FloatFieldIJ = in_field
    with computation(PARALLEL), interval(...):
        out_field = in_field + tmp_2d
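As a hand-written illustration (not code generated by any backend), the following pure-Python reference spells out what the_stencil computes, with fields as nested lists indexed [i][j][k]; the name the_stencil_reference is hypothetical. It shows why tmp_2d only needs an I-by-J plane: it is written once from level k = 0 and read back unchanged at every level.

```python
from typing import List

Grid = List[List[List[float]]]  # nested lists indexed [i][j][k]


def the_stencil_reference(in_field: Grid) -> Grid:
    """Hand-written reference of the_stencil; the name is hypothetical."""
    ni, nj, nk = len(in_field), len(in_field[0]), len(in_field[0][0])
    # FORWARD, interval(0, 1): the 2D temporary is written once, from level k = 0.
    tmp_2d = [[in_field[i][j][0] for j in range(nj)] for i in range(ni)]
    # PARALLEL, interval(...): every level k reads the same tmp_2d[i][j] value.
    return [
        [[in_field[i][j][k] + tmp_2d[i][j] for k in range(nk)] for j in range(nj)]
        for i in range(ni)
    ]


# Two columns (ni=2, nj=1), three levels each.
print(the_stencil_reference([[[1.0, 2.0, 3.0]], [[10.0, 20.0, 30.0]]]))
# [[[2.0, 3.0, 4.0]], [[20.0, 30.0, 40.0]]]
```

With a 3D temporary, tmp_2d would only be defined at k = 0; the 2D declaration is what makes reading it at every level well-defined.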
There is one caveat. The current mixed-precision implementation has introduced a quick way to define non-standard precision on 3D temporaries, e.g.:
tmp_3d: float64 = in_field
With the introduction of 2D temporaries, we should undo (deprecate, then remove) that shortcut and move to explicitly stating the full type, e.g.:
tmp_3d: FloatField64 = in_field
We would, in contrast, keep the current default behavior: any temporary defined without a type hint is a FloatField.
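The typing policy above (shortcut types, unchanged default, deprecated bare dtype) can be sketched as a small resolver. This is a minimal sketch, not the frontend's code: resolve_temporary_type, SHORTCUTS and BARE_DTYPES are hypothetical names, annotations are handled as plain source strings, and FloatField's default precision is assumed to be float64.

```python
import warnings
from typing import Optional, Tuple

Axes = Tuple[str, ...]

# Assumption: the default precision behind FloatField is float64.
DEFAULT_DTYPE = "float64"

# Hypothetical lookup table for the shortcut types named in this proposal.
SHORTCUTS = {
    "FloatField": (("I", "J", "K"), DEFAULT_DTYPE),
    "FloatFieldIJ": (("I", "J"), DEFAULT_DTYPE),
    "FloatField64": (("I", "J", "K"), "float64"),
}

# Bare dtypes accepted today by the mixed-precision shortcut (assumed list).
BARE_DTYPES = {"float32", "float64"}


def resolve_temporary_type(annotation: Optional[str]) -> Tuple[Axes, str]:
    """Return (axes, dtype) for a temporary, given its annotation source (or None)."""
    if annotation is None:
        # Unchanged default: an unannotated temporary is a FloatField.
        return SHORTCUTS["FloatField"]
    if annotation in SHORTCUTS:
        return SHORTCUTS[annotation]
    if annotation in BARE_DTYPES:
        # Deprecation phase: still accepted as a 3D temporary, but warns.
        warnings.warn(
            f"Annotating a temporary with bare '{annotation}' is deprecated; "
            "state the full field type instead (e.g. FloatField64).",
            DeprecationWarning,
            stacklevel=2,
        )
        return (("I", "J", "K"), annotation)
    raise TypeError(f"Unrecognized temporary annotation: {annotation!r}")
```

Once the deprecation period ends, the BARE_DTYPES branch would simply be deleted so that bare dtypes fall through to the error.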
Dev NOTE.
The frontend work is easy:
- Annotations have already been intercepted since the mixed-precision work
- Temporary declarations can see the annotations and reason about them
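To illustrate the second point, here is a standalone sketch using only the standard ast module: it walks a stencil's source and records, for each locally assigned name, the annotation text (or None when there is no type hint). The function collect_temporary_annotations is hypothetical and much simpler than the real IRMaker visitor; it only recognizes plain-name targets.

```python
import ast
import textwrap
from typing import Dict, Optional

STENCIL_SOURCE = textwrap.dedent(
    """
    def the_stencil(in_field: FloatField, out_field: FloatField):
        with computation(FORWARD), interval(0, 1):
            tmp_2d: FloatFieldIJ = in_field
        with computation(PARALLEL), interval(...):
            out_field = in_field + tmp_2d
    """
)


def collect_temporary_annotations(source: str) -> Dict[str, Optional[str]]:
    """Map each locally assigned name to its annotation source, or None if unannotated."""
    func = ast.parse(source).body[0]
    params = {arg.arg for arg in func.args.args}
    found: Dict[str, Optional[str]] = {}
    for node in ast.walk(func):
        if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
            # Annotated declaration: keep the annotation as source text.
            found[node.target.id] = ast.unparse(node.annotation)
        elif isinstance(node, ast.Assign):
            # Unannotated assignment to a non-parameter name: a default temporary.
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id not in params:
                    found.setdefault(target.id, None)
    return found


print(collect_temporary_annotations(STENCIL_SOURCE))  # {'tmp_2d': 'FloatFieldIJ'}
```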
Here is a hard-coded version that works in dace:X. It keys on an IJTemporary[<dtype>] annotation rather than the proposed syntax:
diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py
index 8f461e8c2..3d6d63a8b 100644
--- a/src/gt4py/cartesian/frontend/gtscript_frontend.py
+++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py
@@ -1599,10 +1599,18 @@ class IRMaker(ast.NodeVisitor):
                         loc=nodes.Location.from_ast_node(t),
                     )
                 dtype = nodes.DataType.AUTO
+                axes = nodes.Domain.LatLonGrid().axes_names
                 if target_annotation is not None:
                     source = ast.unparse(target_annotation)
+                    if source.startswith("IJTemporary"):
+                        axes = nodes.Domain.LatLonGrid()
+                        axes.sequential_axis = None
+                        axes = axes.axes_names
+                        dtype_to_translate = ast.unparse(target_annotation.slice)
+                    else:
+                        dtype_to_translate = source
                     try:
-                        dtype = eval(source, self.temporary_type_to_native_type)
+                        dtype = eval(dtype_to_translate, self.temporary_type_to_native_type)
                     except NameError:
                         raise GTScriptSyntaxError(
                             message=f"Failed to recognize type {source} for local symbol {name}."
@@ -1612,7 +1620,7 @@ class IRMaker(ast.NodeVisitor):
                 field_decl = nodes.FieldDecl(
                     name=name,
                     data_type=dtype,
-                    axes=nodes.Domain.LatLonGrid().axes_names,
+                    axes=axes,
                     is_api=False,
                     loc=nodes.Location.from_ast_node(t),
                 )
Doing the work properly would mean evaluating target_annotation against a restricted namespace of gtscript.Field and/or FloatField symbols, then capturing the axes from those objects, which do carry them.
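A minimal sketch of that approach, under stated assumptions: FieldType, _FieldFactory, Field, IJ, IJK and ANNOTATION_NAMESPACE below are hypothetical stand-ins, not gtscript's actual classes, and SyntaxError stands in for GTScriptSyntaxError. The annotation is evaluated in a namespace whose symbols all carry their axes, so 2D and 3D temporaries fall out of the same code path instead of a string prefix check.

```python
import ast
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class FieldType:
    """Hypothetical stand-in for a field type object that carries its axes and dtype."""

    axes: Tuple[str, ...]
    dtype: str


class _FieldFactory:
    """Hypothetical stand-in for gtscript.Field: Field[axes, dtype] returns a FieldType."""

    def __getitem__(self, item):
        axes, dtype = item
        return FieldType(tuple(axes), dtype)


Field = _FieldFactory()
IJ, IJK = ("I", "J"), ("I", "J", "K")

# Restricted namespace: the only symbols an annotation may reference.
# Note: emptying __builtins__ narrows eval but is not a security sandbox.
ANNOTATION_NAMESPACE = {
    "__builtins__": {},
    "Field": Field,
    "IJ": IJ,
    "IJK": IJK,
    "float64": "float64",
    "FloatField": Field[IJK, "float64"],
    "FloatFieldIJ": Field[IJ, "float64"],
}


def temporary_axes_and_dtype(annotation: ast.expr) -> Tuple[Tuple[str, ...], str]:
    """Evaluate an annotation node and read the axes and dtype off the resulting object."""
    source = ast.unparse(annotation)
    try:
        value = eval(source, dict(ANNOTATION_NAMESPACE))
    except NameError as err:
        raise SyntaxError(f"Failed to recognize type {source}") from err
    if isinstance(value, FieldType):
        return value.axes, value.dtype
    if isinstance(value, str):
        # Deprecated bare-dtype shortcut (e.g. `float64`): keep the temporary 3D.
        return IJK, value
    raise SyntaxError(f"{source} is not a field type")


node = ast.parse("tmp_2d: Field[IJ, float64] = in_field").body[0]
print(temporary_axes_and_dtype(node.annotation))  # (('I', 'J'), 'float64')
```

The resulting axes tuple is what would feed the axes argument of the temporary's field declaration.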