
Commit 523a610

[Docs] rename mat_transpose -> mat-transpose (#93)
* Update sgemm_wmma_tf32_stage.cu
* Update sgemm.py
* Update README.md
* Update sgemm_wmma_tf32_stage.cu
* Update hgemm_wmma_stage.cu
* Update hgemm.cu
* Update hgemm.py
* Update hgemm.py
* rename mat_transpose -> mat-transpose
* update hgemm benchmark
* update hgemm benchmark
1 parent 2f854e8 commit 523a610

11 files changed: +922, -522 lines

README.md (+5, -5)

@@ -62,11 +62,11 @@
 | ✔️ [embedding_f16x2](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️|
 | ✔️ [embedding_f16x8](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️|
 | ✔️ [embedding_f16x8_pack](./embedding/embedding.cu)|f16|/|[link](./embedding/)|⭐️⭐️|
-| ✔️ [mat_trans_f32_col2row{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️|
-| ✔️ [mat_trans_f32_row2col{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️|
-| ✔️ [mat_trans_f32_diagonal2d](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️|
-| ✔️ [mat_trans_f32x4_col2row{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️|
-| ✔️ [mat_trans_f32x4_row2col{2d}](./mat_transpose/mat_transpose.cu)|f32|/|[link](./mat_transpose/)|⭐️⭐️|
+| ✔️ [mat_trans_f32_col2row{2d}](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️|
+| ✔️ [mat_trans_f32_row2col{2d}](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️|
+| ✔️ [mat_trans_f32_diagonal2d](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️⭐️|
+| ✔️ [mat_trans_f32x4_col2row{2d}](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️⭐️|
+| ✔️ [mat_trans_f32x4_row2col{2d}](./mat-transpose/mat_transpose.cu)|f32|/|[link](./mat-transpose/)|⭐️⭐️|
 | ✔️ [warp_reduce_[all]](./reduce/reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️|
 | ✔️ [reduce_f32_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
 | ✔️ [reduce_f32x4_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|

hgemm/hgemm.cu (+3)

@@ -1237,6 +1237,8 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(torch::Tensor a, torch::Te
                                                        int stages, bool swizzle, int swizzle_stride);
 void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c,
                                                        int stages, bool swizzle, int swizzle_stride);
+void hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c,
+                                                         int stages, bool swizzle, int swizzle_stride);
 
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -1284,5 +1286,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages)
   TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem)
   TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem)
+  TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem)
 }
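The diff above only declares and binds the new warp2x4x2 kernel; the benchmark script below is what actually drives it. For orientation, here is a minimal sketch of calling the new binding directly from Python. Only the function name and the (a, b, c, stages, swizzle, swizzle_stride) argument order come from the declaration above; the module name "hgemm_lib", building from the single source hgemm.cu, the matrix sizes, and the swizzle_stride value of 1024 are illustrative assumptions, not part of the commit.

    import torch
    from torch.utils.cpp_extension import load

    # Hypothetical build: module name and single-source build are assumptions;
    # the repo's hgemm.py performs its own load of the extension.
    lib = load(name="hgemm_lib", sources=["./hgemm/hgemm.cu"], verbose=True)

    M, N, K = 4096, 4096, 2048
    a = torch.randn((M, K), dtype=torch.half, device="cuda").contiguous()
    b = torch.randn((K, N), dtype=torch.half, device="cuda").contiguous()
    c = torch.empty((M, N), dtype=torch.half, device="cuda").contiguous()

    # Argument order follows the declaration: (a, b, c, stages, swizzle, swizzle_stride).
    # stages=3, swizzle disabled; swizzle_stride=1024 is a placeholder value.
    lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem(a, b, c, 3, False, 1024)
    torch.cuda.synchronize()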

hgemm/hgemm.py (+44, -34)

@@ -3,6 +3,7 @@
 from torch.utils.cpp_extension import load
 from functools import partial
 from typing import Optional
+import argparse
 
 torch.set_grad_enabled(False)
 
@@ -95,15 +96,24 @@ def run_benchmark(perf_func: callable,
         else:
             improve = 0
         MAX_TFLOPS = TFLOPS
-        print(f"{out_info:>40}: {out_val}, time:{mean_time}ms, "
+        print(f"{out_info:>35}: {out_val}, time:{mean_time}ms, "
               f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}(+{improve:.2f}%)")
     else:
-        print(f"{out_info:>40}: {out_val}, time:{mean_time}ms, "
+        print(f"{out_info:>35}: {out_val}, time:{mean_time}ms, "
              f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}")
     if show_all: print(out)
     return out, mean_time
 
 
+def get_args():
+    parser = argparse.ArgumentParser(description="hgemm benchmark")
+    parser.add_argument("--enable-mma-all", "-ma", action="store_true")
+    parser.add_argument("--enable-wmma-all", "-wa", action="store_true")
+    parser.add_argument("--enable-cuda-all", "-ca", action="store_true")
+    return parser.parse_args()
+
+
+args = get_args()
 Ms = [4096, 8192, 16384]
 Ns = [4096, 8192, 16384]
 Ks = [2048, 4096, 8192]
@@ -124,44 +134,44 @@ def run_benchmark(perf_func: callable,
     c = C[:M, :N].contiguous()
     torch.cuda.synchronize()
 
-    # CUDA Cores FP16
-    # run_benchmark(lib.hgemm_naive_f16, a, b, "f16(naive)", c)
-    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(t8x8+bcf)", c)
+    if args.enable_cuda_all:
+        # CUDA Cores FP16
+        run_benchmark(lib.hgemm_naive_f16, a, b, "f16(naive)", c)
+        run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(t8x8+bcf)", c)
+
     run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "f16x8pack(t8x8+dbuf)", c)
     run_benchmark(lib.hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf, a, b, "f16x8pack(t8x8+k16+dbuf)", c)
 
     print("-" * 68 + "WMMA" + "-" * 58)
-    # run_benchmark(lib.hgemm_wmma_m16n16k16_naive, a, b, "f16wmma(naive)", c)
-    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2, a, b, "f16wmma(mma4x2)", c)
-    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4, a, b, "f16wmma(mma4x2+warp2x4)", c)
-    # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_offset, a, b, "f16wmma(mma2x4+warp2x4+dbuf)", c)
-
-    # Stages, dsmem
-    # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+warp2x4+stage4)", c, stages=4)
-    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+warp2x4+stage3)", c, stages=3)
-    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+warp2x4+stage2)", c, stages=2)
-
-    # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(mma2x4+...+stage4+dsmem)", c, stages=4)
-    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(mma2x4+...+stage3+dsmem)", c, stages=3)
-    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(mma2x4+...+stage2+dsmem)", c, stages=2)
-
-    # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+...+stage4+dsmem)", c, stages=4)
-    # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+...+stage3+dsmem)", c, stages=3)
-    # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+...+stage2+dsmem)", c, stages=2)
+    # wmma api, stages, dsmem, swizzle
+    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2, a, b, "(mma4x2)", c)
+    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4, a, b, "(mma4x2+warp2x4)", c)
 
-    # Thread block swizzle
-    # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+...+stage4+swizzle)", c, stages=4, swizzle=True)
-    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+...+stage3+swizzle)", c, stages=3, swizzle=True)
-    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "f16wmma(mma2x4+...+stage2+swizzle)", c, stages=2, swizzle=True)
-
-    # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(...+stage4+dsmem+swizzle)", c, stages=4, swizzle=True)
-    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(...+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
-    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "f16wmma(...+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
-
-    # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+stage4+dsmem+swizzle)", c, stages=4, swizzle=True)
-    # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
-    # run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "f16wmma(mma4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
+    # preferred on NVIDIA L20 devices.
+    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+warp2x4+stage3)", c, stages=3)
+    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+warp2x4+stage2)", c, stages=2)
+
+    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma2x4+...+stage3+dsmem)", c, stages=3)
+    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma2x4+...+stage2+dsmem)", c, stages=2)
+
+    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+...+stage3+swizzle)", c, stages=3, swizzle=True)
+    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma2x4+...+stage2+swizzle)", c, stages=2, swizzle=True)
+
+    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(...+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
+    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(...+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
 
+    if args.enable_wmma_all:
+        # preferred on NVIDIA RTX 3080 Laptop 16GB GDDR6 device.
+        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+...+stage3+dsmem)", c, stages=3)
+        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+...+stage2+dsmem)", c, stages=2)
+        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+...+stage3+dsmem)", c, stages=3)
+        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+...+stage2+dsmem)", c, stages=2)
+
+        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
+        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
+        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
+        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_stages_dsmem, a, b, "(warp2x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
+
     run_benchmark(lib.hgemm_cublas_tensor_op, a, b, "f16(cublas)", c)
     run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th")
     torch.cuda.synchronize()
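With these switches, the naive/bcf CUDA-core kernels (`--enable-cuda-all`) and the mma4x4/warp2x4x2 WMMA kernels (`--enable-wmma-all`) become opt-in, while the mma4x2 stage/dsmem/swizzle variants, cuBLAS, and torch.matmul always run. A minimal sketch of how the new flags parse; the parser mirrors get_args() above, and the explicit argv list stands in for command-line input such as `python3 hgemm.py -wa`:

    import argparse

    # Same flag definitions as get_args() in hgemm.py above.
    parser = argparse.ArgumentParser(description="hgemm benchmark")
    parser.add_argument("--enable-mma-all", "-ma", action="store_true")
    parser.add_argument("--enable-wmma-all", "-wa", action="store_true")
    parser.add_argument("--enable-cuda-all", "-ca", action="store_true")

    # Equivalent to running `python3 hgemm.py -wa`.
    args = parser.parse_args(["-wa"])
    print(args.enable_cuda_all, args.enable_wmma_all, args.enable_mma_all)  # False True False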
