Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions skyrl/train/fully_async_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import asyncio
import inspect
import os
import sys
import traceback
from dataclasses import dataclass
from typing import Iterable, List, Set, Tuple
Expand Down Expand Up @@ -593,19 +592,16 @@ async def _run_generate_for_a_group_loop(self, generation_output_group_buffer: a
except asyncio.QueueFull:
raise AssertionError("Generation buffer should never be full given staleness control.")
await self._staleness_manager.on_rollout_accepted()
slot_acquired = False
except asyncio.CancelledError:
# If a slot was acquired but we exit early, release running count
try:
if "slot_acquired" in locals() and slot_acquired:
raise RuntimeError("Generation workers should only be cancelled when they finish running.")
finally:
return
if "slot_acquired" in locals() and slot_acquired:
logger.warning("Generation worker cancelled while slot was acquired. Releasing slot.")
await self._staleness_manager.on_rollout_rejected()
return
Comment thread
dinhxuanvu marked this conversation as resolved.
except Exception as e:
logger.error(f"Generator worker errored out with exception: {e}")
logger.error(f"Traceback: \n{traceback.format_exc()}")
if "slot_acquired" in locals() and slot_acquired:
raise RuntimeError("Generation workers should only run into error when they finish running.")
sys.exit(1)
os._exit(1)

async def async_sync_policy_weights_to_inference_engines(self):
return await self.policy_model.async_run_method(
Expand Down
Loading