diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 195eaf3..8691bad 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -242,9 +242,14 @@ def test_setstate_readonly_production_status(self): """ batch = self._make_batch() # Simulate pickle round-trip with Arrow zero-copy (read-only array) - state = batch.__dict__.copy() - state["production_status"] = state["production_status"].copy() - state["production_status"].flags.writeable = False + + state = batch.__getstate__() + # Convert tuple to list for modification + state = list(state) + slot_idx = list(BatchMeta.__slots__).index("production_status") + state[slot_idx] = state[slot_idx].copy() + state[slot_idx].flags.writeable = False + state = tuple(state) restored = BatchMeta.__new__(BatchMeta) restored.__setstate__(state) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 9c91a6b..2987695 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -203,12 +203,8 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: return field_schema -@dataclass class BatchMeta: - """Records the metadata of a batch of data samples with optimized field-level schema. - - This is the O(BxF) optimized version that stores field metadata at the field level - instead of per-sample, reducing storage from O(B*F) to O(F). + """Metadata of a batch of data samples. Attributes: global_indexes: List of global sample indices in this batch. @@ -220,24 +216,42 @@ class BatchMeta: _custom_backend_meta: Per-sample per-field storage backend metadata, list aligned with global_indexes. """ - global_indexes: list[int] - partition_ids: list[str] - # field-level metadata: {field_name: {dtype, shape, is_nested, is_non_tensor}} - field_schema: dict[str, dict[str, Any]] = dataclasses.field(default_factory=dict) - # vectorized production status matrix - production_status: np.ndarray = dataclasses.field(default=None, repr=False) # type: ignore[assignment] - extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) - # user-defined meta for each sample (sample-level), list aligned with global_indexes - custom_meta: list[dict[str, Any]] = dataclasses.field(default_factory=list) - # internal meta for different storage backends (per-sample per-field level), list aligned with global_indexes - _custom_backend_meta: list[dict[str, Any]] = dataclasses.field(default_factory=list) + __slots__ = ( + "global_indexes", + "partition_ids", + "field_schema", + "production_status", + "extra_info", + "custom_meta", + "_custom_backend_meta", + "_size", + "_field_names", + "_is_ready", + ) + + def __init__( + self, + global_indexes: list[int], + partition_ids: list[str], + field_schema: Optional[dict[str, dict[str, Any]]] = None, + production_status: Optional[np.ndarray] = None, + extra_info: Optional[dict[str, Any]] = None, + custom_meta: Optional[list[dict[str, Any]]] = None, + _custom_backend_meta: Optional[list[dict[str, Any]]] = None, + ) -> None: + if field_schema is None: + field_schema = {} + if extra_info is None: + extra_info = {} + if custom_meta is None: + custom_meta = [] + if _custom_backend_meta is None: + _custom_backend_meta = [] - def __post_init__(self): - """Initialize all computed properties during initialization""" - self.global_indexes = list(self.global_indexes) - self.partition_ids = list(self.partition_ids) - self.field_schema = {k: dict(v) for k, v in self.field_schema.items()} - self.extra_info = dict(self.extra_info) + self.global_indexes = list(global_indexes) + self.partition_ids = list(partition_ids) + self.field_schema = {k: dict(v) for k, v in field_schema.items()} + self.extra_info = dict(extra_info) # Validation if len(self.global_indexes) != len(self.partition_ids): @@ -248,8 +262,8 @@ def __post_init__(self): batch_size = len(self.global_indexes) - if self.production_status is not None: - self.production_status = np.array(self.production_status, dtype=np.int8, copy=True) + if production_status is not None: + self.production_status = np.array(production_status, dtype=np.int8, copy=True) if len(self.production_status) != batch_size: raise ValueError(f"production_status length {len(self.production_status)} != batch_size {batch_size}") @@ -272,32 +286,40 @@ def __post_init__(self): self._is_ready = is_ready # Validate or initialize columnar custom_meta / _custom_backend_meta - if not self.custom_meta: - self.custom_meta = [{} for _ in range(batch_size)] + if not custom_meta: + self.custom_meta: list[dict[str, Any]] = [{} for _ in range(batch_size)] else: - self.custom_meta = [dict(d) for d in self.custom_meta] + self.custom_meta = [dict(d) for d in custom_meta] if len(self.custom_meta) != batch_size: raise ValueError(f"custom_meta length {len(self.custom_meta)} != batch_size {batch_size}") - if not self._custom_backend_meta: - self._custom_backend_meta = [{} for _ in range(batch_size)] + if not _custom_backend_meta: + self._custom_backend_meta: list[dict[str, Any]] = [{} for _ in range(batch_size)] else: - self._custom_backend_meta = [dict(d) for d in self._custom_backend_meta] + self._custom_backend_meta = [dict(d) for d in _custom_backend_meta] if len(self._custom_backend_meta) != batch_size: raise ValueError( f"_custom_backend_meta length {len(self._custom_backend_meta)} != batch_size {batch_size}" ) + def __getstate__(self): + """Serialize for pickle/Ray. + + Returns tuple of slot values to ensure proper reconstruction. + """ + return tuple(getattr(self, slot) for slot in self.__slots__) + def __setstate__(self, state): - """Restore instance from pickle/Ray deserialization. + """Deserialize from pickle/Ray. - Python dataclass pickle skips __init__/__post_init__, so the - .copy() guard for production_status is bypassed. Ray Arrow - zero-copy deserialization produces read-only numpy arrays. - This method ensures writability after deserialization. + Ray Arrow zero-copy deserialization produces read-only numpy + arrays. This method ensures production_status is writable after + deserialization. """ - self.__dict__.update(state) - if isinstance(self.production_status, np.ndarray) and not self.production_status.flags.writeable: - self.production_status = self.production_status.copy() + for slot, value in zip(self.__slots__, state, strict=False): + # Ray Arrow zero-copy produces read-only numpy arrays + if slot == "production_status" and isinstance(value, np.ndarray) and not value.flags.writeable: + value = value.copy() + setattr(self, slot, value) @property def size(self) -> int: diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index 38917a1..fd57b80 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -37,8 +37,7 @@ CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_TENSOR = 3 # For tensor with buffer reference CUSTOM_TYPE_NESTED_TENSOR = 4 # For nested tensor (strided or jagged) -CUSTOM_TYPE_BATCHMETA = 5 # For BatchMeta serialization -CUSTOM_TYPE_NUMPY = 6 # For numpy ndarray with buffer reference +CUSTOM_TYPE_NUMPY = 5 # For numpy ndarray with buffer reference # 0xC1 is permanently reserved (invalid) in msgpack spec — safe to use as pickle fallback sentinel. _PICKLE_FALLBACK_SENTINEL = b"\xc1\xfe\xed" @@ -79,8 +78,6 @@ def aux_buffers(self) -> list[bytestr]: def encode(self, obj: Any) -> Sequence[bytestr]: """Encode a given object to a byte array.""" - # Pre-process to convert BatchMeta to Ext; msgspec auto-serializes dataclasses and won't call enc_hook for them. - obj = self._preprocess_for_batchmeta(obj) bufs: list[bytestr] = [b""] token = _encoder_aux_buffers.set(bufs) @@ -94,24 +91,6 @@ def encode(self, obj: Any) -> Sequence[bytestr]: finally: _encoder_aux_buffers.reset(token) - def _preprocess_for_batchmeta(self, obj: Any) -> Any: - """Recursively preprocess object to convert BatchMeta to Ext. - - This is necessary because msgspec auto-serializes dataclasses and - won't call enc_hook for them. - """ - from transfer_queue.metadata import BatchMeta - - if isinstance(obj, BatchMeta): - return self._encode_batchmeta(obj) - elif isinstance(obj, dict): - return {k: self._preprocess_for_batchmeta(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [self._preprocess_for_batchmeta(item) for item in obj] - elif isinstance(obj, tuple): - return tuple(self._preprocess_for_batchmeta(item) for item in obj) - return obj - def enc_hook(self, obj: Any) -> Any: """Custom encoding hook for types msgspec doesn't natively support. @@ -120,8 +99,6 @@ def enc_hook(self, obj: Any) -> Any: - TensorDict: Convert to dict structure for recursive processing - numpy.ndarray: Convert to tensor for unified handling - Note: BatchMeta is handled by _preprocess_for_batchmeta() before encode() is called, - so it will never reach this hook. """ if isinstance(obj, torch.Tensor): return self._encode_tensor(obj) @@ -148,15 +125,6 @@ def enc_hook(self, obj: Any) -> Any: # Fallback to pickle for unknown types return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) - def _encode_batchmeta(self, obj: Any) -> msgpack.Ext: - """Encode BatchMeta as a pickle-based Ext payload. - - BatchMeta must be preprocessed before encode() because msgspec auto-serializes - dataclasses (bypassing enc_hook), and BatchMeta fields contain torch.dtype which - msgpack cannot handle natively. - """ - return msgpack.Ext(CUSTOM_TYPE_BATCHMETA, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) - def _encode_tensordict(self, obj: Any) -> dict: """Convert TensorDict to a dict structure for recursive msgpack processing. @@ -373,7 +341,6 @@ def ext_hook(self, code: int, data: memoryview) -> Any: - torch.Tensor: Extract buffer, store metadata - TensorDict: Convert to dict structure for recursive processing - numpy.ndarray: Convert to tensor for unified handling - - BatchMeta: Reconstruct from pickle """ if code == CUSTOM_TYPE_PICKLE: return pickle.loads(data) @@ -385,8 +352,6 @@ def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_NESTED_TENSOR: nested_meta = pickle.loads(data) return self._decode_nested_tensor(nested_meta) - if code == CUSTOM_TYPE_BATCHMETA: - return pickle.loads(data) if code == CUSTOM_TYPE_NUMPY: meta = pickle.loads(data) return self._decode_numpy(meta)