2
2
#include " module_hsolver/kernels/math_kernel_op.h"
3
3
#include " module_psi/kernels/memory_op.h"
4
4
#include " module_psi/psi.h"
5
+ #include " module_base/tool_quit.h"
5
6
6
7
#include < base/macros/macros.h>
7
8
#include < hip/hip_runtime.h>
@@ -676,6 +677,9 @@ void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
676
677
else if (trans == ' C' ) {
677
678
cutrans = HIPBLAS_OP_C;
678
679
}
680
+ else {
681
+ ModuleBase::WARNING_QUIT (" gemv_op" , std::string (" Unknown trans type " ) + trans + std::string (" !" ));
682
+ }
679
683
hipblasErrcheck (hipblasDgemv (cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incx));
680
684
}
681
685
@@ -694,12 +698,18 @@ void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
694
698
const int & incy)
695
699
{
696
700
hipblasOperation_t cutrans = {};
697
- if (trans == ' N' ){
701
+ if (trans == ' N' ) {
698
702
cutrans = HIPBLAS_OP_N;
699
703
}
700
- else if (trans == ' T' ){
704
+ else if (trans == ' T' ) {
701
705
cutrans = HIPBLAS_OP_T;
702
706
}
707
+ else if (trans == ' C' ) {
708
+ cutrans = HIPBLAS_OP_C;
709
+ }
710
+ else {
711
+ ModuleBase::WARNING_QUIT (" gemv_op" , std::string (" Unknown trans type " ) + trans + std::string (" !" ));
712
+ }
703
713
hipblasErrcheck (hipblasCgemv (cublas_handle, cutrans, m, n, (hipblasComplex*)alpha, (hipblasComplex*)A, lda, (hipblasComplex*)X, incx, (hipblasComplex*)beta, (hipblasComplex*)Y, incx));
704
714
}
705
715
@@ -727,6 +737,9 @@ void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
727
737
else if (trans == ' C' ){
728
738
cutrans = HIPBLAS_OP_C;
729
739
}
740
+ else {
741
+ ModuleBase::WARNING_QUIT (" gemv_op" , std::string (" Unknown trans type " ) + trans + std::string (" !" ));
742
+ }
730
743
hipblasErrcheck (hipblasZgemv (cublas_handle, cutrans, m, n, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)A, lda, (hipblasDoubleComplex*)X, incx, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)Y, incx));
731
744
}
732
745
@@ -775,13 +788,19 @@ void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
775
788
else if (transa == ' T' ) {
776
789
cutransA = HIPBLAS_OP_T;
777
790
}
791
+ else {
792
+ ModuleBase::WARNING_QUIT (" gemm_op" , std::string (" Unknown transa type " ) + transa + std::string (" !" ));
793
+ }
778
794
// cutransB
779
795
if (transb == ' N' ) {
780
796
cutransB = HIPBLAS_OP_N;
781
797
}
782
798
else if (transb == ' T' ) {
783
799
cutransB = HIPBLAS_OP_T;
784
800
}
801
+ else {
802
+ ModuleBase::WARNING_QUIT (" gemm_op" , std::string (" Unknown transb type " ) + transb + std::string (" !" ));
803
+ }
785
804
hipblasErrcheck (hipblasDgemm (cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
786
805
}
787
806
@@ -812,6 +831,9 @@ void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
812
831
}
813
832
else if (transa == ' C' ){
814
833
cutransA = HIPBLAS_OP_C;
834
+ }
835
+ else {
836
+ ModuleBase::WARNING_QUIT (" gemm_op" , std::string (" Unknown transa type " ) + transa + std::string (" !" ));
815
837
}
816
838
// cutransB
817
839
if (transb == ' N' ){
@@ -823,6 +845,9 @@ void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
823
845
else if (transb == ' C' ){
824
846
cutransB = HIPBLAS_OP_C;
825
847
}
848
+ else {
849
+ ModuleBase::WARNING_QUIT (" gemm_op" , std::string (" Unknown transb type " ) + transb + std::string (" !" ));
850
+ }
826
851
hipblasErrcheck (hipblasCgemm (cublas_handle, cutransA, cutransB, m, n ,k, (hipblasComplex*)alpha, (hipblasComplex*)a , lda, (hipblasComplex*)b, ldb, (hipblasComplex*)beta, (hipblasComplex*)c, ldc));
827
852
}
828
853
@@ -853,6 +878,9 @@ void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
853
878
}
854
879
else if (transa == ' C' ){
855
880
cutransA = HIPBLAS_OP_C;
881
+ }
882
+ else {
883
+ ModuleBase::WARNING_QUIT (" gemm_op" , std::string (" Unknown transa type " ) + transa + std::string (" !" ));
856
884
}
857
885
// cutransB
858
886
if (transb == ' N' ){
@@ -864,6 +892,9 @@ void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
864
892
else if (transb == ' C' ){
865
893
cutransB = HIPBLAS_OP_C;
866
894
}
895
+ else {
896
+ ModuleBase::WARNING_QUIT (" gemm_op" , std::string (" Unknown transb type " ) + transb + std::string (" !" ));
897
+ }
867
898
hipblasErrcheck (hipblasZgemm (cublas_handle, cutransA, cutransB, m, n ,k, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)a , lda, (hipblasDoubleComplex*)b, ldb, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)c, ldc));
868
899
}
869
900
0 commit comments