@@ -143,24 +143,24 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         attn_metadata = get_forward_context().attn_metadata
         if attn_metadata is None:
             # for profile run
-            return hidden_states
+            is_prefill = True
+        else:
+            is_prefill = attn_metadata.num_prefills > 0
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
 
-        if (self.tp_size > 1 and self.enable_mc2
-                and attn_metadata.num_prefills == 0):
+        if (self.tp_size > 1 and self.enable_mc2 and not is_prefill):
             chunks = torch.chunk(hidden_states,
                                  get_tp_group().world_size,
                                  dim=0)
             hidden_states = chunks[get_tp_group().rank_in_group]
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        is_prefill = True if attn_metadata.num_prefills > 0 else False
-        # is_prefill = attn_metadata.num_prefills > 0
+
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
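The key behavioral change is how `is_prefill` is derived: profile runs (where `attn_metadata` is `None`) previously returned `hidden_states` unchanged, skipping the MoE layer entirely; they are now treated as prefill so the layer still executes. A minimal runnable sketch of that control flow, where `AttnMetadata` is a hypothetical stand-in for the real attention metadata object:

```python
# Sketch of the new is_prefill derivation. AttnMetadata models only the
# num_prefills field that the diff actually reads; the real metadata
# object in vLLM carries much more.
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttnMetadata:
    num_prefills: int


def resolve_is_prefill(attn_metadata: Optional[AttnMetadata]) -> bool:
    if attn_metadata is None:
        # Profile run: no metadata yet, so treat the pass as prefill,
        # which also keeps the decode-only MC2 chunking path disabled.
        return True
    return attn_metadata.num_prefills > 0


assert resolve_is_prefill(None) is True                           # profile run
assert resolve_is_prefill(AttnMetadata(num_prefills=2)) is True   # prefill batch
assert resolve_is_prefill(AttnMetadata(num_prefills=0)) is False  # decode-only
```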
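The MC2 guard now reads `not is_prefill` instead of `attn_metadata.num_prefills == 0`, so the profile-run case (treated as prefill) also bypasses the chunked path, and `is_prefill` is computed once up front rather than re-derived after the gate. The chunking itself is unchanged; a small sketch of what it does per tensor-parallel rank, with `world_size` and `rank_in_group` standing in for the attributes of `get_tp_group()` in the real code:

```python
# Per-rank token split on the MC2 decode path: each TP rank keeps only
# its chunk of the token dimension before expert dispatch.
import torch

world_size, rank_in_group = 4, 1            # assumed TP group of 4, rank 1
hidden_states = torch.randn(8, 16)          # (num_tokens, hidden_dim)

chunks = torch.chunk(hidden_states, world_size, dim=0)
hidden_states = chunks[rank_in_group]       # this rank's 2-token slice

print(hidden_states.shape)                  # torch.Size([2, 16])
```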