from torch.utils.cpp_extension import load
from functools import partial
from typing import Optional
+ import argparse

torch.set_grad_enabled(False)

@@ -95,15 +96,24 @@ def run_benchmark(perf_func: callable,
        else:
            improve = 0
        MAX_TFLOPS = TFLOPS
-       print(f"{out_info:>40}: {out_val}, time:{mean_time}ms, "
+       print(f"{out_info:>35}: {out_val}, time:{mean_time}ms, "
              f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}(+{improve:.2f}%)")
    else:
-       print(f"{out_info:>40}: {out_val}, time:{mean_time}ms, "
+       print(f"{out_info:>35}: {out_val}, time:{mean_time}ms, "
              f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}")
    if show_all: print(out)
    return out, mean_time


+ def get_args():
+     parser = argparse.ArgumentParser(description="hgemm benchmark")
+     parser.add_argument("--enable-mma-all", "-ma", action="store_true")
+     parser.add_argument("--enable-wmma-all", "-wa", action="store_true")
+     parser.add_argument("--enable-cuda-all", "-ca", action="store_true")
+     return parser.parse_args()
+
+
+ args = get_args()
Ms = [4096, 8192, 16384]
Ns = [4096, 8192, 16384]
Ks = [2048, 4096, 8192]
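For reference, a minimal standalone sketch (not part of this diff) of how the new flags parse with argparse; dashes in the long option names become underscores on the returned namespace, which is why the benchmark code below reads args.enable_cuda_all and args.enable_wmma_all:

# Hypothetical, self-contained check of the flag parsing added above.
import argparse

parser = argparse.ArgumentParser(description="hgemm benchmark")
parser.add_argument("--enable-mma-all", "-ma", action="store_true")
parser.add_argument("--enable-wmma-all", "-wa", action="store_true")
parser.add_argument("--enable-cuda-all", "-ca", action="store_true")

# Simulating e.g. `python3 hgemm.py -wa`:
print(parser.parse_args(["-wa"]))
# -> Namespace(enable_mma_all=False, enable_wmma_all=True, enable_cuda_all=False)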
@@ -124,44 +134,44 @@ def run_benchmark(perf_func: callable,
    c = C[:M, :N].contiguous()
    torch.cuda.synchronize()

-   # CUDA Cores FP16
-   # run_benchmark(lib.hgemm_naive_f16, a, b, "f16(naive)", c)
-   # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(t8x8+bcf)", c)
+   if args.enable_cuda_all:
+       # CUDA Cores FP16
+       run_benchmark(lib.hgemm_naive_f16, a, b, "f16(naive)", c)
+       run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(t8x8+bcf)", c)
+
    run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "f16x8pack(t8x8+dbuf)", c)
    run_benchmark(lib.hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf, a, b, "f16x8pack(t8x8+k16+dbuf)", c)

    print("-" * 68 + "WMMA" + "-" * 58)
-   # run_benchmark(lib.hgemm_wmma_m16n16k16_naive, a, b, "f16wmma(naive)", c)
-   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2, a, b, "f16wmma(mma4x2)", c)
-   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4, a, b, "f16wmma(mma4x2+warp2x4)", c)
-   # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_offset, a, b, "f16wmma(mma2x4+warp2x4+dbuf)", c)
-
-   # Stages, dsmem
-   # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+warp2x4+stage4)", c, stages=4)
-   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+warp2x4+stage3)", c, stages=3)
-   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+warp2x4+stage2)", c, stages=2)
-
-   # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(mma2x4+...+stage4+dsmem)", c, stages=4)
-   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(mma2x4+...+stage3+dsmem)", c, stages=3)
-   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(mma2x4+...+stage2+dsmem)", c, stages=2)
-
-   # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+...+stage4+dsmem)", c, stages=4)
-   # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+...+stage3+dsmem)", c, stages=3)
-   # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+...+stage2+dsmem)", c, stages=2)
+   # WMMA API, stages, dsmem, swizzle
+   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2, a, b, "(mma4x2)", c)
+   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4, a, b, "(mma4x2+warp2x4)", c)

-   # Thread block swizzle
-   # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+...+stage4+swizzle)", c, stages=4, swizzle=True)
-   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+...+stage3+swizzle)", c, stages=3, swizzle=True)
-   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+...+stage2+swizzle)", c, stages=2, swizzle=True)
-
-   # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(...+stage4+dsmem+swizzle)", c, stages=4, swizzle=True)
-   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(...+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
-   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(...+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
-
-   # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+stage4+dsmem+swizzle)", c, stages=4, swizzle=True)
-   # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
-   # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
+   # Preferred on NVIDIA L20 devices.
+   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+warp2x4+stage3)", c, stages=3)
+   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+warp2x4+stage2)", c, stages=2)
+
+   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma2x4+...+stage3+dsmem)", c, stages=3)
+   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma2x4+...+stage2+dsmem)", c, stages=2)
+
+   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+...+stage3+swizzle)", c, stages=3, swizzle=True)
+   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+...+stage2+swizzle)", c, stages=2, swizzle=True)
+
+   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(...+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
+   run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(...+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)

+   if args.enable_wmma_all:
+       # Preferred on NVIDIA RTX 3080 Laptop (16GB GDDR6) devices.
+       run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+...+stage3+dsmem)", c, stages=3)
+       run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+...+stage2+dsmem)", c, stages=2)
+       run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+...+stage3+dsmem)", c, stages=3)
+       run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+...+stage2+dsmem)", c, stages=2)
+
+       run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
+       run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
+       run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
+       run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
+
    run_benchmark(lib.hgemm_cublas_tensor_op, a, b, "f16(cublas)", c)
    run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th")
    torch.cuda.synchronize()
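The TFLOPS and improvement figures printed by run_benchmark come from code outside this hunk; a minimal sketch of the usual HGEMM accounting, assuming 2*M*N*K floating-point operations per GEMM and mean_time measured in milliseconds:

# Hypothetical helper mirroring the bookkeeping in run_benchmark (assumed, not shown in this diff).
def hgemm_tflops(M: int, N: int, K: int, mean_time_ms: float) -> float:
    # A single M x N x K half-precision GEMM does 2*M*N*K FLOPs (one multiply and one add per MAC).
    return (2.0 * M * N * K) / (mean_time_ms * 1e-3) / 1e12

# The "(+x.xx%)" column is the gain over the best kernel seen so far:
# improve = (TFLOPS - MAX_TFLOPS) / MAX_TFLOPS * 100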