From af21149eea31548ce91af2e47145c0729216abdd Mon Sep 17 00:00:00 2001 From: Sergey Lebedev Date: Fri, 2 May 2025 14:08:23 +0200 Subject: [PATCH 1/2] coll/ucc: refactor UCC collective operations to handle MPI_IN_PLACE correctly Update the initialization functions for allgather, allgatherv, alltoall, alltoallv, gather, gatherv, scatter, and scatterv to improve handling of the MPI_IN_PLACE argument. When MPI_IN_PLACE is passed to these collectives, the corresponding datatype and count must be ignored. Signed-off-by: Sergey Lebedev --- ompi/mca/coll/ucc/coll_ucc_allgather.c | 13 +++++++---- ompi/mca/coll/ucc/coll_ucc_allgatherv.c | 10 ++++++--- ompi/mca/coll/ucc/coll_ucc_alltoall.c | 13 +++++++---- ompi/mca/coll/ucc/coll_ucc_alltoallv.c | 10 ++++++--- ompi/mca/coll/ucc/coll_ucc_gather.c | 30 ++++++++++++++++--------- ompi/mca/coll/ucc/coll_ucc_gatherv.c | 18 +++++++++------ ompi/mca/coll/ucc/coll_ucc_scatter.c | 22 ++++++++++++++---- ompi/mca/coll/ucc/coll_ucc_scatterv.c | 15 ++++++++----- 8 files changed, 89 insertions(+), 42 deletions(-) diff --git a/ompi/mca/coll/ucc/coll_ucc_allgather.c b/ompi/mca/coll/ucc/coll_ucc_allgather.c index 3312d818bf3..2dd3ac68a55 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgather.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgather.c @@ -15,15 +15,20 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_size = ompi_comm_size(ucc_module->comm); - if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount) || + if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) || !ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) { goto fallback; } - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + if (!is_inplace) { + ucc_sdt =
ompi_dtype_to_ucc_dtype(sdtype); + } + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -50,7 +55,7 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t } }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c index 3190eceb857..bcfb3154514 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c @@ -17,10 +17,14 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int sc ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -52,7 +56,7 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int sc } }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoall.c b/ompi/mca/coll/ucc/coll_ucc_alltoall.c index 275f8d67640..cfb56f47418 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoall.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoall.c @@ -15,15 +15,20 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); int 
comm_size = ompi_comm_size(ucc_module->comm); - if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size) || + if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size)) || !ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) { goto fallback; } - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -50,7 +55,7 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s } }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c index 728636379b8..116166a6080 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c @@ -17,10 +17,14 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_co ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -54,7 +58,7 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_co } }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_gather.c b/ompi/mca/coll/ucc/coll_ucc_gather.c 
index 2b4b2e05474..ba91b40b189 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gather.c +++ b/ompi/mca/coll/ucc/coll_ucc_gather.c @@ -17,27 +17,35 @@ ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct om ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); int comm_size = ompi_comm_size(ucc_module->comm); - if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) { - goto fallback; - } - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); if (comm_rank == root) { - if (!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) { + if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) || + !ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) { goto fallback; } + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); - if ((COLL_UCC_DT_UNSUPPORTED == ucc_rdt) || - (MPI_IN_PLACE != sbuf && COLL_UCC_DT_UNSUPPORTED == ucc_sdt)) { + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + + if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) || + (COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", - (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ? - rdtype->super.name : sdtype->super.name); + (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ? 
+ sdtype->super.name : rdtype->super.name); goto fallback; } } else { + if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) { + goto fallback; + } + + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", sdtype->super.name); @@ -64,7 +72,7 @@ ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct om }, }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_gatherv.c b/ompi/mca/coll/ucc/coll_ucc_gatherv.c index b53d69c9b76..05620b47075 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_gatherv.c @@ -17,20 +17,24 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); if (comm_rank == root) { ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); - if ((COLL_UCC_DT_UNSUPPORTED == ucc_rdt) || - (MPI_IN_PLACE != sbuf && COLL_UCC_DT_UNSUPPORTED == ucc_sdt)) { + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) || + (COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", - (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ? - rdtype->super.name : sdtype->super.name); + (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ? 
+ sdtype->super.name : rdtype->super.name); goto fallback; } } else { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", sdtype->super.name); @@ -62,7 +66,7 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc }, }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_scatter.c b/ompi/mca/coll/ucc/coll_ucc_scatter.c index 55351b8151b..481365f22bd 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatter.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatter.c @@ -18,21 +18,35 @@ ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == rbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); int comm_size = ompi_comm_size(ucc_module->comm); - ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); if (comm_rank == root) { + if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(rdtype, rcount)) || + !ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size)) { + goto fallback; + } + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + if (!is_inplace) { + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + } + if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) || - (MPI_IN_PLACE != rbuf && COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) { + (COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ? 
sdtype->super.name : rdtype->super.name); goto fallback; } } else { + if (!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount)) { + goto fallback; + } + + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); if (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", rdtype->super.name); @@ -59,7 +73,7 @@ ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, }, }; - if (MPI_IN_PLACE == rbuf) { + if (is_inplace) { coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_scatterv.c b/ompi/mca/coll/ucc/coll_ucc_scatterv.c index 08d7dee60aa..27ff64aa097 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatterv.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatterv.c @@ -18,22 +18,25 @@ ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t sco ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == rbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); - int comm_size = ompi_comm_size(ucc_module->comm); - ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); if (comm_rank == root) { ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + if (!is_inplace) { + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + } + if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) || - (MPI_IN_PLACE != rbuf && COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) { + (COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ? 
sdtype->super.name : rdtype->super.name); goto fallback; } - } else { + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); if (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", rdtype->super.name); @@ -65,7 +68,7 @@ ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t sco }, }; - if (MPI_IN_PLACE == rbuf) { + if (is_inplace) { coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; } From 887e7afd42e763b0871dd75b84771f7b42d9a63b Mon Sep 17 00:00:00 2001 From: Sergey Lebedev Date: Fri, 2 May 2025 16:56:37 +0200 Subject: [PATCH 2/2] coll/ucc: fix bigcount support Fix bigcount support in the UCC coll component. The collective flags were not set correctly: the computed 64-bit count/displacement flags were overridden by a duplicate `.flags = 0` designated initializer, and `.mask` was initialized to 0 instead of including UCC_COLL_ARGS_FIELD_FLAGS. Compute all flags (including in-place) before building the collective arguments and set the mask whenever any flag is set. Also widen the allgatherv `scount` parameter from int to size_t. Signed-off-by: Sergey Lebedev --- ompi/mca/coll/ucc/coll_ucc_allgatherv.c | 17 +++++++---------- ompi/mca/coll/ucc/coll_ucc_alltoallv.c | 15 ++++++--------- ompi/mca/coll/ucc/coll_ucc_gatherv.c | 13 +++++-------- ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c | 8 +++++--- ompi/mca/coll/ucc/coll_ucc_scatterv.c | 14 +++++--------- 5 files changed, 28 insertions(+), 39 deletions(-) diff --git a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c index bcfb3154514..68e786e0c2a 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c @@ -9,7 +9,7 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int scount, +static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, struct ompi_datatype_t *rdtype, @@ -19,12 +19,13 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int sc { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); + uint64_t flags = 0; ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); if (!is_inplace) { ucc_sdt =
ompi_dtype_to_ucc_dtype(sdtype); } - + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -33,13 +34,13 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int sc goto fallback; } - uint64_t flags = ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0; - flags |= ompi_disp_array_is_64bit(rdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0; + flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | + (ompi_disp_array_is_64bit(rdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) | + (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0); ucc_coll_args_t coll = { + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, .flags = flags, - .mask = 0, - .flags = 0, .coll_type = UCC_COLL_TYPE_ALLGATHERV, .src.info = { .buffer = (void*)sbuf, @@ -56,10 +57,6 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int sc } }; - if (is_inplace) { - coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c index 116166a6080..1e9e311cf94 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c @@ -19,12 +19,13 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_co { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); + uint64_t flags = 0; ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); if (!is_inplace) { ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); } - + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -34,13 +35,13 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_co } /* Assumes 
that send counts/displs and recv counts/displs are both 32-bit or both 64-bit */ - uint64_t flags = ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0; - flags |= ompi_disp_array_is_64bit(sdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0; + flags = (ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | + (ompi_disp_array_is_64bit(sdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) | + (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0); ucc_coll_args_t coll = { + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, .flags = flags, - .mask = 0, - .flags = 0, .coll_type = UCC_COLL_TYPE_ALLTOALLV, .src.info_v = { .buffer = (void*)sbuf, @@ -58,10 +59,6 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_co } }; - if (is_inplace) { - coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_gatherv.c b/ompi/mca/coll/ucc/coll_ucc_gatherv.c index 05620b47075..5a1da52356c 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_gatherv.c @@ -20,6 +20,7 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); + uint64_t flags = 0; if (comm_rank == root) { ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); @@ -42,13 +43,13 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc } } - uint64_t flags = ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0; - flags |= ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0; + flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | + (ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) | + (is_inplace ? 
UCC_COLL_ARGS_FLAG_IN_PLACE : 0); ucc_coll_args_t coll = { + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, .flags = flags, - .mask = 0, - .flags = 0, .coll_type = UCC_COLL_TYPE_GATHERV, .root = root, .src.info = { @@ -66,10 +67,6 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc }, }; - if (is_inplace) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c index 0df89160e8a..dabc8f11d03 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c @@ -21,6 +21,7 @@ ucc_status_t mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi size_t total_count; int i; int comm_size = ompi_comm_size(ucc_module->comm); + uint64_t flags = 0; if (MPI_IN_PLACE == sbuf) { /* TODO: UCC defines inplace differently: @@ -46,10 +47,11 @@ ucc_status_t mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi total_count += ompi_count_array_get(rcounts, i); } + flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0); + ucc_coll_args_t coll = { - .flags = ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0, - .mask = 0, - .flags = 0, + .mask = flags ? 
UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_REDUCE_SCATTERV, .src.info = { .buffer = (void*)sbuf, diff --git a/ompi/mca/coll/ucc/coll_ucc_scatterv.c b/ompi/mca/coll/ucc/coll_ucc_scatterv.c index 27ff64aa097..36d4086a113 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatterv.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatterv.c @@ -21,7 +21,7 @@ ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t sco ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == rbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); - + uint64_t flags = 0; if (comm_rank == root) { ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); if (!is_inplace) { @@ -44,13 +44,13 @@ ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t sco } } - uint64_t flags = ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0; - flags |= ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0; + flags = (ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | + (ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) | + (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0); ucc_coll_args_t coll = { + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, .flags = flags, - .mask = 0, - .flags = 0, .coll_type = UCC_COLL_TYPE_SCATTERV, .root = root, .src.info_v = { @@ -68,10 +68,6 @@ ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t sco }, }; - if (is_inplace) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: