Skip to content

Commit 980dd4a

Browse files
Fix overflow in awq kernel (#1295)
Co-authored-by: 楚天翔 <[email protected]>
1 parent 8285736 commit 980dd4a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

csrc/quantization/awq/gemm_kernels.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
9090
+ (((int)threadIdx.x) % (128 / 8)) * 8;
9191

9292
half* C_ptr = C
93-
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
93+
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
9494
+ (((int)blockIdx_y) % j_factors1) * 128
9595
+ ((int)threadIdx.y) * 64
9696
+ (((int)threadIdx.x) % 4) * 2;
@@ -323,7 +323,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
323323
+ (((int)threadIdx.x) % (64 / 8)) * 8;
324324

325325
half* C_ptr = C
326-
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
326+
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
327327
+ (((int)blockIdx_y) % j_factors1) * 64
328328
+ ((int)threadIdx.y) * 32
329329
+ (((int)threadIdx.x) % 4) * 2;

0 commit comments

Comments
 (0)