@@ -1283,10 +1283,51 @@ void getAllocInTrivialLoop(ForLoop* fl, std::unordered_set<Expr*>& output) {
  }
}

+ // Create something like below:
+ //   for (int i = 0; i < prefetch + 1; ++i) {
+ //     mbarrier::arrive(mbarrier0[stage + i]);
+ //     mbarrier::arrive(mbarrier1[stage + i]);
+ //     ...
+ //   }
+ // where mbarrierX[stage + i] is the X-th WAR mbarrier for stage i.
+ //
+ // This is needed because we prefetch data in circular buffering, and we
+ // need to make sure the initial prefetches are not blocked by WAR hazards
+ // that do not exist yet.
+ ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop) {
+   const auto& opt =
+       GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
+           circular_buffer_loop->iter_domain());
+   auto circular_buffer_tvs =
+       GpuLower::current()->circularBufferInfo().getCircularBufferTvs(
+           circular_buffer_loop->iter_domain());
+   VectorOfUniqueEntries<TensorView*> mbarriers;
+   for (auto tv : circular_buffer_tvs) {
+     auto ldst = dynamic_cast<LoadStoreOp*>(tv->definition());
+     NVF_ERROR(ldst != nullptr);
+     auto it = GpuLower::current()->mbarrierMap().find(ldst);
+     if (it == GpuLower::current()->mbarrierMap().end()) {
+       continue;
+     }
+     mbarriers.pushBack(it->second);
+   }
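+   // One loop iteration per outstanding prefetch (prefetch + 1 in total);
+   // each iteration arrives on every WAR mbarrier collected above so that
+   // the initial prefetches are not blocked.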
+   auto prefetch_loop = ir_utils::createRangeLoop(opt.prefetch + 1);
+   for (auto mbarrier : mbarriers) {
+     auto mbarrier_to_arrive = IrBuilder::create<kir::TensorIndex>(
+         mbarrier,
+         SimplifyingIrBuilder::addExpr(
+             prefetch_loop->indexOrStartIfTrivial(), opt.stage));
+     auto prefetch = IrBuilder::create<kir::MBarrierArrive>(
+         /*state=*/nullptr, mbarrier_to_arrive);
+     prefetch_loop->body().push_back(prefetch);
+   }
+   return prefetch_loop;
+ }
+
} // namespace

- // Apply circular buffering transformations
- class CircularBufferInserter : private kir::ExprMutator {
+ // Apply warp specialized circular buffering transformations
+ class WarpSpecializedCircularBufferInserter : private kir::ExprMutator {
 public:
  // When there exist multiple circular buffer loops, apply
  // transformations to inner-most loops first. A single ExprMutator
@@ -1296,14 +1337,15 @@ class CircularBufferInserter : private kir::ExprMutator {
      InsertionInfo insertion_info) {
    std::vector<Expr*> inserted_exprs = exprs;
    while (!insertion_info.empty()) {
-       CircularBufferInserter inserter(inserted_exprs, insertion_info);
+       WarpSpecializedCircularBufferInserter inserter(
+           inserted_exprs, insertion_info);
      inserted_exprs = inserter.exprs_;
    }
    return inserted_exprs;
  }

 private:
-   CircularBufferInserter(
+   WarpSpecializedCircularBufferInserter(
      const std::vector<Expr*>& exprs,
      InsertionInfo& insertion_info)
      : insertion_info_(insertion_info) {
@@ -1329,143 +1371,24 @@ class CircularBufferInserter : private kir::ExprMutator {
      return;
    }

-     auto has_cp_async_bulk = std::any_of(
-         it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk);
-
    bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
        GpuLower::current()
            ->circularBufferInfo()
            .getCircularBufferOptionsFor(loop->iter_domain())
            .type);
-     if (use_warp_specialization) {
-       NVF_ERROR(
-           std::all_of(
-               it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk),
-           "In order to use warp specialization, all buffers must be loaded by TMA");
-       int64_t insertion_position =
-           GpuLower::current()
-               ->circularBufferInfo()
-               .getCircularBufferInsertionPosition(loop->iter_domain());
-       insertTmaWarpSpecialized(loop, it->second, insertion_position);
-     } else if (has_cp_async_bulk) {
-       insertTmaPipelined(loop, it->second);
-     } else {
-       insert(loop, it->second);
-     }
-     processed_loop_ = loop;
-     insertion_info_.erase(loop);
-   }
-
-   bool hasPrefetch(ForLoop* circular_buffer_loop) {
-     int64_t prefetch_distance =
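+     // Only warp-specialized loops are expected here; pipelined circular
+     // buffering is handled by PipelineCircularBufferInserter below.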
+     NVF_ERROR(use_warp_specialization);
+     NVF_ERROR(
+         std::all_of(
+             it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk),
+         "In order to use warp specialization, all buffers must be loaded by TMA");
+     int64_t insertion_position =
        GpuLower::current()
            ->circularBufferInfo()
-             .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
-             .prefetch;
-     return prefetch_distance > 0;
-   }
-
-   // Create something like below:
-   //   for (int i = 0; i < prefetch + 1; ++i) {
-   //     mbarrier::arrive(mbarrier0[stage + i]);
-   //     mbarrier::arrive(mbarrier1[stage + i]);
-   //     ...
-   //   }
-   // where mbarrierX[stage + i] is the X-th WAR mbarrier for stage i.
-   //
-   // This is needed because we prefetch data in circular buffering, and we
-   // need to make sure the initial prefetches are not blocked by WAR hazards
-   // that do not exist yet.
-   ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop) {
-     const auto& opt =
-         GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
-             circular_buffer_loop->iter_domain());
-     auto circular_buffer_tvs =
-         GpuLower::current()->circularBufferInfo().getCircularBufferTvs(
-             circular_buffer_loop->iter_domain());
-     VectorOfUniqueEntries<TensorView*> mbarriers;
-     for (auto tv : circular_buffer_tvs) {
-       auto ldst = dynamic_cast<LoadStoreOp*>(tv->definition());
-       NVF_ERROR(ldst != nullptr);
-       auto it = GpuLower::current()->mbarrierMap().find(ldst);
-       if (it == GpuLower::current()->mbarrierMap().end()) {
-         continue;
-       }
-       mbarriers.pushBack(it->second);
-     }
-     auto prefetch_loop = ir_utils::createRangeLoop(opt.prefetch + 1);
-     for (auto mbarrier : mbarriers) {
-       auto mbarrier_to_arrive = IrBuilder::create<kir::TensorIndex>(
-           mbarrier,
-           SimplifyingIrBuilder::addExpr(
-               prefetch_loop->indexOrStartIfTrivial(), opt.stage));
-       auto prefetch = IrBuilder::create<kir::MBarrierArrive>(
-           /*state=*/nullptr, mbarrier_to_arrive);
-       prefetch_loop->body().push_back(prefetch);
-     }
-     return prefetch_loop;
-   }
-
-   static bool usesMBarrierForWAR(ForLoop* circular_buffer_loop) {
-     return GpuLower::current()
-         ->circularBufferInfo()
-         .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
-         .usesMBarrierForWAR();
-   }
-
-   void insertTmaPipelined(
-       ForLoop* circular_buffer_loop,
-       const std::vector<Expr*>& loads) {
-     // Arrive on the WAR mbarriers to let the prefetching start.
-     if (usesMBarrierForWAR(circular_buffer_loop)) {
-       auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
-       registerInsertBefore(circular_buffer_loop, prefetch_loop);
-     }
-
-     // Prologue loop:
-     //  - launch only
-     //  - arrive_expect_tx and tma load operations
-     if (hasPrefetch(circular_buffer_loop)) {
-       // If there is no prefetch, then we don't need a prologue loop.
-       ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
-           circular_buffer_loop,
-           loads,
-           CircularBufferLoopStage::Prolog,
-           /*insertion_position=*/1);
-       registerInsertBefore(circular_buffer_loop, prologue_loop);
-     }
-
-     // Main loop:
-     //  - Launch and wait
-     //  - arrive_expect_tx, tma load operations, and mbarrier_wait
-     ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
-         circular_buffer_loop,
-         loads,
-         CircularBufferLoopStage::Main,
-         /*insertion_position=*/1);
-     registerReplace(circular_buffer_loop, main_loop);
+             .getCircularBufferInsertionPosition(loop->iter_domain());
+     insertTmaWarpSpecialized(loop, it->second, insertion_position);

-     if (!hasPrefetch(circular_buffer_loop)) {
-       // If there is no prefetch, then we don't need an epilogue loop.
-       return;
-     }
-
-     // We can use the exclude argument of
-     // CloneTmaCircularBufferLoopAndInsertSync::clone to avoid
-     // duplicating allocations if the main loop is trivial.
-     std::unordered_set<Expr*> expressions_allocated_in_main_loop;
-     getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop);
-
-     // Epilogue loop:
-     //  - wait only
-     //  - mbarrier_wait
-     ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
-         circular_buffer_loop,
-         loads,
-         CircularBufferLoopStage::Epilog,
-         /*insertion_position=*/1,
-         expressions_allocated_in_main_loop);
-     registerInsertAfter(circular_buffer_loop, epilogue_loop);
+     processed_loop_ = loop;
+     insertion_info_.erase(loop);
  }

  void insertTmaWarpSpecialized(
@@ -1559,6 +1482,145 @@ class CircularBufferInserter : private kir::ExprMutator {
    registerReplace(circular_buffer_loop, warp_dispatch_ite);
  }

+  private:
+   InsertionInfo& insertion_info_;
+   ForLoop* processed_loop_ = nullptr;
+ };
+
+ // Apply pipeline circular buffering transformations
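+ // (loops are dispatched to insertTmaPipelined() when the loads use TMA,
+ // i.e. cp.async.bulk, and to insert() otherwise).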
+ class PipelineCircularBufferInserter : private kir::ExprMutator {
+  public:
+   // When there exist multiple circular buffer loops, apply
+   // transformations to inner-most loops first. A single ExprMutator
+   // pass can only process one loop.
+   static std::vector<Expr*> run(
+       const std::vector<Expr*>& exprs,
+       InsertionInfo insertion_info) {
+     std::vector<Expr*> inserted_exprs = exprs;
+     while (!insertion_info.empty()) {
+       PipelineCircularBufferInserter inserter(inserted_exprs, insertion_info);
+       inserted_exprs = inserter.exprs_;
+     }
+     return inserted_exprs;
+   }
+
+  private:
+   PipelineCircularBufferInserter(
+       const std::vector<Expr*>& exprs,
+       InsertionInfo& insertion_info)
+       : insertion_info_(insertion_info) {
+     size_t num_circular_buffer_loops = insertion_info.size();
+     traverseAndInsert(exprs);
+     NVF_ERROR(processed_loop_ != nullptr);
+     NVF_ERROR(insertion_info.size() == num_circular_buffer_loops - 1);
+   }
+
+   using kir::ExprMutator::handle;
+
+   void handle(ForLoop* loop) final {
+     kir::ExprMutator::handle(loop);
+
+     // If another loop is already taken care of, no more loops should
+     // be processed in the same pass.
+     if (processed_loop_ != nullptr) {
+       return;
+     }
+
+     auto it = insertion_info_.find(loop);
+     if (it == insertion_info_.end()) {
+       return;
+     }
+
+     bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
+         GpuLower::current()
+             ->circularBufferInfo()
+             .getCircularBufferOptionsFor(loop->iter_domain())
+             .type);
+     NVF_ERROR(!use_warp_specialization);
+
+     auto has_cp_async_bulk = std::any_of(
+         it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk);
+     if (has_cp_async_bulk) {
+       insertTmaPipelined(loop, it->second);
+     } else {
+       insert(loop, it->second);
+     }
+
+     processed_loop_ = loop;
+     insertion_info_.erase(loop);
+   }
+
+   bool hasPrefetch(ForLoop* circular_buffer_loop) {
+     int64_t prefetch_distance =
+         GpuLower::current()
+             ->circularBufferInfo()
+             .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
+             .prefetch;
+     return prefetch_distance > 0;
+   }
+
+   static bool usesMBarrierForWAR(ForLoop* circular_buffer_loop) {
+     return GpuLower::current()
+         ->circularBufferInfo()
+         .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
+         .usesMBarrierForWAR();
+   }
+
+   void insertTmaPipelined(
+       ForLoop* circular_buffer_loop,
+       const std::vector<Expr*>& loads) {
+     // Arrive on the WAR mbarriers to let the prefetching start.
+     if (usesMBarrierForWAR(circular_buffer_loop)) {
+       auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
+       registerInsertBefore(circular_buffer_loop, prefetch_loop);
+     }
+
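+     // The circular buffer loop is cloned into up to three stages: a
+     // prologue that only issues loads, a main loop that issues loads and
+     // waits, and an epilogue that only waits. The prologue and epilogue
+     // exist only when the prefetch distance is greater than zero.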
+     // Prologue loop:
+     //  - launch only
+     //  - arrive_expect_tx and tma load operations
+     if (hasPrefetch(circular_buffer_loop)) {
+       // If there is no prefetch, then we don't need a prologue loop.
+       ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+           circular_buffer_loop,
+           loads,
+           CircularBufferLoopStage::Prolog,
+           /*insertion_position=*/1);
+       registerInsertBefore(circular_buffer_loop, prologue_loop);
+     }
+
+     // Main loop:
+     //  - Launch and wait
+     //  - arrive_expect_tx, tma load operations, and mbarrier_wait
+     ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+         circular_buffer_loop,
+         loads,
+         CircularBufferLoopStage::Main,
+         /*insertion_position=*/1);
+     registerReplace(circular_buffer_loop, main_loop);
+
+     if (!hasPrefetch(circular_buffer_loop)) {
+       // If there is no prefetch, then we don't need an epilogue loop.
+       return;
+     }
+
+     // We can use the exclude argument of
+     // CloneTmaCircularBufferLoopAndInsertSync::clone to avoid
+     // duplicating allocations if the main loop is trivial.
+     std::unordered_set<Expr*> expressions_allocated_in_main_loop;
+     getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop);
+
+     // Epilogue loop:
+     //  - wait only
+     //  - mbarrier_wait
+     ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+         circular_buffer_loop,
+         loads,
+         CircularBufferLoopStage::Epilog,
+         /*insertion_position=*/1,
+         expressions_allocated_in_main_loop);
+     registerInsertAfter(circular_buffer_loop, epilogue_loop);
+   }
+
  void insert(ForLoop* circular_buffer_loop, const std::vector<Expr*>& loads) {
    NVF_ERROR(
        !usesMBarrierForWAR(circular_buffer_loop),
@@ -1788,8 +1850,9 @@ std::vector<Expr*> CircularBufferPass::run(const std::vector<Expr*>& exprs) {
  // Pipeline must come before WarpSpecialized. We cannot nest WarpSpecialized
  // inside of Pipeline circular buffering.
  std::vector<Expr*> result_exprs =
-       CircularBufferInserter::run(exprs, pipeline_insertion_info);
-   return CircularBufferInserter::run(result_exprs, ws_insertion_info);
+       PipelineCircularBufferInserter::run(exprs, pipeline_insertion_info);
+   return WarpSpecializedCircularBufferInserter::run(
+       result_exprs, ws_insertion_info);
}

} // namespace nvfuser