[core] Support capture custom ops into aclgraph #2113
Merged: wangxiyuan merged 16 commits into vllm-project:main from ganyi1996ppo:ganyi/meta_registration on Aug 11, 2025 (+332 −13).

Commits (16):
- b017650 register the meta dispatch path for custom kernels (ganyi1996ppo)
- 753b3a9 move the get_dtype_from_torch to the cpp file prevent redefination (ganyi1996ppo)
- 475db0d remove numel operation to support symbolic tracing (ganyi1996ppo)
- 6cc2548 enable meta device registration, and enable the simple model for rope (ganyi1996ppo)
- 06f47e9 add unittes (ganyi1996ppo)
- 0b4899b fix format issue (ganyi1996ppo)
- cb45377 fix the unittest (ganyi1996ppo)
- fb85e0d add python level meta registration method (ganyi1996ppo)
- 3eaae42 add comments for the meta device registration (ganyi1996ppo)
- c2f231c fix lint (ganyi1996ppo)
- f14bc35 makes isort ignore the meta registration's import sequence (ganyi1996ppo)
- e5af114 fix the aclgraph ci issue (ganyi1996ppo)
- f5a07be remove lint to enable ut (ganyi1996ppo)
- 7c9a5ed restore ci (ganyi1996ppo)
- 3bee673 add a warmup aclgraph run for ci (ganyi1996ppo)
- b4ee77e fix lint (ganyi1996ppo)
New file: torch_binding_meta.cpp (C++ meta implementations, +86 lines):

```cpp
#include <torch/extension.h>
#include <torch/library.h>
#include <torch/version.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#include <torch_npu/csrc/npu/Module.h>
#include "utils.h"

/*
 * How to write a meta implementation for a custom operator (meta kernel):
 *
 * Meta implementations are used for shape and dtype inference, tracing, and export.
 * They do NOT perform any real computation or allocate device memory.
 * Instead, they return empty tensors with the correct shapes, dtypes, and device types.
 *
 * Steps to write a meta implementation:
 * 1. The function signature should match the operator's schema, but only use the arguments
 *    necessary to infer output shapes and dtypes.
 * 2. Use input tensor shapes, dtypes, and any relevant arguments to compute the output shapes.
 * 3. Return empty tensors (e.g., at::empty_symint, at::empty_like) with the correct shape and dtype.
 * 4. Do NOT perform any real computation or data movement.
 * 5. Register the meta implementation with the "Meta" dispatch key using TORCH_LIBRARY_IMPL or similar.
 *
 * Example:
 *   std::tuple<at::Tensor, at::Tensor> my_op_meta(
 *       at::Tensor &input, int64_t some_param) {
 *     // Infer output shape based on input and parameters
 *     auto out_shape = ...;
 *     at::Tensor out = at::empty_symint(out_shape, input.options());
 *     // Return empty tensor(s) with correct shape/dtype
 *     return {out, ...};
 *   }
 *
 * See below for real examples.
 */

namespace vllm_ascend {
namespace meta {

std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
    at::Tensor &positions,
    at::Tensor &query,
    at::Tensor &key,
    int64_t head_size,
    at::Tensor &cos_sin_cache,
    bool is_neox) {
  auto num_tokens = positions.sym_numel();
  auto query_hidden_size = query.sym_numel() / num_tokens;
  auto key_hidden_size = key.sym_numel() / num_tokens;

  auto num_heads = query_hidden_size / head_size;
  auto num_kv_heads = key_hidden_size / head_size;
  at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options());
  at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options());

  return {query_dst, key_dst};
}

std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
    at::Tensor &input,
    const int64_t org_vocab_start_index,
    const int64_t org_vocab_end_index,
    const int64_t num_org_vocab_padding,
    const int64_t added_vocab_start_index,
    const int64_t added_vocab_end_index) {

  at::Tensor masked_input = at::empty_like(input);
  at::Tensor mask = at::empty_like(input, input.options().dtype(at::kBool));

  return {masked_input, mask};
}

}  // namespace meta
}  // namespace vllm_ascend
namespace {
// Register the meta implementations of the custom kernels for symbolic tracing; this also
// allows the custom kernels to be captured into aclgraph.
TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) {
  // Rotary embedding meta implementation
  ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
  // Masked input and mask meta implementation
  ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
}
}  // namespace
```
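As a quick way to see what these Meta kernels produce, the sketch below (an illustrative assumption, not part of this PR) calls the custom op on `meta`-device tensors after the vllm_ascend `_C` extension has been imported; only shapes and dtypes are inferred, and no NPU computation or memory allocation happens. The tensor sizes, dtypes, and head configuration are made up for the example.

```python
# Hypothetical check of the Meta registration above; assumes the vllm_ascend _C
# extension is loaded so that torch.ops._C.rotary_embedding is resolvable, and
# that its schema matches the meta signature
# (positions, query, key, head_size, cos_sin_cache, is_neox) -> (Tensor, Tensor).
import torch

num_tokens, num_heads, num_kv_heads, head_size = 8, 32, 8, 128

positions = torch.empty(num_tokens, dtype=torch.int64, device="meta")
query = torch.empty(num_tokens, num_heads * head_size, dtype=torch.float16, device="meta")
key = torch.empty(num_tokens, num_kv_heads * head_size, dtype=torch.float16, device="meta")
cos_sin_cache = torch.empty(4096, head_size, dtype=torch.float16, device="meta")

# Dispatches to rotary_embedding_meta: returns empty tensors with the inferred shapes.
q_out, k_out = torch.ops._C.rotary_embedding(
    positions, query, key, head_size, cos_sin_cache, True)
print(q_out.shape, k_out.shape)  # expected: (8, 32, 128) and (8, 8, 128), still on the meta device
```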
New Python file: Python-level meta registration utilities (+86 lines):

```python
import torch
from torch.library import Library

# This file provides a template and registration utilities for writing "meta" implementations
# of custom operators in Python for the vllm_ascend project.
#
# We offer two ways to implement meta implementations for custom ops:
# 1. Python meta implementation (as shown in this file): write a Python function that
#    takes the same arguments as your operator and returns empty tensors with the correct
#    shapes and dtypes. This is useful for rapid prototyping and for ops that are only
#    used in Python.
# 2. C++ meta implementation: you can also implement the meta function in C++ for better
#    performance or to match the C++ op logic more closely. See `torch_binding_meta.cpp`
#    for examples of C++ meta implementations and how to register them.
#
# Both approaches enable tracing, export, and shape inference in PyTorch and vLLM, which
# is essential for supporting `torch.compile` and aclgraph.

# How to add a new meta implementation in Python:
# -----------------------------------------------
# 1. Write a Python function that takes the same arguments as your operator and returns
#    empty tensors (using torch.empty_like, torch.empty, etc.) with the correct shapes and dtypes.
#    Do NOT perform any real computation or allocate device memory.
#
# 2. Register your meta function using `register_meta_if_necessary`, providing:
#    - the namespace (usually "_C" for custom ops)
#    - the operator name (as registered in C++)
#    - the Python meta function
#    - (optional) the overload name, if your op has overloads
#
# 3. The registration utility checks whether a meta implementation already exists for your op
#    and only registers if necessary. This avoids duplicate registrations.
#
# 4. Example meta implementations are provided below for rotary_embedding and get_masked_input_and_mask;
#    a registration sketch for a new op is shown after this listing.
#
# 5. When developing new custom ops, always provide a meta implementation so that tracing,
#    export, and shape inference work in PyTorch and vLLM; this is required for `torch.compile`
#    and aclgraph capture.
#
# For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors
lib = Library("_C", "IMPL")


def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
    if overload != "":
        op_name = op_name + "." + overload
    schema_to_find = ns + "::" + op_name
    meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key(
        "Meta")
    if schema_to_find in meta_impl_list:
        return
    lib.impl(op_name, fn, "Meta")


def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor,
                          key: torch.Tensor, head_size: int,
                          cos_sin_cache: torch.Tensor, is_neox: bool):

    num_tokens = positions.numel()
    query_hidden_size = query.numel() // num_tokens
    key_hidden_size = key.numel() // num_tokens
    num_heads = query_hidden_size // head_size
    num_kv_heads = key_hidden_size // head_size

    query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
    key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
    return query_dst, key_dst


def get_masked_input_and_mask_meta(input: torch.Tensor,
                                   org_vocab_start_index: int,
                                   org_vocab_end_index: int,
                                   num_org_vocab_padding: int,
                                   added_vocab_start_index: int,
                                   added_vocab_end_index: int):

    masked_input = torch.empty_like(input)
    mask = torch.empty_like(input).to(torch.bool)

    return masked_input, mask


register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta)
register_meta_if_necessary("_C", "get_masked_input_and_mask",
                           get_masked_input_and_mask_meta)
```
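Following the how-to above, registering a meta implementation for a new custom op would look roughly like the sketch below. The op name `my_new_op` and its output layout are hypothetical, purely to illustrate the pattern; a real registration must match the schema that the C++ side registers under the `_C` namespace.

```python
# Hypothetical example: "my_new_op" is not part of this PR, and lib.impl would
# only succeed for an op actually defined in the _C namespace.

# Step 1: a meta function mirroring the (assumed) operator schema. It only builds
# empty outputs with the right shapes/dtypes and performs no real computation.
def my_new_op_meta(x: torch.Tensor, scale: float):
    out = torch.empty_like(x)                                               # same shape/dtype as input
    stats = torch.empty(x.shape[0], dtype=torch.float32, device=x.device)  # one value per row
    return out, stats


# Step 2: register it under the Meta dispatch key, skipping if an implementation already exists.
register_meta_if_necessary("_C", "my_new_op", my_new_op_meta)
```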
Review comment: This place uses a static shape. If the shape of static_positions or static_hidden_states changes, does the meta implementation need to run again?
Reply: yes.
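The exchange above hinges on standard `torch.compile` behavior; the sketch below illustrates that behavior under general PyTorch assumptions and is not code from this PR: when the input shapes seen by a compiled region change, Dynamo re-traces (or falls back to dynamic shapes), and registered meta implementations are consulted again to infer the new output shapes.

```python
# Hypothetical illustration: a shape change triggers a re-trace (or a dynamic-shape
# trace), so Meta implementations run again during the new trace.
import torch

@torch.compile(backend="eager")  # backend choice is arbitrary; tracing still uses fake/meta tensors
def f(x):
    return x + 1  # stand-in for a graph containing a captured custom op

f(torch.randn(4, 128))   # first call: traced with shape (4, 128); meta kernels infer output shapes
f(torch.randn(16, 128))  # new batch size: re-trace (or dynamic-shape trace); meta kernels run again
```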