 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union

 import torch
 import torch.nn.functional as F
+import vllm_ascend.envs as ascend_envs
 from torch import nn
 from transformers import Qwen2Config
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce,
                               tensor_model_parallel_reduce_scatter)
 from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
-from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model
+from vllm.model_executor.models.qwen2 import (Qwen2Attention, Qwen2MLP,
+                                              Qwen2Model)
 from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                               PPMissingLayer, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-
-import vllm_ascend.envs as ascend_envs
 from vllm_ascend.attention.attention_v1 import AscendAttentionState


@@ -47,19 +49,102 @@ def maybe_pad_and_reduce_scatter(
     return hidden_states


-class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
+class CustomQwen2Attention(Qwen2Attention):

     def __init__(
         self,
-        config: Qwen2Config,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[tuple] = None,
         prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
     ) -> None:
-        super().__init__(config=config,
+        super().__init__(hidden_size=hidden_size,
+                         num_heads=num_heads,
+                         num_kv_heads=num_kv_heads,
+                         max_position=max_position,
+                         rope_theta=rope_theta,
                          cache_config=cache_config,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         rope_scaling=rope_scaling,
+                         prefix=prefix,
+                         attn_type=attn_type,
+                         dual_chunk_attention_config=dual_chunk_attention_config)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k, cos=cos, sin=sin, skip_index_select=True)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class CustomQwen2DecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        dual_chunk_attention_config = getattr(config,
+                                              "dual_chunk_attention_config",
+                                              None)
+
+        # By default, Qwen2 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        self.self_attn = CustomQwen2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+        )
+        self.mlp = Qwen2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.self_attn.o_proj.reduce_results = False
@@ -68,6 +153,8 @@ def __init__(
     def forward(
         self,
         positions: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         flashcomm_v1_enabled: bool,
@@ -91,6 +178,8 @@ def forward(
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
+            cos=cos,
+            sin=sin
         )
         if flashcomm_v1_enabled:
             hidden_states = maybe_pad_and_reduce_scatter(
@@ -133,6 +222,9 @@ def __init__(
                          decoder_layer_type=decoder_layer_type)
         self.tp_size = get_tensor_model_parallel_world_size()

+        self.rotary_emb = self.layers[0].self_attn.rotary_emb
+        self.cos_sin_cache = self.rotary_emb.cos_sin_cache
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -161,9 +253,19 @@ def forward(
             num_tokens = hidden_states.size(0)
             pad_size = (self.tp_size -
                         (num_tokens % self.tp_size)) % self.tp_size
+
+        cos_sin = self.cos_sin_cache.index_select(0, positions)
+        head_dim = cos_sin.size()[-1]
+        cos, sin = cos_sin.reshape(-1, 2,
+                                   head_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
+        cos = cos.view(1, -1, 1, head_dim).contiguous()
+        sin = sin.view(1, -1, 1, head_dim).contiguous()
+
         for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
+                cos,
+                sin,
                 hidden_states,
                 residual,
                 flashcomm_v1_enabled,
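
For reference, here is a standalone sketch (not part of the patch) of the cos/sin precompute that the modified Qwen2Model.forward hoists out of the per-layer rotary embedding call, so the cache lookup happens once per forward pass instead of once per decoder layer. The toy cache layout, sizes, and variable names below are illustrative assumptions; only the index_select / reshape / repeat / chunk / view chain mirrors the lines added above.

# Illustrative sketch only -- toy shapes, not the real vLLM rotary cache.
import torch

max_position, head_dim = 16, 8
num_tokens = 4

# Assumed layout: each row is [cos_0..cos_{d/2-1} | sin_0..sin_{d/2-1}],
# i.e. a fused cos half followed by a sin half, one row per position.
cos_sin_cache = torch.randn(max_position, head_dim)
positions = torch.tensor([0, 3, 7, 7])  # token positions for this batch

# One gather per forward pass instead of one per decoder layer.
cos_sin = cos_sin_cache.index_select(0, positions)  # [num_tokens, head_dim]

# Split into the cos half and the sin half, then tile each half twice along
# the last dim so both tensors span the full head_dim.
cos, sin = cos_sin.reshape(-1, 2, head_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.view(1, -1, 1, head_dim).contiguous()  # [1, num_tokens, 1, head_dim]
sin = sin.view(1, -1, 1, head_dim).contiguous()

assert cos.shape == sin.shape == (1, num_tokens, 1, head_dim)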