Refactor lm_head losses #425
base: main
Conversation
jlamypoirier
left a comment
Not sure I understand this PR. If we only have layer distillation, doesn't that mean we don't train the model head at all?
jlamypoirier
left a comment
Thanks for looking into this, it's badly needed.
self.language_model_loss_factor = 0.0
for loss_config in self.losses.values():
    if "dist" in loss_config.type:
        assert self.distillation_model is not None, "Distillation loss requires a distillation model."
Shouldn't the distillation model go with the loss?
Hm, this raises an error when there is no distillation model, which is correct, no?
The distillation_model parameter is not needed in the language model head itself; only the losses use it. It should be moved, along with these checks, to the losses that use it.
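A hedged sketch of the restructuring suggested here: the model reference and the check live on the loss config that needs them. Apart from the distillation_model name, the class and fields below are illustrative, not the actual Fast-LLM config classes.

import dataclasses

# Illustrative only: not the actual Fast-LLM config classes.
@dataclasses.dataclass
class DistillationLossConfigSketch:
    # The model reference lives on the loss that needs it, not on the head config.
    distillation_model: str | None = None
    weight: float = 1.0

    def validate(self) -> None:
        assert self.distillation_model is not None, "Distillation loss requires a distillation model."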
    desc="Configuration for the LM output layer (weight). Ignored for tied embeddings",
    hint=FieldHint.architecture,
)
cross_entropy_implementation: CrossEntropyImpl = Field(
These removals are likely to cause backward-compatibility issues when loading existing models. Please make sure this doesn't disrupt ongoing work and, if needed, add backward compatibility in _validate.
I tested training with checkpoints created on the main branch in both distributed and apriel2 format. Training starts with no issues.
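For illustration, a minimal dict-level sketch of the kind of shim suggested above. The function name, the "implementation" key, and the "fused" value are assumptions; only "cross_entropy" and "losses" appear elsewhere in this PR, and this is not the project's actual compatibility mechanism.

# Hypothetical sketch: fold the removed cross_entropy_implementation field
# into the new per-loss schema before strict validation sees the config dict.
def upgrade_head_config(config: dict) -> dict:
    config = dict(config)
    old_impl = config.pop("cross_entropy_implementation", None)
    if old_impl is not None and "losses" not in config:
        # "cross_entropy" is the type string used in the test configs in this PR;
        # the "implementation" key is assumed for illustration.
        config["losses"] = {"lm_loss": {"type": "cross_entropy", "implementation": old_impl}}
    return config

# An old-style config keeps loading:
assert "losses" in upgrade_head_config({"cross_entropy_implementation": "fused"})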
…ow/Fast-LLM into train_only_layer_losses
…ow/Fast-LLM into train_only_layer_losses
Removed the targets class, moved targets processing to the losses, made loss masks more explicit
else:
    self.language_model_loss_factor = 0.0
if not self.losses:
    if "losses" not in self._explicit_fields:
Not sure it's needed; it doesn't make sense to have a head without a loss.
Can simplify to:
self.losses = {"lm_loss": CrossEntropyLMLossConfig()}
    sequence_parallel=self._sequence_parallel and self._vocab_parallel,
)

# TODO: also move to lm_head_losses?
Why not make it an independent loss like the others?
Addressed.
Actually, it looks like z_loss is useless here: it is implemented using gradient injection, but backward on the logits is never called, and even if it were, it would not backprop into the model due to the detach() in the _logits_loss_forward_backward call.
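A standalone illustration of the detach() point, using generic tensors rather than the actual head code: a penalty computed on detached logits contributes nothing to the model's gradients.

import torch

weight = torch.randn(4, 4, requires_grad=True)
logits = torch.randn(2, 4) @ weight

main_loss = logits.square().mean()  # flows back into `weight`
# z-loss-style penalty on detached logits: no path back to the model.
z_loss = torch.logsumexp(logits.detach(), dim=-1).square().mean()

grad_with = torch.autograd.grad(main_loss + z_loss, weight, retain_graph=True)[0]
grad_without = torch.autograd.grad(main_loss, weight)[0]
assert torch.allclose(grad_with, grad_without)  # z_loss added nothing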
@@ -0,0 +1,23 @@
from fast_llm.layers.block.config import BlockKwargs
Not sure about moving this here; the convention is to leave it in config.py.
It's easier to have it here to avoid circular imports.
@@ -0,0 +1,344 @@
import abc
By convention, configs are expected to go in a file named config.py. I recommend moving this back to the config file. (The other option would be to create a loss subdirectory, but that's not really justified at this stage.)
fast_llm/functional/cross_entropy.py
Outdated
loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
    all_reduce(loss, op=ReduceOp.AVG, group=group)
if return_target_entropy and target_format == TargetFormat.logits:
Incorrect for other target formats?
Should be fine for other formats as well.
fast_llm/functional/cross_entropy.py
Outdated
    all_reduce(loss, op=ReduceOp.AVG, group=group)
if return_target_entropy and target_format == TargetFormat.logits:
    # Compute teacher entropy
    teacher_log_prob = torch.log(target + 1e-20)
For TargetFormat.logits we should be using log_softmax(logits), which is numerically stable. It's simply target_logits - log(sum_exp_target_logits), and we already computed sum_exp_target_logits in _fused_softmax_base.
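A standalone sketch of that formulation (not the fused kernel itself): log-probabilities as logits minus logsumexp, so the teacher entropy avoids the explicit softmax + log(p + eps) path.

import torch

target_logits = torch.randn(8, 1024) * 10.0

# Stable: log p = logits - logsumexp(logits).
log_probs = target_logits - torch.logsumexp(target_logits, dim=-1, keepdim=True)
entropy = -(log_probs.exp() * log_probs).sum(dim=-1)

# Reference: the softmax + log(p + eps) version discussed above.
probs = torch.softmax(target_logits, dim=-1)
entropy_ref = -(probs * torch.log(probs + 1e-20)).sum(dim=-1)
assert torch.allclose(entropy, entropy_ref, atol=1e-3)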
| "head": {"output_weight": init_1}, | ||
| "head": { | ||
| "output_weight": init_1, | ||
| "losses": { |
We can drop this default value, it will make updating easier.
tests/test_config.py
Outdated
    Assert.eq(len(rank_breakdowns), world_size)


if __name__ == "__main__":
Please remove
tests/test_config.py
Outdated
        },
        "num_blocks": 12,
    },
    "head": {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}},
"weight" should not be explicit (will go away if removed from the default in validation)
✨ Description
Refactors loss definition and logging in head.py. TODO:
Config example:
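(A hedged sketch of such a config, in the dict style used by the tests in this PR; the type strings for the KL losses and the use of weight 0.0 for "log only" are assumptions, only "cross_entropy" appears in the diffs.)

head_config = {
    "losses": {
        "lm_loss": {"type": "cross_entropy", "weight": 1.0},
        "reverse_kl": {"type": "reverse_kl", "weight": 1.0},
        # Assumed: a zero weight as one way to log a loss without training on it.
        "forward_kl": {"type": "forward_kl", "weight": 0.0},
    },
}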
This will train using lm_loss, which is a cross_entropy_lm_loss, as well as reverse_kl, both weighted with 1. It will also log forward_kl.
🔍 Type of change
Select all that apply:
📝 Changes
List the key changes introduced in this PR:
✅ Checklist
Make sure the following tasks are completed before submitting the PR:
General
Dependencies and Configuration
Testing
Performance Impact
📊 Performance Impact Details
If there is any impact on performance, describe it and provide benchmark results, if applicable:
🗒️ Additional Notes
Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.