
Commit 583ee1c

support cos_sin_cache prefetch for qwen2
1 parent b3d6e0c commit 583ee1c

File tree

1 file changed (+110, -8 lines)

vllm_ascend/models/qwen2.py

Lines changed: 110 additions & 8 deletions
@@ -1,10 +1,12 @@
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.nn.functional as F
+import vllm_ascend.envs as ascend_envs
 from torch import nn
 from transformers import Qwen2Config
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -13,17 +15,17 @@
                               tensor_model_parallel_all_reduce,
                               tensor_model_parallel_reduce_scatter)
 from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
-from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model
+from vllm.model_executor.models.qwen2 import (Qwen2Attention, Qwen2MLP,
+                                              Qwen2Model)
 from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                               PPMissingLayer, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-
-import vllm_ascend.envs as ascend_envs
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 
 
@@ -47,19 +49,102 @@ def maybe_pad_and_reduce_scatter(
     return hidden_states
 
 
-class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
+class CustomQwen2Attention(Qwen2Attention):
 
     def __init__(
         self,
-        config: Qwen2Config,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[tuple] = None,
         prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
     ) -> None:
-        super().__init__(config=config,
+        super().__init__(hidden_size=hidden_size,
+                         num_heads=num_heads,
+                         num_kv_heads=num_kv_heads,
+                         max_position=max_position,
+                         rope_theta=rope_theta,
                          cache_config=cache_config,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         rope_scaling=rope_scaling,
+                         prefix=prefix,
+                         attn_type=attn_type,
+                         dual_chunk_attention_config=dual_chunk_attention_config)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k, cos=cos, sin=sin, skip_index_select=True)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class CustomQwen2DecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        dual_chunk_attention_config = getattr(config,
+                                              "dual_chunk_attention_config",
+                                              None)
+
+        # By default, Qwen2 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        self.self_attn = CustomQwen2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+        )
+        self.mlp = Qwen2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.self_attn.o_proj.reduce_results = False
@@ -68,6 +153,8 @@ def __init__(
     def forward(
         self,
         positions: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         flashcomm_v1_enabled: bool,
@@ -91,6 +178,8 @@ def forward(
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
+            cos=cos,
+            sin=sin
         )
         if flashcomm_v1_enabled:
            hidden_states = maybe_pad_and_reduce_scatter(
@@ -133,6 +222,9 @@ def __init__(
                          decoder_layer_type=decoder_layer_type)
         self.tp_size = get_tensor_model_parallel_world_size()
 
+        self.rotary_emb = self.layers[0].self_attn.rotary_emb
+        self.cos_sin_cache = self.rotary_emb.cos_sin_cache
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -161,9 +253,19 @@ def forward(
         num_tokens = hidden_states.size(0)
         pad_size = (self.tp_size -
                     (num_tokens % self.tp_size)) % self.tp_size
+
+        cos_sin = self.cos_sin_cache.index_select(0, positions)
+        head_dim = cos_sin.size()[-1]
+        cos, sin = cos_sin.reshape(-1, 2,
+                                   head_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
+        cos = cos.view(1, -1, 1, head_dim).contiguous()
+        sin = sin.view(1, -1, 1, head_dim).contiguous()
+
         for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
+                cos,
+                sin,
                 hidden_states,
                 residual,
                 flashcomm_v1_enabled,
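
For reference, below is a minimal standalone sketch of the cos/sin prefetch that the last hunk adds to CustomQwen2Model.forward. It assumes vLLM's RotaryEmbedding cache layout (cos_sin_cache of shape [max_position, rotary_dim], with the cos half concatenated before the sin half on the last dimension); the max_position, rotary_dim, and positions values here are illustrative, not taken from the commit.

import torch

# Toy cos_sin_cache built the way vLLM's RotaryEmbedding builds it:
# cache[pos] = concat(cos(pos * inv_freq), sin(pos * inv_freq)).
max_position, rotary_dim = 64, 128          # illustrative sizes, not from the commit
inv_freq = 1.0 / (10000**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(max_position).float(), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_position, rotary_dim]

positions = torch.tensor([0, 1, 5, 7])      # flattened token positions for one forward pass

# The prefetch: one cache gather per model forward instead of one per layer.
cos_sin = cos_sin_cache.index_select(0, positions)             # [num_tokens, rotary_dim]
head_dim = cos_sin.size()[-1]
cos, sin = cos_sin.reshape(-1, 2,
                           head_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.view(1, -1, 1, head_dim).contiguous()                # [1, num_tokens, 1, head_dim]
sin = sin.view(1, -1, 1, head_dim).contiguous()

print(cos.shape, sin.shape)                 # torch.Size([1, 4, 1, 128]) for both

With cos and sin computed once at the model level and handed down to every layer, the new CustomQwen2Attention.forward above can call self.rotary_emb with cos=cos, sin=sin, skip_index_select=True, so the per-layer cache lookup is skipped.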
