     check_numerics,
     flatten_items,
     get_or_none,
+    maybe_shard,
     save_and_offload_only_these_names_regex,
     shapes,
     split_prng_key,
-    with_sharding_constraint,
 )
@@ -1560,18 +1560,18 @@ class Config(BaseLayer.Config):
         logit_sink: Optional[bool] = None

         # Partition spec for query ([batch, seq, q_heads, head_dim]) after input projections.
-        q_partition_spec: Optional[PartitionSpec] = None
+        q_partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]] = None

         # Partition spec for key ([batch, seq, kv_heads, head_dim]) after input projections.
         # Follows `q_partition_spec` if None.
-        k_partition_spec: Optional[PartitionSpec] = None
+        k_partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]] = None

         # Partition spec for value ([batch, seq, kv_heads, head_dim]) after input projections.
         # Follows `q_partition_spec` if None.
-        v_partition_spec: Optional[PartitionSpec] = None
+        v_partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]] = None

         # Partition spec for output ([batch, seq, hidden_dim]) after output projections.
-        o_partition_spec: Optional[PartitionSpec] = None
+        o_partition_spec: Optional[Sequence[Union[str, Sequence[str], None]]] = None

     def __init__(self, cfg: Config, *, parent: Module):
         super().__init__(cfg, parent=parent)
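
The partition specs above become plain nested sequences of mesh axis names instead of `jax.sharding.PartitionSpec` objects. A minimal illustration of what such a value looks like and how it can be expanded into a `PartitionSpec` at the point where the constraint is applied; the axis names below are examples, not values taken from this PR:

from jax.sharding import PartitionSpec

# Example spec for q_proj with layout [batch, seq, q_heads, head_dim]:
# shard batch over ("data", "fsdp"), seq over "seq", heads over "model",
# and leave head_dim replicated.
q_spec = (("data", "fsdp"), "seq", "model", None)

# A nested sequence like this expands into an equivalent PartitionSpec.
assert PartitionSpec(*q_spec) == PartitionSpec(("data", "fsdp"), "seq", "model", None)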
@@ -1736,12 +1736,9 @@ def _forward_for_mode(
             time_step = cached_states["time_step"]
             query_positions = query_positions + time_step[:, None]  # [batch, steps]
         q_proj, k_proj, v_proj = self.i_proj(query, query_positions=query_positions, **kv_kwargs)
-        if cfg.q_partition_spec:
-            q_proj = with_sharding_constraint(q_proj, cfg.q_partition_spec)
-        if cfg.q_partition_spec or cfg.k_partition_spec:
-            k_proj = with_sharding_constraint(k_proj, cfg.k_partition_spec or cfg.q_partition_spec)
-        if cfg.q_partition_spec or cfg.v_partition_spec:
-            v_proj = with_sharding_constraint(v_proj, cfg.v_partition_spec or cfg.q_partition_spec)
+        q_proj = maybe_shard(q_proj, cfg.q_partition_spec)
+        k_proj = maybe_shard(k_proj, cfg.k_partition_spec or cfg.q_partition_spec)
+        v_proj = maybe_shard(v_proj, cfg.v_partition_spec or cfg.q_partition_spec)

         if cfg.scale_kv_before_cache_update:
             if has_external_kv_state:
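
Instead of guarding each `with_sharding_constraint` call with an `if`, the projections now go through `maybe_shard`, which is a no-op when no spec is configured. A rough sketch of such a helper, assuming JAX's `with_sharding_constraint` under an ambient mesh (illustration only, not the library's implementation):

from typing import Optional, Sequence, Union

import jax
from jax.sharding import PartitionSpec


def maybe_shard_sketch(
    x: jax.Array,
    spec: Optional[Sequence[Union[str, Sequence[str], None]]],
) -> jax.Array:
    """Applies a sharding constraint only when `spec` is provided (sketch)."""
    if spec is None:
        return x
    # Inside jit with a mesh in scope, a bare PartitionSpec is accepted here.
    return jax.lax.with_sharding_constraint(x, PartitionSpec(*spec))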
@@ -1844,8 +1841,7 @@ def _forward_for_mode(

         # [batch, target_length, output_dim].
         o_proj = self.o_proj(context)
-        if cfg.o_partition_spec:
-            o_proj = with_sharding_constraint(o_proj, cfg.o_partition_spec)
+        o_proj = maybe_shard(o_proj, cfg.o_partition_spec)
         outputs = self._remat_name(o_proj, "o_proj")
         self._add_tensor_stats("o_proj_outputs", outputs)
         return_aux = return_aux or set()
@@ -3608,15 +3604,17 @@ def extend_step(
 def set_attention_partition_specs(
     cfg: MultiheadAttention.Config,
     *,
+    batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"),
     fsdp_axis_names: Union[str, Sequence[str]] = "fsdp",
     tp_axis_names: Union[str, Sequence[str]] = "model",
+    seq_axis_names: Union[str, Sequence[str]] = "seq",
+    set_attn_activation_specs: bool = False,
 ):
     """Sets `cfg` to shard attention weights over both fsdp and tp axes.

     Args:
         cfg: A MultiheadAttention layer config to apply sharding spec to.
-        fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors.
-        tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors.
+        **kwargs: See `set_double_shard_weights_config`.
     """
     # Shard weights.
     input_linear_cfg = cfg.input_linear
@@ -3625,6 +3623,10 @@ def set_attention_partition_specs(
     input_linear_cfg.layer.param_partition_spec = (fsdp_axis_names, tp_axis_names, None)
     cfg.output_linear.param_partition_spec = (fsdp_axis_names, tp_axis_names, None)

+    if set_attn_activation_specs:
+        cfg.q_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names, None)
+        cfg.o_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
+

 def set_feed_forward_partition_specs(
     cfg: TransformerFeedForwardLayer.Config,
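
With `set_attn_activation_specs=True`, `set_attention_partition_specs` now also constrains the attention activations: `q_partition_spec` covers the rank-4 [batch, seq, heads, head_dim] projections (with `k`/`v` following `q` by default) and `o_partition_spec` the rank-3 [batch, seq, hidden_dim] output. A usage sketch, where `attn_cfg` stands in for an existing `MultiheadAttention.Config`:

# Hypothetical usage; `attn_cfg` is a placeholder for a real MultiheadAttention.Config.
set_attention_partition_specs(
    attn_cfg,
    batch_axis_names=("data", "fsdp"),
    fsdp_axis_names="fsdp",
    tp_axis_names="model",
    seq_axis_names="seq",
    set_attn_activation_specs=True,
)
# After the call:
#   attn_cfg.q_partition_spec == (("data", "fsdp"), "seq", "model", None)
#   attn_cfg.o_partition_spec == (("data", "fsdp"), "seq", "model")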
@@ -3638,10 +3640,7 @@ def set_feed_forward_partition_specs(

     Args:
         cfg: A TransformerFeedForwardLayer layer config to apply sharding spec to.
-        batch_axis_names: Axis name(s) over which we shard the batch dimension of output tensors.
-        fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors.
-        tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors.
-        seq_axis_names: Axis name(s) over which we shard sequence-parallel tensors.
+        **kwargs: See `set_double_shard_weights_config`.
     """
     # Shard weights.
     cfg.linear1.param_partition_spec = (fsdp_axis_names, tp_axis_names)
@@ -3658,6 +3657,7 @@ def set_double_shard_weights_config(
     fsdp_axis_names: Union[str, Sequence[str]] = "fsdp",
     tp_axis_names: Union[str, Sequence[str]] = "model",
     seq_axis_names: Union[str, Sequence[str]] = "seq",
+    set_attn_activation_specs: bool = False,
 ):
     """Sets `cfg` to shard FFN and attention weights over both fsdp and tp axes.

@@ -3667,32 +3667,35 @@
         fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors.
         tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors.
         seq_axis_names: Axis name(s) over which we shard sequence-parallel tensors.
+        set_attn_activation_specs: Whether to also set the activation specs of the q/k/v/o
+            projections. This may be required for some complex sharding cases.
     """

     # pytype: disable=attribute-error
     if not isinstance(cfg, Sequence):
         cfg = [cfg]

+    axis_names = dict(
+        batch_axis_names=batch_axis_names,
+        fsdp_axis_names=fsdp_axis_names,
+        tp_axis_names=tp_axis_names,
+        seq_axis_names=seq_axis_names,
+    )
+
     for layer_cfg in cfg:
         set_attention_partition_specs(
             layer_cfg.self_attention.attention,
-            fsdp_axis_names=fsdp_axis_names,
-            tp_axis_names=tp_axis_names,
+            set_attn_activation_specs=set_attn_activation_specs,
+            **axis_names,
         )
         if layer_cfg.cross_attention is not None:
             set_attention_partition_specs(
                 layer_cfg.cross_attention.attention,
-                fsdp_axis_names=fsdp_axis_names,
-                tp_axis_names=tp_axis_names,
+                set_attn_activation_specs=set_attn_activation_specs,
+                **axis_names,
             )
         if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config):
-            set_feed_forward_partition_specs(
-                layer_cfg.feed_forward,
-                batch_axis_names=batch_axis_names,
-                fsdp_axis_names=fsdp_axis_names,
-                tp_axis_names=tp_axis_names,
-                seq_axis_names=seq_axis_names,
-            )
+            set_feed_forward_partition_specs(layer_cfg.feed_forward, **axis_names)
     # pytype: enable=attribute-error
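
Callers of `set_double_shard_weights_config` only need to pass the new flag; the axis names are now forwarded through a single `axis_names` dict to both the attention and feed-forward helpers. A usage sketch, where `layer_cfgs` is a hypothetical list of transformer layer configs:

# Hypothetical usage; `layer_cfgs` stands in for one or more transformer layer configs.
set_double_shard_weights_config(
    layer_cfgs,
    batch_axis_names=("data", "fsdp"),
    fsdp_axis_names="fsdp",
    tp_axis_names="model",
    seq_axis_names="seq",
    set_attn_activation_specs=True,
)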