Skip to content

Commit 924cb7a

Browse files
committed
fix kernel_key for fused_gemm_epilogue_impl
1 parent d2d97e7 commit 924cb7a

File tree

1 file changed

+15
-17
lines changed

1 file changed

+15
-17
lines changed

paddle/phi/api/lib/api_custom_impl.cc

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,21 @@ std::tuple<Tensor, Tensor> fused_gemm_epilogue_impl(
240240
Backend kernel_backend = Backend::UNDEFINED;
241241
DataLayout kernel_layout = DataLayout::UNDEFINED;
242242
DataType kernel_data_type = DataType::UNDEFINED;
243+
if (kernel_backend == Backend::UNDEFINED ||
244+
kernel_layout == DataLayout::UNDEFINED ||
245+
kernel_data_type == DataType::UNDEFINED) {
246+
auto kernel_key_set = ParseKernelKeyByInputArgs(x, y, bias);
247+
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
248+
if (kernel_backend == Backend::UNDEFINED) {
249+
kernel_backend = kernel_key.backend();
250+
}
251+
if (kernel_layout == DataLayout::UNDEFINED) {
252+
kernel_layout = kernel_key.layout();
253+
}
254+
if (kernel_data_type == DataType::UNDEFINED) {
255+
kernel_data_type = kernel_key.dtype();
256+
}
257+
}
243258
#ifdef PADDLE_WITH_DISTRIBUTE
244259
bool run_auto_parallel = AllInputsAreDistTensor(x, y, bias);
245260
bool rank_is_in_current_mesh = true;
@@ -250,23 +265,6 @@ std::tuple<Tensor, Tensor> fused_gemm_epilogue_impl(
250265
.process_mesh();
251266
rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh);
252267
}
253-
if (rank_is_in_current_mesh) {
254-
if (kernel_backend == Backend::UNDEFINED ||
255-
kernel_layout == DataLayout::UNDEFINED ||
256-
kernel_data_type == DataType::UNDEFINED) {
257-
auto kernel_key_set = ParseKernelKeyByInputArgs(x, y, bias);
258-
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
259-
if (kernel_backend == Backend::UNDEFINED) {
260-
kernel_backend = kernel_key.backend();
261-
}
262-
if (kernel_layout == DataLayout::UNDEFINED) {
263-
kernel_layout = kernel_key.layout();
264-
}
265-
if (kernel_data_type == DataType::UNDEFINED) {
266-
kernel_data_type = kernel_key.dtype();
267-
}
268-
}
269-
}
270268

271269
// Kernel Dispatch Body
272270
// Auto Parallel condition

0 commit comments

Comments
 (0)