Skip to content

Commit afd9ef0

Browse files
committed
[Inference All2All] add dispatch of num_experts and num_rdma_ranks
1 parent ce7f373 commit afd9ef0

File tree

5 files changed

+271 -251 lines changed

5 files changed

+271 -251 lines changed

paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@ Buffer::Buffer(int rank,
7676

7777
// Common checks
7878
EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 &&
79-
(num_nvl_bytes <= std::numeric_limits<int>::max() ||
79+
(num_nvl_bytes <= std::numeric_limits<int64_t>::max() ||
8080
num_rdma_bytes == 0));
8181
EP_HOST_ASSERT(
8282
num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 &&

paddle/fluid/distributed/collective/deep_ep/kernels/internode_ll.cu

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -330,7 +330,7 @@ __global__ __launch_bounds__(
330330
EP_DEVICE_ASSERT(num_sms > 1);
331331
if (sm_id == 0) {
332332
// The first SM is also responsible for checking QPs
333-
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_local_experts);
333+
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts);
334334

335335
// The first SM is also responsible for cleaning the next buffer
336336
#pragma unroll
@@ -573,7 +573,8 @@ void dispatch(void* packed_recv_x,
573573
use_fp8
574574
? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, kHidden>
575575
: dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, kHidden>;
576-
SETUP_LAUNCH_CONFIG(num_sms, NUM_WARPS * 32, stream);
576+
SETUP_LAUNCH_CONFIG(
577+
num_sms, kNumWarpGroups * kNumWarpsPerGroup * 32, stream);
577578
LAUNCH_KERNEL(&cfg,
578579
dispatch_func,
579580
packed_recv_x,
@@ -892,7 +893,8 @@ void combine(void* combined_x,
892893
constexpr int kNumWarpsPerGroup = NUM_WARPS / kNumWarpGroups;
893894
auto combine_func =
894895
combine<kNumWarpGroups, kNumWarpsPerGroup, kHidden, kNumMaxTopk>;
895-
SETUP_LAUNCH_CONFIG(num_sms, NUM_WARPS * 32, stream);
896+
SETUP_LAUNCH_CONFIG(
897+
num_sms, kNumWarpGroups * kNumWarpsPerGroup * 32, stream);
896898
LAUNCH_KERNEL(&cfg,
897899
combine_func,
898900
combined_x,

0 commit comments

Comments
 (0)