CompositeLossMetrics now performs a weighted sum of losses. #1251
Open · wants to merge 1 commit into main
axlearn/common/causal_lm.py (51 changes: 17 additions & 34 deletions)
@@ -234,21 +234,6 @@ def _update(x: dict, updates: dict):
     x.update(updates)
 
 
-class CompositeLossWeights(Module):
-    """Computes loss weights."""
-
-    def forward(self, child_metrics: dict[str, tuple[Tensor, Nested[Tensor]]]) -> dict[str, Tensor]:
-        """Computes per-child loss weights from child metrics.
-
-        Args:
-            child_metrics: A mapping from child name to (child_loss, child_metrics).
-
-        Returns:
-            A mapping from child name to loss weight.
-        """
-        raise NotImplementedError(type(self))
-
-
 class CompositeLossMetrics(BaseLossMetrics):
     """Computes a composite loss from multiple child metrics."""
@@ -258,14 +243,14 @@ class Config(BaseLossMetrics.Config):
 
         Attributes:
             metrics: A mapping from child name to metrics config.
-            loss_weights: A `CompositeLossWeights` implementation.
+            loss_weights: A mapping from child name to its weight.
+                If None, all weights are considered 1.
             flatten_metrics: Whether to flatten summaries and metrics from each child. If None,
                 defaults to True.
         """
 
         metrics: Required[dict[str, BaseLossMetrics.Config]] = REQUIRED
-        loss_weights: Optional[CompositeLossWeights.Config] = None
+        loss_weights: Optional[dict[str, float]] = None
         flatten_metrics: Optional[bool] = None
 
     def __init__(self, cfg, *, parent):
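
With this change, per-child loss weights are configured as a plain dict rather than a `CompositeLossWeights` child module. A minimal sketch of the new usage, assuming the `CrossEntropyLossMetrics` and `AuxLossMetrics` children defined elsewhere in this file; the child names and weight values below are illustrative, not taken from the PR:

from axlearn.common import causal_lm

# Sketch only: any BaseLossMetrics configs can serve as children.
cfg = causal_lm.CompositeLossMetrics.default_config().set(
    name="metrics",
    metrics={
        "lm": causal_lm.CrossEntropyLossMetrics.default_config(),
        "aux": causal_lm.AuxLossMetrics.default_config(),
    },
    # Children missing from the dict keep an effective weight of 1;
    # loss_weights=None weights every child equally.
    loss_weights={"lm": 1.0, "aux": 0.01},
)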
@@ -274,10 +259,6 @@ def __init__(self, cfg, *, parent):
         self._metrics: dict[str, BaseLossMetrics] = {}
         for name, child in cfg.metrics.items():
             self._metrics[name] = self._add_child(name, child)
-        if cfg.loss_weights is not None:
-            self.loss_weights: CompositeLossMetrics = self._add_child(
-                "loss_weights", cfg.loss_weights
-            )
 
     def forward(
         self,
@@ -301,19 +282,21 @@
                 module_outputs=module_outputs,
             )
 
-        if "loss_weights" in self.children:
-            loss_weights: dict[str, Tensor] = self.loss_weights(all_child_metrics)
-        else:
-            loss_weights = None
-
+        loss_weights = cfg.loss_weights
         losses = []
         metrics = {}
         for name, (child_loss, child_metrics) in all_child_metrics.items():
-            # Downstream wants unweighted losses.
-            child_metrics[f"loss_{name}"] = child_loss
             if loss_weights is not None and loss_weights.get(name, None) is not None:
-                child_loss = WeightedScalar(child_loss.mean * loss_weights[name], child_loss.weight)
+                # Multiply loss_weights into child_loss.weight only. Note that child_loss.mean
+                # can be interpreted as the result of:
+                #   child_loss.mean = (child_loss.mean * child_loss.weight * loss_weights[name])
+                #       / (child_loss.weight * loss_weights[name])
+                # For reference, the total loss is computed as:
+                #   total_loss = sum(each_loss_weight * each_loss * num_each_samples)
+                #       / sum(each_loss_weight * num_each_samples)
+                child_loss = WeightedScalar(child_loss.mean, child_loss.weight * loss_weights[name])
             losses.append(child_loss)
+            child_metrics[f"loss_{name}"] = child_loss
 
         ctx = self.get_invocation_context()

@@ -327,13 +310,13 @@ def _aggregate(losses):
             if not losses:
                 return WeightedScalar(0.0, 0.0)
 
-            # For backward compatibility, aggregation is done using sum(each.mean) instead of
-            # sum(each.mean * each.weight) / sum(each.weight).
-            loss = weight = 0.0
+            loss_sum = weight = 0.0
             for each in losses:
-                loss += each.mean
+                loss_sum += each.mean * each.weight
                 weight += each.weight
-            return WeightedScalar(loss, weight)
+            # Note: weight = loss_weights * num_samples, so it can be smaller than 1.
+            eps = 1e-8
+            return WeightedScalar(loss_sum / jnp.maximum(weight, eps), weight)
 
         loss = _aggregate(losses)
         return loss, metrics
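
Taken together, the two hunks above turn the composite loss into a weighted mean, sum(w_i * mean_i * n_i) / sum(w_i * n_i), rather than a sum of weighted means. A self-contained sketch of the new semantics, using a simplified (mean, weight) stand-in for axlearn's `WeightedScalar` (an assumption for illustration, not the real class):

from typing import NamedTuple

import jax.numpy as jnp

class WeightedScalar(NamedTuple):
    # Simplified stand-in for axlearn's WeightedScalar: an average and the
    # effective sample count it was computed over.
    mean: jnp.ndarray
    weight: jnp.ndarray

def aggregate(child_losses: dict[str, WeightedScalar],
              loss_weights: dict[str, float]) -> WeightedScalar:
    # Fold each child's loss weight into its WeightedScalar weight, then take
    # the weighted mean, mirroring CompositeLossMetrics.forward + _aggregate.
    loss_sum = weight = 0.0
    for name, child in child_losses.items():
        w = loss_weights.get(name, 1.0)  # children missing from the dict keep weight 1
        loss_sum += child.mean * (child.weight * w)
        weight += child.weight * w
    eps = 1e-8  # weight folds in loss_weights, so it can be smaller than 1
    return WeightedScalar(loss_sum / jnp.maximum(weight, eps), weight)

total = aggregate(
    {"lm": WeightedScalar(jnp.asarray(1.23), jnp.asarray(1.0)),
     "aux": WeightedScalar(jnp.asarray(3.45), jnp.asarray(1.0))},
    loss_weights={"lm": 1.0, "aux": 0.5},
)
# total.mean == (1.0 * 1.23 + 0.5 * 3.45) / (1.0 + 0.5) == 1.97

For the same inputs, the pre-PR code would have returned the bare sum 1.0 * 1.23 + 0.5 * 3.45 = 2.955; the two behaviors agree only when the effective weights sum to 1.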
axlearn/common/causal_lm_test.py (24 changes: 11 additions & 13 deletions)
@@ -647,18 +647,13 @@ def forward(self, input_batch, **kwargs):
                 del kwargs
                 return WeightedScalar(input_batch[self.name], 1.0), {}
 
-        class FixedLossWeights(causal_lm.CompositeLossWeights):
-            def forward(self, child_metrics):
-                del child_metrics
-                return loss_weights
-
         cfg = causal_lm.CompositeLossMetrics.default_config().set(
             name="test",
             metrics={
                 "test0": DummyMetrics.default_config(),
                 "test1": DummyMetrics.default_config(),
             },
-            loss_weights=FixedLossWeights.default_config(),
+            loss_weights=loss_weights,
         )
 
         metrics = cfg.instantiate(parent=None)
@@ -673,14 +668,17 @@ def forward(self, child_metrics):
             is_training=True,
         )
         self.assertAlmostEqual(
-            loss.value(), 1.23 * loss_weights["test0"] + 3.45 * loss_weights["test1"]
+            loss.value(),
+            (1.23 * loss_weights["test0"] + 3.45 * loss_weights["test1"])
+            / (loss_weights["test0"] + loss_weights["test1"]),
         )
 
         def _aggregate(aux):
-            loss = 0.0
-            for name, loss_weight in loss_weights.items():
-                loss += loss_weight * aux[f"loss_{name}"].mean
-            return loss
+            loss_sum = weight = 0.0
+            for name in loss_weights:
+                loss_sum += aux[f"loss_{name}"].weight * aux[f"loss_{name}"].mean
+                weight += aux[f"loss_{name}"].weight
+            return loss_sum / weight
 
         self.assertAlmostEqual(loss.value(), _aggregate(aux))

@@ -784,7 +782,7 @@ def test_aux_loss(self, aux_loss_regex, stack_cfg, use_aux_layer):
             else:
                 self.assertEqual(aux["metrics"]["aux_loss"].value(), 0.0)
             self.assertEqual(
-                aux["metrics"]["cross_entropy"].value() + aux["metrics"]["aux_loss"].value(), loss
+                (aux["metrics"]["cross_entropy"] + aux["metrics"]["aux_loss"]).value(), loss
             )
         else:
             self.assertNotIn("aux_loss", aux)
@@ -862,7 +860,7 @@ def loss_fn(model_params, inputs):
         self.assertIn("aux_loss", summaries)
         self.assertEqual(summaries["aux_loss"].value(), 1.0)
         self.assertNestedAllClose(
-            summaries["cross_entropy_loss"].value() + summaries["aux_loss"].value(),
+            (summaries["cross_entropy_loss"] + summaries["aux_loss"]).value(),
             outputs.forward_outputs.loss,
         )
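
To make the updated assertion concrete, a worked example with hypothetical weights (the actual loss_weights values used by the test are not visible in this diff):

# Hypothetical values, for illustration only.
loss_weights = {"test0": 0.2, "test1": 0.8}
expected = (1.23 * loss_weights["test0"] + 3.45 * loss_weights["test1"]) / (
    loss_weights["test0"] + loss_weights["test1"]
)
# expected == (0.246 + 2.76) / 1.0 == 3.006, the weighted mean of the two
# child losses; the old assertion checked the bare weighted sum instead.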
