@@ -330,7 +330,7 @@ __global__ __launch_bounds__(
330
330
EP_DEVICE_ASSERT (num_sms > 1 );
331
331
if (sm_id == 0 ) {
332
332
// The first SM is also responsible for checking QPs
333
- EP_DEVICE_ASSERT (ibgda_get_state ()->num_rc_per_pe = = num_local_experts);
333
+ EP_DEVICE_ASSERT (ibgda_get_state ()->num_rc_per_pe > = num_local_experts);
334
334
335
335
// The first SM is also responsible for cleaning the next buffer
336
336
#pragma unroll
@@ -573,7 +573,8 @@ void dispatch(void* packed_recv_x,
573
573
use_fp8
574
574
? dispatch<true , kNumWarpGroups , kNumWarpsPerGroup , kHidden >
575
575
: dispatch<false , kNumWarpGroups , kNumWarpsPerGroup , kHidden >;
576
- SETUP_LAUNCH_CONFIG (num_sms, NUM_WARPS * 32 , stream);
576
+ SETUP_LAUNCH_CONFIG (
577
+ num_sms, kNumWarpGroups * kNumWarpsPerGroup * 32 , stream);
577
578
LAUNCH_KERNEL (&cfg,
578
579
dispatch_func,
579
580
packed_recv_x,
@@ -892,7 +893,8 @@ void combine(void* combined_x,
892
893
constexpr int kNumWarpsPerGroup = NUM_WARPS / kNumWarpGroups ;
893
894
auto combine_func =
894
895
combine<kNumWarpGroups , kNumWarpsPerGroup , kHidden , kNumMaxTopk >;
895
- SETUP_LAUNCH_CONFIG (num_sms, NUM_WARPS * 32 , stream);
896
+ SETUP_LAUNCH_CONFIG (
897
+ num_sms, kNumWarpGroups * kNumWarpsPerGroup * 32 , stream);
896
898
LAUNCH_KERNEL (&cfg,
897
899
combine_func,
898
900
combined_x,
0 commit comments