Skip to content

Add opensource MoE switch model for TPU v6e testing #1268

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion axlearn/common/mixture_of_experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"""

import re
from typing import NamedTuple, Optional, Union
from functools import reduce
from typing import NamedTuple, Optional, Sequence, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -45,13 +46,16 @@
from axlearn.common.param_init import FanAxes, constant_initializer
from axlearn.common.quantized_dot_general.layers import DenseGeneralBaseLayer
from axlearn.common.utils import (
HybridMeshShape,
MeshShape,
Nested,
NestedTensor,
PartitionSpec,
Tensor,
VDict,
flatten_items,
get_recursively,
infer_mesh_shape,
set_recursively,
tree_paths,
with_sharding_constraint,
Expand Down Expand Up @@ -167,6 +171,52 @@ def _cap_logits(logits: Tensor, gating_logit_cap: float) -> Tensor:
return logits


def get_outer_batch_from_mesh(
    *,
    mesh_axis_names: Sequence[str],
    outer_batch_axis_names: Sequence[str],
    mesh_shape: Optional[Union[MeshShape, HybridMeshShape]],
) -> Optional[int]:
    """Infers the MoE outer batch size from the mesh shape.

    The outer batch size is the product of the sizes of the mesh axes named in
    `outer_batch_axis_names`. For a `HybridMeshShape`, each axis size is the product of the
    corresponding ICI and DCN axis sizes.

    Args:
        mesh_axis_names: The name of each mesh axis.
        outer_batch_axis_names: The names of the mesh axes corresponding to the outer batch size.
            If empty, the returned outer batch size is 1.
        mesh_shape: The size of each mesh axis corresponding to `mesh_axis_names`.
            If None, the returned outer batch size will also be None.

    Returns:
        The MoE outer batch size. Will be None if `mesh_shape` is None.

    Raises:
        NotImplementedError: If `mesh_shape` is a `HybridMeshShape` whose `dcn_mesh_shape`
            contains -1, since the number of granules cannot be inferred here.
    """
    if mesh_shape is None:
        return None

    ici_mesh_shape = (
        mesh_shape.ici_mesh_shape if isinstance(mesh_shape, HybridMeshShape) else mesh_shape
    )
    try:
        # Resolve a possible -1 in the ICI mesh shape from the number of available devices.
        ici_mesh_shape = infer_mesh_shape(ici_mesh_shape)
    except ValueError as e:
        # This can happen when running locally, where the number of available devices may be
        # smaller than the number of devices required by the mesh shape.
        logging.info(e)

    if isinstance(mesh_shape, HybridMeshShape):
        if -1 in mesh_shape.dcn_mesh_shape:
            # TODO(markblee): Improve support for this. At the moment it is not a use-case.
            raise NotImplementedError(
                "Unable to infer number of granules. Please specify dcn_mesh_shape without -1."
            )
        # The global size of each axis is the ICI size times the DCN size.
        mesh_shape = tuple(x * y for x, y in zip(ici_mesh_shape, mesh_shape.dcn_mesh_shape))
    else:
        mesh_shape = ici_mesh_shape

    # Use an initial value of 1 so that an empty `outer_batch_axis_names` returns 1 instead of
    # raising TypeError from reducing an empty sequence.
    return reduce(
        lambda x, y: x * y,
        (mesh_shape[mesh_axis_names.index(el)] for el in outer_batch_axis_names),
        1,
    )


class AdaptiveLoadBalanceLoss(BaseLayer):
"""A layer to adjust the aux loss weight based on the overcapacity ratio.

Expand Down
59 changes: 58 additions & 1 deletion axlearn/common/mixture_of_experts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
TransformerFeedForwardMoE,
_convert_feedforward_to_moe_parameters,
convert_dense_to_moe_parameters,
get_outer_batch_from_mesh,
)
from axlearn.common.module import functional as F
from axlearn.common.quantized_dot_general.activation_clipping import TanhActivationClippingLayer
Expand All @@ -43,7 +44,14 @@
QuantizedDotGeneral,
)
from axlearn.common.test_utils import TestCase, assert_allclose
from axlearn.common.utils import get_recursively, set_recursively, shapes
from axlearn.common.utils import (
HybridMeshShape,
MeshShape,
get_recursively,
infer_mesh_shape,
set_recursively,
shapes,
)


# pylint: disable=no-self-use,protected-access
Expand Down Expand Up @@ -846,5 +854,54 @@ def test_dense_to_moe_parameters(self):
assert_allclose(outputs_dense.data, outputs_moe.data)


class GetOuterBatchFromMeshTest(absltest.TestCase):
    """Tests get_outer_batch_from_mesh."""

    def test_mesh_shape_is_none(self):
        # A None mesh shape propagates as a None outer batch size.
        self.assertIsNone(
            get_outer_batch_from_mesh(
                mesh_axis_names=["data", "model"],
                outer_batch_axis_names=["data"],
                mesh_shape=None,
            )
        )

    def test_regular_mesh_shape(self):
        # The outer batch size is the product of the named axis sizes: data * fsdp.
        outer_batch = get_outer_batch_from_mesh(
            mesh_axis_names=["data", "fsdp", "model"],
            outer_batch_axis_names=["data", "fsdp"],
            mesh_shape=(2, 4, 8),
        )
        self.assertEqual(2 * 4, outer_batch)

    def test_outer_batch_with_hybrid_mesh_shape(self):
        # ICI and DCN sizes multiply along each axis, so the data axis is 2 * 4.
        outer_batch = get_outer_batch_from_mesh(
            mesh_axis_names=["data", "model"],
            outer_batch_axis_names=["data"],
            mesh_shape=HybridMeshShape(ici_mesh_shape=(2, 8, 1), dcn_mesh_shape=(4, 1, 1)),
        )
        self.assertEqual(2 * 4, outer_batch)

    def test_outer_batch_with_hybrid_mesh_shape_with_raises(self):
        # A -1 in dcn_mesh_shape cannot be resolved and should raise.
        with self.assertRaisesRegex(NotImplementedError, "Unable to infer number of granules"):
            get_outer_batch_from_mesh(
                mesh_axis_names=["data", "fsdp", "model"],
                outer_batch_axis_names=["data", "fsdp"],
                mesh_shape=HybridMeshShape(ici_mesh_shape=(2, 4, 1), dcn_mesh_shape=(-1, 2, 1)),
            )

    def test_outer_batch_with_infer_mesh_shape(self):
        # A -1 in a regular mesh shape is first resolved via infer_mesh_shape.
        outer_batch = get_outer_batch_from_mesh(
            mesh_axis_names=["data", "fsdp", "model"],
            outer_batch_axis_names=["data", "fsdp"],
            mesh_shape=infer_mesh_shape((-1, 2, 4), num_devices=8),
        )
        self.assertEqual(1 * 2, outer_batch)


if __name__ == "__main__":
    # Run this module's tests via absl's test runner.
    absltest.main()
Loading