Skip to content

Commit 108a188

Browse files
committed
Split insertion_info into Pipeline and WarpSpecialized
1 parent 7bc8c17 commit 108a188

File tree

1 file changed

+39
-4
lines changed

1 file changed

+39
-4
lines changed

csrc/device_lower/pass/circular_buffer.cpp

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,9 +1148,38 @@ class IsCircularBufferLoadLoop : public kir::IrVisitor {
11481148
// associated load expressions.
11491149
class CircularBufferLoopNestInspector : private kir::IrVisitor {
11501150
public:
1151-
static InsertionInfo run(const std::vector<Expr*>& exprs) {
1151+
static std::pair<InsertionInfo, InsertionInfo> run(
1152+
const std::vector<Expr*>& exprs) {
11521153
CircularBufferLoopNestInspector inspector(exprs);
1153-
return inspector.insertion_info_;
1154+
1155+
// InsertionInfo holds all circular buffer for-loops and is ordered from
1156+
// inner-most to outer-most. Split it into warp specialized and pipeline
1157+
// circular buffers. Enforce that we can only nest pipeline circular
1158+
// buffering inside of warp-specialization.
1159+
InsertionInfo ws_info;
1160+
InsertionInfo pipeline_info;
1161+
for (auto&& [cb_loop, cb_exprs] : inspector.insertion_info_) {
1162+
bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
1163+
GpuLower::current()
1164+
->circularBufferInfo()
1165+
.getCircularBufferOptionsFor(cb_loop->iter_domain())
1166+
.type);
1167+
if (use_warp_specialization) {
1168+
ws_info[cb_loop] = cb_exprs;
1169+
} else {
1170+
NVF_ERROR(
1171+
ws_info.empty(),
1172+
"Warp Specialization cannot be nested in Pipeline circular buffering!");
1173+
pipeline_info[cb_loop] = cb_exprs;
1174+
}
1175+
}
1176+
NVF_ERROR(
1177+
ws_info.size() <= 4,
1178+
"At most four for-loops can run concurrently inside the AsyncWarp.\n",
1179+
"Detected ",
1180+
ws_info.size(),
1181+
" WarpSpecialized for-loops.");
1182+
return {ws_info, pipeline_info};
11541183
}
11551184

11561185
private:
@@ -1728,8 +1757,14 @@ kir::TensorIndex* TmaCircularBufferInfo::getTensorIndex(const Expr* expr) {
17281757
}
17291758

17301759
std::vector<Expr*> CircularBufferPass::run(const std::vector<Expr*>& exprs) {
1731-
InsertionInfo insertion_info = CircularBufferLoopNestInspector::run(exprs);
1732-
return CircularBufferInserter::run(exprs, insertion_info);
1760+
auto&& [ws_insertion_info, pipeline_insertion_info] =
1761+
CircularBufferLoopNestInspector::run(exprs);
1762+
// Process circular buffer for-loops from inner to outer-most.
1763+
// Pipeline must come before WarpSpecialized. We cannot nest WarpSpecialized
1764+
// inside of Pipeline circular buffering.
1765+
std::vector<Expr*> result_exprs =
1766+
CircularBufferInserter::run(exprs, pipeline_insertion_info);
1767+
return CircularBufferInserter::run(result_exprs, ws_insertion_info);
17331768
}
17341769

17351770
} // namespace nvfuser

0 commit comments

Comments
 (0)