
Create separate CircularBufferInserter for WarpSpecialized and Pipeline #4280

Merged

rdspring1 merged 5 commits into main from cb_pass_refactor_p2 on Apr 22, 2025

Conversation

@rdspring1 (Collaborator) commented Apr 18, 2025

This PR separates CircularBufferInserter into two variants: WarpSpecializedCircularBufferInserter and PipelineCircularBufferInserter. Stacked on: #4275

This refactor lays the foundation for the more complex changes required to support Blackwell and ping-pong. No further enhancements are planned for pipeline circular buffering.
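
For orientation, here is a minimal sketch of how CircularBufferPass::run can dispatch once the two inserters exist; it is not the PR's exact code. It assumes InsertionInfo maps each circular buffer ForLoop* to its load expressions (consistent with the snippets quoted in the review below), and runCircularBufferDispatch is a hypothetical name.

    // Hedged sketch, not the PR's exact code: partition the insertion info by
    // circular buffer type, then let each inserter process only its variant.
    // InsertionInfo as a map from ForLoop* to std::vector<Expr*> and the
    // helper name runCircularBufferDispatch are assumptions.
    std::vector<Expr*> runCircularBufferDispatch(
        const std::vector<Expr*>& exprs,
        const InsertionInfo& insertion_info) {
      InsertionInfo ws_info;
      InsertionInfo pipeline_info;
      for (auto&& [loop, loads] : insertion_info) {
        bool is_warp_specialized = std::holds_alternative<WarpSpecialized>(
            GpuLower::current()
                ->circularBufferInfo()
                .getCircularBufferOptionsFor(loop->iter_domain())
                .type);
        (is_warp_specialized ? ws_info : pipeline_info).emplace(loop, loads);
      }
      // Each inserter then only ever sees loops of its own variant, so the
      // NVF_ERROR checks on the circular buffer type in handle() hold.
      std::vector<Expr*> result =
          WarpSpecializedCircularBufferInserter::run(exprs, ws_info);
      return PipelineCircularBufferInserter::run(result, pipeline_info);
    }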

github-actions bot commented Apr 18, 2025

Review updated until commit 3a7c520

Description

  • Split CircularBufferInserter into WarpSpecializedCircularBufferInserter and PipelineCircularBufferInserter

  • Refactor to handle warp specialized and pipeline circular buffering separately

  • Add createArrivesForWar function for WAR mbarrier handling


Changes walkthrough 📝

Relevant files

Enhancement — csrc/device_lower/pass/circular_buffer.cpp: Separate CircularBufferInserter for WarpSpecialized and Pipeline

  • Created WarpSpecializedCircularBufferInserter class
  • Created PipelineCircularBufferInserter class
  • Moved WAR mbarrier handling to createArrivesForWar function
  • Updated CircularBufferPass::run to use new inserter classes
  • +198/-135

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The createArrivesForWar function is duplicated in both WarpSpecializedCircularBufferInserter and PipelineCircularBufferInserter. This could lead to maintenance issues and potential inconsistencies.

    // This is needed because we prefetch data in circular buffering, and we
    // need to make sure the initial prefetches are not blocked by the
    // non-existing WAR hazards.
    ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop) {
      const auto& opt =
          GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
              circular_buffer_loop->iter_domain());
      auto circular_buffer_tvs =
          GpuLower::current()->circularBufferInfo().getCircularBufferTvs(
              circular_buffer_loop->iter_domain());
      VectorOfUniqueEntries<TensorView*> mbarriers;
      for (auto tv : circular_buffer_tvs) {
        auto ldst = dynamic_cast<LoadStoreOp*>(tv->definition());
        NVF_ERROR(ldst != nullptr);
        auto it = GpuLower::current()->mbarrierMap().find(ldst);
        if (it == GpuLower::current()->mbarrierMap().end()) {
          continue;
        }
        mbarriers.pushBack(it->second);
      }
      auto prefetch_loop = ir_utils::createRangeLoop(opt.prefetch + 1);
      for (auto mbarrier : mbarriers) {
        auto mbarrier_to_arrive = IrBuilder::create<kir::TensorIndex>(
            mbarrier,
            SimplifyingIrBuilder::addExpr(
                prefetch_loop->indexOrStartIfTrivial(), opt.stage));
        auto prefetch = IrBuilder::create<kir::MBarrierArrive>(
            /*state=*/nullptr, mbarrier_to_arrive);
        prefetch_loop->body().push_back(prefetch);
      }
      return prefetch_loop;
    }
    Code Duplication

    The insertTmaPipelined function is duplicated in both WarpSpecializedCircularBufferInserter and PipelineCircularBufferInserter. This could lead to maintenance issues and potential inconsistencies.

    void insertTmaPipelined(
        ForLoop* circular_buffer_loop,
        const std::vector<Expr*>& loads) {
      // Arrive on the WAR mbarriers to let the prefetching start.
      if (usesMBarrierForWAR(circular_buffer_loop)) {
        auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
        registerInsertBefore(circular_buffer_loop, prefetch_loop);
      }
    
      // Prologue loop:
      //  - launch only
      //  - arrive_expect_tx and tma load operations
      if (hasPrefetch(circular_buffer_loop)) {
        // If there is no prefetch, then we don't need a prologue loop.
        ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
            circular_buffer_loop,
            loads,
            CircularBufferLoopStage::Prolog,
            /*insertion_position=*/1);
        registerInsertBefore(circular_buffer_loop, prologue_loop);
      }
    
      // Main loop:
      //  - Launch and wait
      //  - arrive_expect_tx, tma load operations, and mbarrier_wait
      ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
          circular_buffer_loop,
          loads,
          CircularBufferLoopStage::Main,
          /*insertion_position=*/1);
      registerReplace(circular_buffer_loop, main_loop);
    
      if (!hasPrefetch(circular_buffer_loop)) {
        // If there is no prefetch, then we don't need an epilogue loop.
        return;
      }
    
      // We can use exclude argument in
      // CloneTmaCircularBufferLoopAndInsertSync clone to avoid
      // duplicating allocations if main loop is trivial.
      std::unordered_set<Expr*> expressions_allocated_in_main_loop;
      getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop);
    
      // Epilogue loop:
      //  - wait only
      //  - mbarrier_wait
      ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
          circular_buffer_loop,
          loads,
          CircularBufferLoopStage::Epilog,
          /*insertion_position=*/1,
          expressions_allocated_in_main_loop);
      registerInsertAfter(circular_buffer_loop, epilogue_loop);
    }
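
    One way to address both duplication findings above, shown here only as a hedged sketch rather than a concrete request: hoist the shared members into a common base class so each variant-specific inserter inherits a single definition. The name CircularBufferInserterCommon is hypothetical.

    // Hedged sketch (not part of this PR): a shared base owning the helpers
    // flagged as duplicated. Member signatures mirror the snippets above;
    // only the hoisting and the base-class name are new.
    class CircularBufferInserterCommon : protected kir::ExprMutator {
     protected:
      // Single definitions; both derived inserters call these instead of
      // carrying their own copies.
      ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop);
      void insertTmaPipelined(
          ForLoop* circular_buffer_loop,
          const std::vector<Expr*>& loads);
    };

    class WarpSpecializedCircularBufferInserter
        : private CircularBufferInserterCommon { /* ... */ };

    class PipelineCircularBufferInserter
        : private CircularBufferInserterCommon { /* ... */ };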
    Missing Tests

    The PR does not include any new tests or updates to existing tests to validate the new WarpSpecializedCircularBufferInserter and PipelineCircularBufferInserter classes.

    class WarpSpecializedCircularBufferInserter : private kir::ExprMutator {
     public:
      // When there exist multiple circular buffer loops, apply
      // transformations to inner-most loops first. A single ExprMutator
      // pass can only process one loop.
      static std::vector<Expr*> run(
          const std::vector<Expr*>& exprs,
          InsertionInfo insertion_info) {
        std::vector<Expr*> inserted_exprs = exprs;
        while (!insertion_info.empty()) {
          WarpSpecializedCircularBufferInserter inserter(
              inserted_exprs, insertion_info);
          inserted_exprs = inserter.exprs_;
        }
        return inserted_exprs;
      }
    
     private:
      WarpSpecializedCircularBufferInserter(
          const std::vector<Expr*>& exprs,
          InsertionInfo& insertion_info)
          : insertion_info_(insertion_info) {
        size_t num_circular_buffer_loops = insertion_info.size();
        traverseAndInsert(exprs);
        NVF_ERROR(processed_loop_ != nullptr);
        NVF_ERROR(insertion_info.size() == num_circular_buffer_loops - 1);
      }
    
      using kir::ExprMutator::handle;
    
      void handle(ForLoop* loop) final {
        kir::ExprMutator::handle(loop);
    
        // If another loop has already been handled, no more loops should
        // be processed in the same pass.
        if (processed_loop_ != nullptr) {
          return;
        }
    
        auto it = insertion_info_.find(loop);
        if (it == insertion_info_.end()) {
          return;
        }
    
        bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
            GpuLower::current()
                ->circularBufferInfo()
                .getCircularBufferOptionsFor(loop->iter_domain())
                .type);
        NVF_ERROR(use_warp_specialization);
        NVF_ERROR(
            std::all_of(
                it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk),
            "In order to use warp specialization, all buffers must be loaded by TMA");
        int64_t insertion_position =
            GpuLower::current()
                ->circularBufferInfo()
                .getCircularBufferInsertionPosition(loop->iter_domain());
        insertTmaWarpSpecialized(loop, it->second, insertion_position);
    
        processed_loop_ = loop;
        insertion_info_.erase(loop);
      }
    
      void insertTmaWarpSpecialized(
          ForLoop* circular_buffer_loop,
          const std::vector<Expr*>& loads,
          int64_t insertion_position) {
        const auto& opt =
            GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
                circular_buffer_loop->iter_domain());
        ParallelType warp_specialize_on = std::get<WarpSpecialized>(opt.type).on;
    
        // Create warp_dispatch_ite, the predicate is either
        // Tid == bdim - 1 or Tid >= bdim - padded
        int64_t warp_specialization_pad =
            GpuLower::current()
                ->parallelDimensionMap()
                .getWarpSpecializationPaddedVal(warp_specialize_on);
        kir::Predicate* predicate_val = nullptr;
        Val* raw =
            GpuLower::current()->parallelDimensionMap().get(warp_specialize_on);
        Val* raw_minus_pad = SimplifyingIrBuilder::subExpr(
            raw, IrBuilder::create<Val>(warp_specialization_pad, DataType::Index));
        if (warp_specialization_pad == 1) {
          predicate_val = IrBuilder::create<kir::Predicate>(IrBuilder::eqExpr(
              NamedScalar::getParallelIndex(warp_specialize_on), raw_minus_pad));
        } else {
          predicate_val = IrBuilder::create<kir::Predicate>(IrBuilder::geExpr(
              NamedScalar::getParallelIndex(warp_specialize_on), raw_minus_pad));
        }
        kir::IfThenElse* warp_dispatch_ite =
            IrBuilder::create<kir::IfThenElse>(predicate_val);
    
        // Set default value
        auto& circular_buffer_options =
            GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
                circular_buffer_loop->iter_domain());
        bool enable_register_sharing =
            std::holds_alternative<WarpSpecialized>(circular_buffer_options.type) &&
            std::get<WarpSpecialized>(circular_buffer_options.type)
                .num_registers.has_value();
    
        GpuLower::current()->kernel()->manage(
            "enable_register_sharing", enable_register_sharing);
    
        if (enable_register_sharing) {
          auto&& [decrease_num_registers, increase_num_registers] =
              std::get<WarpSpecialized>(circular_buffer_options.type)
                  .num_registers.value();
          GpuLower::current()->decIncRegisterUsage() =
              std::make_pair(decrease_num_registers, increase_num_registers);
          // Decrease registers in async warp group
          kir::SetMaxNReg* dec_reg_async_warp = IrBuilder::create<kir::SetMaxNReg>(
              IrBuilder::create<Val>(decrease_num_registers, DataType::Index),
              /*increase_registers=*/false);
          warp_dispatch_ite->thenBody().push_back(dec_reg_async_warp);
    
          // Increase registers in compute warp group
          kir::SetMaxNReg* inc_reg_async_warp = IrBuilder::create<kir::SetMaxNReg>(
              IrBuilder::create<Val>(increase_num_registers, DataType::Index),
          /*increase_registers=*/true);
          warp_dispatch_ite->elseBody().push_back(inc_reg_async_warp);
        }
    
        // Load loop:
        ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
            circular_buffer_loop,
            loads,
            CircularBufferLoopStage::AsyncWarp,
            insertion_position);
        warp_dispatch_ite->thenBody().push_back(load_loop);
    
        if (enable_register_sharing) {
          // Terminate the warp group handling Load loop immediately after
          // finishing its work.
          kir::Return* ret = IrBuilder::create<kir::Return>();
          warp_dispatch_ite->thenBody().push_back(ret);
        }
    
        // Prefetch:
        auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
        warp_dispatch_ite->elseBody().push_back(prefetch_loop);
    
        // Compute loop:
        ForLoop* compute_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
            circular_buffer_loop,
            loads,
            CircularBufferLoopStage::ComputeWarp,
            insertion_position);
        warp_dispatch_ite->elseBody().push_back(compute_loop);
    
        registerReplace(circular_buffer_loop, warp_dispatch_ite);
      }
    
     private:
      InsertionInfo& insertion_info_;
      ForLoop* processed_loop_ = nullptr;
    };
    
    // Apply pipeline circular buffering transformations
    class PipelineCircularBufferInserter : private kir::ExprMutator {
     public:
      // When there exist multiple circular buffer loops, apply
      // transformations to inner-most loops first. A single ExprMutator
      // pass can only process one loop.
      static std::vector<Expr*> run(
          const std::vector<Expr*>& exprs,
          InsertionInfo insertion_info) {
        std::vector<Expr*> inserted_exprs = exprs;
        while (!insertion_info.empty()) {
          PipelineCircularBufferInserter inserter(inserted_exprs, insertion_info);
          inserted_exprs = inserter.exprs_;
        }
        return inserted_exprs;
      }
    
     private:
      PipelineCircularBufferInserter(
          const std::vector<Expr*>& exprs,
          InsertionInfo& insertion_info)
          : insertion_info_(insertion_info) {
        size_t num_circular_buffer_loops = insertion_info.size();
        traverseAndInsert(exprs);
        NVF_ERROR(processed_loop_ != nullptr);
        NVF_ERROR(insertion_info.size() == num_circular_buffer_loops - 1);
      }
    
      using kir::ExprMutator::handle;
    
      void handle(ForLoop* loop) final {
        kir::ExprMutator::handle(loop);
    
        // If another loop has already been handled, no more loops should
        // be processed in the same pass.
        if (processed_loop_ != nullptr) {
          return;
        }
    
        auto it = insertion_info_.find(loop);
        if (it == insertion_info_.end()) {
          return;
        }
    
        bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
            GpuLower::current()
                ->circularBufferInfo()
                .getCircularBufferOptionsFor(loop->iter_domain())
                .type);
        NVF_ERROR(!use_warp_specialization);
    
        auto has_cp_async_bulk = std::any_of(
            it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk);
        if (has_cp_async_bulk) {
          insertTmaPipelined(loop, it->second);
        } else {
          insert(loop, it->second);
        }
    
        processed_loop_ = loop;
        insertion_info_.erase(loop);
      }
    
      bool hasPrefetch(ForLoop* circular_buffer_loop) {
        int64_t prefetch_distance =
            GpuLower::current()
                ->circularBufferInfo()
                .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
                .prefetch;
        return prefetch_distance > 0;
      }
    
      static bool usesMBarrierForWAR(ForLoop* circular_buffer_loop) {
        return GpuLower::current()
            ->circularBufferInfo()
            .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
            .usesMBarrierForWAR();
      }
    
      void insertTmaPipelined(
          ForLoop* circular_buffer_loop,
          const std::vector<Expr*>& loads) {
        // Arrive on the WAR mbarriers to let the prefetching start.
        if (usesMBarrierForWAR(circular_buffer_loop)) {
          auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
          registerInsertBefore(circular_buffer_loop, prefetch_loop);
        }
    
        // Prologue loop:
        //  - launch only
        //  - arrive_expect_tx and tma load operations
        if (hasPrefetch(circular_buffer_loop)) {
          // If there is no prefetch, then we don't need a prologue loop.
          ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
              circular_buffer_loop,
              loads,
              CircularBufferLoopStage::Prolog,
              /*insertion_position=*/1);
          registerInsertBefore(circular_buffer_loop, prologue_loop);
        }
    
        // Main loop:
        //  - Launch and wait
        //  - arrive_expect_tx, tma load operations, and mbarrier_wait
        ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
            circular_buffer_loop,
            loads,
            CircularBufferLoopStage::Main,
            /*insertion_position=*/1);
        registerReplace(circular_buffer_loop, main_loop);
    
        if (!hasPrefetch(circular_buffer_loop)) {
        // If there is no prefetch, then we don't need an epilogue loop.
          return;
        }
    
        // We can use exclude argument in
        // CloneTmaCircularBufferLoopAndInsertSync clone to avoid
        // duplicating allocations if main loop is trivial.
        std::unordered_set<Expr*> expressions_allocated_in_main_loop;
        getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop);
    
        // Epilogue loop:
        //  - wait only
        //  - mbarrier_wait
        ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
            circular_buffer_loop,
            loads,
            CircularBufferLoopStage::Epilog,
            /*insertion_position=*/1,
            expressions_allocated_in_main_loop);
        registerInsertAfter(circular_buffer_loop, epilogue_loop);
      }
    
      void insert(ForLoop* circular_buffer_loop, const std::vector<Expr*>& loads) {

    @rdspring1 (Collaborator, Author)

    !test

    @rdspring1 (Collaborator, Author)

    !test

    @rdspring1 force-pushed the cb_pass_refactor_p2 branch from ac56adb to d050e74 on April 21, 2025 at 17:05
    Base automatically changed from cb_pass_refactor to main on April 21, 2025 at 23:50
    @rdspring1 (Collaborator, Author)

    !build

    @rdspring1 merged commit 0b2f5a8 into main on Apr 22, 2025
    16 checks passed
    @rdspring1 deleted the cb_pass_refactor_p2 branch on April 22, 2025 at 01:50