# limitations under the License.
#

-import copy
-from typing import Any
+from typing import Any, List

import torch


class TrainableEquivalentTransformation:
    """Weight-only quantization, Trainable Equivalent Transformation (TEQ)."""

-    _PREPARE_ATTRS: list[str] = ["weight_config", "trained_alphas"]
+    _PREPARE_ATTRS: List[str] = ["weight_config", "trained_alphas"]
    _PREPARE_ATTRS_PREFIX = "_prepare_"

-    def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, example_inputs=None):
+    def __init__(self, model, weight_config={}, absorb_to_layer=None, folding=True, example_inputs=None):
        """
        :param model: the model for quantization
        :param weight_config (dict, optional): contains all info required by RTN. Defaults to {}.
@@ -54,6 +53,24 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex
        self.absorb_to_layer = absorb_to_layer
        self._post_initialized = False

+    def _detect_absorb_to_layer(self, model, folding, example_inputs):
+        # If the user does not provide the layers that absorb the quantization scales, detect them automatically.
+        supported_layers = ["Linear"]
+        detected_absorb_layers = {}
+        # Detect the layers that can be absorbed automatically.
+        if folding:
+            from neural_compressor.torch.algorithms.weight_only.utility import GraphTrace
+
+            tg = GraphTrace()
+            detected_absorb_layers, _ = tg.get_absorb_to_layer(model, example_inputs, supported_layers)
+        else:
+            for name, module in model.named_modules():
+                if module.__class__.__name__ in supported_layers:
+                    detected_absorb_layers[name] = [name]
+        logger.info("Detected **absorb layer**: **absorbed layers**")
+        logger.info(detected_absorb_layers)
+        return detected_absorb_layers
+
    def _post_init(self):
        self.dtype = self._get_dtype()
        self.model.to(self.device)
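
Reviewer note: a minimal sketch of what the non-folding branch of `_detect_absorb_to_layer` produces. `TinyMLP` is an illustrative stand-in, not part of this patch; with `folding=True` the mapping instead comes from `GraphTrace.get_absorb_to_layer`, which pairs each absorbing module with the linear layers whose scales it can fold in.

```python
import torch

# Toy model used only to illustrate the folding=False path:
# every nn.Linear is mapped to itself, so the TEQ scale stays on the layer.
class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

detected = {}
for name, module in TinyMLP().named_modules():
    if module.__class__.__name__ in ["Linear"]:
        detected[name] = [name]

print(detected)  # {'fc1': ['fc1'], 'fc2': ['fc2']}
```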
@@ -75,6 +92,8 @@ def add_tuning_scale(self, sqrt_w_init=False):
        to the paper for more details
        :param sqrt_w_init: use sqrt weight to init."""

+        if not self.absorb_to_layer:
+            self.absorb_to_layer = self._detect_absorb_to_layer(self.model, self.folding, self.example_inputs)
        if not self._post_initialized:
            self._post_init()
        # freeze model.
@@ -104,7 +123,7 @@ def add_tuning_scale(self, sqrt_w_init=False):
            self.trained_alphas[layer_norm] = alpha
            for layer_name in self.absorb_to_layer[layer_norm]:
-                if self.weight_config.get(layer_name) is None:  # pragma: no cover
+                if not self.weight_config.get(layer_name):  # pragma: no cover
                    logger.info(f"layer {layer_name} not in weight config, skip.")
                    continue
                num_bits = self.weight_config[layer_name]["bits"]
@@ -117,10 +136,10 @@ def add_tuning_scale(self, sqrt_w_init=False):
                )
                set_module(self.model, layer_name, wrapper_module)

-        for n, m in self.model.named_modules():
+        for layer_name, m in self.model.named_modules():
-            if isinstance(m, torch.nn.Linear) and "orig_layer" not in n:
+            if isinstance(m, torch.nn.Linear) and "orig_layer" not in layer_name:
-                if self.weight_config.get(n) is None:  # pragma: no cover
-                    logger.info(f"out of absorbed layer {n} not in weight config, skip.")
+                if not self.weight_config.get(layer_name):  # pragma: no cover
+                    logger.info(f"out of absorbed layer {layer_name} not in weight config, skip.")
                    continue
                num_bits = self.weight_config[layer_name]["bits"]
                group_size = self.weight_config[layer_name]["group_size"]
@@ -131,7 +150,7 @@ def add_tuning_scale(self, sqrt_w_init=False):
                wrapper_module = TEQLinearFakeQuant(
                    orig_layer=m, alpha=alpha, num_bits=num_bits, group_size=group_size, scheme=scheme
                )
-                set_module(self.model, n, wrapper_module)
+                set_module(self.model, layer_name, wrapper_module)
        # Attach the weight config captured at prepare stage to the model
        self.model._weight_config = self.weight_config
        self.model._trained_alphas = self.trained_alphas
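
Reviewer note: a rough sketch (not a tested snippet) of the prepare-stage flow after this change. The `weight_config` schema below ("bits" / "group_size" / "scheme") and the concrete values are assumed from the keys read in this hunk, and `model` / `example_inputs` are toy stand-ins; `TrainableEquivalentTransformation` refers to the class defined in this file.

```python
import torch

# Illustrative only: a toy two-layer model whose submodules are named "0" and "1".
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))
example_inputs = torch.randn(1, 8)

# Assumed schema: one entry per layer name with bits / group_size / scheme.
weight_config = {
    "0": {"bits": 4, "group_size": -1, "scheme": "asym"},
    "1": {"bits": 4, "group_size": -1, "scheme": "asym"},
}

teq = TrainableEquivalentTransformation(
    model,
    weight_config=weight_config,
    folding=False,               # skip GraphTrace; each Linear absorbs its own scale
    example_inputs=example_inputs,
)
teq.add_tuning_scale()           # absorb_to_layer is None, so it is auto-detected here
```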
@@ -272,7 +291,7 @@ def save(self, save_scale_file="", save_state_dict_file=""):

class TEQuantizer(Quantizer):

-    def __init__(self, quant_config, folding, absorb_to_layer, example_inputs):
+    def __init__(self, quant_config, folding, example_inputs, absorb_to_layer=None):
        super().__init__(quant_config=quant_config)
        self.folding = folding
        self.absorb_to_layer = absorb_to_layer
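
Reviewer note: with `absorb_to_layer` moved after `example_inputs` and defaulted to `None`, any existing call site that passed all four arguments positionally (`TEQuantizer(cfg, folding, absorb, example_inputs)`) would now bind the mapping to `example_inputs`; keyword arguments avoid that. A hedged sketch of the intended call patterns, where `quant_config`, `example_inputs`, and the layer names are placeholders:

```python
# Auto-detection path: absorb_to_layer is left at its new default of None.
quantizer = TEQuantizer(quant_config=quant_config, folding=True, example_inputs=example_inputs)

# Explicit path: a caller-supplied mapping (illustrative names) is kept as-is
# and, being truthy, skips the auto-detection fallback.
quantizer = TEQuantizer(
    quant_config=quant_config,
    folding=True,
    example_inputs=example_inputs,
    absorb_to_layer={"ln_1": ["fc_in"]},
)
```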