Skip to content

Commit 9abeccc

Browse files
authored
reduce peak device memory usage of psi's constructor (#4154)
1 parent b7e91aa commit 9abeccc

File tree

1 file changed

+41
-5
lines changed

1 file changed

+41
-5
lines changed

source/module_psi/psi.cpp

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "module_base/global_variable.h"
44
#include "module_base/tool_quit.h"
55
#include "module_psi/kernels/device.h"
6+
#include <type_traits>
67

78
#include <cassert>
89
#include <complex>
@@ -163,11 +164,46 @@ Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
163164
// this function will copy psi_in.psi to this->psi no matter the device types of each other.
164165
this->device = device::get_device_type<Device>(this->ctx);
165166
this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis());
166-
memory::cast_memory_op<T, T_in, Device, Device_in>()(this->ctx,
167-
psi_in.get_device(),
168-
this->psi,
169-
psi_in.get_pointer() - psi_in.get_psi_bias(),
170-
psi_in.size());
167+
// No need to cast the memory if the data types are the same.
168+
if (std::is_same<T, T_in>::value)
169+
{
170+
memory::synchronize_memory_op<T, Device, Device_in>()(this->ctx,
171+
psi_in.get_device(),
172+
this->psi,
173+
reinterpret_cast<T*>(psi_in.get_pointer()) - psi_in.get_psi_bias(),
174+
psi_in.size());
175+
}
176+
// Specifically, if the Device_in type is CPU and the Device type is GPU:
177+
// Which means we need to initialize a GPU psi from a given CPU psi.
178+
// We first malloc a memory in CPU, then cast the memory from T_in to T in CPU.
179+
// Finally, synchronize the memory from CPU to GPU.
180+
// This could help to reduce the peak memory usage of device.
181+
else if (std::is_same<Device, DEVICE_GPU>::value &&
182+
std::is_same<Device_in, DEVICE_CPU>::value)
183+
{
184+
auto * arr = (T*) malloc(sizeof(T) * psi_in.size());
185+
// cast the memory from T_in to T in CPU
186+
memory::cast_memory_op<T, T_in, Device_in, Device_in>()(psi_in.get_device(),
187+
psi_in.get_device(),
188+
arr,
189+
psi_in.get_pointer() - psi_in.get_psi_bias(),
190+
psi_in.size());
191+
// synchronize the memory from CPU to GPU
192+
memory::synchronize_memory_op<T, Device, Device_in>()(this->ctx,
193+
psi_in.get_device(),
194+
this->psi,
195+
arr,
196+
psi_in.size());
197+
free(arr);
198+
}
199+
else
200+
{
201+
memory::cast_memory_op<T, T_in, Device, Device_in>()(this->ctx,
202+
psi_in.get_device(),
203+
this->psi,
204+
psi_in.get_pointer() - psi_in.get_psi_bias(),
205+
psi_in.size());
206+
}
171207
this->psi_bias = psi_in.get_psi_bias();
172208
this->current_nbasis = psi_in.get_current_nbas();
173209
this->psi_current = this->psi + psi_in.get_psi_bias();

0 commit comments

Comments
 (0)