Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 83 additions & 80 deletions tunix/sft/peft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,95 +622,98 @@ def train(
train_iterator = iter(train_ds)
index = 0
last_step_completion_time = time.perf_counter()
with utils.time_measure("Train loop"):
while True:
self._prof.maybe_activate(self._iter_steps)
with jax.profiler.StepTraceAnnotation(
"train", step_num=self._iter_steps
while True:
self._prof.maybe_activate(self._iter_steps)
with jax.profiler.StepTraceAnnotation(
"train", step_num=self._iter_steps
):
train_example = None
if self.data_hooks:
train_example = self.data_hooks.load_next_train_batch(self)
else:
try:
train_example = next(train_iterator)
if not self.is_managed_externally:
# TODO(mridulsahu): Add support to restore the iterator state
# instead of skipping the already trained examples.
if index < self._iter_steps:
# Skip the examples that are already trained.
index += 1
continue
index += 1
except StopIteration:
pass

if train_example is None:
break

# Stop training if max_steps is reached.
if (
self.config.max_steps is not None
and self._train_steps >= self.config.max_steps
):
train_example = None
if self.data_hooks:
train_example = self.data_hooks.load_next_train_batch(self)
else:
try:
train_example = next(train_iterator)
if not self.is_managed_externally:
# TODO(mridulsahu): Add support to restore the iterator state
# instead of skipping the already trained examples.
if index < self._iter_steps:
# Skip the examples that are already trained.
index += 1
continue
index += 1
except StopIteration:
pass

if train_example is None:
break

# Stop training if max_steps is reached.
if (
self.config.max_steps is not None
and self._train_steps >= self.config.max_steps
):
break
break

train_example = self._prepare_inputs(train_example)
train_example = sharding_utils.shard_input(
train_example, self.config.data_sharding_axis
)
train_example = self._prepare_inputs(train_example)
train_example = sharding_utils.shard_input(
train_example, self.config.data_sharding_axis
)

self._throttler.wait_for_next()
if self.training_hooks:
self.training_hooks.on_train_step_start(self)

with self._perf_tracer.span(
"peft_train_step", pxla.thread_resources.env.physical_mesh.devices
) as span:
train_loss, aux = train_step(train_example)
span.device_end([train_loss])

current_time = time.perf_counter()
step_time_delta = current_time - last_step_completion_time
last_step_completion_time = current_time

self._throttler.add_computation(train_loss)
self._buffered_train_metrics = self._buffer_metrics(
self._buffered_train_metrics,
loss=train_loss,
step=self._train_steps,
step_time_delta=step_time_delta,
self._throttler.wait_for_next()
if self.training_hooks:
self.training_hooks.on_train_step_start(self)

with self._perf_tracer.span(
"peft_train_step", pxla.thread_resources.env.physical_mesh.devices
) as span:
train_loss, aux = train_step(train_example)
span.device_end([train_loss])

current_time = time.perf_counter()
step_time_delta = current_time - last_step_completion_time
last_step_completion_time = current_time

self._throttler.add_computation(train_loss)
self._buffered_train_metrics = self._buffer_metrics(
self._buffered_train_metrics,
loss=train_loss,
step=self._train_steps,
step_time_delta=step_time_delta,
)
# NB: put this after self._buffer_metrics is important
self._post_process_train_step(aux)
self._iter_steps += 1

if (
self._iter_steps
% self.config.get_with_default("gradient_accumulation_steps", 1)
== 0
):
self._train_steps += 1
self._write_train_metrics()

# Checkpoint frequency is configured by checkpointing_options.
self.checkpoint_manager.save(
self._train_steps,
self.model,
self.optimizer,
save_only_lora_params=self._lora_enabled,
custom_metadata=self.custom_checkpoint_metadata(),
)
# NB: put this after self._buffer_metrics is important
self._post_process_train_step(aux)
self._iter_steps += 1

if (
self._iter_steps
% self.config.get_with_default("gradient_accumulation_steps", 1)
== 0
eval_ds
and self._train_steps % self.config.eval_every_n_steps == 0
):
self._train_steps += 1
self._write_train_metrics()

# Checkpoint frequency is configured by checkpointing_options.
self.checkpoint_manager.save(
self._train_steps,
self.model,
self.optimizer,
save_only_lora_params=self._lora_enabled,
custom_metadata=self.custom_checkpoint_metadata(),
)

if (
eval_ds
and self._train_steps % self.config.eval_every_n_steps == 0
):
self._run_eval(eval_ds, eval_step)

self._prof.maybe_deactivate(self._iter_steps)
self._run_eval(eval_ds, eval_step)

self._prof.maybe_deactivate(self._iter_steps)

self._throttler.wait_for_all()
logging.info(
"Train loop finished in: %.4f seconds",
time.perf_counter() - last_step_completion_time,
)
if self.training_hooks:
self.training_hooks.on_train_end(self)
if not self.is_managed_externally:
Expand Down