@@ -1148,9 +1148,38 @@ class IsCircularBufferLoadLoop : public kir::IrVisitor {
1148
1148
// associated load expressions.
1149
1149
class CircularBufferLoopNestInspector : private kir ::IrVisitor {
1150
1150
public:
1151
- static InsertionInfo run (const std::vector<Expr*>& exprs) {
1151
+ static std::pair<InsertionInfo, InsertionInfo> run (
1152
+ const std::vector<Expr*>& exprs) {
1152
1153
CircularBufferLoopNestInspector inspector (exprs);
1153
- return inspector.insertion_info_ ;
1154
+
1155
+ // InsertionInfo holds all circular buffer for-loops and is ordered from
1156
+ // inner-most to outer-most. Split it into warp specialized and pipeline
1157
+ // circular buffers. Enforce that we can only nest pipeline circular
1158
+ // buffering inside of warp-specialization.
1159
+ InsertionInfo ws_info;
1160
+ InsertionInfo pipeline_info;
1161
+ for (auto && [cb_loop, cb_exprs] : inspector.insertion_info_ ) {
1162
+ bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
1163
+ GpuLower::current ()
1164
+ ->circularBufferInfo ()
1165
+ .getCircularBufferOptionsFor (cb_loop->iter_domain ())
1166
+ .type );
1167
+ if (use_warp_specialization) {
1168
+ ws_info[cb_loop] = cb_exprs;
1169
+ } else {
1170
+ NVF_ERROR (
1171
+ ws_info.empty (),
1172
+ " Warp Specialization cannot be nested in Pipeline circular buffering!" );
1173
+ pipeline_info[cb_loop] = cb_exprs;
1174
+ }
1175
+ }
1176
+ NVF_ERROR (
1177
+ ws_info.size () <= 4 ,
1178
+ " At most four for-loops can run concurrently inside the AsyncWarp.\n " ,
1179
+ " Detected " ,
1180
+ ws_info.size (),
1181
+ " WarpSpecialized for-loops." );
1182
+ return {ws_info, pipeline_info};
1154
1183
}
1155
1184
1156
1185
private:
@@ -1728,8 +1757,14 @@ kir::TensorIndex* TmaCircularBufferInfo::getTensorIndex(const Expr* expr) {
1728
1757
}
1729
1758
1730
1759
std::vector<Expr*> CircularBufferPass::run (const std::vector<Expr*>& exprs) {
1731
- InsertionInfo insertion_info = CircularBufferLoopNestInspector::run (exprs);
1732
- return CircularBufferInserter::run (exprs, insertion_info);
1760
+ auto && [ws_insertion_info, pipeline_insertion_info] =
1761
+ CircularBufferLoopNestInspector::run (exprs);
1762
+ // Process circular buffer for-loops from inner to outer-most.
1763
+ // Pipeline must come before WarpSpecialized. We cannot nest WarpSpecialized
1764
+ // inside of Pipeline circular buffering.
1765
+ std::vector<Expr*> result_exprs =
1766
+ CircularBufferInserter::run (exprs, pipeline_insertion_info);
1767
+ return CircularBufferInserter::run (result_exprs, ws_insertion_info);
1733
1768
}
1734
1769
1735
1770
} // namespace nvfuser
0 commit comments