Skip to content

Commit 898afb2

Browse files
authored
Fix: address single precision error of dcu (#4201)
* fix single precision error or dcu. * fix CI test error
1 parent d6c5d43 commit 898afb2

File tree

2 files changed

+61
-2
lines changed

2 files changed

+61
-2
lines changed

source/module_hsolver/kernels/cuda/math_kernel_op.cu

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "module_hsolver/kernels/math_kernel_op.h"
33
#include "module_psi/kernels/memory_op.h"
44
#include "module_psi/psi.h"
5+
#include "module_base/tool_quit.h"
56

67
#include <base/macros/macros.h>
78
#include <cuda_runtime.h>
@@ -669,6 +670,9 @@ void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
669670
else if (trans == 'T') {
670671
cutrans = CUBLAS_OP_T;
671672
}
673+
else {
674+
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
675+
}
672676
cublasErrcheck(cublasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incx));
673677
}
674678

@@ -696,6 +700,9 @@ void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
696700
else if (trans == 'C'){
697701
cutrans = CUBLAS_OP_C;
698702
}
703+
else {
704+
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
705+
}
699706
cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, (float2*)alpha, (float2*)A, lda, (float2*)X, incx, (float2*)beta, (float2*)Y, incx));
700707
}
701708

@@ -723,6 +730,9 @@ void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
723730
else if (trans == 'C'){
724731
cutrans = CUBLAS_OP_C;
725732
}
733+
else {
734+
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
735+
}
726736
cublasErrcheck(cublasZgemv(cublas_handle, cutrans, m, n, (double2*)alpha, (double2*)A, lda, (double2*)X, incx, (double2*)beta, (double2*)Y, incx));
727737
}
728738

@@ -771,13 +781,19 @@ void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
771781
else if (transa == 'T') {
772782
cutransA = CUBLAS_OP_T;
773783
}
784+
else {
785+
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
786+
}
774787
// cutransB
775788
if (transb == 'N') {
776789
cutransB = CUBLAS_OP_N;
777790
}
778791
else if (transb == 'T') {
779792
cutransB = CUBLAS_OP_T;
780793
}
794+
else {
795+
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
796+
}
781797
cublasErrcheck(cublasDgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
782798
}
783799
template <>
@@ -808,6 +824,9 @@ void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
808824
else if (transa == 'C'){
809825
cutransA = CUBLAS_OP_C;
810826
}
827+
else {
828+
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
829+
}
811830
// cutransB
812831
if (transb == 'N'){
813832
cutransB = CUBLAS_OP_N;
@@ -818,6 +837,9 @@ void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
818837
else if (transb == 'C'){
819838
cutransB = CUBLAS_OP_C;
820839
}
840+
else {
841+
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
842+
}
821843
cublasErrcheck(cublasCgemm(cublas_handle, cutransA, cutransB, m, n ,k, (float2*)alpha, (float2*)a , lda, (float2*)b, ldb, (float2*)beta, (float2*)c, ldc));
822844
}
823845

@@ -849,6 +871,9 @@ void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
849871
else if (transa == 'C'){
850872
cutransA = CUBLAS_OP_C;
851873
}
874+
else {
875+
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
876+
}
852877
// cutransB
853878
if (transb == 'N'){
854879
cutransB = CUBLAS_OP_N;
@@ -859,6 +884,9 @@ void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
859884
else if (transb == 'C'){
860885
cutransB = CUBLAS_OP_C;
861886
}
887+
else {
888+
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
889+
}
862890
cublasErrcheck(cublasZgemm(cublas_handle, cutransA, cutransB, m, n ,k, (double2*)alpha, (double2*)a , lda, (double2*)b, ldb, (double2*)beta, (double2*)c, ldc));
863891
}
864892

source/module_hsolver/kernels/rocm/math_kernel_op.hip.cu

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "module_hsolver/kernels/math_kernel_op.h"
33
#include "module_psi/kernels/memory_op.h"
44
#include "module_psi/psi.h"
5+
#include "module_base/tool_quit.h"
56

67
#include <base/macros/macros.h>
78
#include <hip/hip_runtime.h>
@@ -676,6 +677,9 @@ void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
676677
else if (trans == 'C') {
677678
cutrans = HIPBLAS_OP_C;
678679
}
680+
else {
681+
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
682+
}
679683
hipblasErrcheck(hipblasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incx));
680684
}
681685

@@ -694,12 +698,18 @@ void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
694698
const int& incy)
695699
{
696700
hipblasOperation_t cutrans = {};
697-
if (trans == 'N'){
701+
if (trans == 'N') {
698702
cutrans = HIPBLAS_OP_N;
699703
}
700-
else if (trans == 'T'){
704+
else if (trans == 'T') {
701705
cutrans = HIPBLAS_OP_T;
702706
}
707+
else if (trans == 'C') {
708+
cutrans = HIPBLAS_OP_C;
709+
}
710+
else {
711+
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
712+
}
703713
hipblasErrcheck(hipblasCgemv(cublas_handle, cutrans, m, n, (hipblasComplex*)alpha, (hipblasComplex*)A, lda, (hipblasComplex*)X, incx, (hipblasComplex*)beta, (hipblasComplex*)Y, incx));
704714
}
705715

@@ -727,6 +737,9 @@ void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
727737
else if (trans == 'C'){
728738
cutrans = HIPBLAS_OP_C;
729739
}
740+
else {
741+
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
742+
}
730743
hipblasErrcheck(hipblasZgemv(cublas_handle, cutrans, m, n, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)A, lda, (hipblasDoubleComplex*)X, incx, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)Y, incx));
731744
}
732745

@@ -775,13 +788,19 @@ void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
775788
else if (transa == 'T') {
776789
cutransA = HIPBLAS_OP_T;
777790
}
791+
else {
792+
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
793+
}
778794
// cutransB
779795
if (transb == 'N') {
780796
cutransB = HIPBLAS_OP_N;
781797
}
782798
else if (transb == 'T') {
783799
cutransB = HIPBLAS_OP_T;
784800
}
801+
else {
802+
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
803+
}
785804
hipblasErrcheck(hipblasDgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
786805
}
787806

@@ -812,6 +831,9 @@ void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
812831
}
813832
else if (transa == 'C'){
814833
cutransA = HIPBLAS_OP_C;
834+
}
835+
else {
836+
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
815837
}
816838
// cutransB
817839
if (transb == 'N'){
@@ -823,6 +845,9 @@ void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
823845
else if (transb == 'C'){
824846
cutransB = HIPBLAS_OP_C;
825847
}
848+
else {
849+
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
850+
}
826851
hipblasErrcheck(hipblasCgemm(cublas_handle, cutransA, cutransB, m, n ,k, (hipblasComplex*)alpha, (hipblasComplex*)a , lda, (hipblasComplex*)b, ldb, (hipblasComplex*)beta, (hipblasComplex*)c, ldc));
827852
}
828853

@@ -853,6 +878,9 @@ void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
853878
}
854879
else if (transa == 'C'){
855880
cutransA = HIPBLAS_OP_C;
881+
}
882+
else {
883+
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
856884
}
857885
// cutransB
858886
if (transb == 'N'){
@@ -864,6 +892,9 @@ void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
864892
else if (transb == 'C'){
865893
cutransB = HIPBLAS_OP_C;
866894
}
895+
else {
896+
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
897+
}
867898
hipblasErrcheck(hipblasZgemm(cublas_handle, cutransA, cutransB, m, n ,k, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)a , lda, (hipblasDoubleComplex*)b, ldb, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)c, ldc));
868899
}
869900

0 commit comments

Comments
 (0)