@@ -31,6 +31,13 @@ struct DeviceIdHandler {
31
31
FatalErrorInFunction << " Only parallel runs are supported for OGL"
32
32
<< exit (FatalError);
33
33
}
34
+
35
+ if (Pstream::nProcs (0 ) % ranks_per_gpu != 0 ) {
36
+ FatalErrorInFunction
37
+ << " Total number of ranks = " << Pstream::nProcs (0 )
38
+ << " is not a multiple of "
39
+ << " ranksPerGPU " << ranks_per_gpu << exit (FatalError);
40
+ }
34
41
}
35
42
36
43
/* @brief compute the local device id
@@ -102,7 +109,6 @@ struct ExecutorInitFunctor {
102
109
{
103
110
auto host_exec = gko::share (gko::ReferenceExecutor::create ());
104
111
105
-
106
112
auto msg = [](auto exec, auto id) {
107
113
std::string s;
108
114
// auto node_comm = Pstream::commInterHost();
@@ -111,16 +117,17 @@ struct ExecutorInitFunctor {
111
117
label global_ranks = Pstream::nProcs (0 );
112
118
label device_ranks = Pstream::nProcs (node_comm);
113
119
label node_id = global_ranks / device_ranks;
120
+
114
121
// Pstream::barrier(0);
115
122
// sleep(0.03 * global_rank);
116
123
s += std::string (" Create " ) + std::string (exec) +
117
124
std::string (" executor device " ) + std::to_string (id) +
118
125
std::string (" node " ) + std::to_string (node_id) +
119
126
std::string (" local rank [" ) +
120
127
std::to_string (Pstream::myProcNo (node_comm)) +
121
- std::string (" /" ) + std::to_string (device_ranks) +
128
+ std::string (" /" ) + std::to_string (device_ranks - 1 ) +
122
129
std::string (" ] global rank [" ) + std::to_string (global_rank) +
123
- std::string (" /" ) + std::to_string (global_ranks) +
130
+ std::string (" /" ) + std::to_string (global_ranks - 1 ) +
124
131
std::string (" ]" );
125
132
return s;
126
133
};
@@ -135,7 +142,8 @@ struct ExecutorInitFunctor {
135
142
label id = device_id_handler_.compute_device_id (
136
143
gko::CudaExecutor::get_num_devices ());
137
144
LOG_0 (verbose_, msg (executor_name_, id))
138
- return gko::share (gko::CudaExecutor::create (id, host_exec));
145
+ auto ret = gko::share (gko::CudaExecutor::create (id, host_exec));
146
+ return ret;
139
147
}
140
148
if (executor_name_ == " sycl" || executor_name_ == " dpcpp" ) {
141
149
if (version.dpcpp_version .tag == not_compiled_tag) {
0 commit comments