Skip to content

[WIP] Move edge ownership to SegmentedGroup using shared_ptr #4235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from

Conversation

csarofeen
Copy link
Collaborator

Description:
Moves edge ownership from SegmentedFusion to SegmentedGroup using shared_ptr, allowing edges to be properly shared between groups. This eliminates redundant edge storage in SegmentedFusion and provides clearer ownership semantics.

Changes:

  • Convert edge storage to shared_ptr in SegmentedGroup
  • Remove redundant edge storage from SegmentedFusion
  • Update edge creation/management to work with shared edges

Copy link

github-actions bot commented Apr 11, 2025

Review updated until commit d1e7e6d

Description

  • Moved edge ownership to SegmentedGroup using shared_ptr.

  • Removed redundant edge storage from SegmentedFusion.

  • Updated edge creation/management to work with shared edges.

  • Added collectAllEdges method in SegmentedFusion.


Changes walkthrough 📝

Relevant files
Enhancement
fusion_segmenter.cpp
Updated edge management to use shared_ptr                               

csrc/fusion_segmenter.cpp

  • Updated edge storage to use shared_ptr in SegmentedGroup.
  • Removed redundant edge storage from SegmentedFusion.
  • Updated edge creation/management to use shared edges.
  • Added collectAllEdges method.
  • +141/-134
    fusion_segmenter.h
    Updated edge management to use shared_ptr                               

    csrc/fusion_segmenter.h

  • Updated edge storage to use shared_ptr in SegmentedGroup.
  • Removed redundant edge storage from SegmentedFusion.
  • Updated edge creation/management to use shared edges.
  • Added collectAllEdges method.
  • +28/-35 
    Tests
    test_segmentation.cpp
    Updated tests for shared_ptr edges                                             

    tests/cpp/test_segmentation.cpp

    • Updated tests to use shared_ptr for edges.
    +8/-4     

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Memory Management

    The use of std::make_shared and std::shared_ptr introduces reference counting overhead. Ensure that this does not lead to performance regressions or memory leaks.

      std::make_shared<SegmentedEdge>(
          groups_.at(se_fb->from_segmented_group()),
          groups_.at(se_fb->to_segmented_group()),
          vals.at(se_fb->val()));
    }
    
    // Deserialize segmented groups
    Edge Removal

    The removeEdge function now takes a std::shared_ptr to an edge. Ensure that the reference counting is correctly managed when edges are removed to avoid dangling pointers or memory leaks.

      NVF_ERROR(edge, "Edge is nullptr");
      // Validate edge exists in all expected locations
      SegmentedGroup* producer = edge->from;
      SegmentedGroup* consumer = edge->to;
      auto& producer_consumer_edges = producer->consumer_edges;
      auto& consumer_producer_edges = consumer->producer_edges;
    
      // Remove edge from producer's consumer edges
      auto producer_edge_it = std::find(
          producer_consumer_edges.begin(), producer_consumer_edges.end(), edge);
      NVF_ERROR(
          producer_edge_it != producer_consumer_edges.end(),
          "Edge not found in producer's consumer edges");
      producer_consumer_edges.erase(producer_edge_it);
    
      // Remove edge from consumer's producer edges
      auto consumer_edge_it = std::find(
          consumer_producer_edges.begin(), consumer_producer_edges.end(), edge);
      NVF_ERROR(
          consumer_edge_it != consumer_producer_edges.end(),
          "Edge not found in consumer's producer edges");
      consumer_producer_edges.erase(consumer_edge_it);
    }
    Performance Impact

    The introduction of std::shared_ptr for edges may impact performance due to increased memory usage and reference counting overhead. Measure the performance impact and compare it with the previous implementation.

      // Used to log the number of values and expressions in the fusion for
      // serialization sanity check.
      segmented_fusion_ptr->finalize();
      return segmented_fusion_ptr;
    }
    
    std::vector<std::shared_ptr<SegmentedEdge>> SegmentedFusion::collectAllEdges()
        const {
      std::vector<std::shared_ptr<SegmentedEdge>> all_edges;
      // Edges should be symmetric, so we only need to collect edges from one
      // direction. i.e. a producer edge from a group needs to match a consumer edge
      // from another group
      for (auto group : groups()) {
        all_edges.insert(
            all_edges.end(),
            group->producer_edges.begin(),
            group->producer_edges.end());
      }
      return all_edges;
    }

    @csarofeen csarofeen force-pushed the segmenter_edge_shared_ptr branch from ec7c560 to e50c295 Compare April 12, 2025 01:11
    @csarofeen
    Copy link
    Collaborator Author

    !test

    @csarofeen csarofeen changed the base branch from main to segmenter_helpers April 12, 2025 15:35
    @csarofeen
    Copy link
    Collaborator Author

    !test

    @csarofeen
    Copy link
    Collaborator Author

    This PR is fine, but using shared_ptr doesn't seem particularly beneficial. We likely don't need a smart pointer, but furthermore if we're going to change the infrastructure in the segmenter we should probably make SegmentedGroup an Expr and SegmentedEdge a Val to be compatible with the other infrastructure in nvFuser.

    @csarofeen csarofeen closed this Apr 20, 2025
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
    Labels
    None yet
    Projects
    None yet
    Development

    Successfully merging this pull request may close these issues.

    1 participant