Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 108 additions & 71 deletions tunix/rl/experimental/agentic_rl_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,12 @@ def _model_call(
is_first_msg=True, # no op if system msg is populated in reset
)

result = self.rl_cluster.generate(
prompts=chat_lists,
apply_chat_template=False if self.chat_parser else True,
mode=rl_cluster_lib.Mode.TRAIN,
)
with self.rl_cluster.perf.span("rollout"):
result = self.rl_cluster.generate(
prompts=chat_lists,
apply_chat_template=False if self.chat_parser else True,
mode=rl_cluster_lib.Mode.TRAIN,
)

return result.text[0]

Expand Down Expand Up @@ -707,81 +708,105 @@ def train(
train_data_gen = self._data_consumer_batch_generator(
train_data_queue, train_micro_batch_size * self._num_generations()
)
micro_batches_since_last_sync = 0
micro_batches_per_full_batch = full_batch_size // train_micro_batch_size
for train_micro_batch in train_data_gen:
if self.rl_cluster.global_steps >= self._training_config.max_steps:
logging.info(
"Reached max_steps: %d >= %d",
self.rl_cluster.global_steps,
self._training_config.max_steps,
)
prompt_queue.put(None)
break
self._iter_steps += 1

# Filter out examples that are too old (off-policy).
filtered_train_micro_batch = self._filter_outdated_offpolicy_examples(
train_micro_batch
)
if not filtered_train_micro_batch:
continue
train_micro_batch = filtered_train_micro_batch

merged_train_micro_batch = jax.tree.map(
lambda *xs: jnp.concatenate(xs, axis=0), *train_micro_batch
)
data_insufficient = False

while self.rl_cluster.global_steps < self._training_config.max_steps:
# Start of Global Step
with self.rl_cluster.perf.span_group("global_step"):
# TODO(noghabi): This mini batch is not correct. just a hack to make export work.
with self.rl_cluster.perf.span_group("mini_batch_step"):
# TODO(noghabi): Change metrics to single micro-batch not all steps.
with self.rl_cluster.perf.span("micro_batch_steps"):
for _ in range(micro_batches_per_full_batch):
# Start of Micro Batch
try:
with self.rl_cluster.perf.span("data_loading"):
# with sft_utils.time_measure(suppress_logging=True) as timer:
train_micro_batch = next(train_data_gen)

# self.rl_cluster.buffer_metrics(
# {
# "actor_dequeue_time": (
# timer(),
# np.mean,
# ),
# },
# mode=rl_cluster_lib.Mode.TRAIN,
# )
except StopIteration:
data_insufficient = True
break

self._iter_steps += 1

# Filter out examples that are too old (off-policy).
filtered_train_micro_batch = (
self._filter_outdated_offpolicy_examples(train_micro_batch)
)
if not filtered_train_micro_batch:
continue
train_micro_batch = filtered_train_micro_batch

# --- Evaluation Logic ---
current_eval_dataset = None
if (
all_eval_prompts
and self.rl_cluster.actor_trainer.train_steps
% training_config.eval_every_n_steps
== 0
):
self._eval_iter_steps = 0
eval_orchestrator = self._build_orchestrator()

async def _eval_runner_async(current_eval_orchestrator):
eval_examples = []
async for batch in self._orchestrator_producer(
current_eval_orchestrator,
all_eval_prompts,
num_generations=self._num_generations(),
):
eval_example = self._batch_to_train_example(
batch,
rl_cluster_lib.Mode.EVAL,
)
eval_examples.extend(eval_example)
return eval_examples
merged_train_micro_batch = jax.tree.map(
lambda *xs: jnp.concatenate(xs, axis=0), *train_micro_batch
)

eval_future = asyncio.run_coroutine_threadsafe(
_eval_runner_async(eval_orchestrator), self.loop
)
eval_examples = eval_future.result()
self._eval_iter_steps += 1
current_eval_dataset = eval_examples
# --- Evaluation Logic for one micro-batch ---
current_eval_dataset = None
if (
all_eval_prompts
and self.rl_cluster.actor_trainer.train_steps
% training_config.eval_every_n_steps
== 0
):
self._eval_iter_steps = 0
eval_orchestrator = self._build_orchestrator()

async def _eval_runner_async(current_eval_orchestrator):
eval_examples = []
async for batch in self._orchestrator_producer(
current_eval_orchestrator,
all_eval_prompts,
num_generations=self._num_generations(),
):
eval_example = self._batch_to_train_example(
batch,
rl_cluster_lib.Mode.EVAL,
)
eval_examples.extend(eval_example)
return eval_examples

eval_future = asyncio.run_coroutine_threadsafe(
_eval_runner_async(eval_orchestrator), self.loop
)
eval_examples = eval_future.result()
self._eval_iter_steps += 1
current_eval_dataset = eval_examples

# --- Training Step for one micro-batch ---
self.rl_cluster.update_actor(
[merged_train_micro_batch], current_eval_dataset, skip_jit
)
if hasattr(self.rl_cluster, "critic_trainer"):
self.rl_cluster.update_critic(
train_micro_batch, current_eval_dataset, skip_jit
)

# --- Training Step ---
self.rl_cluster.update_actor(
[merged_train_micro_batch], current_eval_dataset, skip_jit
)
if hasattr(self.rl_cluster, "critic_trainer"):
self.rl_cluster.update_critic(
train_micro_batch, current_eval_dataset, skip_jit
)
# End the global step due to data insufficiency.
if data_insufficient:
break

# --- Weight Sync Logic ---
micro_batches_since_last_sync += 1
if micro_batches_since_last_sync == micro_batches_per_full_batch:
# --- Weight Sync Logic ---
if self.should_sync_weights:
logging.info("Requesting sync lock to sync weights...")
self._rollout_sync_lock.acquire_weight_sync()
try:
logging.info("Sync lock acquired. Syncing weights.")
self.rl_cluster.sync_weights()
with self.rl_cluster.perf.span(
"weight_sync", self.rl_cluster.perf.all_devices
):
self.rl_cluster.sync_weights()
self.policy_version += 1
logging.info(
"Weights synced. Policy version incremented to %d.",
Expand All @@ -804,7 +829,19 @@ async def _eval_runner_async(current_eval_orchestrator):
)
except StopIteration:
prompt_queue.put(None)
micro_batches_since_last_sync = 0

self.rl_cluster.buffer_metrics(
self.rl_cluster.perf.export(),
mode=rl_cluster_lib.Mode.TRAIN,
)
# End of Global Step

logging.info(
"Training loop finished. global_steps: %d, max_steps: %d",
self.rl_cluster.global_steps,
self._training_config.max_steps,
)
prompt_queue.put(None)

_ = producer_future.result()
self.rl_cluster.close()
Expand Down