Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
262 commits
Select commit Hold shift + click to select a range
e73d56e
Merge remote-tracking branch 'origin/main' into gtir-dace-concat_where
edopao Feb 5, 2025
2219314
Merge branch 'main' into GTIR_concat_where
SF-N Feb 5, 2025
d89aba2
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Feb 5, 2025
140046f
Merge remote-tracking branch 'origin/main' into gtir-dace-concat_where
edopao Feb 6, 2025
ba6e81f
add test for multi-dim concat_where
edopao Feb 6, 2025
5ec2562
code cleanup
edopao Feb 6, 2025
e7d35af
minor edit
edopao Feb 6, 2025
9eb428a
Merge origin/main
tehrengruber Feb 14, 2025
d16bbd5
ITIR type inference: store param type in Lambda
tehrengruber Feb 15, 2025
aca4824
Merge branch 'main' into store_lambda_param_type
tehrengruber Feb 17, 2025
813f328
Flatten as_fieldop tuple arguments
tehrengruber Feb 18, 2025
3745461
Add support for scan and nested tuples
tehrengruber Feb 19, 2025
1f23e17
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 19, 2025
8bec9ab
Merge branch 'store_lambda_param_type' into GTIR_concat_where
tehrengruber Feb 19, 2025
06806fb
Preserve annex on new nodes
tehrengruber Feb 19, 2025
bab4fe1
Fix unnecessary import
tehrengruber Feb 19, 2025
6257a2b
Merge branch 'eve_annex_preserve_new_node' into GTIR_concat_where
tehrengruber Feb 19, 2025
14b4bf3
Cleanup
tehrengruber Feb 19, 2025
fc20d7c
Fix doctest
tehrengruber Feb 19, 2025
c5fba83
Fix failing tests
tehrengruber Feb 19, 2025
fa17228
Merge branch 'store_lambda_param_type' into collapse_tuple_as_fieldop…
tehrengruber Feb 19, 2025
04ae430
Fix tests
tehrengruber Feb 19, 2025
5136adc
Fix tests
tehrengruber Feb 19, 2025
5939618
Cleanup frontend type deduction
tehrengruber Feb 19, 2025
157b0e2
Cleanup frontend type deduction
tehrengruber Feb 19, 2025
435d057
Cleanup concat where:
tehrengruber Feb 20, 2025
5e5c66e
Merge branch 'eve_annex_preserve_new_node' into GTIR_concat_where
tehrengruber Feb 20, 2025
bd8dbaa
Fix iterator tests
tehrengruber Feb 20, 2025
2c14648
Fix infer domain ops
tehrengruber Feb 20, 2025
a7f3cac
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
1200803
Cleanup
tehrengruber Feb 20, 2025
cf0ffb2
Fix format
tehrengruber Feb 20, 2025
515f79b
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Feb 20, 2025
d84e937
Merge remote-tracking branch 'origin/main' into gtir-dace-concat_where
edopao Feb 20, 2025
dfc3310
fix gtir dace test cases
edopao Feb 20, 2025
335e932
Fix broken scan (e.g. test_tuple_scalar_scan)
tehrengruber Feb 20, 2025
ee1cd1c
Merge remote-tracking branch 'origin/main' into gtir-dace-concat_where
edopao Feb 20, 2025
7518b9c
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
39652de
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber Feb 20, 2025
ba03c7e
Merge remote-tracking branch 'origin/main' into collapse_tuple_as_fie…
tehrengruber Feb 20, 2025
71980af
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
c18b7ad
Fix failing tests
tehrengruber Feb 20, 2025
d399c65
Fix format
tehrengruber Feb 20, 2025
5ad7701
Fix failing tests
tehrengruber Feb 20, 2025
d3957bd
Fix format
tehrengruber Feb 20, 2025
e95fdf0
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
b52a07c
Cleanup
tehrengruber Feb 20, 2025
f8703b2
Merge branch 'collapse_tuple_as_fieldop_args' into GTIR_concat_where
tehrengruber Feb 20, 2025
c5c3e5f
Fix pyproject.toml test marker
tehrengruber Feb 20, 2025
f59fabf
Remove unnecessary visits
tehrengruber Feb 20, 2025
c8e06bd
Cleanup trace shifts
tehrengruber Feb 20, 2025
04d944c
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Feb 20, 2025
c34ad3d
enable concat_where frontend tests
edopao Feb 20, 2025
f748da7
Fix type inference
tehrengruber Feb 20, 2025
45f8b09
Add concat_where transforms to field view transforms
tehrengruber Feb 20, 2025
b3647bf
Fix typo
tehrengruber Feb 20, 2025
6ea11e5
Add support for tuples
tehrengruber Feb 20, 2025
60d0d9a
Fixes
tehrengruber Feb 20, 2025
93a6d33
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber Feb 21, 2025
132e576
Improve docs
tehrengruber Feb 21, 2025
e469075
Improve docs
tehrengruber Feb 21, 2025
d80c78d
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Feb 21, 2025
24e2f57
Fix typo
tehrengruber Feb 21, 2025
44036b8
make simple tests work
edopao Feb 21, 2025
01e8390
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Feb 21, 2025
d14fb21
Cleanup & improve test coverage
tehrengruber Feb 24, 2025
1e3ced5
Cleanup
tehrengruber Feb 24, 2025
545ec6c
concat_where_non_overlapping passes
edopao Feb 24, 2025
4fd741b
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Feb 24, 2025
e7ea3ce
minor edit
edopao Feb 24, 2025
595b675
Cleanup
tehrengruber Feb 24, 2025
fac7de0
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Feb 24, 2025
ba40dd6
d=skip concat_where broadcast test cases
edopao Feb 24, 2025
0b186d3
fix test with horizontal slice
edopao Feb 24, 2025
c242fb5
Revert "d=skip concat_where broadcast test cases"
edopao Feb 24, 2025
6c17104
concat_where with implicit broadcast
edopao Feb 25, 2025
fd44ffb
fix for negative size of temporary fields
edopao Feb 26, 2025
425e2f0
use domain info from node annex
edopao Feb 27, 2025
945908d
preserve domain information in ReplaceSymbols IR pass
edopao Feb 27, 2025
6b9dc5c
fix boundaries of field domain
edopao Feb 28, 2025
59a1226
Improve type inference for concat_where tuple case
tehrengruber Feb 28, 2025
f832a19
Fix typo
tehrengruber Feb 28, 2025
6f34d40
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Feb 28, 2025
475bce2
update stride for slice argument
edopao Feb 28, 2025
7b815f9
renaming of variables
edopao Feb 28, 2025
51a0fd2
more renaming
edopao Feb 28, 2025
75cc4f2
Fix bug in infer domain ops
tehrengruber Mar 2, 2025
6e85bd0
Address review comments
tehrengruber Mar 2, 2025
a8b9736
Merge remote-tracking branch 'origin_tehrengruber/store_lambda_param_…
tehrengruber Mar 2, 2025
9978a43
Address review comments
tehrengruber Mar 2, 2025
232d4b8
Address review comments
tehrengruber Mar 2, 2025
57abfaf
Merge remote-tracking branch 'origin/main' into store_lambda_param_type
tehrengruber Mar 2, 2025
596debf
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Mar 3, 2025
682aef4
Merge remote-tracking branch 'origin/main' into gtir-dace-concat_where
edopao Mar 3, 2025
2cf3c97
edit code comments
edopao Mar 3, 2025
f488b1a
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber Mar 3, 2025
55dc611
Merge branch 'store_lambda_param_type' into GTIR_concat_where
tehrengruber Mar 3, 2025
2674f11
Merge origin/main
tehrengruber Mar 3, 2025
c645cf5
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Mar 3, 2025
d0f93be
Fix deferred type in concat_where
tehrengruber Mar 3, 2025
cd248aa
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Mar 3, 2025
cf50a37
Fix tuple concat_where (not fully done yet)
tehrengruber Mar 3, 2025
4c29e73
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Mar 3, 2025
9659fba
handle scalar argument in concat_where
edopao Mar 3, 2025
5fc42ce
Fix tuple concat_where (not fully done yet)
tehrengruber Mar 3, 2025
ec6a507
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Mar 3, 2025
e3b4f20
handle domain of tuples
edopao Mar 3, 2025
fa82974
fix source subset in concat_where memlet
edopao Mar 4, 2025
321a204
cleanup
edopao Mar 4, 2025
6cfc7eb
Merge remote-tracking branch 'origin/main' into gtir-dace-concat_where
edopao Mar 5, 2025
f0657df
change broadcast
edopao Mar 6, 2025
6f23bc6
fix previous commit
edopao Mar 6, 2025
d63aea9
rename map range variables
edopao Mar 6, 2025
3f58894
try - fix segfault by turning 1:1 into 0:0
edopao Mar 6, 2025
d4d301b
fix formatting
edopao Mar 6, 2025
6eddf30
pruning of empty branches
edopao Mar 6, 2025
674391c
fix for isolated nodes
edopao Mar 6, 2025
7d328ae
undo extra change
edopao Mar 6, 2025
ae495cb
add test case
edopao Mar 6, 2025
5f2aed6
add missing other_subset info
edopao Mar 7, 2025
b1579a7
fix test cases
edopao Mar 7, 2025
77edc98
Unclean fixes (revert tuple lowering)
tehrengruber Mar 11, 2025
ef4979d
fix for field domain with scalr broadcast
edopao Mar 13, 2025
1a4bf3a
Enable laplacian test
tehrengruber Mar 14, 2025
61c2159
fix memlet subset issue
edopao Mar 18, 2025
65f7fcf
undo extra changes
edopao Mar 19, 2025
cadce61
use dynamic memlets for write to output
edopao Mar 20, 2025
b2b12e4
prefix tasklet connectors with '__' to avoid name collision with data…
edopao Mar 20, 2025
294f0cf
use temp array, not view, for concat slice
edopao Mar 20, 2025
ac0625f
Merge origin/main
tehrengruber Mar 21, 2025
b28f1f0
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Mar 21, 2025
1af561e
Merge origin/main
tehrengruber Mar 21, 2025
1ab8c69
embedded concat_where
havogt Mar 21, 2025
73fba27
Merge branch 'GTIR_concat_where' of github.com:SF-N/gt4py into GTIR_c…
havogt Mar 21, 2025
ae07826
add support for more comparison operators
havogt Mar 21, 2025
a8fe04e
change Dimension comparison
havogt Mar 22, 2025
40cf33b
embedded: non-python int comparison
havogt Mar 23, 2025
eb72671
Merge remote-tracking branch 'SF-N/GTIR_concat_where' into gtir-dace-…
edopao Mar 24, 2025
c60492c
add todo comment
edopao Mar 25, 2025
53c2cb7
formatting
edopao Mar 25, 2025
0037278
fix slice memlet
edopao Mar 25, 2025
ac522d2
fix broadcast of scalar value
edopao Mar 31, 2025
2feab82
pruning of empty branches is done by gt_simplify
edopao Mar 31, 2025
8c4fc45
Fix import
tehrengruber Apr 14, 2025
ac6fbb4
Merge origin/main
tehrengruber Apr 14, 2025
9ef6dd0
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
tehrengruber Apr 15, 2025
2731432
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber Apr 15, 2025
8ae4b58
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
tehrengruber Apr 15, 2025
bf7ae21
Merge remote-tracking branch 'upstream/main' into GTIR_concat_where
havogt Apr 17, 2025
e0d65e4
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
havogt Apr 17, 2025
b8b80f8
Merge origin/main
tehrengruber Apr 24, 2025
1f49b71
Merge remote-tracking branch 'origin_sf_n/GTIR_concat_where' into GTI…
tehrengruber Apr 24, 2025
a41eb1a
feat[next]: GTIR concat_where frontend
havogt Apr 24, 2025
e2c053c
disable concat_where tests
havogt Apr 24, 2025
4b46fcd
one more it_ts.DomainType
havogt Apr 24, 2025
d77a4c0
add test for concat_where on scalars and fix typing
havogt Apr 24, 2025
f41c112
add test for chained comparison
havogt Apr 25, 2025
de287c2
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber May 1, 2025
3a46c71
Merge branch 'main' into GTIR_concat_where
tehrengruber May 2, 2025
f4861fd
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
tehrengruber May 2, 2025
8a3aef5
Merge remote-tracking branch 'origin_edopao/gtir-dace-concat_where' i…
tehrengruber May 2, 2025
13baa21
Merge branch 'main' into GTIR_concat_where
tehrengruber May 7, 2025
fc329e6
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
tehrengruber May 7, 2025
4d41d86
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber May 9, 2025
7cbd2cb
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
tehrengruber May 9, 2025
c88adf7
Merge origin/main
tehrengruber May 18, 2025
359a921
Fix broken merge
tehrengruber May 18, 2025
4764c2b
Simplify tuple lowering, unit tests, cleanup
tehrengruber May 18, 2025
62db9ff
Small fix
tehrengruber May 18, 2025
7053c39
Cleanup
tehrengruber May 18, 2025
d597a4d
Cleanup
tehrengruber May 18, 2025
45ccbbc
Cleanup
tehrengruber May 18, 2025
0326e80
Add more unit tests
tehrengruber May 20, 2025
d069b67
Merge branch 'main' into GTIR_concat_where
edopao May 20, 2025
d1ba62f
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao May 20, 2025
ebb0831
fix for list index on slice concat_where
edopao May 20, 2025
ab719d9
fix for new domain on scalars and 1d fields as concat_where args
edopao May 20, 2025
1ce4ed4
Cleanup
tehrengruber May 22, 2025
62688b2
Cleanup
tehrengruber May 23, 2025
afdd60c
Merge remote-tracking branch 'origin/main' into GTIR_concat_where
tehrengruber May 23, 2025
eb9adf9
Cleanup
tehrengruber May 23, 2025
a7527df
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao May 23, 2025
dc98855
Cleanup
tehrengruber May 23, 2025
351ca0f
fix pass_manager
edopao May 23, 2025
b4e5fd1
Cleanup
tehrengruber May 23, 2025
e1b9d88
Merge commit 'a41eb1a' into GTIR_concat_where
tehrengruber May 23, 2025
dfd2b14
Merge remote-tracking branch 'origin/main' into concat_where_frontend
tehrengruber May 23, 2025
ec6b305
Merge branch 'concat_where_frontend' into GTIR_concat_where (#1998)
tehrengruber May 23, 2025
0bd26ce
Cleanup
tehrengruber May 23, 2025
e5dbf4a
Cleanup
tehrengruber May 23, 2025
0e8faad
Fix dace
tehrengruber May 23, 2025
1ac9719
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao May 25, 2025
77b8efe
Merge branch 'main' into GTIR_concat_where
edopao May 26, 2025
e215d66
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao May 26, 2025
c497022
Merge branch 'main' into GTIR_concat_where
edopao May 27, 2025
75c7c1e
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao May 27, 2025
cfe389d
Merge branch 'main' into GTIR_concat_where
edopao May 28, 2025
4af6f19
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao May 28, 2025
599d8f1
Merge remote-tracking branch 'upstream/main' into concat_where_frontend
havogt Jun 3, 2025
2499389
remove unchain comparison (because doesn't make sense)
havogt Jun 4, 2025
398ec68
improve error messages
havogt Jun 4, 2025
f81393a
fix chain test
havogt Jun 4, 2025
eae7dc7
simplify typing
havogt Jun 4, 2025
16e1c65
rename
havogt Jun 4, 2025
5f7e251
add promotion tests
havogt Jun 5, 2025
b1e8f89
Fix small type inference bug
tehrengruber Jun 4, 2025
35c026e
Merge branch 'main' into GTIR_concat_where
tehrengruber Jun 5, 2025
4feb941
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
tehrengruber Jun 5, 2025
fb01638
Merge branch 'concat_where_frontend' into GTIR_concat_where
tehrengruber Jun 5, 2025
06905b8
Merge branch 'main' into GTIR_concat_where
tehrengruber Jun 5, 2025
a66e5ca
Merge remote-tracking branch 'origin/main' into concat_where_frontend
tehrengruber Jun 5, 2025
d89cff6
Backport fixes from main PR
tehrengruber Jun 5, 2025
6f9ebff
Merge branch 'concat_where_frontend' into GTIR_concat_where
tehrengruber Jun 5, 2025
3dac495
Cleanup
tehrengruber Jun 5, 2025
506c2b5
Extract concat_where transformations
tehrengruber Jun 5, 2025
8ae99ae
Merge branch 'gtir_concat_where_passes' into GTIR_concat_where
tehrengruber Jun 5, 2025
af36bc9
Small fix
tehrengruber Jun 5, 2025
f1a99bd
Format
tehrengruber Jun 5, 2025
1f6b284
Format
tehrengruber Jun 5, 2025
45a2e23
Cleanup
tehrengruber Jun 5, 2025
721dde3
Merge branch 'gtir_concat_where_passes' into GTIR_concat_where
tehrengruber Jun 5, 2025
f770961
Merge branch 'main' into GTIR_concat_where
edopao Jun 11, 2025
d8252b7
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao Jun 11, 2025
aadf582
remove uses_concat_where from COMMON_SKIP_TEST_LIST
edopao Jun 11, 2025
1fcc669
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao Jun 11, 2025
e631f00
set find_new_name=True for concat_where slice
edopao Jun 11, 2025
0957ca9
Merge branch 'main' into GTIR_concat_where
edopao Jun 12, 2025
65c2116
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao Jun 12, 2025
b056955
Merge branch 'main' into GTIR_concat_where
edopao Jun 13, 2025
8a4308b
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao Jun 13, 2025
b23ad21
fix domain in scalar broadcast
edopao Jun 20, 2025
fb074fe
Merge branch 'main' into GTIR_concat_where
edopao Jun 23, 2025
08cb3a7
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao Jun 23, 2025
d669e3b
fix merge commit
edopao Jun 23, 2025
80ba402
add test case for coverage of bug in scalar broadcast on empty branch
edopao Jun 23, 2025
4f5f8ae
fix formatting
edopao Jun 25, 2025
a090872
bugfix - support for empty branch on 2d/3d concat_where
edopao Jun 25, 2025
fc2df23
Merge branch 'main' into GTIR_concat_where
edopao Jun 25, 2025
2bd9d3b
add test cases for empty branches
edopao Jun 25, 2025
7aabf11
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao Jun 25, 2025
3c78ec1
Merge branch 'main' into GTIR_concat_where
edopao Jun 25, 2025
a9e9864
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao Jun 25, 2025
eaf2195
fix domain
edopao Jun 26, 2025
f26b312
add extra check for scalar broadcast
edopao Jun 27, 2025
78a61ca
extend test case scalar_broadcast_on_empty_branch
edopao Jun 27, 2025
a9f146e
Merge branch 'main' into GTIR_concat_where
edopao Jun 27, 2025
31da410
pre-commit - format code
edopao Jun 27, 2025
ad31327
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao Jun 27, 2025
71f46f3
Merge branch 'main' into GTIR_concat_where
edopao Jul 4, 2025
de8e737
Merge branch 'GTIR_concat_where' into gtir-dace-concat_where
edopao Jul 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,41 @@ def __add__(self, offset: int) -> Connectivity:
def __sub__(self, offset: int) -> Connectivity:
return self + (-offset)

def __gt__(self, value: core_defs.IntegralScalar) -> Domain:
return Domain(dims=(self,), ranges=(UnitRange(value + 1, Infinity.POSITIVE),))

def __ge__(self, value: core_defs.IntegralScalar) -> Domain:
return Domain(dims=(self,), ranges=(UnitRange(value, Infinity.POSITIVE),))

def __lt__(self, value: core_defs.IntegralScalar) -> Domain:
return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value),))

def __le__(self, value: core_defs.IntegralScalar) -> Domain:
# TODO add test
return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value + 1),))

def __eq__(self, value: Dimension | core_defs.IntegralScalar) -> bool | Domain:
if isinstance(value, Dimension):
return self.value == value.value
elif isinstance(value, core_defs.INTEGRAL_TYPES):
# TODO probably only within valid embedded context?
return Domain(dims=(self,), ranges=(UnitRange(value, value + 1),))
else:
return False

def __ne__(self, value: Dimension | core_defs.IntegralScalar) -> bool | tuple[Domain, Domain]:
# TODO add test
if isinstance(value, Dimension):
return self.value != value.value
elif isinstance(value, core_defs.INTEGRAL_TYPES):
# TODO probably only within valid embedded context?
return (
Domain(self, UnitRange(Infinity.NEGATIVE, value)),
Domain(self, UnitRange(value + 1, Infinity.POSITIVE)),
)
else:
return True


class Infinity(enum.Enum):
"""Describes an unbounded `UnitRange`."""
Expand Down Expand Up @@ -500,6 +535,24 @@ def __and__(self, other: Domain) -> Domain:
)
return Domain(dims=broadcast_dims, ranges=intersected_ranges)

def __or__(self, other: Domain) -> Domain:
# TODO support arbitrary union of domains
# TODO add tests
if self.ndim > 1 or other.ndim > 1:
raise NotImplementedError("Union of multidimensional domains is not supported.")
if self.ndim == 0:
return other
if other.ndim == 0:
return self
sorted_ = sorted((self, other), key=lambda x: x.ranges[0].start)
if sorted_[0].ranges[0].stop >= sorted_[1].ranges[0].start:
return Domain(
dims=(self.dims[0],),
ranges=(UnitRange(sorted_[0].ranges[0].start, sorted_[1].ranges[0].stop),),
)
else:
return (sorted_[0], sorted_[1])

@functools.cached_property
def slice_at(self) -> utils.IndexerCallable[slice, Domain]:
"""
Expand Down
147 changes: 77 additions & 70 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,25 +810,6 @@ def _hyperslice(
NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where"))


def _compute_mask_slices(
mask: core_defs.NDArrayObject,
) -> list[tuple[bool, slice]]:
"""Take a 1-dimensional mask and return a sequence of mappings from boolean values to slices."""
# TODO: does it make sense to upgrade this naive algorithm to numpy?
assert mask.ndim == 1
cur = bool(mask[0].item())
ind = 0
res = []
for i in range(1, mask.shape[0]):
# Use `.item()` to extract the scalar from a 0-d array in case of e.g. cupy
if (mask_i := bool(mask[i].item())) != cur:
res.append((cur, slice(ind, i)))
cur = mask_i
ind = i
res.append((cur, slice(ind, mask.shape[0])))
return res


def _trim_empty_domains(
lst: Iterable[tuple[bool, common.Domain]],
) -> list[tuple[bool, common.Domain]]:
Expand Down Expand Up @@ -896,82 +877,108 @@ def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[c

def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field:
# TODO(havogt): this function could be extended to a general concat
# currently only concatenate along the given dimension and requires the fields to be ordered
# currently only concatenate along the given dimension
sorted_fields = sorted(fields, key=lambda f: f.domain[dim].unit_range.start)

if (
len(fields) > 1
and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty()
len(sorted_fields) > 1
and not embedded_common.domain_intersection(*[f.domain for f in sorted_fields]).is_empty()
):
raise ValueError("Fields to concatenate must not overlap.")
new_domain = _stack_domains(*[f.domain for f in fields], dim=dim)
new_domain = _stack_domains(*[f.domain for f in sorted_fields], dim=dim)
if new_domain is None:
raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.")
nd_array_class = _get_nd_array_class(*fields)
nd_array_class = _get_nd_array_class(*sorted_fields)
return nd_array_class.from_array(
nd_array_class.array_ns.concatenate(
[nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields],
[
nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape)
for f in sorted_fields
],
axis=new_domain.dim_index(dim, allow_missing=False),
),
domain=new_domain,
)


def _invert_domain(
domains: common.Domain | tuple[common.Domain],
) -> common.Domain | tuple[common.Domain, ...]:
if not isinstance(domains, tuple):
domains = (domains,)

assert all(d.ndim == 1 for d in domains)
dim = domains[0].dims[0]
assert all(d.dims[0] == dim for d in domains)
sorted_domains = sorted(domains, key=lambda d: d.ranges[0].start)

result = []
if domains[0].ranges[0].start is not common.Infinity.NEGATIVE:
result.append(
common.Domain(
dims=(dim,),
ranges=(common.UnitRange(common.Infinity.NEGATIVE, domains[0].ranges[0].start),),
)
)
for i in range(len(sorted_domains) - 1):
if sorted_domains[i].ranges[0].stop != sorted_domains[i + 1].ranges[0].start:
result.append(
common.Domain(
dims=(dim,),
ranges=(
common.UnitRange(
sorted_domains[i].ranges[0].stop, sorted_domains[i + 1].ranges[0].start
),
),
)
)
if domains[-1].ranges[0].stop is not common.Infinity.POSITIVE:
result.append(
common.Domain(
dims=(dim,),
ranges=(common.UnitRange(domains[-1].ranges[0].stop, common.Infinity.POSITIVE),),
)
)
return tuple(result)


def _intersect_multiple(
domain: common.Domain, domains: common.Domain | tuple[common.Domain]
) -> tuple[common.Domain, ...]:
if not isinstance(domains, tuple):
domains = (domains,)

return tuple(
intersection
for d in domains
if not (intersection := embedded_common.domain_intersection(domain, d)).is_empty()
)


def _concat_where(
mask_field: common.Field, true_field: common.Field, false_field: common.Field
masks: common.Domain | tuple[common.Domain, ...],
true_field: common.Field,
false_field: common.Field,
) -> common.Field:
cls_ = _get_nd_array_class(mask_field, true_field, false_field)
xp = cls_.array_ns
if mask_field.domain.ndim != 1:
if not isinstance(masks, tuple):
masks = (masks,)
if any(m.ndim for m in masks) != 1:
raise NotImplementedError(
"'concat_where': Can only concatenate fields with a 1-dimensional mask."
)
mask_dim = mask_field.domain.dims[0]
mask_dim = masks[0].dims[0]

# intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain
t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim)

# TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils
# compute the consecutive ranges (first relative, then domain) of true and false values
mask_values_to_slices_mapping: Iterable[tuple[bool, slice]] = _compute_mask_slices(
mask_field.ndarray
)
mask_values_to_domain_mapping: Iterable[tuple[bool, common.Domain]] = (
(mask, mask_field.domain.slice_at[domain_slice])
for mask, domain_slice in mask_values_to_slices_mapping
)
# mask domains intersected with the respective fields
mask_values_to_intersected_domains_mapping: Iterable[tuple[bool, common.Domain]] = (
(
mask_value,
embedded_common.domain_intersection(
t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain
),
)
for mask_value, mask_domain in mask_values_to_domain_mapping
)

# remove the empty domains from the beginning and end
mask_values_to_intersected_domains_mapping = _trim_empty_domains(
mask_values_to_intersected_domains_mapping
)
if any(d.is_empty() for _, d in mask_values_to_intersected_domains_mapping):
raise embedded_exceptions.NonContiguousDomain(
f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in mask_values_to_intersected_domains_mapping]}."
)
true_domains = _intersect_multiple(t_broadcasted.domain, masks)
t_slices = tuple(t_broadcasted[d] for d in true_domains)

# slice the fields with the domain ranges
transformed = [
t_broadcasted[d] if v else f_broadcasted[d]
for v, d in mask_values_to_intersected_domains_mapping
]
inverted_masks = _invert_domain(masks)
false_domains = _intersect_multiple(f_broadcasted.domain, inverted_masks)
f_slices = tuple(f_broadcasted[d] for d in false_domains)

# stack the fields together
if transformed:
return _concat(*transformed, dim=mask_dim)
else:
result_domain = common.Domain(common.NamedRange(mask_dim, common.UnitRange(0, 0)))
result_array = xp.empty(result_domain.shape)
return cls_.from_array(result_array, domain=result_domain)
return _concat(*f_slices, *t_slices, dim=mask_dim)


NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR
Expand Down
9 changes: 7 additions & 2 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,12 @@ def create_if(

return im.let(cond_symref_name, cond_)(result)

_visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where
def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
domain, true_branch, false_branch = self.visit(node.args, **kwargs)
# TODO: use this case again. breaks domain inference in fused_velocity_advection_stencil_1_to_7
# because some tuple elements are never accessed and the collapse tuple
# does not propagate across concat where
return im.concat_where(domain, true_branch, false_branch)

def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return im.call("broadcast")(*self.visit(node.args, **kwargs))
Expand Down Expand Up @@ -488,7 +493,7 @@ def _map(
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists.
"""
if all(
isinstance(t, ts.ScalarType)
isinstance(t, (ts.ScalarType, ts.DimensionType, ts.DomainType))
for arg_type in original_arg_types
for t in type_info.primitive_constituents(arg_type)
):
Expand Down
12 changes: 12 additions & 0 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,16 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
raise BackendNotSelectedError()


@builtin_dispatch
def concat_where(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def in_(*args):
raise BackendNotSelectedError()


UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"}
UNARY_LOGICAL_BUILTINS = {"not_"}
UNARY_MATH_FP_BUILTINS = {
Expand Down Expand Up @@ -494,6 +504,8 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
"scan",
"tuple_get",
"unstructured_domain",
"concat_where",
"in_",
*ARITHMETIC_BUILTINS,
*TYPE_BUILTINS,
}
Expand Down
10 changes: 10 additions & 0 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,6 +1802,16 @@ def index(axis: common.Dimension) -> common.Field:
return IndexField(axis)


@builtins.concat_where.register(EMBEDDED)
def concat_where(*args):
raise NotImplementedError("To be implemented in frontend embedded.")


@builtins.in_.register(EMBEDDED)
def in_(*args):
raise NotImplementedError("To be implemented in frontend embedded.")


def closure(
domain_: runtime.CartesianDomain | runtime.UnstructuredDomain,
sten: Callable[..., Any],
Expand Down
21 changes: 20 additions & 1 deletion src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import ClassVar, List, Optional, Union
import typing
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union

import gt4py.eve as eve
from gt4py.eve import Coerced, SymbolName, SymbolRef
Expand Down Expand Up @@ -63,6 +65,22 @@ class NoneLiteral(Expr):
_none_literal: int = 0


class InfinityLiteral(Expr):
# TODO(tehrengruber): self referential `ClassVar` not supported in eve.
if TYPE_CHECKING:
POSITIVE: ClassVar[InfinityLiteral]
NEGATIVE: ClassVar[InfinityLiteral]

name: typing.Literal["POSITIVE", "NEGATIVE"]

def __str__(self):
return f"{type(self).__name__}.{self.name}"


InfinityLiteral.NEGATIVE = InfinityLiteral(name="NEGATIVE")
InfinityLiteral.POSITIVE = InfinityLiteral(name="POSITIVE")


class OffsetLiteral(Expr):
value: Union[int, str]

Expand Down Expand Up @@ -142,4 +160,5 @@ class Program(Node, ValidatedSymbolTableTrait):
Program.__hash__ = Node.__hash__ # type: ignore[method-assign]
SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign]
IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign]
InfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
Temporary.__hash__ = Node.__hash__ # type: ignore[method-assign]
Loading
Loading