Skip to content

Commit b71967d

Browse files
committed
Track inner-most warp specialized for-loop
1 parent 108a188 commit b71967d

File tree

1 file changed

+42
-17
lines changed

1 file changed

+42
-17
lines changed

csrc/device_lower/pass/circular_buffer.cpp

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,18 @@ class IsCircularBufferLoadLoop : public kir::IrVisitor {
11441144
bool result_ = false;
11451145
};
11461146

1147+
namespace {
1148+
1149+
// Returns true when the circular buffer options looked up for `loop`'s
// iter domain (via GpuLower::current()->circularBufferInfo()) hold the
// WarpSpecialized alternative in their `type` variant; false otherwise.
bool isWarpSpecialized(ForLoop* loop) {
1150+
return std::holds_alternative<WarpSpecialized>(
1151+
GpuLower::current()
1152+
->circularBufferInfo()
1153+
.getCircularBufferOptionsFor(loop->iter_domain())
1154+
.type);
1155+
}
1156+
1157+
} // namespace
1158+
11471159
// Traverse lowered loop-nests and find all circular buffer loops and
11481160
// associated load expressions.
11491161
class CircularBufferLoopNestInspector : private kir::IrVisitor {
@@ -1152,33 +1164,40 @@ class CircularBufferLoopNestInspector : private kir::IrVisitor {
11521164
const std::vector<Expr*>& exprs) {
11531165
CircularBufferLoopNestInspector inspector(exprs);
11541166

1155-
// InsertionInfo holds all circular buffer for-loops and is ordered from
1156-
// inner-most to outer-most. Split it into warp specialized and pipeline
1157-
// circular buffers. Enforce that we can only nest pipeline circular
1158-
// buffering inside of warp-specialization.
1167+
// InsertionInfo holds all circular buffer for-loops. Split it into warp
1168+
// specialized and pipeline circular buffers. Enforce that we can only nest
1169+
// pipeline circular buffering inside of warp-specialization.
1170+
1171+
// Get WarpSpecialized InsertionInfo
11591172
InsertionInfo ws_info;
1160-
InsertionInfo pipeline_info;
1173+
int64_t inner_most_ws_position = -1;
11611174
for (auto&& [cb_loop, cb_exprs] : inspector.insertion_info_) {
1162-
bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
1163-
GpuLower::current()
1164-
->circularBufferInfo()
1165-
.getCircularBufferOptionsFor(cb_loop->iter_domain())
1166-
.type);
1167-
if (use_warp_specialization) {
1168-
ws_info[cb_loop] = cb_exprs;
1169-
} else {
1170-
NVF_ERROR(
1171-
ws_info.empty(),
1172-
"Warp Specialization cannot be nested in Pipeline circular buffering!");
1173-
pipeline_info[cb_loop] = cb_exprs;
1175+
if (!isWarpSpecialized(cb_loop)) {
1176+
continue;
11741177
}
1178+
ws_info[cb_loop] = cb_exprs;
1179+
inner_most_ws_position = std::max(
1180+
inner_most_ws_position, inspector.loop_position_.at(cb_loop));
11751181
}
11761182
NVF_ERROR(
11771183
ws_info.size() <= 4,
11781184
"At most four for-loops can run concurrently inside the AsyncWarp.\n",
11791185
"Detected ",
11801186
ws_info.size(),
11811187
" WarpSpecialized for-loops.");
1188+
1189+
// Get Pipeline InsertionInfo
1190+
InsertionInfo pipeline_info;
1191+
for (auto&& [cb_loop, cb_exprs] : inspector.insertion_info_) {
1192+
if (isWarpSpecialized(cb_loop)) {
1193+
continue;
1194+
}
1195+
NVF_ERROR(
1196+
inspector.loop_position_.at(cb_loop) > inner_most_ws_position,
1197+
"Warp Specialization cannot be nested in Pipeline circular buffering!");
1198+
pipeline_info[cb_loop] = cb_exprs;
1199+
}
1200+
11821201
return {ws_info, pipeline_info};
11831202
}
11841203

@@ -1215,6 +1234,10 @@ class CircularBufferLoopNestInspector : private kir::IrVisitor {
12151234

12161235
validateCircularBufferLoop(circular_buffer_loop);
12171236

1237+
auto cb_loop_it =
1238+
std::find(for_loops_.begin(), for_loops_.end(), circular_buffer_loop);
1239+
loop_position_[circular_buffer_loop] =
1240+
std::distance(for_loops_.begin(), cb_loop_it);
12181241
insertion_info_[circular_buffer_loop].push_back(expr);
12191242
}
12201243

@@ -1240,6 +1263,8 @@ class CircularBufferLoopNestInspector : private kir::IrVisitor {
12401263
loop->toString());
12411264
}
12421265

1266+
// Map circular buffer loop to its position in the for_loops_ stack.
1267+
std::unordered_map<ForLoop*, int64_t> loop_position_;
12431268
InsertionInfo insertion_info_;
12441269
};
12451270

0 commit comments

Comments
 (0)