From 82db416a70a22b537f299368f2b1308885e4fd18 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 20 Mar 2026 11:53:54 +0800 Subject: [PATCH 1/6] optimize memory usage for regular tensor Signed-off-by: 0oshowero0 --- tests/test_async_simple_storage_manager.py | 3 ++- .../storage/managers/simple_backend_manager.py | 15 ++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 0108bdc..fa6d1f9 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -457,7 +457,8 @@ class TestSelectByPositions: def test_regular_tensor(self): t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) result = AsyncSimpleStorageManager._select_by_positions(t, [0, 2]) - assert torch.equal(result, torch.tensor([[1.0, 2.0], [5.0, 6.0]])) + assert torch.equal(result[0], t[0]) + assert torch.equal(result[1], t[2]) def test_nested_tensor(self): t = torch.nested.as_nested_tensor( diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 1fc82bf..ed5847a 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -215,11 +215,16 @@ def _select_by_positions(field_data, positions: list[int]): - list: direct index selection via itemgetter - Regular tensors / numpy arrays: fancy indexing """ - if isinstance(field_data, torch.Tensor) and field_data.is_nested: - unbound = field_data.unbind() - getter = itemgetter(*positions) if len(positions) > 1 else lambda seq: (seq[positions[0]],) - selected = getter(unbound) - return list(selected) + if isinstance(field_data, torch.Tensor): + if field_data.is_nested: + # for nested tensor, unbind and select will not lead to memory copy + unbound = field_data.unbind() + getter = itemgetter(*positions) if len(positions) > 1 else lambda seq: 
(seq[positions[0]],) + selected = getter(unbound) + return list(selected) + else: + # for ordinary tensor, use simple view will also prevent memory copy + return [field_data[i] for i in positions] elif isinstance(field_data, NonTensorStack): items = field_data.tolist() getter = itemgetter(*positions) if len(positions) > 1 else lambda seq: (seq[positions[0]],) From 2fd9fba333f938a60ed3f9ed5d029c9ddd0b7748 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 20 Mar 2026 12:44:51 +0800 Subject: [PATCH 2/6] fix comments Signed-off-by: 0oshowero0 --- tests/test_async_simple_storage_manager.py | 2 ++ transfer_queue/storage/managers/simple_backend_manager.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index fa6d1f9..7650236 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -457,6 +457,8 @@ class TestSelectByPositions: def test_regular_tensor(self): t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) result = AsyncSimpleStorageManager._select_by_positions(t, [0, 2]) + assert isinstance(result, list) + assert len(result) == 2 assert torch.equal(result[0], t[0]) assert torch.equal(result[1], t[2]) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index ed5847a..31ba4bb 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -211,9 +211,10 @@ def _select_by_positions(field_data, positions: list[int]): Handles four data types: - Nested tensors: unbind → select → return as list + - Regular tensors: select by single position → warp by list - NonTensorStack: tolist → select → re-wrap - list: direct index selection via itemgetter - - Regular tensors / numpy arrays: fancy indexing + - numpy arrays: fancy indexing """ if isinstance(field_data, 
torch.Tensor): if field_data.is_nested: From 85baef8a98da85fc6dc4defe16e60e8180227a82 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 20 Mar 2026 13:04:58 +0800 Subject: [PATCH 3/6] more advanced strategy Signed-off-by: 0oshowero0 --- tests/test_async_simple_storage_manager.py | 5 +- .../managers/simple_backend_manager.py | 64 ++++++++++++++++--- 2 files changed, 56 insertions(+), 13 deletions(-) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 7650236..0108bdc 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -457,10 +457,7 @@ class TestSelectByPositions: def test_regular_tensor(self): t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) result = AsyncSimpleStorageManager._select_by_positions(t, [0, 2]) - assert isinstance(result, list) - assert len(result) == 2 - assert torch.equal(result[0], t[0]) - assert torch.equal(result[1], t[2]) + assert torch.equal(result, torch.tensor([[1.0, 2.0], [5.0, 6.0]])) def test_nested_tensor(self): t = torch.nested.as_nested_tensor( diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 31ba4bb..2e024c9 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -209,23 +209,69 @@ def _group_by_hash(self, global_indexes: list[int]) -> dict[str, RoutingGroup]: def _select_by_positions(field_data, positions: list[int]): """Slice a single field's data by non-contiguous batch positions. 
- Handles four data types: - - Nested tensors: unbind → select → return as list - - Regular tensors: select by single position → warp by list - - NonTensorStack: tolist → select → re-wrap - - list: direct index selection via itemgetter - - numpy arrays: fancy indexing + This method attempts to preserve zero-copy views whenever possible, while + falling back to memory-copied single tensors when indices are irregular. + This prevents severe network fragmentation (emitting too many ZMQ frames) + during serialization. + + Supported data types: + - Nested tensors: unbind → select → return as a list of views (zero-copy). + - Regular tensors: Checks for constant-stride to return a single sliced view. + Falls back to `index_select` (memory copy) to ensure a single buffer. + - NonTensorStack: tolist → select → re-wrap. + - List: Direct index selection via `itemgetter`. + - Numpy arrays / Others: Advanced indexing (memory copy). """ + + n = len(positions) + if n == 0: + raise ValueError("No positions specified for selection.") + + # --- Handle PyTorch Tensors --- if isinstance(field_data, torch.Tensor): if field_data.is_nested: - # for nested tensor, unbind and select will not lead to memory copy + # Nested tensors cannot be directly sliced into a single tensor view. + # Unbinding and selecting returns a list of individual views (zero-copy), + # which is acceptable for nested structures. unbound = field_data.unbind() getter = itemgetter(*positions) if len(positions) > 1 else lambda seq: (seq[positions[0]],) selected = getter(unbound) return list(selected) else: - # for ordinary tensor, use simple view will also prevent memory copy - return [field_data[i] for i in positions] + # --- Smart Slicing for Regular Tensors --- + # Goal: Return a single underlying memory view (zero-copy) to avoid both + # memory allocation overhead and downstream ZMQ frame fragmentation. 
+ + # Case 1: Single element selection (returns a single-row view) + if n == 1: + return field_data[positions[0] : positions[0] + 1] + + # Case 2: Check if positions form a constant-stride sequence + step = positions[1] - positions[0] + is_constant_stride = True + for i in range(2, n): + if positions[i] - positions[i - 1] != step: + is_constant_stride = False + break + + # If perfectly regular (e.g., [0, 2, 4]), use Python slicing to get a view + if is_constant_stride and step > 0: + return field_data[positions[0] : positions[-1] + 1 : step] + + # Case 3: Fallback for irregular indices (Typically this will not happen!) + # We intentionally accept a memory copy here to assemble a single contiguous + # tensor. Returning a list of individual views for irregular indices would + # generate excessive multipart ZMQ frames, severely degrading network performance. + else: + logger.debug( + f"Irregular indices detected for tensor of shape {field_data.shape}. " + "Falling back to index_select (memory copy will occur)." 
+ ) + # Note: Ensure idx_tensor is on the same device as field_data + idx_tensor = torch.tensor(positions, device=field_data.device) + return torch.index_select(field_data, dim=0, index=idx_tensor) + + # --- Handle Non-Tensor Types --- elif isinstance(field_data, NonTensorStack): items = field_data.tolist() getter = itemgetter(*positions) if len(positions) > 1 else lambda seq: (seq[positions[0]],) From 6bfd22ac89728405f0194264f1d0b9ab66d1b059 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 20 Mar 2026 16:01:34 +0800 Subject: [PATCH 4/6] update comments Signed-off-by: 0oshowero0 --- .../managers/simple_backend_manager.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 2e024c9..1f07cba 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -209,15 +209,14 @@ def _group_by_hash(self, global_indexes: list[int]) -> dict[str, RoutingGroup]: def _select_by_positions(field_data, positions: list[int]): """Slice a single field's data by non-contiguous batch positions. - This method attempts to preserve zero-copy views whenever possible, while - falling back to memory-copied single tensors when indices are irregular. - This prevents severe network fragmentation (emitting too many ZMQ frames) - during serialization. - - Supported data types: - - Nested tensors: unbind → select → return as a list of views (zero-copy). - - Regular tensors: Checks for constant-stride to return a single sliced view. - Falls back to `index_select` (memory copy) to ensure a single buffer. + This method optimizes selection to minimize memory overhead and network fragmentation: + - Nested tensors: Unbinds into a list of views (end-to-end zero-copy). + - Regular tensors (step == 1): Returns a contiguous slice (end-to-end zero-copy). 
+ - Regular tensors (step > 1): Returns a strided view (shares storage). Note that + downstream serialization will force a `.contiguous()` copy, but slicing is still + faster than `index_select` and the duration of peak memory usage is shorter. + - Regular tensors (irregular): Falls back to `index_select` to assemble a single + contiguous tensor, preventing excessive ZMQ multipart frames. - NonTensorStack: tolist → select → re-wrap. - List: Direct index selection via `itemgetter`. - Numpy arrays / Others: Advanced indexing (memory copy). @@ -244,6 +243,7 @@ def _select_by_positions(field_data, positions: list[int]): # Case 1: Single element selection (returns a single-row view) if n == 1: + # Single element is natively contiguous return field_data[positions[0] : positions[0] + 1] # Case 2: Check if positions form a constant-stride sequence @@ -256,6 +256,12 @@ def _select_by_positions(field_data, positions: list[int]): # If perfectly regular (e.g., [0, 2, 4]), use Python slicing to get a view if is_constant_stride and step > 0: + # Note: + # A strided slice (step > 1) creates a non-contiguous view. + # While it shares storage here, the downstream MsgpackEncoder will force + # a .contiguous() copy before extracting the buffer. However, this pure + # Python slicing is still more efficient than falling back to index_select, + # and it shortens the duration of peak memory usage. return field_data[positions[0] : positions[-1] + 1 : step] # Case 3: Fallback for irregular indices (Typically this will not happen!) @@ -263,11 +269,6 @@ def _select_by_positions(field_data, positions: list[int]): # tensor. Returning a list of individual views for irregular indices would # generate excessive multipart ZMQ frames, severely degrading network performance. else: - logger.debug( - f"Irregular indices detected for tensor of shape {field_data.shape}. " - "Falling back to index_select (memory copy will occur)." 
- ) - # Note: Ensure idx_tensor is on the same device as field_data idx_tensor = torch.tensor(positions, device=field_data.device) return torch.index_select(field_data, dim=0, index=idx_tensor) From a8f94154af8673f506449c100bf7cb330c5b3e6e Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 20 Mar 2026 16:53:29 +0800 Subject: [PATCH 5/6] add fallback for tensor packing Signed-off-by: 0oshowero0 --- .../storage/managers/simple_backend_manager.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 1f07cba..7a94fe1 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -397,7 +397,14 @@ def _pack_field_values(values: list) -> torch.Tensor | NonTensorStack: if all(isinstance(v, torch.Tensor) for v in values): if all(v.shape == values[0].shape for v in values): return torch.stack(values) - return torch.nested.as_nested_tensor(values, layout=torch.jagged) + try: + return torch.nested.as_nested_tensor(values, layout=torch.jagged) + except RuntimeError as e: + logger.warning( + f"Failed to pack nested tensor with jagged layout. " + f"Try to fallback as strided layout. 
Detailed error: {e}" + ) + return torch.nested.as_nested_tensor(values, layout=torch.strided) return NonTensorStack(*values) async def get_data(self, metadata: BatchMeta) -> TensorDict: From ce14308ee2d8288d48cb65ecb2d584db98064834 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 20 Mar 2026 17:06:52 +0800 Subject: [PATCH 6/6] add UT Signed-off-by: 0oshowero0 --- tests/test_async_simple_storage_manager.py | 55 +++++++++++++++++++ .../managers/simple_backend_manager.py | 4 +- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 0108bdc..2b16d50 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -486,6 +486,61 @@ def test_numpy_array(self): result = AsyncSimpleStorageManager._select_by_positions(arr, [0, 2]) np.testing.assert_array_equal(result, np.array([10, 30])) + def test_regular_tensor_single_element(self): + """Case 1: Single element selection returns a single-row view.""" + t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + result = AsyncSimpleStorageManager._select_by_positions(t, [1]) + assert result.shape == (1, 2) + assert torch.equal(result, torch.tensor([[3.0, 4.0]])) + + def test_regular_tensor_strided_slice(self): + """Case 2: Constant stride (step > 1) uses Python slicing for zero-copy view.""" + t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]) + result = AsyncSimpleStorageManager._select_by_positions(t, [0, 2, 4]) + # positions form constant stride of 2 + expected = torch.tensor([[1.0, 2.0], [5.0, 6.0], [9.0, 10.0]]) + assert torch.equal(result, expected) + + def test_regular_tensor_irregular_indices_fallback(self): + """Case 3: Irregular indices fall back to index_select to avoid ZMQ frame fragmentation.""" + t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]) + # positions [0, 2, 3] have irregular gaps (2, then 1) - not constant stride 
+ result = AsyncSimpleStorageManager._select_by_positions(t, [0, 2, 3]) + expected = torch.tensor([[1.0, 2.0], [5.0, 6.0], [7.0, 8.0]]) + assert torch.equal(result, expected) + + def test_regular_tensor_irregular_reverse_order(self): + """Irregular indices in reverse order also fall back to index_select.""" + t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]) + result = AsyncSimpleStorageManager._select_by_positions(t, [3, 1, 0]) + expected = torch.tensor([[7.0, 8.0], [3.0, 4.0], [1.0, 2.0]]) + assert torch.equal(result, expected) + + def test_nested_tensor_single_element(self): + """Single element from nested tensor uses the lambda path.""" + t = torch.nested.as_nested_tensor( + [torch.tensor([1.0]), torch.tensor([2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])], + layout=torch.jagged, + ) + result = AsyncSimpleStorageManager._select_by_positions(t, [1]) + assert isinstance(result, list) + assert len(result) == 1 + assert torch.equal(result[0], torch.tensor([2.0, 3.0])) + + def test_empty_positions_raises_error(self): + """Empty positions list should raise ValueError.""" + t = torch.tensor([1.0, 2.0, 3.0]) + with pytest.raises(ValueError, match="No positions specified"): + AsyncSimpleStorageManager._select_by_positions(t, []) + + def test_regular_tensor_negative_stride_rejected(self): + """Negative stride (reversed order) should fall back to index_select.""" + t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + # positions [2, 1, 0] have step = -1 (negative) + result = AsyncSimpleStorageManager._select_by_positions(t, [2, 1, 0]) + expected = torch.tensor([[5.0, 6.0], [3.0, 4.0], [1.0, 2.0]]) + assert torch.equal(result, expected) + class TestPackFieldValues: """Test _pack_field_values static method packing logic.""" diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 7a94fe1..27e173c 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ 
b/transfer_queue/storage/managers/simple_backend_manager.py @@ -399,10 +399,10 @@ def _pack_field_values(values: list) -> torch.Tensor | NonTensorStack: return torch.stack(values) try: return torch.nested.as_nested_tensor(values, layout=torch.jagged) - except RuntimeError as e: + except (RuntimeError, TypeError) as e: logger.warning( f"Failed to pack nested tensor with jagged layout. " - f"Try to fallback as strided layout. Detailed error: {e}" + f"Falling back to strided layout. Detailed error: {e}" ) return torch.nested.as_nested_tensor(values, layout=torch.strided) return NonTensorStack(*values)