Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions .github/workflows/tutorial-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,14 @@ jobs:
python -m pip install --upgrade pip
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install -e ".[yuanrong]"
- name: Run tutorials
pip install nbconvert ipykernel
- name: Run Python tutorials
run: |
export TQ_NUM_THREADS=2
export RAY_DEDUP_LOGS=0
for file in tutorial/*.py; do python3 "$file"; done
for file in tutorial/*.py; do python3 "$file"; done
- name: Run notebook tutorials
run: |
export TQ_NUM_THREADS=2
export RAY_DEDUP_LOGS=0
jupyter nbconvert --to notebook --execute --ExecutePreprocessor.timeout=120 tutorial/basic.ipynb
35 changes: 30 additions & 5 deletions tests/e2e/test_kv_interface_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,15 @@ def assert_tensor_close(tensor_a, tensor_b, rtol=1e-5, atol=1e-8, msg=""):
assert torch.allclose(tensor_a, tensor_b, rtol=rtol, atol=atol), f"{msg} Tensors are not close"


def assert_nested_tensor_equal(nested_a, nested_b, msg=""):
    """Check that two nested (jagged) tensors match, one component at a time.

    Both arguments are unpacked by iterating over them. The number of
    components must agree, and every aligned pair of components must satisfy
    ``torch.equal``. ``msg`` is prepended to the assertion message on failure.
    """
    parts_a, parts_b = list(nested_a), list(nested_b)
    count_a, count_b = len(parts_a), len(parts_b)
    assert count_a == count_b, f"{msg} Length mismatch: {count_a} vs {count_b}"

    # Lengths are known to be equal here, so indexing the second list is safe.
    for index, left in enumerate(parts_a):
        right = parts_b[index]
        assert torch.equal(left, right), f"{msg} Component {index} not equal: {left} vs {right}"


class TestKVPutE2E:
"""End-to-end tests for kv_put functionality."""

Expand Down Expand Up @@ -518,6 +527,25 @@ def test_kv_batch_put_returns_cumulative_fields(self, controller, tq_api):
class TestKVGetE2E:
"""End-to-end tests for kv_batch_get functionality."""

def test_kv_batch_get_nested_tensor(self, controller, tq_api):
    """Test that variable-length tensors put one key at a time can be batch-read.

    Three samples of lengths 1, 2 and 3 are each written with ``kv_put`` as a
    regular tensor with batch size 1. They are then fetched together with
    ``kv_batch_get`` and compared component-wise against a jagged nested
    tensor built from the original samples.
    """
    partition_id = "test_partition"
    keys = []
    data_list = []

    try:
        for i in range(1, 4):
            key = f"nested_tensor_{i}"
            data = torch.randn(size=(i,))
            # unsqueeze(0) adds the leading batch dimension required by batch_size=1.
            fields = TensorDict({"data": data.unsqueeze(0)}, batch_size=1)
            tq_api.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None)
            # Record the key only after a successful put so that cleanup
            # targets exactly the keys that were written.
            keys.append(key)
            data_list.append(data)

        retrieved = tq_api.kv_batch_get(keys=keys, partition_id=partition_id)

        expected = torch.nested.as_nested_tensor(data_list, layout=torch.jagged)
        assert_nested_tensor_equal(retrieved["data"], expected)
    finally:
        # Always clear what was written, even when the put, get or assertion
        # fails, so a failure here cannot leak keys into the shared partition.
        if keys:
            tq_api.kv_clear(keys=keys, partition_id=partition_id)

def test_kv_batch_get_single_key(self, controller, tq_api):
"""Test getting data for a single key."""
partition_id = "test_partition"
Expand Down Expand Up @@ -569,11 +597,8 @@ def test_kv_batch_get_partial_keys(self, controller, tq_api):
retrieved = tq_api.kv_batch_get(keys=partial_keys, partition_id=partition_id)
assert_tensor_equal(retrieved["data"], expected_data)

for actual, expected in zip(retrieved["nested_data"], expected_nested_data, strict=True):
assert_tensor_equal(actual, expected)

for actual, expected in zip(retrieved["three_d_nested_data"], expected_three_d_nested_data, strict=True):
assert_tensor_equal(actual, expected)
assert_nested_tensor_equal(retrieved["nested_data"], expected_nested_data)
assert_nested_tensor_equal(retrieved["three_d_nested_data"], expected_three_d_nested_data)

tq_api.kv_clear(keys=keys, partition_id=partition_id)

Expand Down
17 changes: 12 additions & 5 deletions tutorial/basic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2026-03-25 15:57:42,882\tINFO worker.py:2014 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n",
"2026-03-27 23:06:04,557\tINFO worker.py:2014 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n",
"/opt/miniconda3/envs/verl/lib/python3.11/site-packages/ray/_private/worker.py:2062: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0\n",
" warnings.warn(\n"
]
Expand All @@ -58,7 +58,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33m(raylet)\u001b[0m It looks like you're creating a detached actor in an anonymous namespace. In order to access this actor in the future, you will need to explicitly connect to this namespace with ray.init(namespace=\"d7384875-7cae-4308-b16c-c92541dc9f07\", ...)\n",
"\u001b[33m(raylet)\u001b[0m It looks like you're creating a detached actor in an anonymous namespace. In order to access this actor in the future, you will need to explicitly connect to this namespace with ray.init(namespace=\"798d49c0-a4e8-4877-813b-bce2d966c4eb\", ...)\n",
"TransferQueue is ready!\n"
]
}
Expand Down Expand Up @@ -386,7 +386,7 @@
],
"source": [
"# Retrieve only input_ids (single field)\n",
"result = tq.kv_batch_get(keys=\"sample_1\", partition_id=\"train\", fields=\"input_ids\")\n",
"result = tq.kv_batch_get(keys=\"sample_1\", partition_id=\"train\", select_fields=\"input_ids\")\n",
"print(\"Fields returned:\", list(result.keys()))\n",
"assert \"input_ids\" in result.keys()\n",
"assert \"attention_mask\" not in result.keys()\n",
Expand All @@ -395,7 +395,7 @@
"result = tq.kv_batch_get(\n",
" keys=\"sample_1\",\n",
" partition_id=\"train\",\n",
" fields=[\"input_ids\", \"attention_mask\"],\n",
" select_fields=[\"input_ids\", \"attention_mask\"],\n",
")\n",
"print(\"Fields returned:\", list(result.keys()))"
]
Expand Down Expand Up @@ -969,7 +969,7 @@
"| Init | `tq.init(config)` | Call once; subsequent processes auto-connect |\n",
"| Put single | `tq.kv_put(key, partition_id, fields, tag)` | `fields` can be a plain dict |\n",
"| Put batch | `tq.kv_batch_put(keys, partition_id, fields, tags)` | `fields` must be a `TensorDict` |\n",
"| Get | `tq.kv_batch_get(keys, partition_id, fields=None)` | Returns a `TensorDict` |\n",
"| Get | `tq.kv_batch_get(keys, partition_id, select_fields=None)` | Returns a `TensorDict` |\n",
"| List | `tq.kv_list(partition_id=None)` | Returns `{partition: {key: tag}}` |\n",
"| Clear | `tq.kv_clear(keys, partition_id)` | Removes keys + data |\n",
"| Close | `tq.close()` | Tears down controller & storage |\n",
Expand All @@ -980,6 +980,13 @@
"For low-level, metadata-based access, see `tq.get_client()` and the\n",
"[official tutorials](https://github.com/Ascend/TransferQueue/tree/main/tutorial)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Loading