Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions tests/test_async_simple_storage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,61 @@ def test_numpy_array(self):
result = AsyncSimpleStorageManager._select_by_positions(arr, [0, 2])
np.testing.assert_array_equal(result, np.array([10, 30]))

def test_regular_tensor_single_element(self):
    """Case 1: selecting a single position yields a one-row result."""
    data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    selected = AsyncSimpleStorageManager._select_by_positions(data, [1])
    # Exactly one row comes back, and it matches the second input row.
    assert selected.shape == (1, 2)
    assert torch.equal(selected, torch.tensor([[3.0, 4.0]]))

def test_regular_tensor_strided_slice(self):
    """Case 2: Constant stride (step > 1) uses Python slicing for zero-copy view."""
    t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]])
    result = AsyncSimpleStorageManager._select_by_positions(t, [0, 2, 4])
    # positions form constant stride of 2
    expected = torch.tensor([[1.0, 2.0], [5.0, 6.0], [9.0, 10.0]])
    assert torch.equal(result, expected)
    # The contract under test is "zero-copy view": equality alone would also
    # pass for an index_select copy, so additionally verify the result shares
    # the input tensor's underlying storage rather than owning a fresh buffer.
    assert result.untyped_storage().data_ptr() == t.untyped_storage().data_ptr()

def test_regular_tensor_irregular_indices_fallback(self):
    """Case 3: Irregular indices fall back to index_select to avoid ZMQ frame fragmentation."""
    source = torch.arange(1.0, 9.0).reshape(4, 2)
    # Gaps between [0, 2, 3] are 2 then 1 — no constant stride exists.
    picked = AsyncSimpleStorageManager._select_by_positions(source, [0, 2, 3])
    assert torch.equal(picked, torch.tensor([[1.0, 2.0], [5.0, 6.0], [7.0, 8.0]]))

def test_regular_tensor_irregular_reverse_order(self):
    """Irregular indices in reverse order also fall back to index_select."""
    source = torch.arange(1.0, 9.0).reshape(4, 2)
    picked = AsyncSimpleStorageManager._select_by_positions(source, [3, 1, 0])
    # Rows must come back in exactly the requested (descending, uneven) order.
    assert torch.equal(picked, torch.tensor([[7.0, 8.0], [3.0, 4.0], [1.0, 2.0]]))

def test_nested_tensor_single_element(self):
    """Single element from nested tensor uses the lambda path."""
    components = [
        torch.tensor([1.0]),
        torch.tensor([2.0, 3.0]),
        torch.tensor([4.0, 5.0, 6.0]),
    ]
    nested = torch.nested.as_nested_tensor(components, layout=torch.jagged)
    picked = AsyncSimpleStorageManager._select_by_positions(nested, [1])
    # Nested tensors are returned as a plain list of per-row tensors.
    assert isinstance(picked, list)
    assert len(picked) == 1
    assert torch.equal(picked[0], torch.tensor([2.0, 3.0]))

def test_empty_positions_raises_error(self):
    """Empty positions list should raise ValueError."""
    data = torch.tensor([1.0, 2.0, 3.0])
    # An empty selection is a caller bug and must fail loudly.
    with pytest.raises(ValueError, match="No positions specified"):
        AsyncSimpleStorageManager._select_by_positions(data, [])

def test_regular_tensor_negative_stride_rejected(self):
    """Negative stride (reversed order) should fall back to index_select."""
    data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    # [2, 1, 0] has a constant stride of -1; a negative step cannot be
    # expressed as a forward slice, so index_select must handle it.
    picked = AsyncSimpleStorageManager._select_by_positions(data, [2, 1, 0])
    assert torch.equal(picked, torch.tensor([[5.0, 6.0], [3.0, 4.0], [1.0, 2.0]]))


class TestPackFieldValues:
"""Test _pack_field_values static method packing logic."""
Expand Down
82 changes: 71 additions & 11 deletions transfer_queue/storage/managers/simple_backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,70 @@ def _group_by_hash(self, global_indexes: list[int]) -> dict[str, RoutingGroup]:
def _select_by_positions(field_data, positions: list[int]):
"""Slice a single field's data by non-contiguous batch positions.

Handles four data types:
- Nested tensors: unbind → select → return as list
- NonTensorStack: tolist → select → re-wrap
- list: direct index selection via itemgetter
- Regular tensors / numpy arrays: fancy indexing
This method optimizes selection to minimize memory overhead and network fragmentation:
- Nested tensors: Unbinds into a list of views (end-to-end zero-copy).
- Regular tensors (step == 1): Returns a contiguous slice (end-to-end zero-copy).
- Regular tensors (step > 1): Returns a strided view (shares storage). Note that
downstream serialization will force a `.contiguous()` copy, but slicing is still
faster than `index_select` and the peak memory period is reduced.
- Regular tensors (irregular): Falls back to `index_select` to assemble a single
contiguous tensor, preventing excessive ZMQ multipart frames.
- NonTensorStack: tolist → select → re-wrap.
- List: Direct index selection via `itemgetter`.
- Numpy arrays / Others: Advanced indexing (memory copy).
"""
if isinstance(field_data, torch.Tensor) and field_data.is_nested:
unbound = field_data.unbind()
getter = itemgetter(*positions) if len(positions) > 1 else lambda seq: (seq[positions[0]],)
selected = getter(unbound)
return list(selected)

n = len(positions)
if n == 0:
raise ValueError("No positions specified for selection.")

# --- Handle PyTorch Tensors ---
if isinstance(field_data, torch.Tensor):
if field_data.is_nested:
# Nested tensors cannot be directly sliced into a single tensor view.
# Unbinding and selecting returns a list of individual views (zero-copy),
# which is acceptable for nested structures.
unbound = field_data.unbind()
getter = itemgetter(*positions) if len(positions) > 1 else lambda seq: (seq[positions[0]],)
selected = getter(unbound)
return list(selected)
else:
# --- Smart Slicing for Regular Tensors ---
# Goal: Return a single underlying memory view (zero-copy) to avoid both
# memory allocation overhead and downstream ZMQ frame fragmentation.

# Case 1: Single element selection (returns a single-row view)
if n == 1:
# Single element is natively contiguous
return field_data[positions[0] : positions[0] + 1]

# Case 2: Check if positions form a constant-stride sequence
step = positions[1] - positions[0]
is_constant_stride = True
for i in range(2, n):
if positions[i] - positions[i - 1] != step:
is_constant_stride = False
break

# If perfectly regular (e.g., [0, 2, 4]), use Python slicing to get a view
if is_constant_stride and step > 0:
# Note:
# A strided slice (step > 1) creates a non-contiguous view.
# While it shares storage here, the downstream MsgpackEncoder will force
# a .contiguous() copy before extracting the buffer. However, this pure
# Python slicing is still more efficient than falling back to index_select,
# and it reduces memory peak period.
return field_data[positions[0] : positions[-1] + 1 : step]

# Case 3: Fallback for irregular indices (Typically this will not happen!)
# We intentionally accept a memory copy here to assemble a single contiguous
# tensor. Returning a list of individual views for irregular indices would
# generate excessive multipart ZMQ frames, severely degrading network performance.
else:
idx_tensor = torch.tensor(positions, device=field_data.device)
return torch.index_select(field_data, dim=0, index=idx_tensor)

# --- Handle Non-Tensor Types ---
elif isinstance(field_data, NonTensorStack):
items = field_data.tolist()
getter = itemgetter(*positions) if len(positions) > 1 else lambda seq: (seq[positions[0]],)
Expand Down Expand Up @@ -344,7 +397,14 @@ def _pack_field_values(values: list) -> torch.Tensor | NonTensorStack:
if all(isinstance(v, torch.Tensor) for v in values):
if all(v.shape == values[0].shape for v in values):
return torch.stack(values)
return torch.nested.as_nested_tensor(values, layout=torch.jagged)
try:
return torch.nested.as_nested_tensor(values, layout=torch.jagged)
except (RuntimeError, TypeError) as e:
logger.warning(
f"Failed to pack nested tensor with jagged layout. "
f"Falling back to strided layout. Detailed error: {e}"
)
return torch.nested.as_nested_tensor(values, layout=torch.strided)
return NonTensorStack(*values)

async def get_data(self, metadata: BatchMeta) -> TensorDict:
Expand Down
Loading