
Commit 3996d8d

yiliu30 committed: add ut
Signed-off-by: yiliu30 <[email protected]>
1 parent 4372a76 commit 3996d8d

File tree

2 files changed: 44 additions and 12 deletions


neural_compressor/torch/algorithms/weight_only/teq.py

Lines changed: 29 additions & 10 deletions
@@ -16,8 +16,7 @@
 # limitations under the License.
 #
 
-import copy
-from typing import Any
+from typing import Any, List
 
 import torch
 
@@ -36,10 +35,10 @@
 class TrainableEquivalentTransformation:
     """Weight-only quantization, Trainable Equivalent Transformation (TEQ)."""
 
-    _PREPARE_ATTRS: list[str] = ["weight_config", "trained_alphas"]
+    _PREPARE_ATTRS: List[str] = ["weight_config", "trained_alphas"]
     _PREPARE_ATTRS_PREFIX = "_prepare_"
 
-    def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, example_inputs=None):
+    def __init__(self, model, weight_config={}, absorb_to_layer=None, folding=True, example_inputs=None):
         """
         :param model: the model for quantization
         :param weight_config (dict, optional): contains all info required by RTN. Defaults to {}.
@@ -54,6 +53,24 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex
         self.absorb_to_layer = absorb_to_layer
         self._post_initialized = False
 
+    def _detect_absorb_to_layer(self, model, folding, example_inputs):
+        # If user not provide the layers to absorb the quantization, detect layers automatically
+        supported_layers = ["Linear"]
+        detected_absorb_layers = {}
+        # Detect the layers that can be absorbed automatically
+        if folding:
+            from neural_compressor.torch.algorithms.weight_only.utility import GraphTrace
+
+            tg = GraphTrace()
+            detected_absorb_layers, _ = tg.get_absorb_to_layer(model, example_inputs, supported_layers)
+        else:
+            for name, module in model.named_modules():
+                if module.__class__.__name__ in supported_layers:
+                    detected_absorb_layers[name] = [name]
+        logger.info("Detected **absorb layer**: **absorbed layers**")
+        logger.info(detected_absorb_layers)
+        return detected_absorb_layers
+
     def _post_init(self):
         self.dtype = self._get_dtype()
         self.model.to(self.device)
@@ -75,6 +92,8 @@ def add_tuning_scale(self, sqrt_w_init=False):
         to the paper for more details
         :param sqrt_w_init: use sqrt weight to init."""
 
+        if not self.absorb_to_layer:
+            self.absorb_to_layer = self._detect_absorb_to_layer(self.model, self.folding, self.example_inputs)
         if not self._post_initialized:
             self._post_init()
         # freeze model.
@@ -104,7 +123,7 @@ def add_tuning_scale(self, sqrt_w_init=False):
 
             self.trained_alphas[layer_norm] = alpha
             for layer_name in self.absorb_to_layer[layer_norm]:
-                if self.weight_config.get(layer_name) is None:  # pragma: no cover
+                if not self.weight_config.get(layer_name):  # pragma: no cover
                     logger.info(f"layer {layer_name} not in weight config, skip.")
                     continue
                 num_bits = self.weight_config[layer_name]["bits"]
@@ -117,10 +136,10 @@
                 )
                 set_module(self.model, layer_name, wrapper_module)
 
-        for n, m in self.model.named_modules():
+        for layer_name, m in self.model.named_modules():
             if isinstance(m, torch.nn.Linear) and "orig_layer" not in n:
-                if self.weight_config.get(n) is None:  # pragma: no cover
-                    logger.info(f"out of absorbed layer {n} not in weight config, skip.")
+                if not self.weight_config.get(layer_name):  # pragma: no cover
+                    logger.info(f"out of absorbed layer {layer_name} not in weight config, skip.")
                     continue
                 num_bits = self.weight_config[layer_name]["bits"]
                 group_size = self.weight_config[layer_name]["group_size"]
@@ -131,7 +150,7 @@
                 wrapper_module = TEQLinearFakeQuant(
                     orig_layer=m, alpha=alpha, num_bits=num_bits, group_size=group_size, scheme=scheme
                 )
-                set_module(self.model, n, wrapper_module)
+                set_module(self.model, layer_name, wrapper_module)
         # Attach the weight config captured at prepare stage to the model
         self.model._weight_config = self.weight_config
         self.model._trained_alphas = self.trained_alphas
@@ -272,7 +291,7 @@ def save(self, save_scale_file="", save_state_dict_file=""):
 
 class TEQuantizer(Quantizer):
 
-    def __init__(self, quant_config, folding, absorb_to_layer, example_inputs):
+    def __init__(self, quant_config, folding, example_inputs, absorb_to_layer=None):
         super().__init__(quant_config=quant_config)
         self.folding = folding
         self.absorb_to_layer = absorb_to_layer
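
For orientation, the sketch below (not part of the commit; the toy torch.nn.Sequential model is assumed purely for illustration) reproduces what the folding=False branch of the new _detect_absorb_to_layer helper does: every Linear module is recorded as absorbing its own scale, keyed by its module name. When folding=True, the helper instead calls GraphTrace.get_absorb_to_layer(model, example_inputs, supported_layers) to trace the model and map each absorbing layer to the layers it absorbs.

# A minimal sketch mirroring the folding=False branch of _detect_absorb_to_layer;
# the tiny Sequential model is a stand-in, not something from this commit.
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))

supported_layers = ["Linear"]
detected_absorb_layers = {}
for name, module in model.named_modules():
    if module.__class__.__name__ in supported_layers:
        # each Linear absorbs its own scale, keyed by its module name
        detected_absorb_layers[name] = [name]

print(detected_absorb_layers)  # {'0': ['0'], '2': ['2']}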

test/3x/torch/algorithms/weight_only/test_teq_quantizer.py

Lines changed: 15 additions & 2 deletions
@@ -82,8 +82,21 @@ def setUpClass(self):
         )
         self.gptj.seqlen = 512
 
-    def train_func(self):
-        pass
+    def test_teq_detect_absorb_layers(self):
+        example_inputs = torch.ones([1, 512], dtype=torch.long)
+        test_input = torch.ones([1, 512], dtype=torch.long)
+        model = copy.deepcopy(self.gptj)
+        out0 = model(test_input)
+
+        weight_config = {
+            # 'op_name': (bit, group_size, scheme)
+            "transformer.h.0.mlp.fc_in": {"bits": 8, "group_size": -1, "scheme": "sym"},
+            "transformer.h.0.mlp.fc_out": {"bits": 4, "group_size": 32, "scheme": "asym"},
+        }
+        quantizer = TEQuantizer(quant_config=weight_config, folding=True, example_inputs=example_inputs)
+        model = quantizer.quantize(copy.deepcopy(self.gptj), run_fn=train)
+        out1 = model(test_input)
+        self.assertTrue(torch.allclose(out1[0], out0[0], atol=0.03))
 
     def test_teq(self):
         example_inputs = torch.ones([1, 512], dtype=torch.long)
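
Since TEQuantizer.__init__ now takes example_inputs third and absorb_to_layer last with a default of None, callers that passed absorb_to_layer positionally should switch to keyword arguments. A usage sketch under that assumption (the import path is inferred from the file location; user_model and calibration_fn are hypothetical placeholders, not defined in this commit):

# Usage sketch for the reordered signature; user_model and calibration_fn are
# hypothetical placeholders supplied by the caller.
import torch
from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer

weight_config = {
    "transformer.h.0.mlp.fc_in": {"bits": 8, "group_size": -1, "scheme": "sym"},
    "transformer.h.0.mlp.fc_out": {"bits": 4, "group_size": 32, "scheme": "asym"},
}
example_inputs = torch.ones([1, 512], dtype=torch.long)

# absorb_to_layer is omitted, so add_tuning_scale() auto-detects it.
quantizer = TEQuantizer(quant_config=weight_config, folding=True, example_inputs=example_inputs)
# quantized_model = quantizer.quantize(user_model, run_fn=calibration_fn)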
