From 5ec1f25b2c0a437b2ab9668aec2f5245bd7cf31b Mon Sep 17 00:00:00 2001
From: Dongseong Hwang
Date: Fri, 30 May 2025 15:11:38 -0700
Subject: [PATCH] `CompositeLossMetrics` now performs a weighted sum of losses.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Currently, `CompositeLossMetrics` sums the losses without considering their
weights (i.e., the number of live targets). To make this a weighted sum,
downstream code has been implementing `CompositeLossWeights` to inject the
number of live targets into `loss_weights`.

This essentially patches surprising logic (the initial plain loss sum) with
complex logic (`CompositeLossWeights`) to emulate straightforward logic (a
weighted sum). Therefore, we're changing the default loss aggregation logic
to be straightforward from the beginning.

From now on, our standardized loss aggregation logic is

```
loss = sum(each_loss_weight * each_loss * num_each_samples)
       / sum(each_loss_weight * num_each_samples)
```

Historically, the complex logic was introduced because the weights of the
losses returned by child metrics were unknown. Now that child metrics return
losses as `WeightedScalar`, we can adopt a simpler, cleaner aggregation logic.

Note: an alternative formulation would be

```
loss = sum(each_loss_weight * each_loss * num_each_samples)
       / sum(num_each_samples)
```

However, when num_each_samples is large and each_loss_weight is small, this
denominator becomes disproportionately large relative to the numerator, so we
discard this option.
---
 axlearn/common/causal_lm.py      | 51 +++++++++++---------------------
 axlearn/common/causal_lm_test.py | 24 +++++++--------
 2 files changed, 28 insertions(+), 47 deletions(-)

diff --git a/axlearn/common/causal_lm.py b/axlearn/common/causal_lm.py
index ee60f84a7..55d486ac1 100644
--- a/axlearn/common/causal_lm.py
+++ b/axlearn/common/causal_lm.py
@@ -234,21 +234,6 @@ def _update(x: dict, updates: dict):
     x.update(updates)
 
 
-class CompositeLossWeights(Module):
-    """Computes loss weights."""
-
-    def forward(self, child_metrics: dict[str, tuple[Tensor, Nested[Tensor]]]) -> dict[str, Tensor]:
-        """Computes per-child loss weights from child metrics.
-
-        Args:
-            child_metrics: A mapping from child name to (child_loss, child_metrics).
-
-        Returns:
-            A mapping from child name to loss weight.
-        """
-        raise NotImplementedError(type(self))
-
-
 class CompositeLossMetrics(BaseLossMetrics):
     """Computes a composite loss from multiple child metrics."""
 
@@ -258,14 +243,14 @@ class Config(BaseLossMetrics.Config):
 
         Attributes:
             metrics: A mapping from child name to metrics config.
-            loss_weights: A `CompositeLossWeights` implementation.
+            loss_weights: A mapping from child name to its weight. If None, all weights are
+                considered 1.
            flatten_metrics: Whether to flatten summaries and metrics from each child.
                 If None, defaults to True.
         """
 
         metrics: Required[dict[str, BaseLossMetrics.Config]] = REQUIRED
-        loss_weights: Optional[CompositeLossWeights.Config] = None
+        loss_weights: Optional[dict[str, float]] = None
         flatten_metrics: Optional[bool] = None
 
     def __init__(self, cfg, *, parent):
@@ -274,10 +259,6 @@ def __init__(self, cfg, *, parent):
         self._metrics: dict[str, BaseLossMetrics] = {}
         for name, child in cfg.metrics.items():
             self._metrics[name] = self._add_child(name, child)
-        if cfg.loss_weights is not None:
-            self.loss_weights: CompositeLossMetrics = self._add_child(
-                "loss_weights", cfg.loss_weights
-            )
 
     def forward(
         self,
@@ -301,19 +282,21 @@ def forward(
             module_outputs=module_outputs,
         )
 
-        if "loss_weights" in self.children:
-            loss_weights: dict[str, Tensor] = self.loss_weights(all_child_metrics)
-        else:
-            loss_weights = None
-
+        loss_weights = cfg.loss_weights
         losses = []
         metrics = {}
         for name, (child_loss, child_metrics) in all_child_metrics.items():
-            # Downstream wants unweighted losses.
-            child_metrics[f"loss_{name}"] = child_loss
             if loss_weights is not None and loss_weights.get(name, None) is not None:
-                child_loss = WeightedScalar(child_loss.mean * loss_weights[name], child_loss.weight)
+                # Apply loss_weights only to child_loss.weight. Note that child_loss.mean can be
+                # interpreted as the result of:
+                #   child_loss.mean = (child_loss.mean * child_loss.weight * loss_weights[name])
+                #                     / (child_loss.weight * loss_weights[name])
+                # For reference, the total loss is computed as:
+                #   total_loss = sum(each_loss_weight * each_loss * num_each_samples)
+                #                / sum(each_loss_weight * num_each_samples)
+                child_loss = WeightedScalar(child_loss.mean, child_loss.weight * loss_weights[name])
             losses.append(child_loss)
+            child_metrics[f"loss_{name}"] = child_loss
 
         ctx = self.get_invocation_context()
 
@@ -327,13 +310,13 @@ def _aggregate(losses):
 
             if not losses:
                 return WeightedScalar(0.0, 0.0)
-            # For backward compatibility, aggregation is done using sum(each.mean) instead of
-            # sum(each.mean * each.weight) / sum(each.weight).
-            loss = weight = 0.0
+            loss_sum = weight = 0.0
             for each in losses:
-                loss += each.mean
+                loss_sum += each.mean * each.weight
                 weight += each.weight
-            return WeightedScalar(loss, weight)
+            # Note: weight = loss_weights * num_samples, so it can be smaller than 1.
+            eps = 1e-8
+            return WeightedScalar(loss_sum / jnp.maximum(weight, eps), weight)
 
         loss = _aggregate(losses)
         return loss, metrics
diff --git a/axlearn/common/causal_lm_test.py b/axlearn/common/causal_lm_test.py
index 199768024..e698b67fe 100644
--- a/axlearn/common/causal_lm_test.py
+++ b/axlearn/common/causal_lm_test.py
@@ -647,18 +647,13 @@ def forward(self, input_batch, **kwargs):
                 del kwargs
                 return WeightedScalar(input_batch[self.name], 1.0), {}
 
-        class FixedLossWeights(causal_lm.CompositeLossWeights):
-            def forward(self, child_metrics):
-                del child_metrics
-                return loss_weights
-
         cfg = causal_lm.CompositeLossMetrics.default_config().set(
             name="test",
             metrics={
                 "test0": DummyMetrics.default_config(),
                 "test1": DummyMetrics.default_config(),
             },
-            loss_weights=FixedLossWeights.default_config(),
+            loss_weights=loss_weights,
         )
         metrics = cfg.instantiate(parent=None)
 
@@ -673,14 +668,17 @@ def forward(self, child_metrics):
             is_training=True,
         )
         self.assertAlmostEqual(
-            loss.value(), 1.23 * loss_weights["test0"] + 3.45 * loss_weights["test1"]
+            loss.value(),
+            (1.23 * loss_weights["test0"] + 3.45 * loss_weights["test1"])
+            / (loss_weights["test0"] + loss_weights["test1"]),
         )
 
         def _aggregate(aux):
-            loss = 0.0
-            for name, loss_weight in loss_weights.items():
-                loss += loss_weight * aux[f"loss_{name}"].mean
-            return loss
+            loss_sum = weight = 0.0
+            for name in loss_weights:
+                loss_sum += aux[f"loss_{name}"].weight * aux[f"loss_{name}"].mean
+                weight += aux[f"loss_{name}"].weight
+            return loss_sum / weight
 
         self.assertAlmostEqual(loss.value(), _aggregate(aux))
 
@@ -784,7 +782,7 @@ def test_aux_loss(self, aux_loss_regex, stack_cfg, use_aux_layer):
             else:
                 self.assertEqual(aux["metrics"]["aux_loss"].value(), 0.0)
             self.assertEqual(
-                aux["metrics"]["cross_entropy"].value() + aux["metrics"]["aux_loss"].value(), loss
+                (aux["metrics"]["cross_entropy"] + aux["metrics"]["aux_loss"]).value(), loss
             )
         else:
             self.assertNotIn("aux_loss", aux)
@@ -862,7 +860,7 @@ def loss_fn(model_params, inputs):
         self.assertIn("aux_loss", summaries)
         self.assertEqual(summaries["aux_loss"].value(), 1.0)
         self.assertNestedAllClose(
-            summaries["cross_entropy_loss"].value() + summaries["aux_loss"].value(),
+            (summaries["cross_entropy_loss"] + summaries["aux_loss"]).value(),
             outputs.forward_outputs.loss,
         )
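
For reference, a minimal self-contained sketch of the new aggregation semantics.
The `WeightedScalar` below is a simplified stand-in for axlearn's class (assumed
to expose only `mean` and `weight`), defined locally so the snippet runs on its
own; only `jax.numpy` is required.

```
from typing import NamedTuple

import jax.numpy as jnp


class WeightedScalar(NamedTuple):
    """Simplified stand-in for axlearn's WeightedScalar: a mean and its weight."""

    mean: float
    weight: float


def aggregate(losses: list[WeightedScalar]) -> WeightedScalar:
    """Computes sum(w_i * loss_i * n_i) / sum(w_i * n_i).

    Each child contributes WeightedScalar(mean=loss_i, weight=w_i * n_i), where
    w_i is its configured loss weight and n_i its number of live targets.
    """
    if not losses:
        return WeightedScalar(0.0, 0.0)
    loss_sum = weight = 0.0
    for each in losses:
        loss_sum += each.mean * each.weight
        weight += each.weight
    # weight = loss_weight * num_samples can be smaller than 1, so guard the division.
    eps = 1e-8
    return WeightedScalar(loss_sum / jnp.maximum(weight, eps), weight)


# Two children with losses 1.23 and 3.45, loss weights 0.1 and 0.9, one sample each:
loss = aggregate([WeightedScalar(1.23, 0.1), WeightedScalar(3.45, 0.9)])
print(loss.mean)  # (1.23 * 0.1 + 3.45 * 0.9) / (0.1 + 0.9) = 3.228
```

Folding the loss weight into `weight` rather than `mean` keeps each per-child
mean interpretable on its own, while the final division yields the weighted
average described in the commit message.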