Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions recipe/simple_use_case/single_controller_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@
import logging
import os
import random
import sys
import time
import uuid
from dataclasses import dataclass, field
from importlib import resources
from pathlib import Path

import ray
import torch
Expand All @@ -35,9 +33,6 @@
import transfer_queue as tq
from transfer_queue import KVBatchMeta

parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

Expand Down
7 changes: 1 addition & 6 deletions scripts/performance_test/perftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,15 @@
import csv
import logging
import os
import sys
import time
from pathlib import Path
from typing import Any

import ray
import torch
from omegaconf import OmegaConf
from tensordict import NonTensorStack, TensorDict

parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))

import transfer_queue as tq # noqa: E402
import transfer_queue as tq

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
Expand Down
5 changes: 0 additions & 5 deletions scripts/performance_test/ray_perftest_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,13 @@
import csv
import logging
import os
import sys
import time
from pathlib import Path
from typing import Any

import ray
import torch
from tensordict import NonTensorStack, TensorDict

parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

Expand Down
15 changes: 5 additions & 10 deletions scripts/put_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import logging
import math
import os
import sys
import time
from pathlib import Path

import numpy as np
import ray
Expand All @@ -30,14 +28,11 @@
from tensordict import TensorDict
from tensordict.utils import LinkedList

parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))

from transfer_queue import TransferQueueClient # noqa: E402
from transfer_queue.controller import TransferQueueController # noqa: E402
from transfer_queue.storage.simple_backend import SimpleStorageUnit # noqa: E402
from transfer_queue.utils.common import get_placement_group # noqa: E402
from transfer_queue.utils.zmq_utils import process_zmq_server_info # noqa: E402
from transfer_queue import TransferQueueClient
from transfer_queue.controller import TransferQueueController
from transfer_queue.storage.simple_backend import SimpleStorageUnit
from transfer_queue.utils.common import get_placement_group
from transfer_queue.utils.zmq_utils import process_zmq_server_info

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand Down
5 changes: 0 additions & 5 deletions tests/e2e/test_e2e_lifecycle_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import os
import sys
import time
from pathlib import Path

import numpy as np
import pytest
Expand All @@ -26,10 +25,6 @@
from tensordict import TensorDict
from tensordict.tensorclass import NonTensorData

# Setup paths (transfer_queue is not pip-installed)
parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))

# Module-level default fields to avoid repeated generation
DEFAULT_FIELDS = [
"tensor_f32",
Expand Down
8 changes: 1 addition & 7 deletions tests/e2e/test_kv_interface_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,14 @@

import asyncio
import os
import sys
from pathlib import Path

import pytest
import ray
import torch
from omegaconf import OmegaConf
from tensordict import TensorDict

# Add parent directory to path
parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))

import transfer_queue as tq # noqa: E402
import transfer_queue as tq


class TQAPIWrapper:
Expand Down
14 changes: 4 additions & 10 deletions tests/test_async_simple_storage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch

import numpy as np
Expand All @@ -24,14 +22,10 @@
import zmq
from tensordict import NonTensorStack, TensorDict

# Setup path
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

from transfer_queue.metadata import BatchMeta # noqa: E402
from transfer_queue.storage import AsyncSimpleStorageManager # noqa: E402
from transfer_queue.utils.enum_utils import TransferQueueRole # noqa: E402
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo # noqa: E402
from transfer_queue.metadata import BatchMeta
from transfer_queue.storage import AsyncSimpleStorageManager
from transfer_queue.utils.enum_utils import TransferQueueRole
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo


@pytest_asyncio.fixture
Expand Down
16 changes: 4 additions & 12 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import time
from pathlib import Path
from threading import Thread
from unittest.mock import patch

Expand All @@ -24,16 +22,10 @@
import zmq
from tensordict import NonTensorStack, TensorDict

# Import your classes here
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

from transfer_queue import TransferQueueClient # noqa: E402
from transfer_queue.metadata import ( # noqa: E402
BatchMeta,
)
from transfer_queue.utils.enum_utils import TransferQueueRole # noqa: E402
from transfer_queue.utils.zmq_utils import ( # noqa: E402
from transfer_queue import TransferQueueClient
from transfer_queue.metadata import BatchMeta
from transfer_queue.utils.enum_utils import TransferQueueRole
from transfer_queue.utils.zmq_utils import (
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
Expand Down
7 changes: 1 addition & 6 deletions tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,17 @@
# limitations under the License.

import logging
import sys
from pathlib import Path

import pytest
import ray
import torch

parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
from transfer_queue.controller import TransferQueueController

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

from transfer_queue.controller import TransferQueueController # noqa: E402


@pytest.fixture(scope="function")
def ray_setup():
Expand Down
6 changes: 0 additions & 6 deletions tests/test_controller_data_partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,7 @@

import logging
import os
import sys
import time
from pathlib import Path

parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))


# Set up logging
logging.basicConfig(level=logging.INFO)
Expand Down
10 changes: 2 additions & 8 deletions tests/test_kv_storage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import sys
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import torch
from tensordict import TensorDict

# Setup path
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

from transfer_queue.metadata import BatchMeta # noqa: E402
from transfer_queue.storage.managers.base import KVStorageManager # noqa: E402
from transfer_queue.metadata import BatchMeta
from transfer_queue.storage.managers.base import KVStorageManager


def get_meta(data, global_indexes=None):
Expand Down
9 changes: 1 addition & 8 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,11 @@

"""Unit tests for TransferQueue metadata module - Columnar BatchMeta + KVBatchMeta."""

import sys
from pathlib import Path

import numpy as np
import pytest
import torch

# Setup path
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

from transfer_queue.metadata import BatchMeta, KVBatchMeta # noqa: E402
from transfer_queue.metadata import BatchMeta, KVBatchMeta

# ==============================================================================
# Columnar BatchMeta Tests
Expand Down
15 changes: 5 additions & 10 deletions tests/test_ray_p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import time
from pathlib import Path

import numpy as np
import ray
import torch
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from tensordict import TensorDict

parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

from transfer_queue.client import TransferQueueClient # noqa: E402
from transfer_queue.metadata import BatchMeta # noqa: E402
from transfer_queue.storage.managers.base import KVStorageManager # noqa: E402
from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory # noqa: E402
from transfer_queue.utils.zmq_utils import ZMQServerInfo # noqa: E402
from transfer_queue.client import TransferQueueClient
from transfer_queue.metadata import BatchMeta
from transfer_queue.storage.managers.base import KVStorageManager
from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory
from transfer_queue.utils.zmq_utils import ZMQServerInfo

TEST_CONFIGS: list[tuple[tuple[int, int], torch.dtype]] = [
((5000, 5000), torch.float32),
Expand Down
14 changes: 4 additions & 10 deletions tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,14 @@

"""Unit tests for TransferQueue samplers."""

import sys
from pathlib import Path
from typing import Any

import pytest

# Setup path
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

from transfer_queue.sampler import BaseSampler # noqa: E402
from transfer_queue.sampler.grpo_group_n_sampler import GRPOGroupNSampler # noqa: E402
from transfer_queue.sampler.rank_aware_sampler import RankAwareSampler # noqa: E402
from transfer_queue.sampler.sequential_sampler import SequentialSampler # noqa: E402
from transfer_queue.sampler import BaseSampler
from transfer_queue.sampler.grpo_group_n_sampler import GRPOGroupNSampler
from transfer_queue.sampler.rank_aware_sampler import RankAwareSampler
from transfer_queue.sampler.sequential_sampler import SequentialSampler


class TestBaseSampler:
Expand Down
8 changes: 1 addition & 7 deletions tests/test_serial_utils_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from pathlib import Path

import numpy as np
import pytest
import torch
from tensordict import TensorDict

# Import your classes here
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

from transfer_queue.utils.serial_utils import MsgpackDecoder, MsgpackEncoder # noqa: E402
from transfer_queue.utils.serial_utils import MsgpackDecoder, MsgpackEncoder


@pytest.mark.parametrize(
Expand Down
10 changes: 2 additions & 8 deletions tests/test_simple_storage_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import time
from pathlib import Path

import pytest
import ray
import tensordict
import torch
import zmq

# Setup path
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

from transfer_queue.storage.simple_backend import SimpleStorageUnit # noqa: E402
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType # noqa: E402
from transfer_queue.storage.simple_backend import SimpleStorageUnit
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType


class MockStorageClient:
Expand Down
Loading
Loading