@@ -1144,6 +1144,18 @@ class IsCircularBufferLoadLoop : public kir::IrVisitor {
1144
1144
bool result_ = false ;
1145
1145
};
1146
1146
1147
+ namespace {
1148
+
1149
+ bool isWarpSpecialized (ForLoop* loop) {
1150
+ return std::holds_alternative<WarpSpecialized>(
1151
+ GpuLower::current ()
1152
+ ->circularBufferInfo ()
1153
+ .getCircularBufferOptionsFor (loop->iter_domain ())
1154
+ .type );
1155
+ }
1156
+
1157
+ } // namespace
1158
+
1147
1159
// Traverse lowered loop-nests and find all circular buffer loops and
1148
1160
// associated load expressions.
1149
1161
class CircularBufferLoopNestInspector : private kir ::IrVisitor {
@@ -1152,33 +1164,40 @@ class CircularBufferLoopNestInspector : private kir::IrVisitor {
1152
1164
const std::vector<Expr*>& exprs) {
1153
1165
CircularBufferLoopNestInspector inspector (exprs);
1154
1166
1155
- // InsertionInfo holds all circular buffer for-loops and is ordered from
1156
- // inner-most to outer-most. Split it into warp specialized and pipeline
1157
- // circular buffers. Enforce that we can only nest pipeline circular
1158
- // buffering inside of warp-specialization.
1167
+ // InsertionInfo holds all circular buffer for-loops. Split it into warp
1168
+ // specialized and pipeline circular buffers. Enforce that we can only nest
1169
+ // pipeline circular buffering inside of warp-specialization.
1170
+
1171
+ // Get WarpSpecialized InsertionInfo
1159
1172
InsertionInfo ws_info;
1160
- InsertionInfo pipeline_info ;
1173
+ int64_t inner_most_ws_position = - 1 ;
1161
1174
for (auto && [cb_loop, cb_exprs] : inspector.insertion_info_ ) {
1162
- bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
1163
- GpuLower::current ()
1164
- ->circularBufferInfo ()
1165
- .getCircularBufferOptionsFor (cb_loop->iter_domain ())
1166
- .type );
1167
- if (use_warp_specialization) {
1168
- ws_info[cb_loop] = cb_exprs;
1169
- } else {
1170
- NVF_ERROR (
1171
- ws_info.empty (),
1172
- " Warp Specialization cannot be nested in Pipeline circular buffering!" );
1173
- pipeline_info[cb_loop] = cb_exprs;
1175
+ if (!isWarpSpecialized (cb_loop)) {
1176
+ continue ;
1174
1177
}
1178
+ ws_info[cb_loop] = cb_exprs;
1179
+ inner_most_ws_position = std::max (
1180
+ inner_most_ws_position, inspector.loop_position_ .at (cb_loop));
1175
1181
}
1176
1182
NVF_ERROR (
1177
1183
ws_info.size () <= 4 ,
1178
1184
" At most four for-loops can run concurrently inside the AsyncWarp.\n " ,
1179
1185
" Detected " ,
1180
1186
ws_info.size (),
1181
1187
" WarpSpecialized for-loops." );
1188
+
1189
+ // Get Pipeline InsertionInfo
1190
+ InsertionInfo pipeline_info;
1191
+ for (auto && [cb_loop, cb_exprs] : inspector.insertion_info_ ) {
1192
+ if (isWarpSpecialized (cb_loop)) {
1193
+ continue ;
1194
+ }
1195
+ NVF_ERROR (
1196
+ inspector.loop_position_ .at (cb_loop) > inner_most_ws_position,
1197
+ " Warp Specialization cannot be nested in Pipeline circular buffering!" );
1198
+ pipeline_info[cb_loop] = cb_exprs;
1199
+ }
1200
+
1182
1201
return {ws_info, pipeline_info};
1183
1202
}
1184
1203
@@ -1215,6 +1234,10 @@ class CircularBufferLoopNestInspector : private kir::IrVisitor {
1215
1234
1216
1235
validateCircularBufferLoop (circular_buffer_loop);
1217
1236
1237
+ auto cb_loop_it =
1238
+ std::find (for_loops_.begin (), for_loops_.end (), circular_buffer_loop);
1239
+ loop_position_[circular_buffer_loop] =
1240
+ std::distance (for_loops_.begin (), cb_loop_it);
1218
1241
insertion_info_[circular_buffer_loop].push_back (expr);
1219
1242
}
1220
1243
@@ -1240,6 +1263,8 @@ class CircularBufferLoopNestInspector : private kir::IrVisitor {
1240
1263
loop->toString ());
1241
1264
}
1242
1265
1266
+ // Map circular buffer loop to its position in the for_loop_ stack.
1267
+ std::unordered_map<ForLoop*, int64_t > loop_position_;
1243
1268
InsertionInfo insertion_info_;
1244
1269
};
1245
1270
0 commit comments