Conversation

@oleksost
Contributor

@oleksost oleksost commented Dec 16, 2025

✨ Description

Refactors loss definition and logging in head.py:

  • makes logging more explicit
  • implements forward KL
  • logs only unscaled losses
  • defines a separate loss config for each loss in the lm_head
  • allows logging losses without training on them

TODO:

  • Update tests

Config example:

    head:
      lr_scale: 0.0
      losses:
        lm_loss:
          type: cross_entropy
          factor: 1.0
        reverse_kl:
          type: reverse_kl_distillation
          factor: 1.0
        forward_kl:
          type: forward_kl_distillation
          factor: 0.0 # logged but not trained on

This will train using lm_loss (a cross_entropy LM loss) as well as reverse_kl, both weighted at 1.0, and will also log forward_kl without training on it.
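As a reference for the two distillation loss types above, here is a minimal PyTorch sketch of forward and reverse KL between student and teacher logits (illustrative only, not the actual head.py implementation, which operates on fused and possibly vocab-parallel logits). The total training loss is presumably the factor-weighted sum of these terms, while zero-factor losses are only logged.

    import torch
    import torch.nn.functional as F

    def forward_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
        # Forward KL: KL(teacher || student), mean-seeking; the student is pushed
        # to cover all modes of the teacher distribution.
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        return (teacher_log_probs.exp() * (teacher_log_probs - student_log_probs)).sum(-1).mean()

    def reverse_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
        # Reverse KL: KL(student || teacher), mode-seeking; the student concentrates
        # on high-probability regions of the teacher distribution.
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        return (student_log_probs.exp() * (student_log_probs - teacher_log_probs)).sum(-1).mean()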

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

📝 Changes

List the key changes introduced in this PR:

  1. Change A
  2. Change B

✅ Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • 📜 I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • 🎉 The functionality is complete, and I have tested the changes.
  • 📝 I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • 🐋 I have updated the Docker configuration or dependencies, if applicable.
  • 🔄 I have ensured compatibility with the existing setup after dependency changes.

Testing

  • 🧪 I have added or updated tests to cover my changes.
  • ✔️ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • 🏋️ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • 📊 I have run benchmarks where applicable to evaluate the performance impact.
  • ✅ The benchmarks show no performance regression.
  • 🚀 The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • 📈 I have provided benchmark results and detailed any performance impact below, if applicable.

📊 Performance Impact Details

If there is any impact on performance, describe it and provide benchmark results, if applicable:


🗒️ Additional Notes

Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.

@oleksost oleksost marked this pull request as draft December 16, 2025 13:31
@oleksost oleksost marked this pull request as ready for review December 16, 2025 14:20
Collaborator

@jlamypoirier jlamypoirier left a comment

Not sure I understand this PR. If we only have layer distillation, doesn't that mean we don't train the model head at all?

@oleksost oleksost changed the title Train with only layer distillation losses Train with only layer distillation losses + explicit logging Dec 17, 2025
@oleksost oleksost changed the title Train with only layer distillation losses + explicit logging Refactor lm head losses Dec 22, 2025
@oleksost oleksost changed the title Refactor lm head losses Refactor lm_head losses Dec 22, 2025
Collaborator

@jlamypoirier jlamypoirier left a comment

Thanks for looking into this, it's badly needed.

self.language_model_loss_factor = 0.0
for loss_config in self.losses.values():
    if "dist" in loss_config.type:
        assert self.distillation_model is not None, "Distillation loss requires a distillation model."
Collaborator

Shouldn't the distillation model go with the loss?

Contributor Author

Hm, this raises an error when there is no distillation model, which is correct, no?

Collaborator

The distillation_model parameter is not needed in the language model head itself, only the losses use it, so it should be moved with the losses that use it along with these checks.

desc="Configuration for the LM output layer (weight). Ignored for tied embeddings",
hint=FieldHint.architecture,
)
cross_entropy_implementation: CrossEntropyImpl = Field(
Collaborator

These removals are likely to cause backward compatibility issues when loading existing models. Please make sure it doesn't disrupt ongoing work, and if needed add backward compatibility in _validate

Contributor Author

I tested training with checkpoints created on the main branch in both distributed and apriel2 format. Training starts with no issues.

Removed the targets class, moved targets processing into the losses, and made the loss masks more explicit.
else:
    self.language_model_loss_factor = 0.0
if not self.losses:
    if "losses" not in self._explicit_fields:
Collaborator

Not sure this is needed; it doesn't make sense to have a head without a loss.

Can simplify to:

    self.losses = {"lm_loss": CrossEntropyLMLossConfig()}

    sequence_parallel=self._sequence_parallel and self._vocab_parallel,
)

# TODO: also move to lm_head_losses?
Collaborator

Why not make it an independent loss like the others?

Contributor Author

Addressed.

Actually, it looks like z_loss is useless here: it is implemented using gradient injection, but backward on the logits is never called, and if it is, it does not backprop into the model due to the detach() in the _logits_loss_forward_backward call.
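For context, the usual z-loss formulation (PaLM-style) penalizes the squared log-partition of the logits; a minimal sketch over plain logits, not Fast-LLM's gradient-injection implementation:

    import torch

    def z_loss(logits: torch.Tensor, factor: float = 1e-4) -> torch.Tensor:
        # Penalizes the squared log-partition function to keep logit magnitudes in check;
        # it only has an effect if its gradient actually reaches the logits.
        return factor * torch.logsumexp(logits, dim=-1).pow(2).mean()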

@@ -0,0 +1,23 @@
from fast_llm.layers.block.config import BlockKwargs
Collaborator

Not sure about moving this here; the convention is to leave it in config.py.

Contributor Author

It's easier to have it here to avoid circular imports.

@@ -0,0 +1,344 @@
import abc
Collaborator

By convention, configs are expected to go in a file named config.py. I recommend moving this back to the config file. (The other option would be to create a loss subdirectory, but that's not really justified at this stage.)

loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
    all_reduce(loss, op=ReduceOp.AVG, group=group)
if return_target_entropy and target_format == TargetFormat.logits:
Collaborator

Incorrect for other target formats?

Contributor Author

should be fine for other formats as well

    all_reduce(loss, op=ReduceOp.AVG, group=group)
if return_target_entropy and target_format == TargetFormat.logits:
    # Compute teacher entropy
    teacher_log_prob = torch.log(target + 1e-20)
Collaborator

For TargetFormat.logits we should be using log_softmax(logits), which is numerically stable. It's simply target_logits - log(sum_exp_target_logits), and we already computed sum_exp_target_logits in _fused_softmax_base.
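A minimal sketch of the numerically stable variant suggested here, assuming raw teacher logits are available (names are illustrative, not the actual Fast-LLM code):

    import torch

    def teacher_entropy(target_logits: torch.Tensor) -> torch.Tensor:
        # log_softmax(x) = x - logsumexp(x): avoids log(softmax(x) + eps), which loses
        # precision for near-zero probabilities and biases the entropy estimate.
        teacher_log_prob = target_logits - torch.logsumexp(target_logits, dim=-1, keepdim=True)
        return -(teacher_log_prob.exp() * teacher_log_prob).sum(dim=-1).mean()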

"head": {"output_weight": init_1},
"head": {
"output_weight": init_1,
"losses": {
Collaborator

We can drop this default value; it will make updating easier.

Assert.eq(len(rank_breakdowns), world_size)


if __name__ == "__main__":
Collaborator

Please remove

},
"num_blocks": 12,
},
"head": {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}},
Collaborator

"weight" should not be explicit (will go away if removed from the default in validation)
