 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
-    _moe_permute)
+from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
+    compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
+    TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
-from vllm.utils import has_deep_gemm, round_up
+from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous

 logger = init_logger(__name__)
@@ -93,18 +93,25 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         return TopKWeightAndReduceNoOP()

     def workspace_shapes(
-        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
-        topk: int, global_num_experts: int, local_num_experts: int
+        self,
+        a: torch.Tensor,
+        aq: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert self.block_shape is not None
-        # We use global_num_experts due to how moe_align_block_size handles
-        # expert_maps.
-        num_experts = global_num_experts
         block_m = self.block_shape[0]
-        M_sum = (M * topk) + num_experts * (block_m - 1)
-        M_sum = round_up(M_sum, block_m)
-        workspace1 = (M_sum, max(N // 2, K))
-        workspace2 = (M_sum, max(N, K))
+        M_sum = compute_aligned_M(M, topk, local_num_experts, block_m,
+                                  expert_tokens_meta)
+        assert M_sum % block_m == 0
+
+        workspace1 = (M_sum, max(N, K))
+        workspace2 = (M_sum, max(N // 2, K))
         output = (M, K)
         return (workspace1, workspace2, output, a.dtype)
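For context, the removed in-line computation bounded M_sum by assuming every expert could waste up to a full block of padding. The sketch below shows how such an aligned row count can be derived; it is an illustrative assumption, not the actual compute_aligned_M from deep_gemm_utils, and the helper name compute_aligned_m_sketch is hypothetical.

from typing import Optional

def compute_aligned_m_sketch(M: int, num_topk: int, local_num_experts: int,
                             alignment: int,
                             expert_num_tokens: Optional[list[int]]) -> int:
    # Hypothetical sketch of an aligned row-count computation; the real
    # compute_aligned_M in vLLM may differ.
    def round_up(x: int, a: int) -> int:
        return (x + a - 1) // a * a

    if expert_num_tokens is None:
        # Worst case: each local expert may waste up to (alignment - 1)
        # padding rows, mirroring the bound the removed in-line code used.
        return round_up(M * num_topk + local_num_experts * (alignment - 1),
                        alignment)
    # With exact per-expert token counts, pad each expert's group separately.
    return sum(round_up(n, alignment) for n in expert_num_tokens)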
@@ -131,43 +138,40 @@ def apply(
         apply_router_weight_on_input: bool,
     ):
         assert self.block_shape is not None
+        assert a1q_scale is not None

         a1q = hidden_states
         _, N, K = w1.size()
-        M, _ = output.size()
-        num_topk = topk_ids.size(1)

+        local_num_experts = w1.size(0)
         if global_num_experts == -1:
-            global_num_experts = w1.size(0)
+            global_num_experts = local_num_experts

         assert w2.size(1) == K

-        a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
-            a1q,
-            a1q_scale,
-            topk_ids,
-            global_num_experts,
-            expert_map,
-            self.block_shape[0],
-        )
-
-        if expert_map is not None:
-            # DeepGemm (Grouped Contiguous) kernel needs a valid B index
-            # for all rows of A. To that effect, simply compute with
-            # the 0th weight matrix.
-            # Note that this relies on the fact that corresponding topk
-            # weights would be 0 during weight multiplication.
-            expert_ids = torch.where(expert_ids == -1, 0, expert_ids)
-
-        # Note: M_sum is different than the pre-permuted shape of a1q.
-        M_sum = a1q.size(0)
-
-        mm1_out = _resize_cache(workspace2, (M_sum, N))
-        act_out = _resize_cache(workspace13, (M_sum, N // 2))
-        quant_out = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
+        M_sum = compute_aligned_M(M=topk_ids.size(0),
+                                  num_topk=topk_ids.size(1),
+                                  local_num_experts=local_num_experts,
+                                  alignment=deep_gemm_block_shape()[0],
+                                  expert_tokens_meta=expert_tokens_meta)
+
+        a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
+                                 (M_sum, K))
+        mm1_out = _resize_cache(workspace13, (M_sum, N))
+        act_out = _resize_cache(workspace2, (M_sum, N // 2))
+        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
                                   (M_sum, N // 2))
-        mm2_out = _resize_cache(workspace13, (M_sum, K))
-        perm_out = _resize_cache(workspace2, (M * num_topk, K))
+        mm2_out = _resize_cache(workspace2, (M_sum, K))
+
+        a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
+            aq=a1q,
+            aq_scale=a1q_scale,
+            topk_ids=topk_ids,
+            local_num_experts=local_num_experts,
+            expert_map=expert_map,
+            expert_tokens_meta=expert_tokens_meta,
+            aq_out=a1q_perm)
+        assert a1q.size(0) == M_sum

         m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
                                          mm1_out, expert_ids)
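The grouped contiguous GEMM expects activation rows sorted by expert, with each expert's group padded to the block alignment and expert_ids naming the expert that owns every row. The toy function below only illustrates that layout; it is an assumption about the semantics, not the deepgemm_moe_permute kernel itself, and toy_grouped_layout is a made-up name.

import torch

def toy_grouped_layout(topk_ids: torch.Tensor, num_experts: int,
                       block_m: int) -> tuple[torch.Tensor, torch.Tensor]:
    # perm maps padded, expert-grouped rows back to flat (token, topk) rows;
    # expert_ids gives the expert index for each padded row.
    flat = topk_ids.reshape(-1)
    perm, expert_ids = [], []
    for e in range(num_experts):
        src = (flat == e).nonzero(as_tuple=True)[0]
        pad = (-src.numel()) % block_m          # pad the group to block_m rows
        perm.append(torch.cat([src, src.new_zeros(pad)]))  # pad rows reuse row 0
        expert_ids.append(torch.full((src.numel() + pad,), e, dtype=torch.long))
    return torch.cat(perm), torch.cat(expert_ids)

# 3 tokens, top-2 routing over 2 experts, block_m = 4: the six (token, topk)
# rows are grouped by expert and each group is padded to 4 rows.
perm, expert_ids = toy_grouped_layout(
    torch.tensor([[0, 1], [0, 0], [1, 0]]), num_experts=2, block_m=4)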
@@ -183,14 +187,15 @@ def apply(
         m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
                                          mm2_out, expert_ids)

-        torch.index_select(mm2_out, 0, inv_perm, out=perm_out)
+        if apply_router_weight_on_input:
+            topk_weights = torch.ones_like(topk_weights)

-        TopKWeightAndReduceContiguous().apply(
-            output=output,
-            fused_expert_output=perm_out,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+        deepgemm_unpermute_and_reduce(a=mm2_out,
+                                      topk_ids=topk_ids,
+                                      topk_weights=topk_weights,
+                                      inv_perm=inv_perm,
+                                      expert_map=expert_map,
+                                      output=output)


 def deep_gemm_moe_fp8(
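For reference, the final reduce conceptually gathers the second GEMM's output back to (token, topk) order via inv_perm, scales by topk_weights, and sums over the topk dimension; when apply_router_weight_on_input is set the weights were already folded into the activations, so ones are passed instead. Below is a plain-PyTorch sketch of that behavior under these assumptions; it is not the fused deepgemm_unpermute_and_reduce kernel and it ignores expert_map masking.

import torch

def toy_unpermute_and_reduce(a: torch.Tensor, inv_perm: torch.Tensor,
                             topk_weights: torch.Tensor,
                             output: torch.Tensor) -> None:
    # a: (M_sum, K) expert-grouped rows, inv_perm: (M * topk,) row indices,
    # topk_weights: (M, topk), output: (M, K).
    M, topk = topk_weights.shape
    K = a.size(-1)
    gathered = a.index_select(0, inv_perm).view(M, topk, K)
    weights = topk_weights.to(gathered.dtype).unsqueeze(-1)
    output.copy_((gathered * weights).sum(dim=1))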