Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions nki/test/test_nki_nl_atomic_rmw.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def atomic_rmw_indirect_indices(in_tensor, indices_tensor, value_tensor):
M = 512

# NKI_EXAMPLE_18_BEGIN
# Build a mask that is True for valid lanes, i.e. where the index is not the int32 max sentinel
sentinel_mask = indices_tile != np.iinfo(np.int32).max
# Use safe indices (0 for sentinel lanes) to avoid OOB access
safe_indices = nl.where(sentinel_mask, indices_tile, nl.zeros(indices_tile.shape, dtype=indices_tile.dtype))

value: tensor[N, M] = nl.load(value_tensor)

# dynamic indices have to be in SBUF, with shape [N, 1]
Expand All @@ -41,12 +46,13 @@ def atomic_rmw_indirect_indices(in_tensor, indices_tensor, value_tensor):

########################################################################
# Atomic read-modify-write example:
# - read: values of rmw_tensor are indexed by values from indices_tile
# - modify: incremented by value
# - write: saved back into rmw_tensor
# resulting in rmw_tensor = rmw_tensor + value
# - read: values of rmw_tensor are indexed by safe_indices
# - modify: incremented by value (only for non-sentinel lanes)
# - write: saved back into rmw_tensor using predicate to skip sentinels
# resulting in rmw_tensor = rmw_tensor + value (only for valid indices)
########################################################################
nl.atomic_rmw(rmw_tensor[indices_tile, ix], value=value, op=np.add)
nl.atomic_rmw(rmw_tensor[safe_indices, ix], value=value, op=np.add,
predicate=sentinel_mask.squeeze(axis=-1))
# NKI_EXAMPLE_18_END
return rmw_tensor

Expand Down