OREAL/train_oreal.py, lines 598 to 613 at commit 133434b:

```python
rl_global_batch = args.rl_global_batch

if args.filter_trajectory:
    _world_size = actor_dp_mesh.size()
    _data_size = len(trajectory_dataset)
    # train_global_batch is divisible by world_size
    rl_global_batch = _data_size // _world_size * _world_size

rl_loader = DataLoader(
    trajectory_dataset,
    batch_size=args.rl_mirco_batch,
    num_workers=0,
    collate_fn=TrajectoryCollator(pack_batch=True),
    shuffle=False,
    sampler=RLParallelSampler(trajectory_dataset, actor_dp_mesh, rl_global_batch, shuffle=False),
    persistent_workers=False,
)
```
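For concreteness, the `filter_trajectory` branch uses integer floor division, so when fewer trajectories survive filtering than there are data-parallel ranks, the result rounds all the way down to zero. A minimal illustration with hypothetical sizes:

```python
# Hypothetical sizes: 3 surviving trajectories across 8 data-parallel ranks.
_data_size, _world_size = 3, 8
rl_global_batch = _data_size // _world_size * _world_size
print(rl_global_batch)  # 0 -> any later division by rl_global_batch raises ZeroDivisionError
```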
When training large models (especially the 32B model) with many data-parallel ranks, `rl_global_batch` can become zero: if the filtered `trajectory_dataset` holds fewer samples than `_world_size`, then `_data_size // _world_size * _world_size` evaluates to 0, which later triggers a ZeroDivisionError. Is there any reasonable method to fix this problem?
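One possible mitigation, sketched here only as an untested suggestion and not as the project's intended fix, is to detect the degenerate case before constructing the sampler, and either fail fast with an actionable error or clamp to one sample per rank:

```python
# Untested sketch of a possible guard (not from the repository): avoid an empty
# global batch when fewer trajectories survive filtering than there are ranks.
rl_global_batch = args.rl_global_batch

if args.filter_trajectory:
    _world_size = actor_dp_mesh.size()
    _data_size = len(trajectory_dataset)
    rl_global_batch = _data_size // _world_size * _world_size
    if rl_global_batch == 0:
        # Option A: fail fast with a clear message instead of a ZeroDivisionError
        # deep inside the sampler/loader.
        raise RuntimeError(
            f"Only {_data_size} trajectories survived filtering, but the actor "
            f"data-parallel world size is {_world_size}; relax the filter or use fewer ranks."
        )
        # Option B (alternative): clamp to one sample per rank and let the sampler
        # repeat trajectories, e.g. rl_global_batch = _world_size
```

Which option is preferable probably depends on whether `RLParallelSampler` can oversample a dataset that is smaller than the requested global batch.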