Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,14 @@ def test_setstate_readonly_production_status(self):
"""
batch = self._make_batch()
# Simulate pickle round-trip with Arrow zero-copy (read-only array)
state = batch.__dict__.copy()
state["production_status"] = state["production_status"].copy()
state["production_status"].flags.writeable = False

state = batch.__getstate__()
# Convert tuple to list for modification
state = list(state)
slot_idx = list(BatchMeta.__slots__).index("production_status")
state[slot_idx] = state[slot_idx].copy()
state[slot_idx].flags.writeable = False
state = tuple(state)

restored = BatchMeta.__new__(BatchMeta)
restored.__setstate__(state)
Expand Down
98 changes: 60 additions & 38 deletions transfer_queue/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,8 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]:
return field_schema


@dataclass
class BatchMeta:
"""Records the metadata of a batch of data samples with optimized field-level schema.

This version optimizes the former O(B*F) per-sample layout: field metadata is stored at the field level
instead of per-sample, reducing storage from O(B*F) to O(F).
"""Metadata of a batch of data samples.

Attributes:
global_indexes: List of global sample indices in this batch.
Expand All @@ -220,24 +216,42 @@ class BatchMeta:
_custom_backend_meta: Per-sample per-field storage backend metadata, list aligned with global_indexes.
"""

global_indexes: list[int]
partition_ids: list[str]
# field-level metadata: {field_name: {dtype, shape, is_nested, is_non_tensor}}
field_schema: dict[str, dict[str, Any]] = dataclasses.field(default_factory=dict)
# vectorized production status matrix
production_status: np.ndarray = dataclasses.field(default=None, repr=False) # type: ignore[assignment]
extra_info: dict[str, Any] = dataclasses.field(default_factory=dict)
# user-defined meta for each sample (sample-level), list aligned with global_indexes
custom_meta: list[dict[str, Any]] = dataclasses.field(default_factory=list)
# internal meta for different storage backends (per-sample per-field level), list aligned with global_indexes
_custom_backend_meta: list[dict[str, Any]] = dataclasses.field(default_factory=list)
__slots__ = (
"global_indexes",
"partition_ids",
"field_schema",
"production_status",
"extra_info",
"custom_meta",
"_custom_backend_meta",
"_size",
"_field_names",
"_is_ready",
)

def __init__(
self,
global_indexes: list[int],
partition_ids: list[str],
field_schema: Optional[dict[str, dict[str, Any]]] = None,
production_status: Optional[np.ndarray] = None,
extra_info: Optional[dict[str, Any]] = None,
custom_meta: Optional[list[dict[str, Any]]] = None,
_custom_backend_meta: Optional[list[dict[str, Any]]] = None,
) -> None:
if field_schema is None:
field_schema = {}
if extra_info is None:
extra_info = {}
if custom_meta is None:
custom_meta = []
if _custom_backend_meta is None:
_custom_backend_meta = []

def __post_init__(self):
"""Initialize all computed properties during initialization"""
self.global_indexes = list(self.global_indexes)
self.partition_ids = list(self.partition_ids)
self.field_schema = {k: dict(v) for k, v in self.field_schema.items()}
self.extra_info = dict(self.extra_info)
self.global_indexes = list(global_indexes)
self.partition_ids = list(partition_ids)
self.field_schema = {k: dict(v) for k, v in field_schema.items()}
self.extra_info = dict(extra_info)

# Validation
if len(self.global_indexes) != len(self.partition_ids):
Expand All @@ -248,8 +262,8 @@ def __post_init__(self):

batch_size = len(self.global_indexes)

if self.production_status is not None:
self.production_status = np.array(self.production_status, dtype=np.int8, copy=True)
if production_status is not None:
self.production_status = np.array(production_status, dtype=np.int8, copy=True)

if len(self.production_status) != batch_size:
raise ValueError(f"production_status length {len(self.production_status)} != batch_size {batch_size}")
Expand All @@ -272,32 +286,40 @@ def __post_init__(self):
self._is_ready = is_ready

# Validate or initialize columnar custom_meta / _custom_backend_meta
if not self.custom_meta:
self.custom_meta = [{} for _ in range(batch_size)]
if not custom_meta:
self.custom_meta: list[dict[str, Any]] = [{} for _ in range(batch_size)]
else:
self.custom_meta = [dict(d) for d in self.custom_meta]
self.custom_meta = [dict(d) for d in custom_meta]
if len(self.custom_meta) != batch_size:
raise ValueError(f"custom_meta length {len(self.custom_meta)} != batch_size {batch_size}")
if not self._custom_backend_meta:
self._custom_backend_meta = [{} for _ in range(batch_size)]
if not _custom_backend_meta:
self._custom_backend_meta: list[dict[str, Any]] = [{} for _ in range(batch_size)]
else:
self._custom_backend_meta = [dict(d) for d in self._custom_backend_meta]
self._custom_backend_meta = [dict(d) for d in _custom_backend_meta]
if len(self._custom_backend_meta) != batch_size:
raise ValueError(
f"_custom_backend_meta length {len(self._custom_backend_meta)} != batch_size {batch_size}"
)

def __getstate__(self):
    """Build the pickle/Ray state for this instance.

    Returns:
        A tuple holding the current value of every attribute named in
        ``__slots__``, in the same order as ``__slots__``.
    """
    # Read each slot in declaration order; an unset slot raises AttributeError.
    slot_values = [getattr(self, name) for name in self.__slots__]
    return tuple(slot_values)

def __setstate__(self, state):
"""Restore instance from pickle/Ray deserialization.
"""Deserialize from pickle/Ray.

Python dataclass pickle skips __init__/__post_init__, so the
.copy() guard for production_status is bypassed. Ray Arrow
zero-copy deserialization produces read-only numpy arrays.
This method ensures writability after deserialization.
Ray Arrow zero-copy deserialization produces read-only numpy
arrays. This method ensures production_status is writable after
deserialization.
"""
self.__dict__.update(state)
if isinstance(self.production_status, np.ndarray) and not self.production_status.flags.writeable:
self.production_status = self.production_status.copy()
for slot, value in zip(self.__slots__, state, strict=False):
# Ray Arrow zero-copy produces read-only numpy arrays
if slot == "production_status" and isinstance(value, np.ndarray) and not value.flags.writeable:
value = value.copy()
setattr(self, slot, value)

@property
def size(self) -> int:
Expand Down
37 changes: 1 addition & 36 deletions transfer_queue/utils/serial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_TENSOR = 3 # For tensor with buffer reference
CUSTOM_TYPE_NESTED_TENSOR = 4 # For nested tensor (strided or jagged)
CUSTOM_TYPE_BATCHMETA = 5 # For BatchMeta serialization
CUSTOM_TYPE_NUMPY = 6 # For numpy ndarray with buffer reference
CUSTOM_TYPE_NUMPY = 5 # For numpy ndarray with buffer reference

# 0xC1 is permanently reserved (invalid) in msgpack spec — safe to use as pickle fallback sentinel.
_PICKLE_FALLBACK_SENTINEL = b"\xc1\xfe\xed"
Expand Down Expand Up @@ -79,8 +78,6 @@ def aux_buffers(self) -> list[bytestr]:

def encode(self, obj: Any) -> Sequence[bytestr]:
"""Encode a given object to a byte array."""
# Pre-process to convert BatchMeta to Ext; msgspec auto-serializes dataclasses and won't call enc_hook for them.
obj = self._preprocess_for_batchmeta(obj)

bufs: list[bytestr] = [b""]
token = _encoder_aux_buffers.set(bufs)
Expand All @@ -94,24 +91,6 @@ def encode(self, obj: Any) -> Sequence[bytestr]:
finally:
_encoder_aux_buffers.reset(token)

def _preprocess_for_batchmeta(self, obj: Any) -> Any:
    """Replace every BatchMeta found in ``obj`` with its Ext encoding.

    msgspec auto-serializes dataclasses and does not call enc_hook for
    them, so the substitution is done up front. Dicts, lists and tuples
    are rebuilt with each contained value processed recursively; any
    other object is returned unchanged.

    Args:
        obj: The object about to be encoded.

    Returns:
        ``obj`` with BatchMeta instances swapped for the result of
        ``_encode_batchmeta``.
    """
    from transfer_queue.metadata import BatchMeta

    walk = self._preprocess_for_batchmeta

    if isinstance(obj, BatchMeta):
        return self._encode_batchmeta(obj)
    if isinstance(obj, dict):
        # Keys are kept as-is; only the values are walked.
        return {key: walk(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [walk(element) for element in obj]
    if isinstance(obj, tuple):
        return tuple(walk(element) for element in obj)
    return obj

def enc_hook(self, obj: Any) -> Any:
"""Custom encoding hook for types msgspec doesn't natively support.

Expand All @@ -120,8 +99,6 @@ def enc_hook(self, obj: Any) -> Any:
- TensorDict: Convert to dict structure for recursive processing
- numpy.ndarray: Convert to tensor for unified handling

Note: BatchMeta is handled by _preprocess_for_batchmeta() before encode() is called,
so it will never reach this hook.
"""
if isinstance(obj, torch.Tensor):
return self._encode_tensor(obj)
Expand All @@ -148,15 +125,6 @@ def enc_hook(self, obj: Any) -> Any:
# Fallback to pickle for unknown types
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

def _encode_batchmeta(self, obj: Any) -> msgpack.Ext:
    """Wrap a pickled BatchMeta in a msgpack Ext tagged CUSTOM_TYPE_BATCHMETA.

    BatchMeta has to be converted before encode() runs: msgspec
    auto-serializes dataclasses (bypassing enc_hook), and BatchMeta fields
    contain torch.dtype, which msgpack cannot handle natively.

    Args:
        obj: The BatchMeta instance to encode.

    Returns:
        A ``msgpack.Ext`` whose payload is ``obj`` pickled with the
        highest available protocol.
    """
    pickled = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    return msgpack.Ext(CUSTOM_TYPE_BATCHMETA, pickled)

def _encode_tensordict(self, obj: Any) -> dict:
"""Convert TensorDict to a dict structure for recursive msgpack processing.

Expand Down Expand Up @@ -373,7 +341,6 @@ def ext_hook(self, code: int, data: memoryview) -> Any:
- torch.Tensor: Extract buffer, store metadata
- TensorDict: Convert to dict structure for recursive processing
- numpy.ndarray: Convert to tensor for unified handling
- BatchMeta: Reconstruct from pickle
"""
if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data)
Expand All @@ -385,8 +352,6 @@ def ext_hook(self, code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_NESTED_TENSOR:
nested_meta = pickle.loads(data)
return self._decode_nested_tensor(nested_meta)
if code == CUSTOM_TYPE_BATCHMETA:
return pickle.loads(data)
if code == CUSTOM_TYPE_NUMPY:
meta = pickle.loads(data)
return self._decode_numpy(meta)
Expand Down
Loading