[core] Support capture custom ops into aclgraph #2113

Merged · 16 commits · Aug 11, 2025
11 changes: 11 additions & 0 deletions csrc/torch_binding.cpp
@@ -27,6 +27,17 @@

namespace vllm_ascend {

AscendType get_dtype_from_torch(at::ScalarType scalarType)
{
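    // Map the torch scalar type to its Ascend counterpart; any other dtype
    // (e.g. Half) falls back to FP16.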
if (scalarType == at::ScalarType::Float) {
return AscendType::FP32;
} else if (scalarType == at::ScalarType::BFloat16) {
return AscendType::BF16;
} else {
return AscendType::FP16;
}
}

std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
{
86 changes: 86 additions & 0 deletions csrc/torch_binding_meta.cpp
@@ -0,0 +1,86 @@
#include <torch/extension.h>
#include <torch/library.h>
#include <torch/version.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#include <torch_npu/csrc/npu/Module.h>
#include "utils.h"
/*
* How to write a meta implementation for a custom operator (meta kernel):
*
* Meta implementations are used for shape and dtype inference, tracing, and export.
* They do NOT perform any real computation or allocate device memory.
* Instead, they return empty tensors with the correct shapes, dtypes, and device types.
*
* Steps to write a meta implementation:
* 1. The function signature should match the operator's schema, but only use the arguments
* necessary to infer output shapes and dtypes.
* 2. Use input tensor shapes, dtypes, and any relevant arguments to compute the output shapes.
* 3. Return empty tensors (e.g., at::empty_symint, at::empty_like) with the correct shape and dtype.
* 4. Do NOT perform any real computation or data movement.
* 5. Register the meta implementation with the "Meta" dispatch key using TORCH_LIBRARY_IMPL or similar.
*
* Example:
* std::tuple<at::Tensor, at::Tensor> my_op_meta(
* at::Tensor &input, int64_t some_param) {
* // Infer output shape based on input and parameters
* auto out_shape = ...;
* at::Tensor out = at::empty_symint(out_shape, input.options());
* // Return empty tensor(s) with correct shape/dtype
* return {out, ...};
* }
*
* See below for real examples.
*/

namespace vllm_ascend {
namespace meta {

std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
at::Tensor &positions,
at::Tensor &query,
at::Tensor &key,
int64_t head_size,
at::Tensor &cos_sin_cache,
bool is_neox) {
auto num_tokens = positions.sym_numel();
auto query_hidden_size = query.sym_numel() / num_tokens;
auto key_hidden_size = key.sym_numel() / num_tokens;

auto num_heads = query_hidden_size / head_size;
auto num_kv_heads = key_hidden_size / head_size;
at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options());
at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options());

return {query_dst, key_dst};
}

std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
at::Tensor &input,
const int64_t org_vocab_start_index,
const int64_t org_vocab_end_index,
const int64_t num_org_vocab_padding,
const int64_t added_vocab_start_index,
const int64_t added_vocab_end_index) {

at::Tensor masked_input = at::empty_like(input);
at::Tensor mask = at::empty_like(input, input.options().dtype(at::kBool));

return {masked_input, mask};
}


} // namespace meta
} // namespace vllm_ascend

namespace {
// Register the meta implementations of the custom kernels for symbolic tracing; this also
// allows the custom kernels to be captured into aclgraph.
TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) {
// Rotary embedding meta implementation
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
// Masked input and mask meta implementation
ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);

}
}
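
A note on what this registration enables: with a Meta implementation registered, the custom op can run on fake tensors, which is exactly how `torch.compile` traces a graph before aclgraph capture. Below is a minimal sketch, assuming the vllm_ascend C extension is importable (the shapes are illustrative, not from the PR):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

import vllm_ascend.vllm_ascend_C  # noqa: F401  # registers the _C ops and their Meta impls

num_tokens, num_heads, head_size = 10, 17, 64
with FakeTensorMode():
    positions = torch.empty(num_tokens, dtype=torch.int64)
    query = torch.empty(num_tokens, num_heads * head_size, dtype=torch.half)
    key = torch.empty(num_tokens, num_heads * head_size, dtype=torch.half)
    cos_sin_cache = torch.empty(8192, head_size, dtype=torch.half)
    # Dispatches to rotary_embedding_meta: shapes/dtypes are inferred, no NPU work happens.
    q_out, k_out = torch.ops._C.rotary_embedding(
        positions, query, key, head_size, cos_sin_cache, True)
    assert tuple(q_out.shape) == (num_tokens, num_heads, head_size)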
12 changes: 0 additions & 12 deletions csrc/utils.h
@@ -29,15 +29,3 @@
}


namespace vllm_ascend {
AscendType get_dtype_from_torch(at::ScalarType scalarType)
{
if (scalarType == at::ScalarType::Float) {
return AscendType::FP32;
} else if (scalarType == at::ScalarType::BFloat16) {
return AscendType::BF16;
} else {
return AscendType::FP16;
}
}
} // namespace vllm_ascend
146 changes: 145 additions & 1 deletion tests/e2e/singlecard/ops/test_rotary_embedding.py
@@ -17,11 +17,12 @@
# Only Neox style true scenario is supported for now
IS_NEOX_STYLE = [True]
DTYPES = [torch.half]
-HEAD_SIZES = [64, 96, 128, 256]
+HEAD_SIZES = [64, 64, 96, 128, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [17] # Arbitrary values for testing
BATCH_SIZES = [5] # Arbitrary values for testing
SEQ_LENS = [11, 4096] # Arbitrary values for testing
NUM_TOKENS = [10, 21]
SEEDS = [0]
DEVICES = [f"npu:{0}"]
# Set tolerance to 1 for quant ops
@@ -198,3 +199,146 @@ def test_rotary_embedding_quant_with_leading_dim(
ref_key,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)


class ModelwithRotaryEmbedding(nn.Module):

def __init__(
self,
hidden_size: int,
num_heads: int,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3)
self.rope = RotaryEmbedding(
head_size=head_size,
rotary_dim=rotary_dim,
max_position_embeddings=max_position_embeddings,
base=base,
is_neox_style=is_neox_style,
dtype=dtype,
)
self.o_proj = nn.Linear(num_heads * head_size, hidden_size)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> torch.Tensor:
        # Simulate a simple attention-like layer to test that it can be seamlessly captured
        # into aclgraph (key and v are computed but left unused; only capture is under test).
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(3, dim=-1)
query, key = torch.ops._C.rotary_embedding(
positions,
q,
k,
self.rope.head_size,
self.rope.cos_sin_cache,
self.rope.is_neox_style,
)
query = query.view(q.shape)
key = key.view(k.shape)
o = self.o_proj(query)
return o


# The first graph replay seems to have an accuracy issue when pytest is run directly on the
# ops folder; add a warmup graph replay as a workaround.
ACL_GRAPH_FIRST_RUN = True


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_capture_rotary_embedding_in_aclgraph(
is_neox_style: bool,
num_tokens: int,
num_heads: int,
head_size: int,
rotary_dim: int,
dtype: torch.dtype,
seed: int,
device: str,
max_position_embeddings: int = 8192,
base: int = 10000,
):
"""Test if the rotary embedding can be captured in aclgraph."""
torch.manual_seed(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
model = ModelwithRotaryEmbedding(
hidden_size=num_heads * head_size,
num_heads=num_heads,
head_size=head_size,
rotary_dim=rotary_dim,
max_position_embeddings=max_position_embeddings,
base=base,
is_neox_style=is_neox_style,
dtype=dtype,
)

def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input):
        # Validate that the rotary_embedding custom kernel is indeed in the traced graph
        # via string matching
graph = str(gm.graph)
assert "_C.rotary_embedding" in graph
return gm

static_positions = torch.randint(0, max_position_embeddings,
(num_tokens, ))
static_hidden_states = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device="npu")
compiled_model = torch.compile(model, backend=custom_op_checking_backend)
stream = torch.npu.Stream()
stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(stream):
# warmup the fx graph before capture
for i in range(3):
static_output = compiled_model(static_positions,
static_hidden_states,
offsets=None)
    torch.npu.current_stream().wait_stream(stream)

aclgraph = torch.npu.NPUGraph()

with torch.npu.graph(aclgraph):
# Capture the model in aclgraph.
static_output = compiled_model(static_positions, static_hidden_states)
Review comment (Contributor):
This place is a static shape. If the shape of static_positions or static_hidden_states changes, does the meta need to go through again?

Reply (ganyi1996ppo, Collaborator, Author, Jul 31, 2025):
Yes.
    # Fill the static input buffers with fresh random data before replaying the graph.
random_filled_positions = torch.randint(0,
max_position_embeddings,
(num_tokens, ),
device="npu")
random_filled_hidden_states = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device="npu")
static_positions.copy_(random_filled_positions)
static_hidden_states.copy_(random_filled_hidden_states)

aclgraph.replay()
    global ACL_GRAPH_FIRST_RUN
    if ACL_GRAPH_FIRST_RUN:
        ACL_GRAPH_FIRST_RUN = False
        return
output_reference = model(static_positions, static_hidden_states)
torch.testing.assert_close(static_output,
output_reference,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
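
Distilled from the test above: the captured graph records fixed device buffer addresses, so feeding it new data means copying into the same static tensors and replaying, never passing fresh allocations. A sketch of the idiom, reusing the names from the test (new_positions and new_hidden_states are illustrative stand-ins):

aclgraph = torch.npu.NPUGraph()
with torch.npu.graph(aclgraph):
    # Kernel launches are recorded against the addresses of the static buffers.
    static_output = compiled_model(static_positions, static_hidden_states)

static_positions.copy_(new_positions)          # update inputs in place
static_hidden_states.copy_(new_hidden_states)
aclgraph.replay()                              # static_output now reflects the new inputs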
86 changes: 86 additions & 0 deletions vllm_ascend/meta_registration.py
@@ -0,0 +1,86 @@
import torch
from torch.library import Library

# This file provides a template and registration utilities for writing "meta" implementations
# of custom operators in Python for the vllm_ascend project.
#
# We offer two ways to implement meta implementations for custom ops:
# 1. Python meta implementation (as shown in this file): Write a Python function that
# takes the same arguments as your operator and returns empty tensors with the correct
# shapes and dtypes. This is useful for rapid prototyping and for ops that are only
# used in Python.
# 2. C++ meta implementation: You can also implement the meta function in C++ for better
# performance or to match the C++ op logic more closely. See `torch_binding_meta.cpp`
# for examples of C++ meta implementations and how to register them.
#
# Both approaches enable tracing, export, and shape inference in PyTorch and vLLM, which
# is essential for supporting `torch.compile` and aclgraph.

# How to add a new meta implementation in Python:
# -------------------------------------
# 1. Write a Python function that takes the same arguments as your operator, and returns
# empty tensors (using torch.empty_like, torch.empty, etc.) with the correct shapes and dtypes.
# Do NOT perform any real computation or allocate device memory.
#
# 2. Register your meta function using `register_meta_if_necessary`, providing:
# - The namespace (usually "_C" for custom ops)
# - The operator name (as registered in C++)
# - The Python meta function
# - (Optional) The overload name, if your op has overloads
#
# 3. The registration utility will check if a meta implementation already exists for your op,
# and only register if necessary. This avoids duplicate registrations.
#
# 4. Example meta implementations are provided below for rotary_embedding and get_masked_input_and_mask.
#
# 5. When developing new custom ops, always provide a meta implementation so that PyTorch
#    and vLLM can trace, export, and infer shapes for the op, and so that it can be
#    captured by `torch.compile` and aclgraph.
#
# For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors

lib = Library("_C", "IMPL")


def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
if overload != "":
op_name = op_name + "." + overload
schema_to_find = ns + "::" + op_name
meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key(
"Meta")
if schema_to_find in meta_impl_list:
return
lib.impl(op_name, fn, "Meta")


def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool):

num_tokens = positions.numel()
query_hidden_size = query.numel() // num_tokens
key_hidden_size = key.numel() // num_tokens
num_heads = query_hidden_size // head_size
num_kv_heads = key_hidden_size // head_size

query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
return query_dst, key_dst


def get_masked_input_and_mask_meta(input: torch.Tensor,
org_vocab_start_index: int,
org_vocab_end_index: int,
num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int):

masked_input = torch.empty_like(input)
    mask = torch.empty_like(input, dtype=torch.bool)

return masked_input, mask


register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta)
register_meta_if_necessary("_C", "get_masked_input_and_mask",
get_masked_input_and_mask_meta)
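
As a quick sanity check (illustrative, not part of the PR; it mirrors the same private lookup register_meta_if_necessary performs), one can confirm the Meta entries actually landed:

import torch

meta_regs = torch._C._dispatch_get_registrations_for_dispatch_key("Meta")
assert "_C::rotary_embedding" in meta_regs
assert "_C::get_masked_input_and_mask" in meta_regs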
4 changes: 4 additions & 0 deletions vllm_ascend/utils.py
@@ -214,8 +214,12 @@ def enable_custom_op():
if _CUSTOM_OP_ENABLED is not None:
return _CUSTOM_OP_ENABLED
try:
# isort: off
# register custom ops into torch_library here
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
# register the meta implementation for custom kernel if necessary
import vllm_ascend.meta_registration # type: ignore # noqa: F401
# isort: on
_CUSTOM_OP_ENABLED = True
except ImportError:
_CUSTOM_OP_ENABLED = False
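
For completeness, a sketch of how downstream code is expected to use this entry point (the Linear module is a stand-in for a real model using the _C ops, and .npu() assumes torch_npu is installed):

import torch
import torch.nn as nn
from vllm_ascend.utils import enable_custom_op

# One call loads the custom kernels and their meta implementations; it returns
# False (instead of raising) when the C extension is unavailable.
if enable_custom_op():
    model = nn.Linear(8, 8).npu()      # stand-in for a model that calls _C custom ops
    compiled = torch.compile(model)    # custom ops in the graph can now be traced and captured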