@@ -240,6 +240,21 @@ std::tuple<Tensor, Tensor> fused_gemm_epilogue_impl(
   Backend kernel_backend = Backend::UNDEFINED;
   DataLayout kernel_layout = DataLayout::UNDEFINED;
   DataType kernel_data_type = DataType::UNDEFINED;
+  if (kernel_backend == Backend::UNDEFINED ||
+      kernel_layout == DataLayout::UNDEFINED ||
+      kernel_data_type == DataType::UNDEFINED) {
+    auto kernel_key_set = ParseKernelKeyByInputArgs(x, y, bias);
+    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+    if (kernel_backend == Backend::UNDEFINED) {
+      kernel_backend = kernel_key.backend();
+    }
+    if (kernel_layout == DataLayout::UNDEFINED) {
+      kernel_layout = kernel_key.layout();
+    }
+    if (kernel_data_type == DataType::UNDEFINED) {
+      kernel_data_type = kernel_key.dtype();
+    }
+  }
 #ifdef PADDLE_WITH_DISTRIBUTE
   bool run_auto_parallel = AllInputsAreDistTensor(x, y, bias);
   bool rank_is_in_current_mesh = true;
@@ -250,23 +265,6 @@ std::tuple<Tensor, Tensor> fused_gemm_epilogue_impl(
             .process_mesh();
     rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh);
   }
-  if (rank_is_in_current_mesh) {
-    if (kernel_backend == Backend::UNDEFINED ||
-        kernel_layout == DataLayout::UNDEFINED ||
-        kernel_data_type == DataType::UNDEFINED) {
-      auto kernel_key_set = ParseKernelKeyByInputArgs(x, y, bias);
-      auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
-      if (kernel_backend == Backend::UNDEFINED) {
-        kernel_backend = kernel_key.backend();
-      }
-      if (kernel_layout == DataLayout::UNDEFINED) {
-        kernel_layout = kernel_key.layout();
-      }
-      if (kernel_data_type == DataType::UNDEFINED) {
-        kernel_data_type = kernel_key.dtype();
-      }
-    }
-  }
 
   // Kernel Dispatch Body
   // Auto Parallel condition
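
In effect, the diff hoists the kernel-key fallback out of the rank_is_in_current_mesh guard: backend, layout, and dtype are now resolved from the input tensors once, before the PADDLE_WITH_DISTRIBUTE branch, instead of only when the current rank belongs to the mesh. Below is a minimal standalone sketch of that fallback pattern; the enums, KernelKey struct, and ResolveFromInputs() are simplified stand-ins for the phi types and ParseKernelKeyByInputArgs/GetHighestPriorityKernelKey, not the real API.

#include <iostream>

// Simplified stand-ins for phi::Backend / DataLayout / DataType (hypothetical).
enum class Backend { UNDEFINED, CPU, GPU };
enum class DataLayout { UNDEFINED, NCHW };
enum class DataType { UNDEFINED, FLOAT32 };

// Hypothetical stand-in for the kernel key derived from the input tensors.
struct KernelKey {
  Backend backend;
  DataLayout layout;
  DataType dtype;
};

// Plays the role of ParseKernelKeyByInputArgs(x, y, bias)
//   .GetHighestPriorityKernelKey(); here it just returns a fixed key.
KernelKey ResolveFromInputs() {
  return {Backend::GPU, DataLayout::NCHW, DataType::FLOAT32};
}

int main() {
  Backend kernel_backend = Backend::UNDEFINED;
  DataLayout kernel_layout = DataLayout::UNDEFINED;
  DataType kernel_data_type = DataType::UNDEFINED;

  // Fallback: fill in only the fields that are still UNDEFINED, once,
  // before any distributed / auto-parallel branching.
  if (kernel_backend == Backend::UNDEFINED ||
      kernel_layout == DataLayout::UNDEFINED ||
      kernel_data_type == DataType::UNDEFINED) {
    KernelKey key = ResolveFromInputs();
    if (kernel_backend == Backend::UNDEFINED) kernel_backend = key.backend;
    if (kernel_layout == DataLayout::UNDEFINED) kernel_layout = key.layout;
    if (kernel_data_type == DataType::UNDEFINED) kernel_data_type = key.dtype;
  }

  std::cout << "backend resolved to GPU: "
            << (kernel_backend == Backend::GPU) << "\n";
  return 0;
}

The design point is simply ordering: by resolving the kernel key unconditionally, later dispatch logic can rely on a defined backend/layout/dtype regardless of whether the distributed path runs.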