@@ -1334,10 +1334,51 @@ void getAllocInTrivialLoop(ForLoop* fl, std::unordered_set<Expr*>& output) {
   }
 }
 
+// Create something like below:
+//   for (int i = 0; i < prefetch + 1; ++i) {
+//     mbarrier::arrive(mbarrier0[stage + i]]);
+//     mbarrier::arrive(mbarrier1[stage + i]);
+//     ...
+//   }
+// where mbarrierX[stage + i] is the X-th WAR mbarrier for stage i.
+//
+// This is needed because we prefetch data in circular buffering, and we
+// need to make sure the initial prefetches are not blocked by the
+// non-existing WAR hazards.
+ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop) {
+  const auto& opt =
+      GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
+          circular_buffer_loop->iter_domain());
+  auto circular_buffer_tvs =
+      GpuLower::current()->circularBufferInfo().getCircularBufferTvs(
+          circular_buffer_loop->iter_domain());
+  VectorOfUniqueEntries<TensorView*> mbarriers;
+  for (auto tv : circular_buffer_tvs) {
+    auto ldst = dynamic_cast<LoadStoreOp*>(tv->definition());
+    NVF_ERROR(ldst != nullptr);
+    auto it = GpuLower::current()->mbarrierMap().find(ldst);
+    if (it == GpuLower::current()->mbarrierMap().end()) {
+      continue;
+    }
+    mbarriers.pushBack(it->second);
+  }
+  auto prefetch_loop = ir_utils::createRangeLoop(opt.prefetch + 1);
+  for (auto mbarrier : mbarriers) {
+    auto mbarrier_to_arrive = IrBuilder::create<kir::TensorIndex>(
+        mbarrier,
+        SimplifyingIrBuilder::addExpr(
+            prefetch_loop->indexOrStartIfTrivial(), opt.stage));
+    auto prefetch = IrBuilder::create<kir::MBarrierArrive>(
+        /*state=*/nullptr, mbarrier_to_arrive);
+    prefetch_loop->body().push_back(prefetch);
+  }
+  return prefetch_loop;
+}
+
 } // namespace
 
-// Apply circular buffering transformations
-class CircularBufferInserter : private kir::ExprMutator {
+// Apply warp specialized circular buffering transformations
+class WarpSpecializedCircularBufferInserter : private kir::ExprMutator {
  public:
   // When there exist multiple circular buffer loops, apply
   // transformations to inner-most loops first. A single ExprMutator
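Not part of the diff, for orientation only: at the kernel level, the loop built by createArrivesForWar behaves roughly like the pseudocode below (mbarrier names are hypothetical). insertTmaPipelined, for example, places this loop immediately before the circular buffer loop, so the first prefetch + 1 TMA loads are not blocked by WAR waits on buffers that no consumer has read yet.

    // Rough lowered form, assuming two circular-buffered tensors and
    // circular buffer options {stage, prefetch}:
    for (int i = 0; i < prefetch + 1; ++i) {
      mbarrier::arrive(war_mbarrier_a[stage + i]);
      mbarrier::arrive(war_mbarrier_b[stage + i]);
    }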
@@ -1347,14 +1388,15 @@ class CircularBufferInserter : private kir::ExprMutator {
       InsertionInfo insertion_info) {
     std::vector<Expr*> inserted_exprs = exprs;
     while (!insertion_info.empty()) {
-      CircularBufferInserter inserter(inserted_exprs, insertion_info);
+      WarpSpecializedCircularBufferInserter inserter(
+          inserted_exprs, insertion_info);
       inserted_exprs = inserter.exprs_;
     }
     return inserted_exprs;
   }
 
  private:
-  CircularBufferInserter(
+  WarpSpecializedCircularBufferInserter(
       const std::vector<Expr*>& exprs,
       InsertionInfo& insertion_info)
       : insertion_info_(insertion_info) {
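The run()/constructor pairing above is a small fixed-point driver: each constructed mutator rewrites exactly one circular buffer loop (innermost first) and erases it from insertion_info, and run() keeps constructing fresh mutators until the map is empty. A minimal standalone sketch of that pattern, with hypothetical stand-in types rather than nvFuser's:

    #include <cassert>
    #include <deque>
    #include <vector>

    using WorkList = std::deque<int>;  // stands in for InsertionInfo

    struct OneLoopRewriter {
      OneLoopRewriter(const std::vector<int>& exprs, WorkList& work)
          : exprs_(exprs) {
        // A real kir::ExprMutator traversal would rewrite exactly one loop
        // here, then remove it from the work list.
        assert(!work.empty());
        work.pop_front();
      }
      std::vector<int> exprs_;  // the (possibly mutated) expressions
    };

    std::vector<int> run(std::vector<int> exprs, WorkList work) {
      while (!work.empty()) {
        OneLoopRewriter rewriter(exprs, work);  // one loop handled per pass
        exprs = rewriter.exprs_;
      }
      return exprs;
    }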
@@ -1380,143 +1422,24 @@ class CircularBufferInserter : private kir::ExprMutator {
       return;
     }
 
-    auto has_cp_async_bulk = std::any_of(
-        it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk);
-
     bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
         GpuLower::current()
             ->circularBufferInfo()
             .getCircularBufferOptionsFor(loop->iter_domain())
             .type);
-    if (use_warp_specialization) {
-      NVF_ERROR(
-          std::all_of(
-              it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk),
-          "In order to use warp specialization, all buffers must be loaded by TMA");
-      int64_t insertion_position =
-          GpuLower::current()
-              ->circularBufferInfo()
-              .getCircularBufferInsertionPosition(loop->iter_domain());
-      insertTmaWarpSpecialized(loop, it->second, insertion_position);
-    } else if (has_cp_async_bulk) {
-      insertTmaPipelined(loop, it->second);
-    } else {
-      insert(loop, it->second);
-    }
-    processed_loop_ = loop;
-    insertion_info_.erase(loop);
-  }
-
-  bool hasPrefetch(ForLoop* circular_buffer_loop) {
-    int64_t prefetch_distance =
+    NVF_ERROR(use_warp_specialization);
+    NVF_ERROR(
+        std::all_of(
+            it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk),
+        "In order to use warp specialization, all buffers must be loaded by TMA");
+    int64_t insertion_position =
         GpuLower::current()
             ->circularBufferInfo()
-            .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
-            .prefetch;
-    return prefetch_distance > 0;
-  }
-
-  // Create something like below:
-  //   for (int i = 0; i < prefetch + 1; ++i) {
-  //     mbarrier::arrive(mbarrier0[stage + i]]);
-  //     mbarrier::arrive(mbarrier1[stage + i]);
-  //     ...
-  //   }
-  // where mbarrierX[stage + i] is the X-th WAR mbarrier for stage i.
-  //
-  // This is needed because we prefetch data in circular buffering, and we
-  // need to make sure the initial prefetches are not blocked by the
-  // non-existing WAR hazards.
-  ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop) {
-    const auto& opt =
-        GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
-            circular_buffer_loop->iter_domain());
-    auto circular_buffer_tvs =
-        GpuLower::current()->circularBufferInfo().getCircularBufferTvs(
-            circular_buffer_loop->iter_domain());
-    VectorOfUniqueEntries<TensorView*> mbarriers;
-    for (auto tv : circular_buffer_tvs) {
-      auto ldst = dynamic_cast<LoadStoreOp*>(tv->definition());
-      NVF_ERROR(ldst != nullptr);
-      auto it = GpuLower::current()->mbarrierMap().find(ldst);
-      if (it == GpuLower::current()->mbarrierMap().end()) {
-        continue;
-      }
-      mbarriers.pushBack(it->second);
-    }
-    auto prefetch_loop = ir_utils::createRangeLoop(opt.prefetch + 1);
-    for (auto mbarrier : mbarriers) {
-      auto mbarrier_to_arrive = IrBuilder::create<kir::TensorIndex>(
-          mbarrier,
-          SimplifyingIrBuilder::addExpr(
-              prefetch_loop->indexOrStartIfTrivial(), opt.stage));
-      auto prefetch = IrBuilder::create<kir::MBarrierArrive>(
-          /*state=*/nullptr, mbarrier_to_arrive);
-      prefetch_loop->body().push_back(prefetch);
-    }
-    return prefetch_loop;
-  }
-
-  static bool usesMBarrierForWAR(ForLoop* circular_buffer_loop) {
-    return GpuLower::current()
-        ->circularBufferInfo()
-        .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
-        .usesMBarrierForWAR();
-  }
-
-  void insertTmaPipelined(
-      ForLoop* circular_buffer_loop,
-      const std::vector<Expr*>& loads) {
-    // Arrive on the WAR mbarriers to let the prefetching start.
-    if (usesMBarrierForWAR(circular_buffer_loop)) {
-      auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
-      registerInsertBefore(circular_buffer_loop, prefetch_loop);
-    }
-
-    // Prologue loop:
-    //  - launch only
-    //  - arrive_expect_tx and tma load operations
-    if (hasPrefetch(circular_buffer_loop)) {
-      // If there is no prefetch, then we don't need a prologue loop.
-      ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
-          circular_buffer_loop,
-          loads,
-          CircularBufferLoopStage::Prolog,
-          /*insertion_position=*/1);
-      registerInsertBefore(circular_buffer_loop, prologue_loop);
-    }
-
-    // Main loop:
-    //  - Launch and wait
-    //  - arrive_expect_tx, tma load operations, and mbarrier_wait
-    ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
-        circular_buffer_loop,
-        loads,
-        CircularBufferLoopStage::Main,
-        /*insertion_position=*/1);
-    registerReplace(circular_buffer_loop, main_loop);
+            .getCircularBufferInsertionPosition(loop->iter_domain());
+    insertTmaWarpSpecialized(loop, it->second, insertion_position);
 
-    if (!hasPrefetch(circular_buffer_loop)) {
-      // If there is no prefetch, then we don't need a epilogue loop.
-      return;
-    }
-
-    // We can use exclude argument in
-    // CloneTmaCircularBufferLoopAndInsertSync clone to avoid
-    // duplicating allocations if main loop is trivial.
-    std::unordered_set<Expr*> expressions_allocated_in_main_loop;
-    getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop);
-
-    // Epilogue loop:
-    //  - wait only
-    //  - mbarrier_wait
-    ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
-        circular_buffer_loop,
-        loads,
-        CircularBufferLoopStage::Epilog,
-        /*insertion_position=*/1,
-        expressions_allocated_in_main_loop);
-    registerInsertAfter(circular_buffer_loop, epilogue_loop);
+    processed_loop_ = loop;
+    insertion_info_.erase(loop);
   }
 
   void insertTmaWarpSpecialized(
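With the pipelined paths moved out (see the new PipelineCircularBufferInserter below), this handle() now only accepts warp-specialized loops and requires every circular-buffered load to be a TMA load. A self-contained sketch of those two checks, using hypothetical stand-in types rather than nvFuser's:

    #include <algorithm>
    #include <cassert>
    #include <variant>
    #include <vector>

    struct Pipelined {};
    struct WarpSpecialized {};
    using CircularBufferType = std::variant<Pipelined, WarpSpecialized>;

    struct Expr { bool is_tma_load; };
    bool isCpAsyncBulk(const Expr& e) { return e.is_tma_load; }

    void checkWarpSpecialized(
        const CircularBufferType& type, const std::vector<Expr>& loads) {
      // Only warp-specialized circular buffering is handled here ...
      assert(std::holds_alternative<WarpSpecialized>(type));
      // ... and it requires all loads to be TMA (cp.async.bulk) loads.
      assert(std::all_of(loads.begin(), loads.end(), isCpAsyncBulk));
    }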
@@ -1610,6 +1533,145 @@ class CircularBufferInserter : private kir::ExprMutator {
     registerReplace(circular_buffer_loop, warp_dispatch_ite);
   }
 
+ private:
+  InsertionInfo& insertion_info_;
+  ForLoop* processed_loop_ = nullptr;
+};
+
+// Apply pipeline circular buffering transformations
+class PipelineCircularBufferInserter : private kir::ExprMutator {
+ public:
+  // When there exist multiple circular buffer loops, apply
+  // transformations to inner-most loops first. A single ExprMutator
+  // pass can only process one loop.
+  static std::vector<Expr*> run(
+      const std::vector<Expr*>& exprs,
+      InsertionInfo insertion_info) {
+    std::vector<Expr*> inserted_exprs = exprs;
+    while (!insertion_info.empty()) {
+      PipelineCircularBufferInserter inserter(inserted_exprs, insertion_info);
+      inserted_exprs = inserter.exprs_;
+    }
+    return inserted_exprs;
+  }
+
+ private:
+  PipelineCircularBufferInserter(
+      const std::vector<Expr*>& exprs,
+      InsertionInfo& insertion_info)
+      : insertion_info_(insertion_info) {
+    size_t num_circular_buffer_loops = insertion_info.size();
+    traverseAndInsert(exprs);
+    NVF_ERROR(processed_loop_ != nullptr);
+    NVF_ERROR(insertion_info.size() == num_circular_buffer_loops - 1);
+  }
+
+  using kir::ExprMutator::handle;
+
+  void handle(ForLoop* loop) final {
+    kir::ExprMutator::handle(loop);
+
+    // If another loop is already taken care of, no more loop should
+    // be done in the same pass
+    if (processed_loop_ != nullptr) {
+      return;
+    }
+
+    auto it = insertion_info_.find(loop);
+    if (it == insertion_info_.end()) {
+      return;
+    }
+
+    bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
+        GpuLower::current()
+            ->circularBufferInfo()
+            .getCircularBufferOptionsFor(loop->iter_domain())
+            .type);
+    NVF_ERROR(!use_warp_specialization);
+
+    auto has_cp_async_bulk = std::any_of(
+        it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk);
+    if (has_cp_async_bulk) {
+      insertTmaPipelined(loop, it->second);
+    } else {
+      insert(loop, it->second);
+    }
+
+    processed_loop_ = loop;
+    insertion_info_.erase(loop);
+  }
+
+  bool hasPrefetch(ForLoop* circular_buffer_loop) {
+    int64_t prefetch_distance =
+        GpuLower::current()
+            ->circularBufferInfo()
+            .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
+            .prefetch;
+    return prefetch_distance > 0;
+  }
+
+  static bool usesMBarrierForWAR(ForLoop* circular_buffer_loop) {
+    return GpuLower::current()
+        ->circularBufferInfo()
+        .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
+        .usesMBarrierForWAR();
+  }
+
+  void insertTmaPipelined(
+      ForLoop* circular_buffer_loop,
+      const std::vector<Expr*>& loads) {
+    // Arrive on the WAR mbarriers to let the prefetching start.
+    if (usesMBarrierForWAR(circular_buffer_loop)) {
+      auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
+      registerInsertBefore(circular_buffer_loop, prefetch_loop);
+    }
+
+    // Prologue loop:
+    //  - launch only
+    //  - arrive_expect_tx and tma load operations
+    if (hasPrefetch(circular_buffer_loop)) {
+      // If there is no prefetch, then we don't need a prologue loop.
+      ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+          circular_buffer_loop,
+          loads,
+          CircularBufferLoopStage::Prolog,
+          /*insertion_position=*/1);
+      registerInsertBefore(circular_buffer_loop, prologue_loop);
+    }
+
+    // Main loop:
+    //  - Launch and wait
+    //  - arrive_expect_tx, tma load operations, and mbarrier_wait
+    ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+        circular_buffer_loop,
+        loads,
+        CircularBufferLoopStage::Main,
+        /*insertion_position=*/1);
+    registerReplace(circular_buffer_loop, main_loop);
+
+    if (!hasPrefetch(circular_buffer_loop)) {
+      // If there is no prefetch, then we don't need a epilogue loop.
+      return;
+    }
+
+    // We can use exclude argument in
+    // CloneTmaCircularBufferLoopAndInsertSync clone to avoid
+    // duplicating allocations if main loop is trivial.
+    std::unordered_set<Expr*> expressions_allocated_in_main_loop;
+    getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop);
+
+    // Epilogue loop:
+    //  - wait only
+    //  - mbarrier_wait
+    ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+        circular_buffer_loop,
+        loads,
+        CircularBufferLoopStage::Epilog,
+        /*insertion_position=*/1,
+        expressions_allocated_in_main_loop);
+    registerInsertAfter(circular_buffer_loop, epilogue_loop);
+  }
+
   void insert(ForLoop* circular_buffer_loop, const std::vector<Expr*>& loads) {
     NVF_ERROR(
         !usesMBarrierForWAR(circular_buffer_loop),
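Taken together, the Prolog/Main/Epilog stages described in the comments above transform a circular-buffered TMA loop into roughly the following structure (pseudocode with hypothetical names; the actual cloning, predication, and mbarrier indexing is handled by CloneTmaCircularBufferLoopAndInsertSync):

    // Circular buffer loop over [0, N) with `stage` buffers and prefetch > 0:
    for (int i = 0; i < prefetch; ++i) {          // Prolog: launch only
      arrive_expect_tx(raw_mbarrier[i]);
      tma_load(buffer[i], src[i]);
    }
    for (int i = 0; i < N - prefetch; ++i) {      // Main: launch and wait
      arrive_expect_tx(raw_mbarrier[(i + prefetch) % stage]);
      tma_load(buffer[(i + prefetch) % stage], src[i + prefetch]);
      mbarrier_wait(raw_mbarrier[i % stage]);
      compute(buffer[i % stage]);
    }
    for (int i = N - prefetch; i < N; ++i) {      // Epilog: wait only
      mbarrier_wait(raw_mbarrier[i % stage]);
      compute(buffer[i % stage]);
    }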
@@ -1839,8 +1901,9 @@ std::vector<Expr*> CircularBufferPass::run(const std::vector<Expr*>& exprs) {
   // Pipeline must come before WarpSpecialized. We cannot nest WarpSpecialized
   // inside of Pipeline circular buffering.
   std::vector<Expr*> result_exprs =
-      CircularBufferInserter::run(exprs, pipeline_insertion_info);
-  return CircularBufferInserter::run(result_exprs, ws_insertion_info);
+      PipelineCircularBufferInserter::run(exprs, pipeline_insertion_info);
+  return WarpSpecializedCircularBufferInserter::run(
+      result_exprs, ws_insertion_info);
 }
 
 } // namespace nvfuser