
Commit 4a4da82

hanzhi713 authored and changlan committed
Support attention activation sharding
GitOrigin-RevId: 3c305c44d263851fb49318feaef0ac6d9594449f
1 parent 0fb8d94 commit 4a4da82

3 files changed, +60 -39 lines changed


axlearn/common/attention.py

Lines changed: 33 additions & 30 deletions
@@ -161,10 +161,10 @@
     check_numerics,
     flatten_items,
     get_or_none,
+    maybe_shard,
     save_and_offload_only_these_names_regex,
     shapes,
     split_prng_key,
-    with_sharding_constraint,
 )

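For readers unfamiliar with the helper: `maybe_shard` is imported from `axlearn.common.utils` and, judging from how this diff uses it, applies a sharding constraint only when a partition spec is actually configured. A minimal sketch of that assumed behavior (not the library's implementation) follows.

# Sketch only: assumed behavior of `maybe_shard`, inferred from its call sites in
# this diff. The real helper lives in axlearn/common/utils.py and may differ.
from typing import Optional, Sequence, Union

import jax
from jax.sharding import PartitionSpec


def maybe_shard_sketch(
    x: jax.Array,
    partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]],
) -> jax.Array:
    if partition_spec is None:
        # No spec configured: leave the activation layout to the compiler.
        return x
    # Otherwise constrain it, mirroring the removed
    # `if cfg.q_partition_spec: ... with_sharding_constraint(...)` pattern.
    return jax.lax.with_sharding_constraint(x, PartitionSpec(*partition_spec))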
@@ -1560,18 +1560,18 @@ class Config(BaseLayer.Config):
         logit_sink: Optional[bool] = None

         # Partition spec for query ([batch, seq, q_heads, head_dim]) after input projections.
-        q_partition_spec: Optional[PartitionSpec] = None
+        q_partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]] = None

         # Partition spec for key ([batch, seq, kv_heads, head_dim]) after input projections.
         # Follows `q_partition_spec` if None.
-        k_partition_spec: Optional[PartitionSpec] = None
+        k_partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]] = None

         # Partition spec for value ([batch, seq, kv_heads, head_dim]) after input projections.
         # Follows `q_partition_spec` if None.
-        v_partition_spec: Optional[PartitionSpec] = None
+        v_partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]] = None

         # Partition spec for output ([batch, seq, hidden_dim]) after output projections.
-        o_partition_spec: Optional[PartitionSpec] = None
+        o_partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]] = None

     def __init__(self, cfg: Config, *, parent: Module):
         super().__init__(cfg, parent=parent)
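The config fields above now accept plain sequences of mesh axis names rather than `PartitionSpec` objects. As an illustration only (the axis names below are not mandated by this commit), a spec for q_proj activations of shape [batch, seq, q_heads, head_dim] could look like the following.

# Illustrative values, not part of the commit: each entry names the mesh axis
# (or tuple of axes) a dimension is sharded over; None means replicated.
q_partition_spec = (
    ("data", "fsdp"),  # batch sharded over the data and fsdp axes
    "seq",             # sequence dimension sharded over the seq axis
    "model",           # query heads sharded over the tensor-parallel axis
    None,              # head_dim replicated
)
# Hypothetical usage on a MultiheadAttention.Config; k/v specs fall back to
# q_partition_spec when left as None:
# attention_cfg.set(q_partition_spec=q_partition_spec)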
@@ -1736,12 +1736,9 @@ def _forward_for_mode(
             time_step = cached_states["time_step"]
             query_positions = query_positions + time_step[:, None]  # [batch, steps]
         q_proj, k_proj, v_proj = self.i_proj(query, query_positions=query_positions, **kv_kwargs)
-        if cfg.q_partition_spec:
-            q_proj = with_sharding_constraint(q_proj, cfg.q_partition_spec)
-        if cfg.q_partition_spec or cfg.k_partition_spec:
-            k_proj = with_sharding_constraint(k_proj, cfg.k_partition_spec or cfg.q_partition_spec)
-        if cfg.q_partition_spec or cfg.v_partition_spec:
-            v_proj = with_sharding_constraint(v_proj, cfg.v_partition_spec or cfg.q_partition_spec)
+        q_proj = maybe_shard(q_proj, cfg.q_partition_spec)
+        k_proj = maybe_shard(k_proj, cfg.k_partition_spec or cfg.q_partition_spec)
+        v_proj = maybe_shard(v_proj, cfg.v_partition_spec or cfg.q_partition_spec)

         if cfg.scale_kv_before_cache_update:
             if has_external_kv_state:
@@ -1844,8 +1841,7 @@ def _forward_for_mode(

         # [batch, target_length, output_dim].
         o_proj = self.o_proj(context)
-        if cfg.o_partition_spec:
-            o_proj = with_sharding_constraint(o_proj, cfg.o_partition_spec)
+        o_proj = maybe_shard(o_proj, cfg.o_partition_spec)
         outputs = self._remat_name(o_proj, "o_proj")
         self._add_tensor_stats("o_proj_outputs", outputs)
         return_aux = return_aux or set()
@@ -3608,15 +3604,17 @@ def extend_step(
 def set_attention_partition_specs(
     cfg: MultiheadAttention.Config,
     *,
+    batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"),
     fsdp_axis_names: Union[str, Sequence[str]] = "fsdp",
     tp_axis_names: Union[str, Sequence[str]] = "model",
+    seq_axis_names: Union[str, Sequence[str]] = "seq",
+    set_attn_activation_specs: bool = False,
 ):
     """Sets `cfg` to shard attention weights over both fsdp and tp axes.

     Args:
         cfg: A MultiheadAttention layer config to apply sharding spec to.
-        fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors.
-        tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors.
+        **kwargs: See `set_double_shard_weights_config`.
     """
     # Shard weights.
     input_linear_cfg = cfg.input_linear
@@ -3625,6 +3623,10 @@ def set_attention_partition_specs(
         input_linear_cfg.layer.param_partition_spec = (fsdp_axis_names, tp_axis_names, None)
     cfg.output_linear.param_partition_spec = (fsdp_axis_names, tp_axis_names, None)

+    if set_attn_activation_specs:
+        cfg.q_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names, None)
+        cfg.o_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
+

 def set_feed_forward_partition_specs(
     cfg: TransformerFeedForwardLayer.Config,
@@ -3638,10 +3640,7 @@ def set_feed_forward_partition_specs(

     Args:
         cfg: A TransformerFeedForwardLayer layer config to apply sharding spec to.
-        batch_axis_names: Axis name(s) over which we shard the batch dimension of output tensors.
-        fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors.
-        tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors.
-        seq_axis_names: Axis name(s) over which we shard sequence-parallel tensors.
+        **kwargs: See `set_double_shard_weights_config`.
     """
     # Shard weights.
     cfg.linear1.param_partition_spec = (fsdp_axis_names, tp_axis_names)
@@ -3658,6 +3657,7 @@ def set_double_shard_weights_config(
     fsdp_axis_names: Union[str, Sequence[str]] = "fsdp",
     tp_axis_names: Union[str, Sequence[str]] = "model",
     seq_axis_names: Union[str, Sequence[str]] = "seq",
+    set_attn_activation_specs: bool = False,
 ):
     """Sets `cfg` to shard FFN and attention weights over both fsdp and tp axes.

@@ -3667,32 +3667,35 @@ def set_double_shard_weights_config(
         fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors.
         tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors.
         seq_axis_names: Axis name(s) over which we shard sequence-parallel tensors.
+        set_attn_activation_specs: Whether to set activation specs of the qkvo projections. This
+            may be required for some complex sharding cases.
     """

     # pytype: disable=attribute-error
     if not isinstance(cfg, Sequence):
         cfg = [cfg]

+    axis_names = dict(
+        batch_axis_names=batch_axis_names,
+        fsdp_axis_names=fsdp_axis_names,
+        tp_axis_names=tp_axis_names,
+        seq_axis_names=seq_axis_names,
+    )
+
     for layer_cfg in cfg:
         set_attention_partition_specs(
             layer_cfg.self_attention.attention,
-            fsdp_axis_names=fsdp_axis_names,
-            tp_axis_names=tp_axis_names,
+            set_attn_activation_specs=set_attn_activation_specs,
+            **axis_names,
         )
         if layer_cfg.cross_attention is not None:
             set_attention_partition_specs(
                 layer_cfg.cross_attention.attention,
-                fsdp_axis_names=fsdp_axis_names,
-                tp_axis_names=tp_axis_names,
+                set_attn_activation_specs=set_attn_activation_specs,
+                **axis_names,
             )
         if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config):
-            set_feed_forward_partition_specs(
-                layer_cfg.feed_forward,
-                batch_axis_names=batch_axis_names,
-                fsdp_axis_names=fsdp_axis_names,
-                tp_axis_names=tp_axis_names,
-                seq_axis_names=seq_axis_names,
-            )
+            set_feed_forward_partition_specs(layer_cfg.feed_forward, **axis_names)
     # pytype: enable=attribute-error

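Putting the pieces together, here is a hedged usage sketch of the new flag on a list of transformer layer configs; `layer_cfgs` is an assumed variable and the axis names simply restate the function defaults.

# Sketch: turning on attention activation sharding for a stack of layer configs.
# `layer_cfgs` is an assumed list of transformer layer configs; axis names mirror
# the defaults in this diff rather than any required mesh layout.
from axlearn.common.attention import set_double_shard_weights_config

set_double_shard_weights_config(
    layer_cfgs,
    batch_axis_names=("data", "fsdp"),
    fsdp_axis_names="fsdp",
    tp_axis_names="model",
    seq_axis_names="seq",
    set_attn_activation_specs=True,
)
# Per the added code above, each attention config then carries:
#   q_partition_spec = (("data", "fsdp"), "seq", "model", None)  # [batch, seq, q_heads, head_dim]
#   o_partition_spec = (("data", "fsdp"), "seq", "model")        # [batch, seq, hidden_dim]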
axlearn/common/attention_test.py

Lines changed: 7 additions & 7 deletions
@@ -2467,7 +2467,7 @@ def test_gqa_forward(
         )
         self.assertNestedAllClose(base_outputs, test_outputs)

-    @parameterized.product(kv_part=[None, PartitionSpec("fsdp", None, "model", None)])
+    @parameterized.product(kv_part=[None, ("fsdp", None, "model", None)])
     @pytest.mark.d8
     def test_qkvo_partition_spec(self, kv_part):
         """Tests that QKVO partition spec are applied correctly when specified."""
@@ -2477,8 +2477,8 @@ def test_qkvo_partition_spec(self, kv_part):
         model_dim = 16
         num_heads = 4
         mesh = jax.make_mesh(mesh_shape, axis_names=("fsdp", "seq", "model"))
-        q_part = PartitionSpec("fsdp", "seq", "model", None)
-        o_part = PartitionSpec("fsdp", "seq", None)
+        q_part = ("fsdp", "seq", "model", None)
+        o_part = ("fsdp", "seq", None)

         layer_kwargs = dict(
             query_dim=model_dim,
@@ -2514,14 +2514,14 @@ def callback(sharding):
             # pylint: disable-next=protected-access
             normalize_spec = sharding.spec._normalized_spec_for_aval(len(tensor.shape))
             if name == "q_proj":
-                self.assertEqual(normalize_spec, q_part)
+                self.assertEqual(normalize_spec, PartitionSpec(*q_part))
             elif name == "o_proj":
-                self.assertEqual(normalize_spec, o_part)
+                self.assertEqual(normalize_spec, PartitionSpec(*o_part))
             elif name in ["k_proj", "v_proj"]:
                 if kv_part is None:
-                    self.assertEqual(normalize_spec, q_part)
+                    self.assertEqual(normalize_spec, PartitionSpec(*q_part))
                 else:
-                    self.assertEqual(normalize_spec, kv_part)
+                    self.assertEqual(normalize_spec, PartitionSpec(*kv_part))

         jax.debug.inspect_array_sharding(tensor, callback=callback)
         return tensor
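Outside the test harness, the same inspection pattern can be reproduced with a few lines of JAX; the mesh shape and spec below are illustrative, and the snippet assumes 8 devices, matching the test's d8 marker.

# Standalone sketch of the sharding-inspection pattern used above. Assumes 8
# devices are available; values are illustrative, not the test's exact setup.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

mesh = jax.make_mesh((2, 1, 4), axis_names=("fsdp", "seq", "model"))
sharding = NamedSharding(mesh, PartitionSpec("fsdp", "seq", "model", None))


def callback(s):
    # Receives the sharding the compiler picked for the inspected array.
    print(s)


@jax.jit
def fn(x):
    x = jax.lax.with_sharding_constraint(x, sharding)
    jax.debug.inspect_array_sharding(x, callback=callback)
    return x


fn(jnp.zeros((8, 16, 4, 8)))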

axlearn/common/flash_attention/layer.py

Lines changed: 20 additions & 2 deletions
@@ -8,6 +8,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
+from absl import logging
 from jax.experimental.shard_map import shard_map
 from jax.interpreters.pxla import thread_resources
 from jax.sharding import PartitionSpec
@@ -17,7 +18,7 @@
 from axlearn.common.config import ConfigBase, ConfigModifier, config_class
 from axlearn.common.flash_attention.utils import flash_attention_implementation
 from axlearn.common.module import Module
-from axlearn.common.utils import Tensor, with_sharding_constraint
+from axlearn.common.utils import Tensor, maybe_shard, with_sharding_constraint


 class FlashAttention(GroupedQueryAttention):
@@ -108,7 +109,7 @@ def _logit_biases_spec(self, attention_logit_biases: BaseAttentionBias) -> BaseA

     def _maybe_repeat_kv_heads(self, key_or_value: Tensor) -> Tensor:
         """Repeats key or value heads dim to be shardable."""
-        cfg = self.config
+        cfg: FlashAttention.Config = self.config
         partition_spec = cfg.mha_dim_to_partition_spec["bsnh"]
         global_mesh = thread_resources.env.physical_mesh
         if (
@@ -132,7 +133,24 @@ def _maybe_repeat_kv_heads(self, key_or_value: Tensor) -> Tensor:
         num_head_repeats = axis_size // key_or_value.shape[-2]
         # Repeat along the num_heads dim: [batch, source_length, repeated_num_heads, per_head_dim].
         if num_head_repeats > 1:
+            logging.info(
+                "Repeating %d KV heads %d times to meet the size of %s, which is %d.",
+                key_or_value.shape[-2],
+                num_head_repeats,
+                axis,
+                axis_size,
+            )
             key_or_value = jnp.repeat(key_or_value, num_head_repeats, axis=-2)
+            if cfg.k_partition_spec != cfg.v_partition_spec:
+                raise ValueError(
+                    "FlashAttention doesn't support "
+                    f"{cfg.k_partition_spec=} != {cfg.v_partition_spec}"
+                )
+            # This maybe_shard is required when using "seq" > num_kv_heads and DeepSpeed Ulysses
+            # style sequence parallelism. It tells the compiler to not reshard from partitioning
+            # along the sequence axis to head axis before the `jnp.repeat` above, which otherwise
+            # would cause an involuntary full materialization.
+            key_or_value = maybe_shard(key_or_value, cfg.k_partition_spec or cfg.q_partition_spec)

         if key_or_value.shape[-2] % axis_size != 0:
             raise ValueError(
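To make the interaction of the repeat and the added sharding constraint concrete, here is a small illustration with made-up shapes; the head counts and axis size are not taken from the commit.

# Illustrative shapes only: why KV heads get repeated and then re-constrained.
import jax.numpy as jnp

batch, source_len, num_kv_heads, per_head_dim = 2, 16, 2, 8
axis_size = 8  # hypothetical size of the mesh axis sharding the heads ("n") dim

key = jnp.zeros((batch, source_len, num_kv_heads, per_head_dim))
num_head_repeats = axis_size // key.shape[-2]  # 8 // 2 == 4
if num_head_repeats > 1:
    # Repeat along the heads dim so it divides the axis size: (2, 16, 8, 8).
    key = jnp.repeat(key, num_head_repeats, axis=-2)

# In the layer above, the repeated tensor is then passed through
# maybe_shard(key, cfg.k_partition_spec or cfg.q_partition_spec) so that, under
# Ulysses-style sequence parallelism, the compiler keeps it sharded along the
# sequence axis instead of resharding to the head axis, which would otherwise
# fully materialize the repeated tensor.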
