Skip to content

Commit d70eb9d

Browse files
committed
Add LSX support for S8S8 and S8U8 GEMM kernels
- Add LSX support for S8S8/S8U8 to fix build on loong64: ``` error: struct MLAS_PLATFORM’ has no member named ‘GemmS8S8Dispatch ``` - Add new dispatch entries for S8S8 and S8U8 GEMM operations in mlasi.h - Extend MLAS_PLATFORM struct to include S8S8 and S8U8 dispatch pointers - Add GemmS8S8Dispatch/GemmS8U8Dispatch to qgemm.h's AIsSigned paths Signed-off-by: Zhou Qiankang <[email protected]>
1 parent d58ff6b commit d70eb9d

File tree

4 files changed

+399
-1
lines changed

4 files changed

+399
-1
lines changed

onnxruntime/core/mlas/lib/mlasi.h

+4
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,8 @@ struct MLAS_GEMM_QUANT_DISPATCH;
982982

983983
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse;
984984
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchLSX;
985+
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchLSX;
986+
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8U8DispatchLSX;
985987
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41;
986988
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2;
987989
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2;
@@ -1150,6 +1152,8 @@ struct MLAS_PLATFORM {
11501152
#if defined(MLAS_TARGET_LARCH64)
11511153
const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch;
11521154
const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch;
1155+
const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch;
1156+
const MLAS_GEMM_QUANT_DISPATCH* GemmS8U8Dispatch;
11531157
MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel;
11541158
MLAS_GEMM_DOUBLE_KERNEL* GemmDoubleKernel;
11551159
MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel;

onnxruntime/core/mlas/lib/platform.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -658,10 +658,14 @@ Return Value:
658658

659659
this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX;
660660
this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX;
661+
this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchLSX;
662+
this->GemmS8U8Dispatch = &MlasGemmS8U8DispatchLSX;
661663
}else if( cap_lsx ){
662664
this->GemmFloatKernel = MlasGemmFloatKernelLSX;
663665
this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX;
664666
this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX;
667+
this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchLSX;
668+
this->GemmS8U8Dispatch = &MlasGemmS8U8DispatchLSX;
665669
this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4LSX;
666670
this->GemmDoubleKernel = MlasGemmDoubleKernelLSX;
667671
this->ConvNchwFloatKernel = MlasConvNchwFloatKernelLSX;

onnxruntime/core/mlas/lib/qgemm.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -905,7 +905,10 @@ MlasGemmQuantGetDispatch(
905905
GemmQuantDispatch = GetMlasPlatform().GemmU8X8Dispatch;
906906
}
907907
#elif defined(MLAS_TARGET_LARCH64)
908-
if (!AIsSigned) {
908+
if (AIsSigned) {
909+
GemmQuantDispatch =
910+
BIsSigned ? GetMlasPlatform().GemmS8S8Dispatch : GetMlasPlatform().GemmS8U8Dispatch;
911+
} else { // !AIsSigned
909912
GemmQuantDispatch =
910913
BIsSigned ? GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch;
911914
}

0 commit comments

Comments
 (0)