Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,7 @@ def write_before_read(
if offsets[tmp] != {(0, 0, 0)}:
return False
return next(
o.is_write and o.horizontal_mask is None
for o in ordered_accesses
if o.field == tmp
o.is_write and not o.is_conditional for o in ordered_accesses if o.field == tmp
)

write_before_read_tmps = {
Expand Down
20 changes: 17 additions & 3 deletions src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,20 @@ class GenericAccess(Generic[OffsetT]):
is_write: bool
data_index: List[oir.Expr] = dataclasses.field(default_factory=list)
horizontal_mask: Optional[common.HorizontalMask] = None
conditional: bool | None = None
"""Set to True if this access is conditional, e.g. part of the body of a MaskStmt."""

@property
def is_read(self) -> bool:
    """Return the negation of ``is_write``: an access that is not a write is a read."""
    written = self.is_write
    return not written

@property
def is_conditional(self) -> bool:
    """Tell whether this access is conditional.

    Returns:
        The value stored in ``self.conditional``.

    Raises:
        RuntimeError: If ``self.conditional`` is ``None``, i.e. the flag was
            never set for this access.
    """
    # `None` means "not determined"; fail loudly rather than guess a value.
    if self.conditional is None:
        raise RuntimeError("It is unknown whether or not this access is conditional.")

    return self.conditional

def to_extent(
self,
horizontal_extent: Extent,
Expand Down Expand Up @@ -87,6 +96,7 @@ def visit_FieldAccess(
accesses: List[GeneralAccess],
is_write: bool,
horizontal_mask: Optional[common.HorizontalMask] = None,
conditional: bool | None = None,
**kwargs: Any,
) -> None:
self.generic_visit(node, accesses=accesses, is_write=is_write, **kwargs)
Expand All @@ -98,6 +108,7 @@ def visit_FieldAccess(
data_index=node.data_index,
is_write=is_write,
horizontal_mask=horizontal_mask,
conditional=conditional is True,
)
)

Expand All @@ -107,14 +118,17 @@ def visit_AssignStmt(self, node: oir.AssignStmt, **kwargs: Any) -> None:

def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs: Any) -> None:
    """Visit a mask statement, flagging everything in its body as conditional.

    The mask expression is visited as a read with the incoming keyword
    arguments unchanged. The body is then visited exactly once with
    ``conditional=True``.
    """
    self.visit(node.mask, is_write=False, **kwargs)
    kwargs.pop("conditional", None)  # avoid multiple values for kwarg `conditional`
    self.visit(node.body, conditional=True, **kwargs)

def visit_While(self, node: oir.While, **kwargs: Any) -> None:
    """Visit a while loop, flagging everything in its body as conditional.

    The loop condition is visited as a read with the incoming keyword
    arguments unchanged. The body is then visited exactly once with
    ``conditional=True``.
    """
    self.visit(node.cond, is_write=False, **kwargs)
    kwargs.pop("conditional", None)  # avoid multiple values for kwarg `conditional`
    self.visit(node.body, conditional=True, **kwargs)

def visit_HorizontalRestriction(self, node: oir.HorizontalRestriction, **kwargs: Any) -> None:
    """Visit the body of a horizontal restriction exactly once.

    The body is visited with ``horizontal_mask`` set to the restriction's
    mask and with ``conditional=True``.
    """
    kwargs.pop("conditional", None)  # avoid multiple values for kwarg `conditional`
    self.visit(node.body, horizontal_mask=node.mask, conditional=True, **kwargs)

def visit_Interval(self, node: oir.Interval, **kwargs: Any) -> None:
self.visit(node.start, is_write=False, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.cartesian.gtc import oir
from gt4py.cartesian.gtc.common import BuiltInLiteral, DataType, LoopOrder
from gt4py.cartesian.gtc.passes.oir_optimizations.temporaries import (
LocalTemporariesToScalars,
WriteBeforeReadTemporariesToScalars,
Expand All @@ -15,28 +16,35 @@
from ...oir_utils import (
AssignStmtFactory,
HorizontalExecutionFactory,
LiteralFactory,
MaskStmtFactory,
StencilFactory,
TemporaryFactory,
VerticalLoopFactory,
VerticalLoopSectionFactory,
)


def test_local_temporaries_to_scalars_basic() -> None:
    """A temporary written and read within one horizontal execution becomes a scalar."""
    testee = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions__0__body=[
            AssignStmtFactory(left__name="tmp"),
            AssignStmtFactory(right__name="tmp"),
        ],
        declarations=[TemporaryFactory(name="tmp")],
    )

    transformed = LocalTemporariesToScalars().visit(testee)

    assert isinstance(transformed, oir.Stencil)
    hexec = transformed.vertical_loops[0].sections[0].horizontal_executions[0]
    # Both the write and the read of `tmp` must now be scalar accesses ...
    assert isinstance(hexec.body[0].left, oir.ScalarAccess)
    assert isinstance(hexec.body[1].right, oir.ScalarAccess)
    # ... and the declaration must have moved from the stencil to the horizontal execution.
    assert not transformed.declarations
    assert len(hexec.declarations) == 1


def test_local_temporaries_to_scalars_multiexec():
def test_local_temporaries_to_scalars_multiexec() -> None:
testee = StencilFactory(
vertical_loops__0__sections__0__horizontal_executions=[
HorizontalExecutionFactory(
Expand All @@ -46,12 +54,15 @@ def test_local_temporaries_to_scalars_multiexec():
],
declarations=[TemporaryFactory(name="tmp")],
)

transformed = LocalTemporariesToScalars().visit(testee)

assert isinstance(transformed, oir.Stencil)
assert "tmp" in {d.name for d in transformed.declarations}
assert not transformed.walk_values().if_isinstance(oir.ScalarAccess).to_list()


def test_write_before_read_temporaries_to_scalars():
def test_write_before_read_temporaries_to_scalars() -> None:
testee = StencilFactory(
vertical_loops__0__sections__0__horizontal_executions=[
HorizontalExecutionFactory(
Expand All @@ -75,7 +86,10 @@ def test_write_before_read_temporaries_to_scalars():
TemporaryFactory(name="tmp3"),
],
)

transformed = WriteBeforeReadTemporariesToScalars().visit(testee)

assert isinstance(transformed, oir.Stencil)
hexec0 = transformed.vertical_loops[0].sections[0].horizontal_executions[0]
hexec1 = transformed.vertical_loops[0].sections[0].horizontal_executions[1]
assert len(hexec0.declarations) == 2
Expand All @@ -89,3 +103,54 @@ def test_write_before_read_temporaries_to_scalars():
assert not isinstance(hexec1.body[0].right, oir.ScalarAccess)
assert isinstance(hexec1.body[1].left, oir.ScalarAccess)
assert isinstance(hexec1.body[2].right, oir.ScalarAccess)


def test_conditional_write_before_read() -> None:
    """Check that a masked write followed by a plain read keeps the temporary.

    ``tmp1`` is written unconditionally in a FORWARD loop. In a later PARALLEL
    loop it is written inside a mask statement and then read outside of it.
    """
    forward_loop = VerticalLoopFactory(
        loop_order=LoopOrder.FORWARD,
        sections=[
            VerticalLoopSectionFactory(
                horizontal_executions=[
                    HorizontalExecutionFactory(body=[AssignStmtFactory(left__name="tmp1")]),
                ]
            )
        ],
    )

    # Factories are invoked in the same order as in a fully nested expression.
    false_mask = LiteralFactory(dtype=DataType.BOOL, value=BuiltInLiteral.FALSE)
    masked_write = MaskStmtFactory(
        mask=false_mask,
        body=[AssignStmtFactory(left__name="tmp1")],
    )
    plain_read = AssignStmtFactory(right__name="tmp1")
    parallel_loop = VerticalLoopFactory(
        loop_order=LoopOrder.PARALLEL,
        sections=[
            VerticalLoopSectionFactory(
                horizontal_executions=[
                    HorizontalExecutionFactory(body=[masked_write, plain_read]),
                ]
            )
        ],
    )

    testee = StencilFactory(
        vertical_loops=[forward_loop, parallel_loop],
        declarations=[TemporaryFactory(name="tmp1")],
    )

    transformed = WriteBeforeReadTemporariesToScalars().visit(testee)

    # Make sure we don't scalarize in case of conditional write before unconditional read
    assert isinstance(transformed, oir.Stencil)
    remaining = transformed.declarations
    assert len(remaining) == 1
    assert remaining[0].name == "tmp1"
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ def test_access_collector():
"mask": {(-1, -1, 1)},
}
ordered_accesses = [
GeneralAccess(field="foo", offset=(1, 0, 0), is_write=False),
GeneralAccess(field="tmp", offset=(0, 0, 0), is_write=True),
GeneralAccess(field="tmp", offset=(0, 0, 0), is_write=False),
GeneralAccess(field="bar", offset=(0, 0, 0), is_write=True),
GeneralAccess(field="mask", offset=(-1, -1, 1), is_write=False),
GeneralAccess(field="tmp", offset=(0, 1, 0), is_write=False),
GeneralAccess(field="baz", offset=(0, 0, 0), is_write=True),
GeneralAccess(field="foo", offset=(1, 0, 0), is_write=False, conditional=False),
GeneralAccess(field="tmp", offset=(0, 0, 0), is_write=True, conditional=False),
GeneralAccess(field="tmp", offset=(0, 0, 0), is_write=False, conditional=False),
GeneralAccess(field="bar", offset=(0, 0, 0), is_write=True, conditional=False),
GeneralAccess(field="mask", offset=(-1, -1, 1), is_write=False, conditional=False),
GeneralAccess(field="tmp", offset=(0, 1, 0), is_write=False, conditional=True),
GeneralAccess(field="baz", offset=(0, 0, 0), is_write=True, conditional=True),
]

result = AccessCollector.apply(testee)
Expand Down
Loading