Commit d47661f

Authored by mgoin
[Kernel] Basic tuned configs for NVFP4 CUTLASS dense GEMM (#20646)
Signed-off-by: mgoin <[email protected]>
1 parent 53fa457 commit d47661f

File tree: 1 file changed (+85, -50 lines)


csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu

Lines changed: 85 additions & 50 deletions
@@ -30,35 +30,40 @@
 
 #include "cutlass/util/packed_stride.hpp"
 
+#include "core/math.hpp"
+
 using namespace cute;
 
 #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
-// Kernel Perf config
-template <typename T>
-struct KernelTraits;
 
-template <>
-struct KernelTraits<float> {
-  using MmaTileShape = Shape<_128, _128, _256>;
-  using ClusterShape = Shape<_1, _1, _1>;
-  using PerSmTileShape_MNK = Shape<_128, _128, _256>;
+// Configuration for M in (256, inf)
+struct sm100_fp4_config_default {
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_256, _256, _256>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
 };
 
-template <>
-struct KernelTraits<cutlass::half_t> {
-  using MmaTileShape = Shape<_256, _256, _256>;
-  using ClusterShape = Shape<_4, _4, _1>;
-  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
+// Configuration for M in (16, 256]
+struct sm100_fp4_config_M256 {
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_256, _128, _256>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using PerSmTileShape_MNK = Shape<_128, _128, _256>;
 };
 
-template <>
-struct KernelTraits<cutlass::bfloat16_t> {
-  using MmaTileShape = Shape<_256, _256, _256>;
-  using ClusterShape = Shape<_4, _4, _1>;
-  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
+// Configuration for M in [1, 16]
+struct sm100_fp4_config_M16 {
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_128, _128, _256>;
+  using ClusterShape = Shape<_1, _1, _1>;
+  using PerSmTileShape_MNK = Shape<_128, _128, _256>;
 };
 
-template <typename T>
+template <typename Config, typename OutType>
 struct Fp4GemmSm100 {
   // A matrix configuration
   using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
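Each of the new config structs is just a bag of type aliases; Fp4GemmSm100<Config, OutType> below only reads TileShape, ClusterShape, and PerSmTileShape_MNK from it, so adding another tuned shape means adding one more struct plus one dispatch branch. A minimal sketch of that pattern outside CUTLASS (ShapeMNK and toy_config_M16 are illustrative stand-ins, not names from this file):

#include <cstdio>

// Toy stand-in for cute::Shape<_M, _N, _K>, only to illustrate the pattern.
template <int M, int N, int K>
struct ShapeMNK {
  static constexpr int m = M, n = N, k = K;
};

// A "config" is a plain struct of type aliases, mirroring sm100_fp4_config_M16.
struct toy_config_M16 {
  using TileShape = ShapeMNK<128, 128, 256>;
  using ClusterShape = ShapeMNK<1, 1, 1>;
  using PerSmTileShape_MNK = ShapeMNK<128, 128, 256>;
};

// The consumer only reads the aliases, the way Fp4GemmSm100<Config, OutType> does.
template <typename Config>
void print_tile(const char* name) {
  using Tile = typename Config::TileShape;
  std::printf("%s: tile %d x %d x %d\n", name, Tile::m, Tile::n, Tile::k);
}

int main() {
  print_tile<toy_config_M16>("toy_config_M16");
  return 0;
}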
@@ -71,21 +76,22 @@ struct Fp4GemmSm100 {
   static constexpr int AlignmentB = 32;
 
   // C/D matrix configuration
-  using ElementD = T;
-  using ElementC = T;
+  using ElementD = OutType;
+  using ElementC = OutType;
   using LayoutCTag = cutlass::layout::RowMajor;
   using LayoutDTag = cutlass::layout::RowMajor;
   static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
   static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+
   // Kernel functional config
   using ElementAccumulator = float;
   using ArchTag = cutlass::arch::Sm100;
   using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
 
-  // Kernel Perf config
-  using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
-  using ClusterShape = typename KernelTraits<T>::ClusterShape;
-  using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
+  // Use config's tile shapes
+  using MmaTileShape = typename Config::TileShape;
+  using ClusterShape = typename Config::ClusterShape;
+  using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;
 
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
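AlignmentC and AlignmentD are expressed in elements rather than bytes: 128 bits divided by the element width, so a 16-bit output (half or bfloat16) aligns to 8 elements, i.e. 16 bytes. A quick check of that arithmetic using plain standard-width types (the real code uses cutlass::sizeof_bits on the CUTLASS element types):

#include <cstdint>

// Mirrors 128 / cutlass::sizeof_bits<ElementD>::value, using sizeof() on
// standard types instead of the CUTLASS element types.
template <typename T>
constexpr int alignment_in_elements() {
  return 128 / (8 * static_cast<int>(sizeof(T)));
}

static_assert(alignment_in_elements<std::uint16_t>() == 8,
              "16-bit outputs (half/bfloat16): 8 elements = 16 bytes");
static_assert(alignment_in_elements<float>() == 4,
              "32-bit outputs: 4 elements = 16 bytes");

int main() { return 0; }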
@@ -119,22 +125,22 @@ struct Fp4GemmSm100 {
   using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
 };
 
-template <typename T>
-typename T::Gemm::Arguments args_from_options(
+template <typename Config>
+typename Config::Gemm::Arguments args_from_options(
     at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
     at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
     int64_t M, int64_t N, int64_t K) {
-  using ElementA = typename T::Gemm::ElementA;
-  using ElementB = typename T::Gemm::ElementB;
+  using ElementA = typename Config::Gemm::ElementA;
+  using ElementB = typename Config::Gemm::ElementB;
   using ElementSFA = cutlass::float_ue4m3_t;
   using ElementSFB = cutlass::float_ue4m3_t;
-  using ElementD = typename T::Gemm::ElementD;
+  using ElementD = typename Config::Gemm::ElementD;
   using ElementCompute = float;
-  using StrideA = typename T::StrideA;
-  using StrideB = typename T::StrideB;
-  using StrideD = typename T::StrideD;
-  using Sm100BlkScaledConfig =
-      typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
+  using StrideA = typename Config::StrideA;
+  using StrideB = typename Config::StrideB;
+  using StrideD = typename Config::StrideD;
+  using Sm100BlkScaledConfig = typename Config::Gemm::GemmKernel::
+      CollectiveMainloop::Sm1xxBlkScaledConfig;
 
   int m = static_cast<int>(M);
   int n = static_cast<int>(N);
@@ -148,7 +154,7 @@ typename T::Gemm::Arguments args_from_options(
   auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
       cute::make_shape(m, n, k, 1));
 
-  typename T::Gemm::Arguments arguments{
+  typename Config::Gemm::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGemm,
       {m, n, k, 1},
       {// Mainloop arguments
@@ -167,17 +173,17 @@ typename T::Gemm::Arguments args_from_options(
   return arguments;
 }
 
-template <typename T>
+template <typename Config>
 void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
              at::Tensor const& A_sf, at::Tensor const& B_sf,
              at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
             cudaStream_t stream) {
-  typename Fp4GemmSm100<T>::Gemm gemm;
+  typename Config::Gemm gemm;
 
   auto arguments =
-      args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
+      args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
 
-  size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
+  size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
   auto const workspace_options =
       torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
   auto workspace = torch::empty(workspace_size, workspace_options);
@@ -188,12 +194,40 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
 
   CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
 }
+
+// Dispatch function to select appropriate config based on M
+template <typename OutType>
+void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
+                               torch::Tensor const& B,
+                               torch::Tensor const& A_sf,
+                               torch::Tensor const& B_sf,
+                               torch::Tensor const& alpha, int64_t m, int64_t n,
+                               int64_t k, cudaStream_t stream) {
+  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
+
+  if (mp2 <= 16) {
+    // m in [1, 16]
+    runGemm<Fp4GemmSm100<sm100_fp4_config_M16, OutType>>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else if (mp2 <= 256) {
+    // m in (16, 256]
+    runGemm<Fp4GemmSm100<sm100_fp4_config_M256, OutType>>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else {
+    // m in (256, inf)
+    runGemm<Fp4GemmSm100<sm100_fp4_config_default, OutType>>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  }
+}
+
 #else
-template <typename T>
-void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
-             at::Tensor const& A_sf, at::Tensor const& B_sf,
-             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
-             cudaStream_t stream) {
+template <typename OutType>
+void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
+                               torch::Tensor const& B,
+                               torch::Tensor const& A_sf,
+                               torch::Tensor const& B_sf,
+                               torch::Tensor const& alpha, int64_t m, int64_t n,
+                               int64_t k, cudaStream_t stream) {
   TORCH_CHECK(false,
               "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
               "a CUTLASS 3.8 source directory to enable support.");
@@ -271,12 +305,13 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
 
   if (out_dtype == at::ScalarType::Half) {
-    runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
+                                               k, stream);
   } else if (out_dtype == at::ScalarType::BFloat16) {
-    runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
-  } else if (out_dtype == at::ScalarType::Float) {
-    runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+    cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
+                                                   m, n, k, stream);
   } else {
-    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
+    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
+                ")");
   }
 }
