Commit ac56adb

Create separate CircularBufferInserter for WarpSpecialized and Pipeline
1 parent b71967d commit ac56adb
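
In outline, the commit splits the former CircularBufferInserter into two ExprMutator passes: WarpSpecializedCircularBufferInserter, which handles loops whose circular-buffer type is WarpSpecialized (all loads must be TMA), and PipelineCircularBufferInserter, which handles pipelined loops with or without TMA (prologue/main/epilogue cloning). The sketch below paraphrases the resulting dispatch in CircularBufferPass::run, taken from the last hunk of the diff; it is a summary only, not the full pass.

// Sketch (C++): dispatch after this commit, paraphrased from the diff below.
// Pipeline circular buffering must run first; warp specialization cannot be
// nested inside pipeline circular buffering.
std::vector<Expr*> result_exprs =
    PipelineCircularBufferInserter::run(exprs, pipeline_insertion_info);
return WarpSpecializedCircularBufferInserter::run(
    result_exprs, ws_insertion_info);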

File tree

1 file changed: +198 -135 lines changed


csrc/device_lower/pass/circular_buffer.cpp

Lines changed: 198 additions & 135 deletions
@@ -1283,10 +1283,51 @@ void getAllocInTrivialLoop(ForLoop* fl, std::unordered_set<Expr*>& output) {
   }
 }
 
+// Create something like below:
+//   for (int i = 0; i < prefetch + 1; ++i) {
+//     mbarrier::arrive(mbarrier0[stage + i]);
+//     mbarrier::arrive(mbarrier1[stage + i]);
+//     ...
+//   }
+// where mbarrierX[stage + i] is the X-th WAR mbarrier for stage i.
+//
+// This is needed because we prefetch data in circular buffering, and we
+// need to make sure the initial prefetches are not blocked by the
+// non-existing WAR hazards.
+ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop) {
+  const auto& opt =
+      GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
+          circular_buffer_loop->iter_domain());
+  auto circular_buffer_tvs =
+      GpuLower::current()->circularBufferInfo().getCircularBufferTvs(
+          circular_buffer_loop->iter_domain());
+  VectorOfUniqueEntries<TensorView*> mbarriers;
+  for (auto tv : circular_buffer_tvs) {
+    auto ldst = dynamic_cast<LoadStoreOp*>(tv->definition());
+    NVF_ERROR(ldst != nullptr);
+    auto it = GpuLower::current()->mbarrierMap().find(ldst);
+    if (it == GpuLower::current()->mbarrierMap().end()) {
+      continue;
+    }
+    mbarriers.pushBack(it->second);
+  }
+  auto prefetch_loop = ir_utils::createRangeLoop(opt.prefetch + 1);
+  for (auto mbarrier : mbarriers) {
+    auto mbarrier_to_arrive = IrBuilder::create<kir::TensorIndex>(
+        mbarrier,
+        SimplifyingIrBuilder::addExpr(
+            prefetch_loop->indexOrStartIfTrivial(), opt.stage));
+    auto prefetch = IrBuilder::create<kir::MBarrierArrive>(
+        /*state=*/nullptr, mbarrier_to_arrive);
+    prefetch_loop->body().push_back(prefetch);
+  }
+  return prefetch_loop;
+}
+
 } // namespace
 
-// Apply circular buffering transformations
-class CircularBufferInserter : private kir::ExprMutator {
+// Apply warp specialized circular buffering transformations
+class WarpSpecializedCircularBufferInserter : private kir::ExprMutator {
  public:
   // When there exist multiple circular buffer loops, apply
   // transformations to inner-most loops first. A single ExprMutator
@@ -1296,14 +1337,15 @@ class CircularBufferInserter : private kir::ExprMutator {
       InsertionInfo insertion_info) {
     std::vector<Expr*> inserted_exprs = exprs;
     while (!insertion_info.empty()) {
-      CircularBufferInserter inserter(inserted_exprs, insertion_info);
+      WarpSpecializedCircularBufferInserter inserter(
+          inserted_exprs, insertion_info);
       inserted_exprs = inserter.exprs_;
     }
     return inserted_exprs;
   }
 
  private:
-  CircularBufferInserter(
+  WarpSpecializedCircularBufferInserter(
       const std::vector<Expr*>& exprs,
       InsertionInfo& insertion_info)
       : insertion_info_(insertion_info) {
@@ -1329,143 +1371,24 @@ class CircularBufferInserter : private kir::ExprMutator {
       return;
     }
 
-    auto has_cp_async_bulk = std::any_of(
-        it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk);
-
     bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
         GpuLower::current()
             ->circularBufferInfo()
             .getCircularBufferOptionsFor(loop->iter_domain())
             .type);
-    if (use_warp_specialization) {
-      NVF_ERROR(
-          std::all_of(
-              it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk),
-          "In order to use warp specialization, all buffers must be loaded by TMA");
-      int64_t insertion_position =
-          GpuLower::current()
-              ->circularBufferInfo()
-              .getCircularBufferInsertionPosition(loop->iter_domain());
-      insertTmaWarpSpecialized(loop, it->second, insertion_position);
-    } else if (has_cp_async_bulk) {
-      insertTmaPipelined(loop, it->second);
-    } else {
-      insert(loop, it->second);
-    }
-    processed_loop_ = loop;
-    insertion_info_.erase(loop);
-  }
-
-  bool hasPrefetch(ForLoop* circular_buffer_loop) {
-    int64_t prefetch_distance =
+    NVF_ERROR(use_warp_specialization);
+    NVF_ERROR(
+        std::all_of(
+            it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk),
+        "In order to use warp specialization, all buffers must be loaded by TMA");
+    int64_t insertion_position =
         GpuLower::current()
             ->circularBufferInfo()
-            .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
-            .prefetch;
-    return prefetch_distance > 0;
-  }
-
-  // Create something like below:
-  //   for (int i = 0; i < prefetch + 1; ++i) {
-  //     mbarrier::arrive(mbarrier0[stage + i]);
-  //     mbarrier::arrive(mbarrier1[stage + i]);
-  //     ...
-  //   }
-  // where mbarrierX[stage + i] is the X-th WAR mbarrier for stage i.
-  //
-  // This is needed because we prefetch data in circular buffering, and we
-  // need to make sure the initial prefetches are not blocked by the
-  // non-existing WAR hazards.
-  ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop) {
-    const auto& opt =
-        GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
-            circular_buffer_loop->iter_domain());
-    auto circular_buffer_tvs =
-        GpuLower::current()->circularBufferInfo().getCircularBufferTvs(
-            circular_buffer_loop->iter_domain());
-    VectorOfUniqueEntries<TensorView*> mbarriers;
-    for (auto tv : circular_buffer_tvs) {
-      auto ldst = dynamic_cast<LoadStoreOp*>(tv->definition());
-      NVF_ERROR(ldst != nullptr);
-      auto it = GpuLower::current()->mbarrierMap().find(ldst);
-      if (it == GpuLower::current()->mbarrierMap().end()) {
-        continue;
-      }
-      mbarriers.pushBack(it->second);
-    }
-    auto prefetch_loop = ir_utils::createRangeLoop(opt.prefetch + 1);
-    for (auto mbarrier : mbarriers) {
-      auto mbarrier_to_arrive = IrBuilder::create<kir::TensorIndex>(
-          mbarrier,
-          SimplifyingIrBuilder::addExpr(
-              prefetch_loop->indexOrStartIfTrivial(), opt.stage));
-      auto prefetch = IrBuilder::create<kir::MBarrierArrive>(
-          /*state=*/nullptr, mbarrier_to_arrive);
-      prefetch_loop->body().push_back(prefetch);
-    }
-    return prefetch_loop;
-  }
-
-  static bool usesMBarrierForWAR(ForLoop* circular_buffer_loop) {
-    return GpuLower::current()
-        ->circularBufferInfo()
-        .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
-        .usesMBarrierForWAR();
-  }
-
-  void insertTmaPipelined(
-      ForLoop* circular_buffer_loop,
-      const std::vector<Expr*>& loads) {
-    // Arrive on the WAR mbarriers to let the prefetching start.
-    if (usesMBarrierForWAR(circular_buffer_loop)) {
-      auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
-      registerInsertBefore(circular_buffer_loop, prefetch_loop);
-    }
-
-    // Prologue loop:
-    // - launch only
-    // - arrive_expect_tx and tma load operations
-    if (hasPrefetch(circular_buffer_loop)) {
-      // If there is no prefetch, then we don't need a prologue loop.
-      ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
-          circular_buffer_loop,
-          loads,
-          CircularBufferLoopStage::Prolog,
-          /*insertion_position=*/1);
-      registerInsertBefore(circular_buffer_loop, prologue_loop);
-    }
-
-    // Main loop:
-    // - Launch and wait
-    // - arrive_expect_tx, tma load operations, and mbarrier_wait
-    ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
-        circular_buffer_loop,
-        loads,
-        CircularBufferLoopStage::Main,
-        /*insertion_position=*/1);
-    registerReplace(circular_buffer_loop, main_loop);
+            .getCircularBufferInsertionPosition(loop->iter_domain());
+    insertTmaWarpSpecialized(loop, it->second, insertion_position);
 
-    if (!hasPrefetch(circular_buffer_loop)) {
-      // If there is no prefetch, then we don't need an epilogue loop.
-      return;
-    }
-
-    // We can use exclude argument in
-    // CloneTmaCircularBufferLoopAndInsertSync clone to avoid
-    // duplicating allocations if main loop is trivial.
-    std::unordered_set<Expr*> expressions_allocated_in_main_loop;
-    getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop);
-
-    // Epilogue loop:
-    // - wait only
-    // - mbarrier_wait
-    ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
-        circular_buffer_loop,
-        loads,
-        CircularBufferLoopStage::Epilog,
-        /*insertion_position=*/1,
-        expressions_allocated_in_main_loop);
-    registerInsertAfter(circular_buffer_loop, epilogue_loop);
+    processed_loop_ = loop;
+    insertion_info_.erase(loop);
   }
 
   void insertTmaWarpSpecialized(
@@ -1559,6 +1482,145 @@ class CircularBufferInserter : private kir::ExprMutator {
     registerReplace(circular_buffer_loop, warp_dispatch_ite);
   }
 
+ private:
+  InsertionInfo& insertion_info_;
+  ForLoop* processed_loop_ = nullptr;
+};
+
+// Apply pipeline circular buffering transformations
+class PipelineCircularBufferInserter : private kir::ExprMutator {
+ public:
+  // When there exist multiple circular buffer loops, apply
+  // transformations to inner-most loops first. A single ExprMutator
+  // pass can only process one loop.
+  static std::vector<Expr*> run(
+      const std::vector<Expr*>& exprs,
+      InsertionInfo insertion_info) {
+    std::vector<Expr*> inserted_exprs = exprs;
+    while (!insertion_info.empty()) {
+      PipelineCircularBufferInserter inserter(inserted_exprs, insertion_info);
+      inserted_exprs = inserter.exprs_;
+    }
+    return inserted_exprs;
+  }
+
+ private:
+  PipelineCircularBufferInserter(
+      const std::vector<Expr*>& exprs,
+      InsertionInfo& insertion_info)
+      : insertion_info_(insertion_info) {
+    size_t num_circular_buffer_loops = insertion_info.size();
+    traverseAndInsert(exprs);
+    NVF_ERROR(processed_loop_ != nullptr);
+    NVF_ERROR(insertion_info.size() == num_circular_buffer_loops - 1);
+  }
+
+  using kir::ExprMutator::handle;
+
+  void handle(ForLoop* loop) final {
+    kir::ExprMutator::handle(loop);
+
+    // If another loop is already taken care of, no more loop should
+    // be done in the same pass
+    if (processed_loop_ != nullptr) {
+      return;
+    }
+
+    auto it = insertion_info_.find(loop);
+    if (it == insertion_info_.end()) {
+      return;
+    }
+
+    bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
+        GpuLower::current()
+            ->circularBufferInfo()
+            .getCircularBufferOptionsFor(loop->iter_domain())
+            .type);
+    NVF_ERROR(!use_warp_specialization);
+
+    auto has_cp_async_bulk = std::any_of(
+        it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk);
+    if (has_cp_async_bulk) {
+      insertTmaPipelined(loop, it->second);
+    } else {
+      insert(loop, it->second);
+    }
+
+    processed_loop_ = loop;
+    insertion_info_.erase(loop);
+  }
+
+  bool hasPrefetch(ForLoop* circular_buffer_loop) {
+    int64_t prefetch_distance =
+        GpuLower::current()
+            ->circularBufferInfo()
+            .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
+            .prefetch;
+    return prefetch_distance > 0;
+  }
+
+  static bool usesMBarrierForWAR(ForLoop* circular_buffer_loop) {
+    return GpuLower::current()
+        ->circularBufferInfo()
+        .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
+        .usesMBarrierForWAR();
+  }
+
+  void insertTmaPipelined(
+      ForLoop* circular_buffer_loop,
+      const std::vector<Expr*>& loads) {
+    // Arrive on the WAR mbarriers to let the prefetching start.
+    if (usesMBarrierForWAR(circular_buffer_loop)) {
+      auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
+      registerInsertBefore(circular_buffer_loop, prefetch_loop);
+    }
+
+    // Prologue loop:
+    // - launch only
+    // - arrive_expect_tx and tma load operations
+    if (hasPrefetch(circular_buffer_loop)) {
+      // If there is no prefetch, then we don't need a prologue loop.
+      ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+          circular_buffer_loop,
+          loads,
+          CircularBufferLoopStage::Prolog,
+          /*insertion_position=*/1);
+      registerInsertBefore(circular_buffer_loop, prologue_loop);
+    }
+
+    // Main loop:
+    // - Launch and wait
+    // - arrive_expect_tx, tma load operations, and mbarrier_wait
+    ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+        circular_buffer_loop,
+        loads,
+        CircularBufferLoopStage::Main,
+        /*insertion_position=*/1);
+    registerReplace(circular_buffer_loop, main_loop);
+
+    if (!hasPrefetch(circular_buffer_loop)) {
+      // If there is no prefetch, then we don't need an epilogue loop.
+      return;
+    }
+
+    // We can use exclude argument in
+    // CloneTmaCircularBufferLoopAndInsertSync clone to avoid
+    // duplicating allocations if main loop is trivial.
+    std::unordered_set<Expr*> expressions_allocated_in_main_loop;
+    getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop);
+
+    // Epilogue loop:
+    // - wait only
+    // - mbarrier_wait
+    ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+        circular_buffer_loop,
+        loads,
+        CircularBufferLoopStage::Epilog,
+        /*insertion_position=*/1,
+        expressions_allocated_in_main_loop);
+    registerInsertAfter(circular_buffer_loop, epilogue_loop);
+  }
+
   void insert(ForLoop* circular_buffer_loop, const std::vector<Expr*>& loads) {
     NVF_ERROR(
         !usesMBarrierForWAR(circular_buffer_loop),
@@ -1788,8 +1850,9 @@ std::vector<Expr*> CircularBufferPass::run(const std::vector<Expr*>& exprs) {
   // Pipeline must come before WarpSpecialized. We cannot nest WarpSpecialized
   // inside of Pipeline circular buffering.
   std::vector<Expr*> result_exprs =
-      CircularBufferInserter::run(exprs, pipeline_insertion_info);
-  return CircularBufferInserter::run(result_exprs, ws_insertion_info);
+      PipelineCircularBufferInserter::run(exprs, pipeline_insertion_info);
+  return WarpSpecializedCircularBufferInserter::run(
+      result_exprs, ws_insertion_info);
 }
 
 } // namespace nvfuser
