diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 0108bdc..2b16d50 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -486,6 +486,61 @@ def test_numpy_array(self): result = AsyncSimpleStorageManager._select_by_positions(arr, [0, 2]) np.testing.assert_array_equal(result, np.array([10, 30])) + def test_regular_tensor_single_element(self): + """Case 1: Single element selection returns a single-row view.""" + t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + result = AsyncSimpleStorageManager._select_by_positions(t, [1]) + assert result.shape == (1, 2) + assert torch.equal(result, torch.tensor([[3.0, 4.0]])) + + def test_regular_tensor_strided_slice(self): + """Case 2: Constant stride (step > 1) uses Python slicing for zero-copy view.""" + t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]) + result = AsyncSimpleStorageManager._select_by_positions(t, [0, 2, 4]) + # positions form constant stride of 2 + expected = torch.tensor([[1.0, 2.0], [5.0, 6.0], [9.0, 10.0]]) + assert torch.equal(result, expected) + + def test_regular_tensor_irregular_indices_fallback(self): + """Case 3: Irregular indices fall back to index_select to avoid ZMQ frame fragmentation.""" + t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]) + # positions [0, 2, 3] have irregular gaps (2, then 1) - not constant stride + result = AsyncSimpleStorageManager._select_by_positions(t, [0, 2, 3]) + expected = torch.tensor([[1.0, 2.0], [5.0, 6.0], [7.0, 8.0]]) + assert torch.equal(result, expected) + + def test_regular_tensor_irregular_reverse_order(self): + """Irregular indices in reverse order also fall back to index_select.""" + t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]) + result = AsyncSimpleStorageManager._select_by_positions(t, [3, 1, 0]) + expected = torch.tensor([[7.0, 8.0], [3.0, 4.0], [1.0, 2.0]]) + assert 
torch.equal(result, expected) + + def test_nested_tensor_single_element(self): + """Single element from nested tensor uses the lambda path.""" + t = torch.nested.as_nested_tensor( + [torch.tensor([1.0]), torch.tensor([2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])], + layout=torch.jagged, + ) + result = AsyncSimpleStorageManager._select_by_positions(t, [1]) + assert isinstance(result, list) + assert len(result) == 1 + assert torch.equal(result[0], torch.tensor([2.0, 3.0])) + + def test_empty_positions_raises_error(self): + """Empty positions list should raise ValueError.""" + t = torch.tensor([1.0, 2.0, 3.0]) + with pytest.raises(ValueError, match="No positions specified"): + AsyncSimpleStorageManager._select_by_positions(t, []) + + def test_regular_tensor_negative_stride_rejected(self): + """Negative stride (reversed order) should fall back to index_select.""" + t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + # positions [2, 1, 0] have step = -1 (negative) + result = AsyncSimpleStorageManager._select_by_positions(t, [2, 1, 0]) + expected = torch.tensor([[5.0, 6.0], [3.0, 4.0], [1.0, 2.0]]) + assert torch.equal(result, expected) + class TestPackFieldValues: """Test _pack_field_values static method packing logic.""" diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 1fc82bf..27e173c 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -209,17 +209,70 @@ def _group_by_hash(self, global_indexes: list[int]) -> dict[str, RoutingGroup]: def _select_by_positions(field_data, positions: list[int]): """Slice a single field's data by non-contiguous batch positions. 
- Handles four data types: - - Nested tensors: unbind → select → return as list - - NonTensorStack: tolist → select → re-wrap - - list: direct index selection via itemgetter - - Regular tensors / numpy arrays: fancy indexing + This method optimizes selection to minimize memory overhead and network fragmentation: + - Nested tensors: Unbinds into a list of views (end-to-end zero-copy). + - Regular tensors (step == 1): Returns a contiguous slice (end-to-end zero-copy). + - Regular tensors (step > 1): Returns a strided view (shares storage). Note that + downstream serialization will force a `.contiguous()` copy, but slicing is still + faster than `index_select` and the peak-memory window is shorter. + - Regular tensors (irregular): Falls back to `index_select` to assemble a single + contiguous tensor, preventing excessive ZMQ multipart frames. + - NonTensorStack: tolist → select → re-wrap. + - List: Direct index selection via `itemgetter`. + - Numpy arrays / Others: Advanced indexing (memory copy). """ - if isinstance(field_data, torch.Tensor) and field_data.is_nested: - unbound = field_data.unbind() - getter = itemgetter(*positions) if len(positions) > 1 else lambda seq: (seq[positions[0]],) - selected = getter(unbound) - return list(selected) + + n = len(positions) + if n == 0: + raise ValueError("No positions specified for selection.") + + # --- Handle PyTorch Tensors --- + if isinstance(field_data, torch.Tensor): + if field_data.is_nested: + # Nested tensors cannot be directly sliced into a single tensor view. + # Unbinding and selecting returns a list of individual views (zero-copy), + # which is acceptable for nested structures. 
+ unbound = field_data.unbind() + getter = itemgetter(*positions) if len(positions) > 1 else lambda seq: (seq[positions[0]],) + selected = getter(unbound) + return list(selected) + else: + # --- Smart Slicing for Regular Tensors --- + # Goal: Return a single underlying memory view (zero-copy) to avoid both + # memory allocation overhead and downstream ZMQ frame fragmentation. + + # Case 1: Single element selection (returns a single-row view) + if n == 1: + # Single element is natively contiguous + return field_data[positions[0] : positions[0] + 1] + + # Case 2: Check if positions form a constant-stride sequence + step = positions[1] - positions[0] + is_constant_stride = True + for i in range(2, n): + if positions[i] - positions[i - 1] != step: + is_constant_stride = False + break + + # If perfectly regular (e.g., [0, 2, 4]), use Python slicing to get a view + if is_constant_stride and step > 0: + # Note: + # A strided slice (step > 1) creates a non-contiguous view. + # While it shares storage here, the downstream MsgpackEncoder will force + # a .contiguous() copy before extracting the buffer. However, this pure + # Python slicing is still more efficient than falling back to index_select, + # and it shortens the peak-memory window. + return field_data[positions[0] : positions[-1] + 1 : step] + + # Case 3: Fallback for irregular indices (Typically this will not happen!) + # We intentionally accept a memory copy here to assemble a single contiguous + # tensor. Returning a list of individual views for irregular indices would + # generate excessive multipart ZMQ frames, severely degrading network performance. 
+ else: + idx_tensor = torch.tensor(positions, device=field_data.device) + return torch.index_select(field_data, dim=0, index=idx_tensor) + + # --- Handle Non-Tensor Types --- elif isinstance(field_data, NonTensorStack): items = field_data.tolist() getter = itemgetter(*positions) if len(positions) > 1 else lambda seq: (seq[positions[0]],) @@ -344,7 +397,14 @@ def _pack_field_values(values: list) -> torch.Tensor | NonTensorStack: if all(isinstance(v, torch.Tensor) for v in values): if all(v.shape == values[0].shape for v in values): return torch.stack(values) - return torch.nested.as_nested_tensor(values, layout=torch.jagged) + try: + return torch.nested.as_nested_tensor(values, layout=torch.jagged) + except (RuntimeError, TypeError) as e: + logger.warning( + f"Failed to pack nested tensor with jagged layout. " + f"Falling back to strided layout. Detailed error: {e}" + ) + return torch.nested.as_nested_tensor(values, layout=torch.strided) return NonTensorStack(*values) async def get_data(self, metadata: BatchMeta) -> TensorDict: