From 90d5efe215fc0af4a73aad2ab77984c756b2a323 Mon Sep 17 00:00:00 2001
From: LS
Date: Fri, 5 Apr 2024 22:56:02 +0800
Subject: [PATCH 1/2] feat: support loading eetq quantized model

---
 .../custom_modeling/flash_llama_modeling.py |  2 +-
 server/lorax_server/utils/layers.py         | 71 ++++++++++++-------
 server/lorax_server/utils/weights.py        | 21 +++++-
 3 files changed, 65 insertions(+), 29 deletions(-)

diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py
index 01b6a6f15..dcc54aa7d 100644
--- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py
@@ -195,7 +195,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
 
-    if config.quantize not in ["gptq", "awq"]:
+    if isinstance(weight, torch.Tensor):
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
     head_size = config.hidden_size // config.num_attention_heads
diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py
index e4c803302..736f43499 100644
--- a/server/lorax_server/utils/layers.py
+++ b/server/lorax_server/utils/layers.py
@@ -7,6 +7,7 @@
 from accelerate import init_empty_weights
 from torch import nn
 from torch.nn import functional as F
+from loguru import logger
 
 from lorax_server.adapters.types import LORA, MEDUSA
 from lorax_server.utils.gptq.quant_linear import QuantLinear
@@ -166,32 +167,40 @@ def __init__(
         self,
         weight,
         bias,
+        scales=None,
+        quantized=False,
     ) -> None:
         super().__init__()
-        # Get the device where the weight tensor is currently stored.
-        device = weight.device
-
-        # Transpose the weight tensor and make a contiguous copy of it on the CPU.
-        # The contiguous() function is used to ensure that the tensor is stored in a contiguous block of memory,
-        # which can improve performance in some cases.
-        weight_transposed = torch.t(weight)
-        weight_contiguous = weight_transposed.contiguous()
-        weight_cpu = weight_contiguous.cpu()
-
-        # Quantize the weights. The quant_weights function is assumed to perform the quantization.
-        # The weights are quantized to int8 format, and the quantization is not performed in place (False).
-        weight_quantized, scale = quant_weights(weight_cpu, torch.int8, False)
-
-        # Move the quantized weights and the scale back to the original device (GPU if available).
-        # The cuda() function is used to move the tensors to the GPU.
-        self.weight = weight_quantized.cuda(device)
-        self.scale = scale.cuda(device)
-
-        # If a bias is present, move it to the GPU as well. If not, set the bias to None.
-        if bias is not None:
-            self.bias = bias.cuda(device)
+
+        if quantized:
+            self.weight = weight
+            self.scale = scales
+            self.bias = bias if bias is not None else None
         else:
-            self.bias = None
+            # Get the device where the weight tensor is currently stored.
+            device = weight.device
+
+            # Transpose the weight tensor and make a contiguous copy of it on the CPU.
+            # The contiguous() function is used to ensure that the tensor is stored in a contiguous block of memory,
+            # which can improve performance in some cases.
+            weight_transposed = torch.t(weight)
+            weight_contiguous = weight_transposed.contiguous()
+            weight_cpu = weight_contiguous.cpu()
+
+            # Quantize the weights. The quant_weights function is assumed to perform the quantization.
+            # The weights are quantized to int8 format, and the quantization is not performed in place (False).
+            weight_quantized, scale = quant_weights(weight_cpu, torch.int8, False)
+
+            # Move the quantized weights and the scale back to the original device (GPU if available).
+            # The cuda() function is used to move the tensors to the GPU.
+            self.weight = weight_quantized.cuda(device)
+            self.scale = scale.cuda(device)
+
+            # If a bias is present, move it to the GPU as well. If not, set the bias to None.
+            if bias is not None:
+                self.bias = bias.cuda(device)
+            else:
+                self.bias = None
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         """
@@ -344,10 +353,20 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False):
             quant_type="fp4",
         )
     elif quantize == "eetq":
-        if HAS_EETQ:
-            linear = EETQLinear(weight, bias)
-        else:
+        if not HAS_EETQ:
             raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
+
+        try:
+            qweight, scales = weight
+            linear = EETQLinear(
+                qweight,
+                bias,
+                scales,
+                True,
+            )
+        except Exception:
+            logger.info("Weight does not appear to be EETQ pre-quantized; quantizing it just in time")
+            linear = EETQLinear(weight, bias)
     elif quantize == "gptq":
         try:
             qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
diff --git a/server/lorax_server/utils/weights.py b/server/lorax_server/utils/weights.py
index 1bc777d92..03778bb5c 100644
--- a/server/lorax_server/utils/weights.py
+++ b/server/lorax_server/utils/weights.py
@@ -139,7 +139,7 @@ def get_tensor(self, tensor_name: str):
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32
-        if tensor.dtype not in [torch.int32, torch.int64]:
+        if tensor.dtype not in [torch.int8, torch.int32, torch.int64]:
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
@@ -178,7 +178,7 @@ def get_partial_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[
             raise NotImplementedError("Let's make that generic when needed")
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32
-        if tensor.dtype != torch.int32:
+        if tensor.dtype not in [torch.int8, torch.int32]:
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
@@ -226,6 +226,15 @@ def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str
 
             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        elif quantize == "eetq":
+            try:
+                qweight = torch.cat(self.get_sharded_list("qweight", prefixes, dim=1), dim=1)
+                scales = torch.cat(self.get_sharded_list("weight_scales", prefixes, dim=0), dim=0)
+                weight = (qweight, scales)
+            except RuntimeError:
+                logger.info("Weight does not appear to be EETQ pre-quantized; loading it as-is for just-in-time quantization")
+                w = self.get_sharded_list("weight", prefixes, dim=0)
+                weight = torch.cat(w, dim=dim)
         else:
             w = self.get_sharded_list("weight", prefixes, dim=0)
             weight = torch.cat(w, dim=dim)
@@ -310,6 +319,14 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
                 g_idx = None
                 use_exllama = False
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+        elif quantize == "eetq":
+            try:
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+                scales = self.get_sharded(f"{prefix}.weight_scales", dim=0)
+                weight = (qweight, scales)
+            except RuntimeError:
+                logger.info("Weight does not appear to be EETQ pre-quantized; loading it as-is for just-in-time quantization")
+                weight = self.get_sharded(f"{prefix}.weight", dim=1)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
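Patch 1 relies on a tensor-name convention in the checkpoint: an EETQ-quantized model ships an int8 `<prefix>.qweight` alongside `<prefix>.weight_scales`, while an unquantized model only ships `<prefix>.weight`, which `get_linear` then hands to `EETQLinear` for just-in-time quantization. The sketch below is illustrative only and is not part of the patch; the helper name `load_eetq_or_plain` and the key-presence check are assumptions for the example (the patch itself detects the unquantized case by catching the loader's `RuntimeError`).

```python
# Illustration only (not part of the patch): the checkpoint layout the loader above expects.
from safetensors import safe_open


def load_eetq_or_plain(path: str, prefix: str):
    """Return (qweight, scales) for an EETQ-quantized layer, or the plain weight tensor."""
    with safe_open(path, framework="pt", device="cpu") as f:
        keys = set(f.keys())
        if f"{prefix}.qweight" in keys and f"{prefix}.weight_scales" in keys:
            # Pre-quantized path: get_linear() passes this tuple straight to EETQLinear.
            return f.get_tensor(f"{prefix}.qweight"), f.get_tensor(f"{prefix}.weight_scales")
        # Unquantized path: EETQLinear quantizes this weight at construction time instead.
        return f.get_tensor(f"{prefix}.weight")
```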
From 9d91b378c3599aefcf891d98600e6a873d4b503e Mon Sep 17 00:00:00 2001
From: LS
Date: Sat, 6 Apr 2024 20:59:43 +0800
Subject: [PATCH 2/2] fix: reduce the changes in EETQLinear

---
 server/lorax_server/utils/layers.py | 49 +++++++++++++++--------
 1 file changed, 25 insertions(+), 24 deletions(-)

diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py
index 736f43499..61a336e51 100644
--- a/server/lorax_server/utils/layers.py
+++ b/server/lorax_server/utils/layers.py
@@ -176,31 +176,32 @@ def __init__(
             self.weight = weight
             self.scale = scales
             self.bias = bias if bias is not None else None
+            return
+
+        # Get the device where the weight tensor is currently stored.
+        device = weight.device
+
+        # Transpose the weight tensor and make a contiguous copy of it on the CPU.
+        # The contiguous() function is used to ensure that the tensor is stored in a contiguous block of memory,
+        # which can improve performance in some cases.
+        weight_transposed = torch.t(weight)
+        weight_contiguous = weight_transposed.contiguous()
+        weight_cpu = weight_contiguous.cpu()
+
+        # Quantize the weights. The quant_weights function is assumed to perform the quantization.
+        # The weights are quantized to int8 format, and the quantization is not performed in place (False).
+        weight_quantized, scale = quant_weights(weight_cpu, torch.int8, False)
+
+        # Move the quantized weights and the scale back to the original device (GPU if available).
+        # The cuda() function is used to move the tensors to the GPU.
+        self.weight = weight_quantized.cuda(device)
+        self.scale = scale.cuda(device)
+
+        # If a bias is present, move it to the GPU as well. If not, set the bias to None.
+        if bias is not None:
+            self.bias = bias.cuda(device)
         else:
-            # Get the device where the weight tensor is currently stored.
-            device = weight.device
-
-            # Transpose the weight tensor and make a contiguous copy of it on the CPU.
-            # The contiguous() function is used to ensure that the tensor is stored in a contiguous block of memory,
-            # which can improve performance in some cases.
-            weight_transposed = torch.t(weight)
-            weight_contiguous = weight_transposed.contiguous()
-            weight_cpu = weight_contiguous.cpu()
-
-            # Quantize the weights. The quant_weights function is assumed to perform the quantization.
-            # The weights are quantized to int8 format, and the quantization is not performed in place (False).
-            weight_quantized, scale = quant_weights(weight_cpu, torch.int8, False)
-
-            # Move the quantized weights and the scale back to the original device (GPU if available).
-            # The cuda() function is used to move the tensors to the GPU.
-            self.weight = weight_quantized.cuda(device)
-            self.scale = scale.cuda(device)
-
-            # If a bias is present, move it to the GPU as well. If not, set the bias to None.
-            if bias is not None:
-                self.bias = bias.cuda(device)
-            else:
-                self.bias = None
+            self.bias = None
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         """
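After patch 2, `EETQLinear` keeps the same two code paths but exits early for pre-quantized weights instead of nesting the just-in-time branch in an `else`. Below is a rough usage sketch of both paths; it is not part of the series, assumes a CUDA device plus the EETQ package, and the 4096x4096 shape is arbitrary.

```python
# Rough usage sketch (not part of the patch series) of the two EETQLinear paths.
# Assumes a CUDA device, the EETQ package, and EETQLinear as defined in layers.py above.
import torch
from EETQ import quant_weights  # same helper the constructor uses for JIT quantization

from lorax_server.utils.layers import EETQLinear

fp16_weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# Path 1: weights quantized ahead of time -- pass int8 weights and scales straight through.
# The transpose/contiguous/cpu steps mirror what the JIT branch does before quantizing.
qweight, scales = quant_weights(fp16_weight.t().contiguous().cpu(), torch.int8, False)
prequantized = EETQLinear(qweight.cuda(), bias=None, scales=scales.cuda(), quantized=True)

# Path 2: plain fp16 weights -- the constructor quantizes them just in time on the CPU
# and moves the int8 weights and scales back to the original GPU device.
jit_quantized = EETQLinear(fp16_weight, bias=None)
```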