Skip to content

Commit 5f303c3

Browse files
authored
[optimize] Refactor BatchMeta to an ordinary class (#53)
1. Refactor `BatchMeta` from a `@dataclass` to a regular class. 2. Remove the now-unnecessary serialization workaround for the dataclass `BatchMeta`. --- Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent c4e86ed commit 5f303c3

3 files changed

Lines changed: 69 additions & 77 deletions

File tree

tests/test_metadata.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,14 @@ def test_setstate_readonly_production_status(self):
242242
"""
243243
batch = self._make_batch()
244244
# Simulate pickle round-trip with Arrow zero-copy (read-only array)
245-
state = batch.__dict__.copy()
246-
state["production_status"] = state["production_status"].copy()
247-
state["production_status"].flags.writeable = False
245+
246+
state = batch.__getstate__()
247+
# Convert tuple to list for modification
248+
state = list(state)
249+
slot_idx = list(BatchMeta.__slots__).index("production_status")
250+
state[slot_idx] = state[slot_idx].copy()
251+
state[slot_idx].flags.writeable = False
252+
state = tuple(state)
248253

249254
restored = BatchMeta.__new__(BatchMeta)
250255
restored.__setstate__(state)

transfer_queue/metadata.py

Lines changed: 60 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,8 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]:
203203
return field_schema
204204

205205

206-
@dataclass
207206
class BatchMeta:
208-
"""Records the metadata of a batch of data samples with optimized field-level schema.
209-
210-
This is the O(BxF) optimized version that stores field metadata at the field level
211-
instead of per-sample, reducing storage from O(B*F) to O(F).
207+
"""Metadata of a batch of data samples.
212208
213209
Attributes:
214210
global_indexes: List of global sample indices in this batch.
@@ -220,24 +216,42 @@ class BatchMeta:
220216
_custom_backend_meta: Per-sample per-field storage backend metadata, list aligned with global_indexes.
221217
"""
222218

223-
global_indexes: list[int]
224-
partition_ids: list[str]
225-
# field-level metadata: {field_name: {dtype, shape, is_nested, is_non_tensor}}
226-
field_schema: dict[str, dict[str, Any]] = dataclasses.field(default_factory=dict)
227-
# vectorized production status matrix
228-
production_status: np.ndarray = dataclasses.field(default=None, repr=False) # type: ignore[assignment]
229-
extra_info: dict[str, Any] = dataclasses.field(default_factory=dict)
230-
# user-defined meta for each sample (sample-level), list aligned with global_indexes
231-
custom_meta: list[dict[str, Any]] = dataclasses.field(default_factory=list)
232-
# internal meta for different storage backends (per-sample per-field level), list aligned with global_indexes
233-
_custom_backend_meta: list[dict[str, Any]] = dataclasses.field(default_factory=list)
219+
__slots__ = (
220+
"global_indexes",
221+
"partition_ids",
222+
"field_schema",
223+
"production_status",
224+
"extra_info",
225+
"custom_meta",
226+
"_custom_backend_meta",
227+
"_size",
228+
"_field_names",
229+
"_is_ready",
230+
)
231+
232+
def __init__(
233+
self,
234+
global_indexes: list[int],
235+
partition_ids: list[str],
236+
field_schema: Optional[dict[str, dict[str, Any]]] = None,
237+
production_status: Optional[np.ndarray] = None,
238+
extra_info: Optional[dict[str, Any]] = None,
239+
custom_meta: Optional[list[dict[str, Any]]] = None,
240+
_custom_backend_meta: Optional[list[dict[str, Any]]] = None,
241+
) -> None:
242+
if field_schema is None:
243+
field_schema = {}
244+
if extra_info is None:
245+
extra_info = {}
246+
if custom_meta is None:
247+
custom_meta = []
248+
if _custom_backend_meta is None:
249+
_custom_backend_meta = []
234250

235-
def __post_init__(self):
236-
"""Initialize all computed properties during initialization"""
237-
self.global_indexes = list(self.global_indexes)
238-
self.partition_ids = list(self.partition_ids)
239-
self.field_schema = {k: dict(v) for k, v in self.field_schema.items()}
240-
self.extra_info = dict(self.extra_info)
251+
self.global_indexes = list(global_indexes)
252+
self.partition_ids = list(partition_ids)
253+
self.field_schema = {k: dict(v) for k, v in field_schema.items()}
254+
self.extra_info = dict(extra_info)
241255

242256
# Validation
243257
if len(self.global_indexes) != len(self.partition_ids):
@@ -248,8 +262,8 @@ def __post_init__(self):
248262

249263
batch_size = len(self.global_indexes)
250264

251-
if self.production_status is not None:
252-
self.production_status = np.array(self.production_status, dtype=np.int8, copy=True)
265+
if production_status is not None:
266+
self.production_status = np.array(production_status, dtype=np.int8, copy=True)
253267

254268
if len(self.production_status) != batch_size:
255269
raise ValueError(f"production_status length {len(self.production_status)} != batch_size {batch_size}")
@@ -272,32 +286,40 @@ def __post_init__(self):
272286
self._is_ready = is_ready
273287

274288
# Validate or initialize columnar custom_meta / _custom_backend_meta
275-
if not self.custom_meta:
276-
self.custom_meta = [{} for _ in range(batch_size)]
289+
if not custom_meta:
290+
self.custom_meta: list[dict[str, Any]] = [{} for _ in range(batch_size)]
277291
else:
278-
self.custom_meta = [dict(d) for d in self.custom_meta]
292+
self.custom_meta = [dict(d) for d in custom_meta]
279293
if len(self.custom_meta) != batch_size:
280294
raise ValueError(f"custom_meta length {len(self.custom_meta)} != batch_size {batch_size}")
281-
if not self._custom_backend_meta:
282-
self._custom_backend_meta = [{} for _ in range(batch_size)]
295+
if not _custom_backend_meta:
296+
self._custom_backend_meta: list[dict[str, Any]] = [{} for _ in range(batch_size)]
283297
else:
284-
self._custom_backend_meta = [dict(d) for d in self._custom_backend_meta]
298+
self._custom_backend_meta = [dict(d) for d in _custom_backend_meta]
285299
if len(self._custom_backend_meta) != batch_size:
286300
raise ValueError(
287301
f"_custom_backend_meta length {len(self._custom_backend_meta)} != batch_size {batch_size}"
288302
)
289303

304+
def __getstate__(self):
305+
"""Serialize for pickle/Ray.
306+
307+
Returns tuple of slot values to ensure proper reconstruction.
308+
"""
309+
return tuple(getattr(self, slot) for slot in self.__slots__)
310+
290311
def __setstate__(self, state):
291-
"""Restore instance from pickle/Ray deserialization.
312+
"""Deserialize from pickle/Ray.
292313
293-
Python dataclass pickle skips __init__/__post_init__, so the
294-
.copy() guard for production_status is bypassed. Ray Arrow
295-
zero-copy deserialization produces read-only numpy arrays.
296-
This method ensures writability after deserialization.
314+
Ray Arrow zero-copy deserialization produces read-only numpy
315+
arrays. This method ensures production_status is writable after
316+
deserialization.
297317
"""
298-
self.__dict__.update(state)
299-
if isinstance(self.production_status, np.ndarray) and not self.production_status.flags.writeable:
300-
self.production_status = self.production_status.copy()
318+
for slot, value in zip(self.__slots__, state, strict=False):
319+
# Ray Arrow zero-copy produces read-only numpy arrays
320+
if slot == "production_status" and isinstance(value, np.ndarray) and not value.flags.writeable:
321+
value = value.copy()
322+
setattr(self, slot, value)
301323

302324
@property
303325
def size(self) -> int:

transfer_queue/utils/serial_utils.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@
3737
CUSTOM_TYPE_CLOUDPICKLE = 2
3838
CUSTOM_TYPE_TENSOR = 3 # For tensor with buffer reference
3939
CUSTOM_TYPE_NESTED_TENSOR = 4 # For nested tensor (strided or jagged)
40-
CUSTOM_TYPE_BATCHMETA = 5 # For BatchMeta serialization
41-
CUSTOM_TYPE_NUMPY = 6 # For numpy ndarray with buffer reference
40+
CUSTOM_TYPE_NUMPY = 5 # For numpy ndarray with buffer reference
4241

4342
# 0xC1 is permanently reserved (invalid) in msgpack spec — safe to use as pickle fallback sentinel.
4443
_PICKLE_FALLBACK_SENTINEL = b"\xc1\xfe\xed"
@@ -79,8 +78,6 @@ def aux_buffers(self) -> list[bytestr]:
7978

8079
def encode(self, obj: Any) -> Sequence[bytestr]:
8180
"""Encode a given object to a byte array."""
82-
# Pre-process to convert BatchMeta to Ext; msgspec auto-serializes dataclasses and won't call enc_hook for them.
83-
obj = self._preprocess_for_batchmeta(obj)
8481

8582
bufs: list[bytestr] = [b""]
8683
token = _encoder_aux_buffers.set(bufs)
@@ -94,24 +91,6 @@ def encode(self, obj: Any) -> Sequence[bytestr]:
9491
finally:
9592
_encoder_aux_buffers.reset(token)
9693

97-
def _preprocess_for_batchmeta(self, obj: Any) -> Any:
98-
"""Recursively preprocess object to convert BatchMeta to Ext.
99-
100-
This is necessary because msgspec auto-serializes dataclasses and
101-
won't call enc_hook for them.
102-
"""
103-
from transfer_queue.metadata import BatchMeta
104-
105-
if isinstance(obj, BatchMeta):
106-
return self._encode_batchmeta(obj)
107-
elif isinstance(obj, dict):
108-
return {k: self._preprocess_for_batchmeta(v) for k, v in obj.items()}
109-
elif isinstance(obj, list):
110-
return [self._preprocess_for_batchmeta(item) for item in obj]
111-
elif isinstance(obj, tuple):
112-
return tuple(self._preprocess_for_batchmeta(item) for item in obj)
113-
return obj
114-
11594
def enc_hook(self, obj: Any) -> Any:
11695
"""Custom encoding hook for types msgspec doesn't natively support.
11796
@@ -120,8 +99,6 @@ def enc_hook(self, obj: Any) -> Any:
12099
- TensorDict: Convert to dict structure for recursive processing
121100
- numpy.ndarray: Convert to tensor for unified handling
122101
123-
Note: BatchMeta is handled by _preprocess_for_batchmeta() before encode() is called,
124-
so it will never reach this hook.
125102
"""
126103
if isinstance(obj, torch.Tensor):
127104
return self._encode_tensor(obj)
@@ -148,15 +125,6 @@ def enc_hook(self, obj: Any) -> Any:
148125
# Fallback to pickle for unknown types
149126
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
150127

151-
def _encode_batchmeta(self, obj: Any) -> msgpack.Ext:
152-
"""Encode BatchMeta as a pickle-based Ext payload.
153-
154-
BatchMeta must be preprocessed before encode() because msgspec auto-serializes
155-
dataclasses (bypassing enc_hook), and BatchMeta fields contain torch.dtype which
156-
msgpack cannot handle natively.
157-
"""
158-
return msgpack.Ext(CUSTOM_TYPE_BATCHMETA, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
159-
160128
def _encode_tensordict(self, obj: Any) -> dict:
161129
"""Convert TensorDict to a dict structure for recursive msgpack processing.
162130
@@ -373,7 +341,6 @@ def ext_hook(self, code: int, data: memoryview) -> Any:
373341
- torch.Tensor: Extract buffer, store metadata
374342
- TensorDict: Convert to dict structure for recursive processing
375343
- numpy.ndarray: Convert to tensor for unified handling
376-
- BatchMeta: Reconstruct from pickle
377344
"""
378345
if code == CUSTOM_TYPE_PICKLE:
379346
return pickle.loads(data)
@@ -385,8 +352,6 @@ def ext_hook(self, code: int, data: memoryview) -> Any:
385352
if code == CUSTOM_TYPE_NESTED_TENSOR:
386353
nested_meta = pickle.loads(data)
387354
return self._decode_nested_tensor(nested_meta)
388-
if code == CUSTOM_TYPE_BATCHMETA:
389-
return pickle.loads(data)
390355
if code == CUSTOM_TYPE_NUMPY:
391356
meta = pickle.loads(data)
392357
return self._decode_numpy(meta)

0 commit comments

Comments
 (0)