@@ -203,12 +203,8 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]:
203203 return field_schema
204204
205205
206- @dataclass
207206class BatchMeta :
208- """Records the metadata of a batch of data samples with optimized field-level schema.
209-
210- This is the O(BxF) optimized version that stores field metadata at the field level
211- instead of per-sample, reducing storage from O(B*F) to O(F).
207+ """Metadata of a batch of data samples.
212208
213209 Attributes:
214210 global_indexes: List of global sample indices in this batch.
@@ -220,24 +216,42 @@ class BatchMeta:
220216 _custom_backend_meta: Per-sample per-field storage backend metadata, list aligned with global_indexes.
221217 """
222218
223- global_indexes : list [int ]
224- partition_ids : list [str ]
225- # field-level metadata: {field_name: {dtype, shape, is_nested, is_non_tensor}}
226- field_schema : dict [str , dict [str , Any ]] = dataclasses .field (default_factory = dict )
227- # vectorized production status matrix
228- production_status : np .ndarray = dataclasses .field (default = None , repr = False ) # type: ignore[assignment]
229- extra_info : dict [str , Any ] = dataclasses .field (default_factory = dict )
230- # user-defined meta for each sample (sample-level), list aligned with global_indexes
231- custom_meta : list [dict [str , Any ]] = dataclasses .field (default_factory = list )
232- # internal meta for different storage backends (per-sample per-field level), list aligned with global_indexes
233- _custom_backend_meta : list [dict [str , Any ]] = dataclasses .field (default_factory = list )
219+ __slots__ = (
220+ "global_indexes" ,
221+ "partition_ids" ,
222+ "field_schema" ,
223+ "production_status" ,
224+ "extra_info" ,
225+ "custom_meta" ,
226+ "_custom_backend_meta" ,
227+ "_size" ,
228+ "_field_names" ,
229+ "_is_ready" ,
230+ )
231+
232+ def __init__ (
233+ self ,
234+ global_indexes : list [int ],
235+ partition_ids : list [str ],
236+ field_schema : Optional [dict [str , dict [str , Any ]]] = None ,
237+ production_status : Optional [np .ndarray ] = None ,
238+ extra_info : Optional [dict [str , Any ]] = None ,
239+ custom_meta : Optional [list [dict [str , Any ]]] = None ,
240+ _custom_backend_meta : Optional [list [dict [str , Any ]]] = None ,
241+ ) -> None :
242+ if field_schema is None :
243+ field_schema = {}
244+ if extra_info is None :
245+ extra_info = {}
246+ if custom_meta is None :
247+ custom_meta = []
248+ if _custom_backend_meta is None :
249+ _custom_backend_meta = []
234250
235- def __post_init__ (self ):
236- """Initialize all computed properties during initialization"""
237- self .global_indexes = list (self .global_indexes )
238- self .partition_ids = list (self .partition_ids )
239- self .field_schema = {k : dict (v ) for k , v in self .field_schema .items ()}
240- self .extra_info = dict (self .extra_info )
251+ self .global_indexes = list (global_indexes )
252+ self .partition_ids = list (partition_ids )
253+ self .field_schema = {k : dict (v ) for k , v in field_schema .items ()}
254+ self .extra_info = dict (extra_info )
241255
242256 # Validation
243257 if len (self .global_indexes ) != len (self .partition_ids ):
@@ -248,8 +262,8 @@ def __post_init__(self):
248262
249263 batch_size = len (self .global_indexes )
250264
251- if self . production_status is not None :
252- self .production_status = np .array (self . production_status , dtype = np .int8 , copy = True )
265+ if production_status is not None :
266+ self .production_status = np .array (production_status , dtype = np .int8 , copy = True )
253267
254268 if len (self .production_status ) != batch_size :
255269 raise ValueError (f"production_status length { len (self .production_status )} != batch_size { batch_size } " )
@@ -272,32 +286,40 @@ def __post_init__(self):
272286 self ._is_ready = is_ready
273287
274288 # Validate or initialize columnar custom_meta / _custom_backend_meta
275- if not self . custom_meta :
276- self .custom_meta = [{} for _ in range (batch_size )]
289+ if not custom_meta :
290+ self .custom_meta : list [ dict [ str , Any ]] = [{} for _ in range (batch_size )]
277291 else :
278- self .custom_meta = [dict (d ) for d in self . custom_meta ]
292+ self .custom_meta = [dict (d ) for d in custom_meta ]
279293 if len (self .custom_meta ) != batch_size :
280294 raise ValueError (f"custom_meta length { len (self .custom_meta )} != batch_size { batch_size } " )
281- if not self . _custom_backend_meta :
282- self ._custom_backend_meta = [{} for _ in range (batch_size )]
295+ if not _custom_backend_meta :
296+ self ._custom_backend_meta : list [ dict [ str , Any ]] = [{} for _ in range (batch_size )]
283297 else :
284- self ._custom_backend_meta = [dict (d ) for d in self . _custom_backend_meta ]
298+ self ._custom_backend_meta = [dict (d ) for d in _custom_backend_meta ]
285299 if len (self ._custom_backend_meta ) != batch_size :
286300 raise ValueError (
287301 f"_custom_backend_meta length { len (self ._custom_backend_meta )} != batch_size { batch_size } "
288302 )
289303
304+ def __getstate__ (self ):
305+ """Serialize for pickle/Ray.
306+
307+ Returns tuple of slot values to ensure proper reconstruction.
308+ """
309+ return tuple (getattr (self , slot ) for slot in self .__slots__ )
310+
290311 def __setstate__ (self , state ):
291- """Restore instance from pickle/Ray deserialization .
312+ """Deserialize from pickle/Ray.
292313
293- Python dataclass pickle skips __init__/__post_init__, so the
294- .copy() guard for production_status is bypassed. Ray Arrow
295- zero-copy deserialization produces read-only numpy arrays.
296- This method ensures writability after deserialization.
314+ Ray Arrow zero-copy deserialization produces read-only numpy
315+ arrays. This method ensures production_status is writable after
316+ deserialization.
297317 """
298- self .__dict__ .update (state )
299- if isinstance (self .production_status , np .ndarray ) and not self .production_status .flags .writeable :
300- self .production_status = self .production_status .copy ()
318+ for slot , value in zip (self .__slots__ , state , strict = False ):
319+ # Ray Arrow zero-copy produces read-only numpy arrays
320+ if slot == "production_status" and isinstance (value , np .ndarray ) and not value .flags .writeable :
321+ value = value .copy ()
322+ setattr (self , slot , value )
301323
302324 @property
303325 def size (self ) -> int :
0 commit comments