diff --git a/.github/workflows/tutorial-check.yml b/.github/workflows/tutorial-check.yml index c24cab8..67bbcf8 100644 --- a/.github/workflows/tutorial-check.yml +++ b/.github/workflows/tutorial-check.yml @@ -33,8 +33,14 @@ jobs: python -m pip install --upgrade pip pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install -e ".[yuanrong]" - - name: Run tutorials + pip install nbconvert ipykernel + - name: Run Python tutorials run: | export TQ_NUM_THREADS=2 export RAY_DEDUP_LOGS=0 - for file in tutorial/*.py; do python3 "$file"; done \ No newline at end of file + for file in tutorial/*.py; do python3 "$file"; done + - name: Run notebook tutorials + run: | + export TQ_NUM_THREADS=2 + export RAY_DEDUP_LOGS=0 + jupyter nbconvert --to notebook --execute --ExecutePreprocessor.timeout=120 tutorial/basic.ipynb \ No newline at end of file diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index 4fac9d2..6ace099 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -178,6 +178,15 @@ def assert_tensor_close(tensor_a, tensor_b, rtol=1e-5, atol=1e-8, msg=""): assert torch.allclose(tensor_a, tensor_b, rtol=rtol, atol=atol), f"{msg} Tensors are not close" +def assert_nested_tensor_equal(nested_a, nested_b, msg=""): + """Assert two nested (jagged) tensors are equal component-wise.""" + components_a = list(nested_a) + components_b = list(nested_b) + assert len(components_a) == len(components_b), f"{msg} Length mismatch: {len(components_a)} vs {len(components_b)}" + for i, (a, b) in enumerate(zip(components_a, components_b, strict=True)): + assert torch.equal(a, b), f"{msg} Component {i} not equal: {a} vs {b}" + + class TestKVPutE2E: """End-to-end tests for kv_put functionality.""" @@ -518,6 +527,25 @@ def test_kv_batch_put_returns_cumulative_fields(self, controller, tq_api): class TestKVGetE2E: """End-to-end tests for kv_batch_get functionality.""" + def test_kv_batch_get_nested_tensor(self, controller, tq_api): + # test put a regular tensor with batch size 1 and get it back as a nested tensor + partition_id = "test_partition" + keys = [] + data_list = [] + + for i in range(1, 4): + key = f"nested_tensor_{i}" + keys.append(key) + data = torch.randn(size=(i,)) + data_list.append(data) + fields = TensorDict({"data": data.unsqueeze(0)}, batch_size=1) + tq_api.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None) + + retrieved = tq_api.kv_batch_get(keys=keys, partition_id=partition_id) + + assert_nested_tensor_equal(retrieved["data"], torch.nested.as_nested_tensor(data_list, layout=torch.jagged)) + tq_api.kv_clear(keys=keys, partition_id=partition_id) + def test_kv_batch_get_single_key(self, controller, tq_api): """Test getting data for a single key.""" partition_id = "test_partition" @@ -569,11 +597,8 @@ def test_kv_batch_get_partial_keys(self, controller, tq_api): retrieved = tq_api.kv_batch_get(keys=partial_keys, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) - for actual, expected in zip(retrieved["nested_data"], expected_nested_data, strict=True): - assert_tensor_equal(actual, expected) - - for actual, expected in zip(retrieved["three_d_nested_data"], expected_three_d_nested_data, strict=True): - assert_tensor_equal(actual, expected) + assert_nested_tensor_equal(retrieved["nested_data"], expected_nested_data) + assert_nested_tensor_equal(retrieved["three_d_nested_data"], expected_three_d_nested_data) tq_api.kv_clear(keys=keys, partition_id=partition_id) diff --git a/tutorial/basic.ipynb b/tutorial/basic.ipynb index aaa5d44..5e5944d 100644 --- a/tutorial/basic.ipynb +++ b/tutorial/basic.ipynb @@ -49,7 +49,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "2026-03-25 15:57:42,882\tINFO worker.py:2014 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n", + "2026-03-27 23:06:04,557\tINFO worker.py:2014 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n", "/opt/miniconda3/envs/verl/lib/python3.11/site-packages/ray/_private/worker.py:2062: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0\n", " warnings.warn(\n" ] @@ -58,7 +58,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33m(raylet)\u001b[0m It looks like you're creating a detached actor in an anonymous namespace. In order to access this actor in the future, you will need to explicitly connect to this namespace with ray.init(namespace=\"d7384875-7cae-4308-b16c-c92541dc9f07\", ...)\n", + "\u001b[33m(raylet)\u001b[0m It looks like you're creating a detached actor in an anonymous namespace. In order to access this actor in the future, you will need to explicitly connect to this namespace with ray.init(namespace=\"798d49c0-a4e8-4877-813b-bce2d966c4eb\", ...)\n", "TransferQueue is ready!\n" ] } @@ -386,7 +386,7 @@ ], "source": [ "# Retrieve only input_ids (single field)\n", - "result = tq.kv_batch_get(keys=\"sample_1\", partition_id=\"train\", fields=\"input_ids\")\n", + "result = tq.kv_batch_get(keys=\"sample_1\", partition_id=\"train\", select_fields=\"input_ids\")\n", "print(\"Fields returned:\", list(result.keys()))\n", "assert \"input_ids\" in result.keys()\n", "assert \"attention_mask\" not in result.keys()\n", @@ -395,7 +395,7 @@ "result = tq.kv_batch_get(\n", " keys=\"sample_1\",\n", " partition_id=\"train\",\n", - " fields=[\"input_ids\", \"attention_mask\"],\n", + " select_fields=[\"input_ids\", \"attention_mask\"],\n", ")\n", "print(\"Fields returned:\", list(result.keys()))" ] @@ -969,7 +969,7 @@ "| Init | `tq.init(config)` | Call once; subsequent processes auto-connect |\n", "| Put single | `tq.kv_put(key, partition_id, fields, tag)` | `fields` can be a plain dict |\n", "| Put batch | `tq.kv_batch_put(keys, partition_id, fields, tags)` | `fields` must be a `TensorDict` |\n", - "| Get | `tq.kv_batch_get(keys, partition_id, fields=None)` | Returns a `TensorDict` |\n", + "| Get | `tq.kv_batch_get(keys, partition_id, select_fields=None)` | Returns a `TensorDict` |\n", "| List | `tq.kv_list(partition_id=None)` | Returns `{partition: {key: tag}}` |\n", "| Clear | `tq.kv_clear(keys, partition_id)` | Removes keys + data |\n", "| Close | `tq.close()` | Tears down controller & storage |\n", @@ -980,6 +980,13 @@ "For low-level, metadata-based access, see `tq.get_client()` and the\n", "[official tutorials](https://github.com/Ascend/TransferQueue/tree/main/tutorial)." ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {