
Have losses return non-reduced tensors#1020

Draft
Arcomano1234 wants to merge 6 commits into main from feature/losses-return-1d-loss-vector

Conversation

@Arcomano1234
Contributor

Short description of why the PR is needed and how it satisfies those requirements, in sentence form.

Changes:

  • symbol (e.g. fme.core.my_function) or script and concise description of changes or added feature

  • Can group multiple related symbols on a single bullet

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Resolves # (delete if none)

@Arcomano1234 Arcomano1234 changed the title Feature/losses return 1d loss vector Have losses return non-reduced tensors Mar 26, 2026
)
step_loss = self._loss_obj(gen_step, target_step, step=step)
metrics[f"loss_step_{step}"] = step_loss.detach()
step_total = step_loss.sum()
Contributor


nit: step total what?

Suggested change
step_total = step_loss.sum()
step_total_loss = step_loss.sum()

You may want to rename step_loss to step_channel_loss if you do this; but if you also do the all-reductions here (like we discussed earlier), having them here should make it clear.
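A minimal sketch of how the suggested rename could read in the stepper loop. The helper name and the dummy per-channel loss below are illustrative only, not the repo's actual API; the point is that step_channel_loss stays a 1-D per-channel tensor for logging while step_total_loss is the scalar used for backprop.

```python
import torch

def accumulate_step_loss(loss_obj, gen_step, target_step, step, metrics):
    # loss_obj is assumed to return a non-reduced (per-channel) loss tensor
    step_channel_loss = loss_obj(gen_step, target_step, step=step)
    # log the per-channel vector, detached so it is not part of the graph
    metrics[f"loss_step_{step}"] = step_channel_loss.detach()
    # reduce to the scalar actually used for backprop
    step_total_loss = step_channel_loss.sum()
    return step_total_loss

# usage with a dummy per-channel MSE loss (mean over the batch dim)
loss_obj = lambda g, t, step: ((g - t) ** 2).mean(dim=0)
metrics = {}
total = accumulate_step_loss(
    loss_obj, torch.ones(4, 3), torch.zeros(4, 3), 0, metrics
)
```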

prediction.data, target_data, prediction.step
return (
self._weight
* self._loss(prediction.data, target_data, prediction.step).sum()
Contributor


What we discussed was having the loss take in a reduce argument, and passing reduce=False for the ace case, which would avoid needing to do anything here (though coupled wouldn't get access to these improvements, either).
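A sketch of the reduce-argument idea: the loss takes reduce=True by default and returns the usual scalar, while the ACE-style caller passes reduce=False to get the non-reduced per-channel vector and reduces it itself. The class and its signature are assumptions for illustration, not the repo's actual loss API.

```python
import torch

class PerChannelMSELoss:
    """Hypothetical loss with a reduce flag: reduce=True returns the
    reduced scalar, reduce=False returns the per-channel vector."""

    def __init__(self, weight: float = 1.0):
        self._weight = weight

    def __call__(self, gen, target, step=None, reduce=True):
        # mean over batch/spatial dims, keeping the channel dim -> shape [C]
        per_channel = self._weight * ((gen - target) ** 2).mean(dim=(0, 2, 3))
        return per_channel.sum() if reduce else per_channel

# caller that wants per-channel metrics passes reduce=False
loss_obj = PerChannelMSELoss()
gen, target = torch.ones(2, 3, 4, 4), torch.zeros(2, 3, 4, 4)
vec = loss_obj(gen, target, reduce=False)  # shape [3]
scalar = loss_obj(gen, target)             # same as vec.sum()
```

Existing callers (e.g. the coupled case) would keep the default and see no change, which is the trade-off noted above.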

gen_step, target_step, step=step, reduce=False
)
step_total_loss = step_loss.sum()
metrics[f"loss_step_{step}"] = step_total_loss.detach()
Contributor Author


These changes are temporary and will be replaced by per-channel reductions in the other per-channel loss PR.
