|
3 | 3 | #include "module_base/global_variable.h"
|
4 | 4 | #include "module_base/tool_quit.h"
|
5 | 5 | #include "module_psi/kernels/device.h"
|
| 6 | +#include <type_traits> |
6 | 7 |
|
7 | 8 | #include <cassert>
|
8 | 9 | #include <complex>
|
@@ -163,11 +164,46 @@ Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
|
163 | 164 | // this function will copy psi_in.psi to this->psi no matter the device types of each other.
|
164 | 165 | this->device = device::get_device_type<Device>(this->ctx);
|
165 | 166 | this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis());
|
166 |
| - memory::cast_memory_op<T, T_in, Device, Device_in>()(this->ctx, |
167 |
| - psi_in.get_device(), |
168 |
| - this->psi, |
169 |
| - psi_in.get_pointer() - psi_in.get_psi_bias(), |
170 |
| - psi_in.size()); |
| 167 | + // No need to cast the memory if the data types are the same. |
| 168 | + if (std::is_same<T, T_in>::value) |
| 169 | + { |
| 170 | + memory::synchronize_memory_op<T, Device, Device_in>()(this->ctx, |
| 171 | + psi_in.get_device(), |
| 172 | + this->psi, |
| 173 | + reinterpret_cast<T*>(psi_in.get_pointer()) - psi_in.get_psi_bias(), |
| 174 | + psi_in.size()); |
| 175 | + } |
| 176 | + // Specifically, if the Device_in type is CPU and the Device type is GPU: |
| 177 | + // Which means we need to initialize a GPU psi from a given CPU psi. |
| 178 | + // We first malloc a memory in CPU, then cast the memory from T_in to T in CPU. |
| 179 | + // Finally, synchronize the memory from CPU to GPU. |
| 180 | + // This could help to reduce the peak memory usage of device. |
| 181 | + else if (std::is_same<Device, DEVICE_GPU>::value && |
| 182 | + std::is_same<Device_in, DEVICE_CPU>::value) |
| 183 | + { |
| 184 | + auto * arr = (T*) malloc(sizeof(T) * psi_in.size()); |
| 185 | + // cast the memory from T_in to T in CPU |
| 186 | + memory::cast_memory_op<T, T_in, Device_in, Device_in>()(psi_in.get_device(), |
| 187 | + psi_in.get_device(), |
| 188 | + arr, |
| 189 | + psi_in.get_pointer() - psi_in.get_psi_bias(), |
| 190 | + psi_in.size()); |
| 191 | + // synchronize the memory from CPU to GPU |
| 192 | + memory::synchronize_memory_op<T, Device, Device_in>()(this->ctx, |
| 193 | + psi_in.get_device(), |
| 194 | + this->psi, |
| 195 | + arr, |
| 196 | + psi_in.size()); |
| 197 | + free(arr); |
| 198 | + } |
| 199 | + else |
| 200 | + { |
| 201 | + memory::cast_memory_op<T, T_in, Device, Device_in>()(this->ctx, |
| 202 | + psi_in.get_device(), |
| 203 | + this->psi, |
| 204 | + psi_in.get_pointer() - psi_in.get_psi_bias(), |
| 205 | + psi_in.size()); |
| 206 | + } |
171 | 207 | this->psi_bias = psi_in.get_psi_bias();
|
172 | 208 | this->current_nbasis = psi_in.get_current_nbas();
|
173 | 209 | this->psi_current = this->psi + psi_in.get_psi_bias();
|
|
0 commit comments