Commit 0b2f5a8

Create separate CircularBufferInserter for WarpSpecialized and Pipeline (#4280)
This PR separates `CircularBufferInserter` into two variants:
`WarpSpecializedCircularBufferInserter` and `PipelineCircularBufferInserter`.
Stacked on: #4275

This refactor is a foundation for more complex changes required to support
Blackwell UTCMMA and Ping-Pong WarpSpecialization. Further enhancements are
not planned for pipeline circular buffering.
1 parent: fb9b956 · commit: 0b2f5a8

File tree: 1 file changed, +198 −135 lines changed

csrc/device_lower/pass/circular_buffer.cpp (198 additions & 135 deletions)
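To orient the reader before the diff: the top-level dispatch in
`CircularBufferPass::run` (final hunk below) now runs the two inserters in
sequence. A condensed sketch, with the earlier computation of
`pipeline_insertion_info` and `ws_insertion_info` elided:

```cpp
// Condensed from the final hunk of this diff. Each inserter pass handles
// exactly one circular-buffer loop, so run() iterates until its
// InsertionInfo map is drained.
std::vector<Expr*> CircularBufferPass::run(const std::vector<Expr*>& exprs) {
  // ... split InsertionInfo into pipeline vs. warp-specialized loops ...
  // Pipeline must come before WarpSpecialized; WarpSpecialized cannot be
  // nested inside Pipeline circular buffering.
  std::vector<Expr*> result_exprs =
      PipelineCircularBufferInserter::run(exprs, pipeline_insertion_info);
  return WarpSpecializedCircularBufferInserter::run(
      result_exprs, ws_insertion_info);
}
```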
@@ -1334,10 +1334,51 @@ void getAllocInTrivialLoop(ForLoop* fl, std::unordered_set<Expr*>& output) {
   }
 }
 
+// Create something like below:
+//   for (int i = 0; i < prefetch + 1; ++i) {
+//     mbarrier::arrive(mbarrier0[stage + i]);
+//     mbarrier::arrive(mbarrier1[stage + i]);
+//     ...
+//   }
+// where mbarrierX[stage + i] is the X-th WAR mbarrier for stage i.
+//
+// This is needed because we prefetch data in circular buffering, and we
+// need to make sure the initial prefetches are not blocked by
+// nonexistent WAR hazards.
+ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop) {
+  const auto& opt =
+      GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
+          circular_buffer_loop->iter_domain());
+  auto circular_buffer_tvs =
+      GpuLower::current()->circularBufferInfo().getCircularBufferTvs(
+          circular_buffer_loop->iter_domain());
+  VectorOfUniqueEntries<TensorView*> mbarriers;
+  for (auto tv : circular_buffer_tvs) {
+    auto ldst = dynamic_cast<LoadStoreOp*>(tv->definition());
+    NVF_ERROR(ldst != nullptr);
+    auto it = GpuLower::current()->mbarrierMap().find(ldst);
+    if (it == GpuLower::current()->mbarrierMap().end()) {
+      continue;
+    }
+    mbarriers.pushBack(it->second);
+  }
+  auto prefetch_loop = ir_utils::createRangeLoop(opt.prefetch + 1);
+  for (auto mbarrier : mbarriers) {
+    auto mbarrier_to_arrive = IrBuilder::create<kir::TensorIndex>(
+        mbarrier,
+        SimplifyingIrBuilder::addExpr(
+            prefetch_loop->indexOrStartIfTrivial(), opt.stage));
+    auto prefetch = IrBuilder::create<kir::MBarrierArrive>(
+        /*state=*/nullptr, mbarrier_to_arrive);
+    prefetch_loop->body().push_back(prefetch);
+  }
+  return prefetch_loop;
+}
+
 } // namespace
 
-// Apply circular buffering transformations
-class CircularBufferInserter : private kir::ExprMutator {
+// Apply warp specialized circular buffering transformations
+class WarpSpecializedCircularBufferInserter : private kir::ExprMutator {
  public:
   // When there exist multiple circular buffer loops, apply
   // transformations to inner-most loops first. A single ExprMutator
@@ -1347,14 +1388,15 @@ class CircularBufferInserter : private kir::ExprMutator {
       InsertionInfo insertion_info) {
     std::vector<Expr*> inserted_exprs = exprs;
     while (!insertion_info.empty()) {
-      CircularBufferInserter inserter(inserted_exprs, insertion_info);
+      WarpSpecializedCircularBufferInserter inserter(
+          inserted_exprs, insertion_info);
       inserted_exprs = inserter.exprs_;
     }
     return inserted_exprs;
   }
 
  private:
-  CircularBufferInserter(
+  WarpSpecializedCircularBufferInserter(
       const std::vector<Expr*>& exprs,
       InsertionInfo& insertion_info)
       : insertion_info_(insertion_info) {
@@ -1380,143 +1422,24 @@ class CircularBufferInserter : private kir::ExprMutator {
       return;
     }
 
-    auto has_cp_async_bulk = std::any_of(
-        it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk);
-
     bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
         GpuLower::current()
             ->circularBufferInfo()
             .getCircularBufferOptionsFor(loop->iter_domain())
             .type);
-    if (use_warp_specialization) {
-      NVF_ERROR(
-          std::all_of(
-              it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk),
-          "In order to use warp specialization, all buffers must be loaded by TMA");
-      int64_t insertion_position =
-          GpuLower::current()
-              ->circularBufferInfo()
-              .getCircularBufferInsertionPosition(loop->iter_domain());
-      insertTmaWarpSpecialized(loop, it->second, insertion_position);
-    } else if (has_cp_async_bulk) {
-      insertTmaPipelined(loop, it->second);
-    } else {
-      insert(loop, it->second);
-    }
-    processed_loop_ = loop;
-    insertion_info_.erase(loop);
-  }
-
-  bool hasPrefetch(ForLoop* circular_buffer_loop) {
-    int64_t prefetch_distance =
+    NVF_ERROR(use_warp_specialization);
+    NVF_ERROR(
+        std::all_of(
+            it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk),
+        "In order to use warp specialization, all buffers must be loaded by TMA");
+    int64_t insertion_position =
         GpuLower::current()
             ->circularBufferInfo()
-            .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
-            .prefetch;
-    return prefetch_distance > 0;
-  }
-
-  // Create something like below:
-  //   for (int i = 0; i < prefetch + 1; ++i) {
-  //     mbarrier::arrive(mbarrier0[stage + i]);
-  //     mbarrier::arrive(mbarrier1[stage + i]);
-  //     ...
-  //   }
-  // where mbarrierX[stage + i] is the X-th WAR mbarrier for stage i.
-  //
-  // This is needed because we prefetch data in circular buffering, and we
-  // need to make sure the initial prefetches are not blocked by
-  // nonexistent WAR hazards.
-  ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop) {
-    const auto& opt =
-        GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
-            circular_buffer_loop->iter_domain());
-    auto circular_buffer_tvs =
-        GpuLower::current()->circularBufferInfo().getCircularBufferTvs(
-            circular_buffer_loop->iter_domain());
-    VectorOfUniqueEntries<TensorView*> mbarriers;
-    for (auto tv : circular_buffer_tvs) {
-      auto ldst = dynamic_cast<LoadStoreOp*>(tv->definition());
-      NVF_ERROR(ldst != nullptr);
-      auto it = GpuLower::current()->mbarrierMap().find(ldst);
-      if (it == GpuLower::current()->mbarrierMap().end()) {
-        continue;
-      }
-      mbarriers.pushBack(it->second);
-    }
-    auto prefetch_loop = ir_utils::createRangeLoop(opt.prefetch + 1);
-    for (auto mbarrier : mbarriers) {
-      auto mbarrier_to_arrive = IrBuilder::create<kir::TensorIndex>(
-          mbarrier,
-          SimplifyingIrBuilder::addExpr(
-              prefetch_loop->indexOrStartIfTrivial(), opt.stage));
-      auto prefetch = IrBuilder::create<kir::MBarrierArrive>(
-          /*state=*/nullptr, mbarrier_to_arrive);
-      prefetch_loop->body().push_back(prefetch);
-    }
-    return prefetch_loop;
-  }
-
-  static bool usesMBarrierForWAR(ForLoop* circular_buffer_loop) {
-    return GpuLower::current()
-        ->circularBufferInfo()
-        .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
-        .usesMBarrierForWAR();
-  }
-
-  void insertTmaPipelined(
-      ForLoop* circular_buffer_loop,
-      const std::vector<Expr*>& loads) {
-    // Arrive on the WAR mbarriers to let the prefetching start.
-    if (usesMBarrierForWAR(circular_buffer_loop)) {
-      auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
-      registerInsertBefore(circular_buffer_loop, prefetch_loop);
-    }
-
-    // Prologue loop:
-    // - launch only
-    // - arrive_expect_tx and tma load operations
-    if (hasPrefetch(circular_buffer_loop)) {
-      // If there is no prefetch, then we don't need a prologue loop.
-      ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
-          circular_buffer_loop,
-          loads,
-          CircularBufferLoopStage::Prolog,
-          /*insertion_position=*/1);
-      registerInsertBefore(circular_buffer_loop, prologue_loop);
-    }
-
-    // Main loop:
-    // - Launch and wait
-    // - arrive_expect_tx, tma load operations, and mbarrier_wait
-    ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
-        circular_buffer_loop,
-        loads,
-        CircularBufferLoopStage::Main,
-        /*insertion_position=*/1);
-    registerReplace(circular_buffer_loop, main_loop);
+            .getCircularBufferInsertionPosition(loop->iter_domain());
+    insertTmaWarpSpecialized(loop, it->second, insertion_position);
 
-    if (!hasPrefetch(circular_buffer_loop)) {
-      // If there is no prefetch, then we don't need an epilogue loop.
-      return;
-    }
-
-    // We can use the exclude argument in
-    // CloneTmaCircularBufferLoopAndInsertSync::clone to avoid
-    // duplicating allocations if the main loop is trivial.
-    std::unordered_set<Expr*> expressions_allocated_in_main_loop;
-    getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop);
-
-    // Epilogue loop:
-    // - wait only
-    // - mbarrier_wait
-    ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
-        circular_buffer_loop,
-        loads,
-        CircularBufferLoopStage::Epilog,
-        /*insertion_position=*/1,
-        expressions_allocated_in_main_loop);
-    registerInsertAfter(circular_buffer_loop, epilogue_loop);
+    processed_loop_ = loop;
+    insertion_info_.erase(loop);
   }
 
   void insertTmaWarpSpecialized(
@@ -1610,6 +1533,145 @@ class CircularBufferInserter : private kir::ExprMutator {
     registerReplace(circular_buffer_loop, warp_dispatch_ite);
   }
 
+ private:
+  InsertionInfo& insertion_info_;
+  ForLoop* processed_loop_ = nullptr;
+};
+
+// Apply pipeline circular buffering transformations
+class PipelineCircularBufferInserter : private kir::ExprMutator {
+ public:
+  // When there exist multiple circular buffer loops, apply
+  // transformations to inner-most loops first. A single ExprMutator
+  // pass can only process one loop.
+  static std::vector<Expr*> run(
+      const std::vector<Expr*>& exprs,
+      InsertionInfo insertion_info) {
+    std::vector<Expr*> inserted_exprs = exprs;
+    while (!insertion_info.empty()) {
+      PipelineCircularBufferInserter inserter(inserted_exprs, insertion_info);
+      inserted_exprs = inserter.exprs_;
+    }
+    return inserted_exprs;
+  }
+
+ private:
+  PipelineCircularBufferInserter(
+      const std::vector<Expr*>& exprs,
+      InsertionInfo& insertion_info)
+      : insertion_info_(insertion_info) {
+    size_t num_circular_buffer_loops = insertion_info.size();
+    traverseAndInsert(exprs);
+    NVF_ERROR(processed_loop_ != nullptr);
+    NVF_ERROR(insertion_info.size() == num_circular_buffer_loops - 1);
+  }
+
+  using kir::ExprMutator::handle;
+
+  void handle(ForLoop* loop) final {
+    kir::ExprMutator::handle(loop);
+
+    // If another loop has already been handled, no more loops should
+    // be processed in the same pass.
+    if (processed_loop_ != nullptr) {
+      return;
+    }
+
+    auto it = insertion_info_.find(loop);
+    if (it == insertion_info_.end()) {
+      return;
+    }
+
+    bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
+        GpuLower::current()
+            ->circularBufferInfo()
+            .getCircularBufferOptionsFor(loop->iter_domain())
+            .type);
+    NVF_ERROR(!use_warp_specialization);
+
+    auto has_cp_async_bulk = std::any_of(
+        it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk);
+    if (has_cp_async_bulk) {
+      insertTmaPipelined(loop, it->second);
+    } else {
+      insert(loop, it->second);
+    }
+
+    processed_loop_ = loop;
+    insertion_info_.erase(loop);
+  }
+
+  bool hasPrefetch(ForLoop* circular_buffer_loop) {
+    int64_t prefetch_distance =
+        GpuLower::current()
+            ->circularBufferInfo()
+            .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
+            .prefetch;
+    return prefetch_distance > 0;
+  }
+
+  static bool usesMBarrierForWAR(ForLoop* circular_buffer_loop) {
+    return GpuLower::current()
+        ->circularBufferInfo()
+        .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
+        .usesMBarrierForWAR();
+  }
+
+  void insertTmaPipelined(
+      ForLoop* circular_buffer_loop,
+      const std::vector<Expr*>& loads) {
+    // Arrive on the WAR mbarriers to let the prefetching start.
+    if (usesMBarrierForWAR(circular_buffer_loop)) {
+      auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
+      registerInsertBefore(circular_buffer_loop, prefetch_loop);
+    }
+
+    // Prologue loop:
+    // - launch only
+    // - arrive_expect_tx and tma load operations
+    if (hasPrefetch(circular_buffer_loop)) {
+      // If there is no prefetch, then we don't need a prologue loop.
+      ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+          circular_buffer_loop,
+          loads,
+          CircularBufferLoopStage::Prolog,
+          /*insertion_position=*/1);
+      registerInsertBefore(circular_buffer_loop, prologue_loop);
+    }
+
+    // Main loop:
+    // - Launch and wait
+    // - arrive_expect_tx, tma load operations, and mbarrier_wait
+    ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+        circular_buffer_loop,
+        loads,
+        CircularBufferLoopStage::Main,
+        /*insertion_position=*/1);
+    registerReplace(circular_buffer_loop, main_loop);
+
+    if (!hasPrefetch(circular_buffer_loop)) {
+      // If there is no prefetch, then we don't need an epilogue loop.
+      return;
+    }
+
+    // We can use the exclude argument in
+    // CloneTmaCircularBufferLoopAndInsertSync::clone to avoid
+    // duplicating allocations if the main loop is trivial.
+    std::unordered_set<Expr*> expressions_allocated_in_main_loop;
+    getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop);
+
+    // Epilogue loop:
+    // - wait only
+    // - mbarrier_wait
+    ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+        circular_buffer_loop,
+        loads,
+        CircularBufferLoopStage::Epilog,
+        /*insertion_position=*/1,
+        expressions_allocated_in_main_loop);
+    registerInsertAfter(circular_buffer_loop, epilogue_loop);
+  }
+
   void insert(ForLoop* circular_buffer_loop, const std::vector<Expr*>& loads) {
     NVF_ERROR(
         !usesMBarrierForWAR(circular_buffer_loop),
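For orientation before the final hunk: `insertTmaPipelined` above clones the
circular-buffer loop into up to three stages. A schematic of the result, as
illustrative pseudocode only (`arrive_expect_tx`, `tma_load`, `mbarrier_wait`,
`compute`, and `N` are placeholders, not literal output of the pass):

```cpp
// Prefetch distance d = prefetch. With d == 0 the prologue and epilogue
// loops are skipped entirely and only the main loop is emitted.
for (int i = 0; i < d; ++i) {      // Prolog: launch only
  arrive_expect_tx();
  tma_load(i);
}
for (int i = 0; i < N - d; ++i) {  // Main: launch and wait
  arrive_expect_tx();
  tma_load(i + d);                 // keep the pipeline full
  mbarrier_wait(i);                // data for iteration i is ready
  compute(i);
}
for (int i = N - d; i < N; ++i) {  // Epilog: wait only, no more loads
  mbarrier_wait(i);
  compute(i);
}
```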
@@ -1839,8 +1901,9 @@ std::vector<Expr*> CircularBufferPass::run(const std::vector<Expr*>& exprs) {
   // Pipeline must come before WarpSpecialized. We cannot nest WarpSpecialized
   // inside of Pipeline circular buffering.
   std::vector<Expr*> result_exprs =
-      CircularBufferInserter::run(exprs, pipeline_insertion_info);
-  return CircularBufferInserter::run(result_exprs, ws_insertion_info);
+      PipelineCircularBufferInserter::run(exprs, pipeline_insertion_info);
+  return WarpSpecializedCircularBufferInserter::run(
+      result_exprs, ws_insertion_info);
 }
 
 } // namespace nvfuser
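A worked example of the `createArrivesForWar` helper that the first hunk
hoists to the anonymous namespace: following its comment, with the
hypothetical values `stage = 4` and `prefetch = 2`, and two TMA-loaded
circular buffers whose WAR mbarriers are `mbarrier0` and `mbarrier1`, the
generated loop is equivalent to:

```cpp
// Pre-arrive the WAR mbarriers for the first prefetch + 1 = 3 slots so the
// initial TMA prefetches are not blocked waiting on reads of the buffer
// that have not happened yet (values hypothetical; see the comment above
// createArrivesForWar in the diff).
for (int i = 0; i < 3; ++i) {
  mbarrier::arrive(mbarrier0[4 + i]);
  mbarrier::arrive(mbarrier1[4 + i]);
}
```

Hoisting the helper out of the old `CircularBufferInserter` class is what
allows it to be shared now that the class has been split in two.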
