Skip to content

CompositeLossMetrics now performs a weighted sum of losses. #1251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

ds-hwang
Copy link
Contributor

Currently, CompositeLossMetrics sums the losses without considering their weights (i.e., the number of live targets). To make this a weighted sum, downstream code has been implementing CompositeLossWeights to inject the number of live targets into loss_weights. This is essentially patching a surprising logic (initail loss sum) with complex logic (CompositeLossWeights) into a straightforward one (weighted sum).

Therefore, we’re changing the default loss aggregation logic to be straightforward from the beginning.

From now on, our standarized loss aggregation logic is

loss = sum(each_loss_weight * each_loss * num_each_samples) / sum(each_loss_weight * num_each_samples)

Historically, the complex logic was introduced because the weights of losses returned by child metrics were unknown. But now that child metrics return losses as WeightedScalar, we can adopt a simpler, cleaner aggregation logic.

Note: alternative formulation could be

loss = sum(each_loss_weight * each_loss * num_each_samples) / sum(num_each_samples)

However, when num_each_samples is large and each_loss_weight is small, the denominator can become disproportionately large. So we discard this option.

Currently, `CompositeLossMetrics` sums the losses without considering their
weights (i.e., the number of live targets). To make this a weighted sum,
downstream code has been implementing `CompositeLossWeights` to inject the
number of live targets into `loss_weights`. This is essentially patching a
surprising logic (initail loss sum) with complex logic (CompositeLossWeights)
into a straightforward one (weighted sum).

Therefore, we’re changing the default loss aggregation logic to be
straightforward from the beginning.

From now on, our standarized loss aggregation logic is
```
loss = sum(each_loss_weight * each_loss * num_each_samples) / sum(each_loss_weight * num_each_samples)
```

Historically, the complex logic was introduced because the weights of
losses returned by child metrics were unknown. But now that child metrics
return losses as `WeightedScalar`, we can adopt a simpler, cleaner aggregation
logic.

Note: alternative formulation could be
```
loss = sum(each_loss_weight * each_loss * num_each_samples) / sum(num_each_samples)
```
However, when num_each_samples is large and each_loss_weight is small, the
denominator can become disproportionately large. So we discard this option.
@ds-hwang ds-hwang requested review from ruomingp, markblee and a team as code owners June 10, 2025 16:22
@ds-hwang
Copy link
Contributor Author

@markblee Could you take a look? From 1399

Copy link
Contributor

@markblee markblee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Will approve after the internal review completes.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants