Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 21 additions & 72 deletions transfer_queue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
import logging
import os
import threading
from functools import wraps
from typing import Any, Callable, Optional
from uuid import uuid4
from typing import Any, Optional

import torch
import zmq
Expand All @@ -38,8 +36,7 @@
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
create_zmq_socket,
format_zmq_address,
dynamic_zmq_socket,
)

logger = logging.getLogger(__name__)
Expand All @@ -53,6 +50,13 @@

TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8))

# Pre-bound decorator for controller socket operations.
# Built once at import time from the shared `dynamic_zmq_socket` factory and then
# applied (as `@_controller_socket`) to the client's async methods that accept a
# `socket` keyword argument.
# NOTE(review): presumably "request_handle_socket" is the port name to connect to,
# `owner_id_attr` names the attribute on `self` used for the socket identity, and
# `server_attr` names the attribute holding the controller's server info --
# confirm against the `dynamic_zmq_socket` definition.
_controller_socket = dynamic_zmq_socket(
"request_handle_socket",
owner_id_attr="client_id",
server_attr="_controller",
)


class AsyncTransferQueueClient:
"""Asynchronous client for interacting with TransferQueue controller and storage systems.
Expand Down Expand Up @@ -99,63 +103,8 @@ def initialize_storage_manager(
manager_type, controller_info=self._controller, config=config
)

# TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
@staticmethod
def dynamic_socket(socket_name: str):
"""Decorator to auto-manage ZMQ sockets for Controller/Storage servers.

Handles socket lifecycle: create -> connect -> inject -> close.

Args:
socket_name: Port name from server config to use for ZMQ connection (e.g., "data_req_port")

Decorated Function Requirements:
1. Must be an async class method (needs `self`)
2. `self` must have:
- `_controller`: Server registry
- `client_id`: Unique client ID for socket identity
3. Receives ZMQ socket via `socket` keyword argument (injected by decorator)
"""

def decorator(func: Callable):
@wraps(func)
async def wrapper(self, *args, **kwargs):
server_info = self._controller
if not server_info:
raise RuntimeError("No controller registered")

context = zmq.asyncio.Context()
address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name))
identity = f"{self.client_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip)

try:
sock.connect(address)
logger.debug(
f"[{self.client_id}]: Connected to Controller {server_info.id} at {address} "
f"with identity {identity.decode()}"
)

kwargs["socket"] = sock
return await func(self, *args, **kwargs)
except Exception as e:
logger.error(f"[{self.client_id}]: Error in socket operation with Controller {server_info.id}: {e}")
raise
finally:
try:
if not sock.closed:
sock.close(linger=-1)
except Exception as e:
logger.warning(f"[{self.client_id}]: Error closing socket to Controller {server_info.id}: {e}")

context.term()

return wrapper

return decorator

# ==================== Basic API ====================
@dynamic_socket(socket_name="request_handle_socket")
@_controller_socket
async def async_get_meta(
self,
data_fields: list[str],
Expand Down Expand Up @@ -245,7 +194,7 @@ async def async_get_meta(
f"{response_msg.body.get('message', 'Unknown error')}"
)

@dynamic_socket(socket_name="request_handle_socket")
@_controller_socket
async def async_set_custom_meta(
self,
metadata: BatchMeta,
Expand Down Expand Up @@ -534,7 +483,7 @@ async def async_clear_samples(self, metadata: BatchMeta):
except Exception as e:
raise RuntimeError(f"Error in clear_samples operation: {str(e)}") from e

@dynamic_socket(socket_name="request_handle_socket")
@_controller_socket
async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
"""Clear metadata in the controller.

Expand All @@ -560,7 +509,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE:
raise RuntimeError("Failed to clear samples metadata in controller.")

@dynamic_socket(socket_name="request_handle_socket")
@_controller_socket
async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta:
"""Get metadata required for the whole partition from controller.

Expand Down Expand Up @@ -590,7 +539,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta

return response_msg.body["metadata"]

@dynamic_socket(socket_name="request_handle_socket")
@_controller_socket
async def _clear_partition_in_controller(self, partition_id, socket=None):
"""Clear the whole partition in the controller.

Expand All @@ -617,7 +566,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None):
raise RuntimeError(f"Failed to clear partition {partition_id} in controller.")

# ==================== Status Query API ====================
@dynamic_socket(socket_name="request_handle_socket")
@_controller_socket
async def async_get_consumption_status(
self,
task_name: str,
Expand Down Expand Up @@ -680,7 +629,7 @@ async def async_get_consumption_status(
except Exception as e:
raise RuntimeError(f"[{self.client_id}]: Error in get_consumption_status: {str(e)}") from e

@dynamic_socket(socket_name="request_handle_socket")
@_controller_socket
async def async_get_production_status(
self,
data_fields: list[str],
Expand Down Expand Up @@ -812,7 +761,7 @@ async def async_check_production_status(
return False
return torch.all(production_status == 1).item()

@dynamic_socket(socket_name="request_handle_socket")
@_controller_socket
async def async_reset_consumption(
self,
partition_id: str,
Expand Down Expand Up @@ -874,7 +823,7 @@ async def async_reset_consumption(
except Exception as e:
raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e

@dynamic_socket(socket_name="request_handle_socket")
@_controller_socket
async def async_get_partition_list(
self,
socket: Optional[zmq.asyncio.Socket] = None,
Expand Down Expand Up @@ -920,7 +869,7 @@ async def async_get_partition_list(
raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e

# ==================== KV Interface API ====================
@dynamic_socket(socket_name="request_handle_socket")
@_controller_socket
async def async_kv_retrieve_meta(
self,
keys: list[str] | str,
Expand Down Expand Up @@ -986,7 +935,7 @@ async def async_kv_retrieve_meta(
except Exception as e:
raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e

@dynamic_socket(socket_name="request_handle_socket")
@_controller_socket
async def async_kv_retrieve_keys(
self,
global_indexes: list[int] | int,
Expand Down Expand Up @@ -1049,7 +998,7 @@ async def async_kv_retrieve_keys(
except Exception as e:
raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_indexes: {str(e)}") from e

@dynamic_socket(socket_name="request_handle_socket")
@_controller_socket
async def async_kv_list(
self,
partition_id: Optional[str] = None,
Expand Down
94 changes: 14 additions & 80 deletions transfer_queue/storage/managers/simple_backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
import warnings
from collections import defaultdict
from collections.abc import Mapping
from functools import wraps
from operator import itemgetter
from typing import Any, Callable, NamedTuple
from uuid import uuid4
from typing import Any, NamedTuple

import torch
import zmq
Expand All @@ -36,8 +34,7 @@
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
create_zmq_socket,
format_zmq_address,
dynamic_zmq_socket,
)

logger = logging.getLogger(__name__)
Expand All @@ -51,6 +48,15 @@

TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds

# Pre-bound decorator for storage-unit socket operations.
# Built once at import time from the shared `dynamic_zmq_socket` factory and then
# applied (as `@_storage_unit_socket`) to the manager's per-storage-unit async
# methods, which accept `target_storage_unit` and `socket` keyword arguments.
# The send/recv timeout comes from TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT (seconds).
# NOTE(review): presumably `owner_id_attr` names the attribute on `self` used for
# the socket identity, `server_attr` the attribute holding the registered
# storage-unit infos, and `target_kwarg` the keyword that selects which unit to
# connect to -- confirm against the `dynamic_zmq_socket` definition.
_storage_unit_socket = dynamic_zmq_socket(
"put_get_socket",
owner_id_attr="storage_manager_id",
server_attr="storage_unit_infos",
target_kwarg="target_storage_unit",
timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT,
)


class RoutingGroup(NamedTuple):
"""Routing result for a single storage unit."""
Expand Down Expand Up @@ -114,78 +120,6 @@ def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerIn

return server_infos_transform

# TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
@staticmethod
def dynamic_storage_manager_socket(socket_name: str, timeout: int):
"""Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close).

Args:
socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port").
timeout (float): Timeout in seconds for ZMQ connection (in seconds).

Decorated Function Rules:
1. Must be an async class method (needs `self`).
2. `self` requires:
- `storage_unit_infos: storage unit infos (ZMQServerInfo | dict[Any, ZMQServerInfo]).
3. Specify target server via:
- `target_storage_unit` arg.
4. Receives ZMQ socket via `socket` keyword arg (injected by decorator).
"""

def decorator(func: Callable):
@wraps(func)
async def wrapper(self, *args, **kwargs):
server_key = kwargs.get("target_storage_unit")
if server_key is None:
for arg in args:
if isinstance(arg, str) and arg in self.storage_unit_infos.keys():
server_key = arg
break

server_info = self.storage_unit_infos.get(server_key)

if not server_info:
raise RuntimeError(f"Server {server_key} not found in registered servers")

context = zmq.asyncio.Context()
address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name))
identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity)

try:
sock.connect(address)
# Timeouts to avoid indefinite await on recv/send
sock.setsockopt(zmq.RCVTIMEO, timeout * 1000)
sock.setsockopt(zmq.SNDTIMEO, timeout * 1000)
logger.debug(
f"[{self.storage_manager_id}]: Connected to StorageUnit {server_info.id} at {address} "
f"with identity {identity.decode()}"
)

kwargs["socket"] = sock
return await func(self, *args, **kwargs)
except Exception as e:
logger.error(
f"[{self.storage_manager_id}]: Error in socket operation with "
f"StorageUnit {server_info.id} at {address}: "
f"{type(e).__name__}: {e}"
)
raise
finally:
try:
if not sock.closed:
sock.close(linger=-1)
except Exception as e:
logger.warning(
f"[{self.storage_manager_id}]: Error closing socket to StorageUnit {server_info.id}: {e}"
)

context.term()

return wrapper

return decorator

def _group_by_hash(self, global_indexes: list[int]) -> dict[str, RoutingGroup]:
"""Group samples by global_idx % num_su, return {storage_id: RoutingGroup}.

Expand Down Expand Up @@ -335,7 +269,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
field_schema,
)

@dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT)
@_storage_unit_socket
async def _put_to_single_storage_unit(
self,
global_indexes: list[int],
Expand Down Expand Up @@ -456,7 +390,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict:

return TensorDict(tensor_data, batch_size=len(metadata))

@dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT)
@_storage_unit_socket
async def _get_from_single_storage_unit(
self,
global_indexes: list[int],
Expand Down Expand Up @@ -528,7 +462,7 @@ async def clear_data(self, metadata: BatchMeta) -> None:
if isinstance(result, Exception):
logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}")

@dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT)
@_storage_unit_socket
async def _clear_single_storage_unit(self, global_indexes, target_storage_unit=None, socket=None):
try:
request_msg = ZMQMessage.create(
Expand Down
Loading
Loading