30
30
31
31
#include " cutlass/util/packed_stride.hpp"
32
32
33
+ #include " core/math.hpp"
34
+
33
35
using namespace cute ;
34
36
35
37
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
36
- // Kernel Perf config
37
- template <typename T>
38
- struct KernelTraits ;
39
38
40
- template <>
41
- struct KernelTraits <float > {
42
- using MmaTileShape = Shape<_128, _128, _256>;
43
- using ClusterShape = Shape<_1, _1, _1>;
44
- using PerSmTileShape_MNK = Shape<_128, _128, _256>;
39
+ // Configuration for M in (256, inf)
40
+ struct sm100_fp4_config_default {
41
+ using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
42
+ using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
43
+ using TileShape = Shape<_256, _256, _256>;
44
+ using ClusterShape = Shape<_2, _1, _1>;
45
+ using PerSmTileShape_MNK = Shape<_128, _256, _256>;
45
46
};
46
47
47
- template <>
48
- struct KernelTraits <cutlass::half_t > {
49
- using MmaTileShape = Shape<_256, _256, _256>;
50
- using ClusterShape = Shape<_4, _4, _1>;
51
- using PerSmTileShape_MNK = Shape<_128, _256, _256>;
48
+ // Configuration for M in (16, 256]
49
+ struct sm100_fp4_config_M256 {
50
+ using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
51
+ using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
52
+ using TileShape = Shape<_256, _128, _256>;
53
+ using ClusterShape = Shape<_2, _1, _1>;
54
+ using PerSmTileShape_MNK = Shape<_128, _128, _256>;
52
55
};
53
56
54
- template <>
55
- struct KernelTraits <cutlass::bfloat16_t > {
56
- using MmaTileShape = Shape<_256, _256, _256>;
57
- using ClusterShape = Shape<_4, _4, _1>;
58
- using PerSmTileShape_MNK = Shape<_128, _256, _256>;
57
+ // Configuration for M in [1, 16]
58
+ struct sm100_fp4_config_M16 {
59
+ using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
60
+ using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
61
+ using TileShape = Shape<_128, _128, _256>;
62
+ using ClusterShape = Shape<_1, _1, _1>;
63
+ using PerSmTileShape_MNK = Shape<_128, _128, _256>;
59
64
};
60
65
61
- template <typename T >
66
+ template <typename Config, typename OutType >
62
67
struct Fp4GemmSm100 {
63
68
// A matrix configuration
64
69
using ElementA = cutlass::nv_float4_t <cutlass::float_e2m1_t >;
@@ -71,21 +76,22 @@ struct Fp4GemmSm100 {
71
76
static constexpr int AlignmentB = 32 ;
72
77
73
78
// C/D matrix configuration
74
- using ElementD = T ;
75
- using ElementC = T ;
79
+ using ElementD = OutType ;
80
+ using ElementC = OutType ;
76
81
using LayoutCTag = cutlass::layout::RowMajor;
77
82
using LayoutDTag = cutlass::layout::RowMajor;
78
83
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
79
84
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
85
+
80
86
// Kernel functional config
81
87
using ElementAccumulator = float ;
82
88
using ArchTag = cutlass::arch::Sm100;
83
89
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
84
90
85
- // Kernel Perf config
86
- using MmaTileShape = typename KernelTraits<T>::MmaTileShape ;
87
- using ClusterShape = typename KernelTraits<T> ::ClusterShape;
88
- using PerSmTileShape_MNK = typename KernelTraits<T> ::PerSmTileShape_MNK;
91
+ // Use config's tile shapes
92
+ using MmaTileShape = typename Config::TileShape ;
93
+ using ClusterShape = typename Config ::ClusterShape;
94
+ using PerSmTileShape_MNK = typename Config ::PerSmTileShape_MNK;
89
95
90
96
using CollectiveEpilogue =
91
97
typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -119,22 +125,22 @@ struct Fp4GemmSm100 {
119
125
using LayoutD = decltype (cute::make_layout(make_shape(0 , 0 , 0 ), StrideD{}));
120
126
};
121
127
122
- template <typename T >
123
- typename T ::Gemm::Arguments args_from_options (
128
+ template <typename Config >
129
+ typename Config ::Gemm::Arguments args_from_options (
124
130
at::Tensor& D, at::Tensor const & A, at::Tensor const & B,
125
131
at::Tensor const & A_sf, at::Tensor const & B_sf, at::Tensor const & alpha,
126
132
int64_t M, int64_t N, int64_t K) {
127
- using ElementA = typename T ::Gemm::ElementA;
128
- using ElementB = typename T ::Gemm::ElementB;
133
+ using ElementA = typename Config ::Gemm::ElementA;
134
+ using ElementB = typename Config ::Gemm::ElementB;
129
135
using ElementSFA = cutlass::float_ue4m3_t ;
130
136
using ElementSFB = cutlass::float_ue4m3_t ;
131
- using ElementD = typename T ::Gemm::ElementD;
137
+ using ElementD = typename Config ::Gemm::ElementD;
132
138
using ElementCompute = float ;
133
- using StrideA = typename T ::StrideA;
134
- using StrideB = typename T ::StrideB;
135
- using StrideD = typename T ::StrideD;
136
- using Sm100BlkScaledConfig =
137
- typename T::Gemm::GemmKernel:: CollectiveMainloop::Sm1xxBlkScaledConfig;
139
+ using StrideA = typename Config ::StrideA;
140
+ using StrideB = typename Config ::StrideB;
141
+ using StrideD = typename Config ::StrideD;
142
+ using Sm100BlkScaledConfig = typename Config::Gemm::GemmKernel::
143
+ CollectiveMainloop::Sm1xxBlkScaledConfig;
138
144
139
145
int m = static_cast <int >(M);
140
146
int n = static_cast <int >(N);
@@ -148,7 +154,7 @@ typename T::Gemm::Arguments args_from_options(
148
154
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB (
149
155
cute::make_shape (m, n, k, 1 ));
150
156
151
- typename T ::Gemm::Arguments arguments{
157
+ typename Config ::Gemm::Arguments arguments{
152
158
cutlass::gemm::GemmUniversalMode::kGemm ,
153
159
{m, n, k, 1 },
154
160
{// Mainloop arguments
@@ -167,17 +173,17 @@ typename T::Gemm::Arguments args_from_options(
167
173
return arguments;
168
174
}
169
175
170
- template <typename T >
176
+ template <typename Config >
171
177
void runGemm (at::Tensor& D, at::Tensor const & A, at::Tensor const & B,
172
178
at::Tensor const & A_sf, at::Tensor const & B_sf,
173
179
at::Tensor const & alpha, int64_t m, int64_t n, int64_t k,
174
180
cudaStream_t stream) {
175
- typename Fp4GemmSm100<T> ::Gemm gemm;
181
+ typename Config ::Gemm gemm;
176
182
177
183
auto arguments =
178
- args_from_options<Fp4GemmSm100<T> >(D, A, B, A_sf, B_sf, alpha, m, n, k);
184
+ args_from_options<Config >(D, A, B, A_sf, B_sf, alpha, m, n, k);
179
185
180
- size_t workspace_size = Fp4GemmSm100<T> ::Gemm::get_workspace_size (arguments);
186
+ size_t workspace_size = Config ::Gemm::get_workspace_size (arguments);
181
187
auto const workspace_options =
182
188
torch::TensorOptions ().dtype (torch::kUInt8 ).device (A.device ());
183
189
auto workspace = torch::empty (workspace_size, workspace_options);
@@ -188,12 +194,40 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
188
194
189
195
CUTLASS_CHECK (gemm.run (arguments, workspace.data_ptr (), stream));
190
196
}
197
+
198
+ // Dispatch function to select appropriate config based on M
199
+ template <typename OutType>
200
+ void cutlass_fp4_gemm_dispatch (torch::Tensor& D, torch::Tensor const & A,
201
+ torch::Tensor const & B,
202
+ torch::Tensor const & A_sf,
203
+ torch::Tensor const & B_sf,
204
+ torch::Tensor const & alpha, int64_t m, int64_t n,
205
+ int64_t k, cudaStream_t stream) {
206
+ uint32_t const mp2 = std::max (static_cast <uint32_t >(16 ), next_pow_2 (m));
207
+
208
+ if (mp2 <= 16 ) {
209
+ // m in [1, 16]
210
+ runGemm<Fp4GemmSm100<sm100_fp4_config_M16, OutType>>(
211
+ D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
212
+ } else if (mp2 <= 256 ) {
213
+ // m in (16, 256]
214
+ runGemm<Fp4GemmSm100<sm100_fp4_config_M256, OutType>>(
215
+ D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
216
+ } else {
217
+ // m in (256, inf)
218
+ runGemm<Fp4GemmSm100<sm100_fp4_config_default, OutType>>(
219
+ D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
220
+ }
221
+ }
222
+
191
223
#else
192
- template <typename T>
193
- void runGemm (at::Tensor& D, at::Tensor const & A, at::Tensor const & B,
194
- at::Tensor const & A_sf, at::Tensor const & B_sf,
195
- at::Tensor const & alpha, int64_t m, int64_t n, int64_t k,
196
- cudaStream_t stream) {
224
+ template <typename OutType>
225
+ void cutlass_fp4_gemm_dispatch (torch::Tensor& D, torch::Tensor const & A,
226
+ torch::Tensor const & B,
227
+ torch::Tensor const & A_sf,
228
+ torch::Tensor const & B_sf,
229
+ torch::Tensor const & alpha, int64_t m, int64_t n,
230
+ int64_t k, cudaStream_t stream) {
197
231
TORCH_CHECK (false ,
198
232
" Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
199
233
" a CUTLASS 3.8 source directory to enable support." );
@@ -271,12 +305,13 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
271
305
const cudaStream_t stream = at::cuda::getCurrentCUDAStream (A.get_device ());
272
306
273
307
if (out_dtype == at::ScalarType::Half) {
274
- runGemm<cutlass::half_t >(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
308
+ cutlass_fp4_gemm_dispatch<cutlass::half_t >(D, A, B, A_sf, B_sf, alpha, m, n,
309
+ k, stream);
275
310
} else if (out_dtype == at::ScalarType::BFloat16) {
276
- runGemm<cutlass::bfloat16_t >(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
277
- } else if (out_dtype == at::ScalarType::Float) {
278
- runGemm<float >(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
311
+ cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t >(D, A, B, A_sf, B_sf, alpha,
312
+ m, n, k, stream);
279
313
} else {
280
- TORCH_CHECK (false , " Unsupported output data type of nvfp4 mm" );
314
+ TORCH_CHECK (false , " Unsupported output data type of nvfp4 mm (" , out_dtype,
315
+ " )" );
281
316
}
282
317
}
0 commit comments