diff --git a/nki/test/test_nki_nl_atomic_rmw.py b/nki/test/test_nki_nl_atomic_rmw.py index 5ad1b225..3c8725e8 100644 --- a/nki/test/test_nki_nl_atomic_rmw.py +++ b/nki/test/test_nki_nl_atomic_rmw.py @@ -32,6 +32,11 @@ def atomic_rmw_indirect_indices(in_tensor, indices_tensor, value_tensor): M = 512 # NKI_EXAMPLE_18_BEGIN + # Initialize sentinel mask for int32 max value + sentinel_mask = indices_tile != np.iinfo(np.int32).max + # Use safe indices (0 for sentinel lanes) to avoid OOB access + safe_indices = nl.where(sentinel_mask, indices_tile, nl.zeros(indices_tile.shape, dtype=indices_tile.dtype)) + value: tensor[N, M] = nl.load(value_tensor) # dynamic indices have to be in SBUF, with shape [N, 1] @@ -41,12 +46,13 @@ def atomic_rmw_indirect_indices(in_tensor, indices_tensor, value_tensor): ######################################################################## # Atomic read-modify-write example: - # - read: values of rmw_tensor is indexed by values from indices_tile - # - modify: incremented by value - # - write: saved back into rmw_tensor - # resulting in rmw_tensor = rmw_tensor + value + # - read: values of rmw_tensor is indexed by safe_indices + # - modify: incremented by value (only for non-sentinel lanes) + # - write: saved back into rmw_tensor using predicate to skip sentinels + # resulting in rmw_tensor = rmw_tensor + value (only for valid indices) ######################################################################## - nl.atomic_rmw(rmw_tensor[indices_tile, ix], value=value, op=np.add) + nl.atomic_rmw(rmw_tensor[safe_indices, ix], value=value, op=np.add, + predicate=sentinel_mask.squeeze(axis=-1)) # NKI_EXAMPLE_18_END return rmw_tensor