diff --git a/skyrl/train/fully_async_trainer.py b/skyrl/train/fully_async_trainer.py index a04bfd6cf4..47ec230088 100644 --- a/skyrl/train/fully_async_trainer.py +++ b/skyrl/train/fully_async_trainer.py @@ -14,7 +14,6 @@ import asyncio import inspect import os -import sys import traceback from dataclasses import dataclass from typing import Iterable, List, Set, Tuple @@ -593,19 +592,16 @@ async def _run_generate_for_a_group_loop(self, generation_output_group_buffer: a except asyncio.QueueFull: raise AssertionError("Generation buffer should never be full given staleness control.") await self._staleness_manager.on_rollout_accepted() + slot_acquired = False except asyncio.CancelledError: - # If a slot was acquired but we exit early, release running count - try: - if "slot_acquired" in locals() and slot_acquired: - raise RuntimeError("Generation workers should only be cancelled when they finish running.") - finally: - return + if "slot_acquired" in locals() and slot_acquired: + logger.warning("Generation worker cancelled while slot was acquired. Releasing slot.") + await self._staleness_manager.on_rollout_rejected() + return except Exception as e: logger.error(f"Generator worker errored out with exception: {e}") logger.error(f"Traceback: \n{traceback.format_exc()}") - if "slot_acquired" in locals() and slot_acquired: - raise RuntimeError("Generation workers should only run into error when they finish running.") - sys.exit(1) + os._exit(1) async def async_sync_policy_weights_to_inference_engines(self): return await self.policy_model.async_run_method(