diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py
index ecddd44892..c874286df4 100644
--- a/lmdeploy/lite/apis/smooth_quant.py
+++ b/lmdeploy/lite/apis/smooth_quant.py
@@ -4,15 +4,14 @@
 
 import fire
 import torch
-from accelerate import (infer_auto_device_map, init_empty_weights,
-                        load_checkpoint_in_model)
 from torch import nn
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoTokenizer
 
 from lmdeploy.lite.quantization import CalibrationContext
 from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP,
                                             smooth_layers)
-from lmdeploy.lite.utils import collect_target_modules, get_calib_loaders
+from lmdeploy.lite.utils import (collect_target_modules, get_calib_loaders,
+                                 load_hf_from_pretrained)
 from lmdeploy.pytorch.models import QLinear, QRMSNorm
 
 LAYER_TYPE_MAP = {
@@ -114,36 +113,32 @@ def smooth_quant(model: str,
     tokenizer = AutoTokenizer.from_pretrained(model,
                                               use_fast=False,
                                               trust_remote_code=True)
-    hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
-    checkpoint = hf_config._name_or_path
-    with init_empty_weights():
-        # Load model
-        model = AutoModelForCausalLM.from_pretrained(model,
-                                                     torch_dtype=torch.float16,
-                                                     trust_remote_code=True)
-        model.config.use_cache = False
+    model = load_hf_from_pretrained(model,
+                                    torch_dtype=torch.float16,
+                                    trust_remote_code=True)
+
+    model_type = type(model).__name__
+    if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
+        raise RuntimeError(
+            f'Currently, quantification and calibration of {model_type} are '
+            f'not supported. The supported model types are '
+            f"{', '.join(LAYER_TYPE_MAP.keys())}.")
+
+    if model_type == 'QWenLMHeadModel':
+        try:
+            import flash_attn  # noqa: F401
+        except ImportError:
+            raise RuntimeError(
+                'When using Qwen, you need to `pip install flash-attn` first, '
+                'otherwise calibration and quantification will not work '
+                'properly.')
 
     layer_type = LAYER_TYPE_MAP[type(model).__name__]
     norm_type = NORM_TYPE_MAP[type(model).__name__]
 
     fc2fcs = FC_FCS_MAP[layer_type]
     norm2fcs = NORM_FCS_MAP[layer_type]
 
-    decoder_layers = collect_target_modules(model, layer_type)
-
-    # Infer device map
-    device_map = infer_auto_device_map(model,
-                                       no_split_module_classes=[layer_type])
-    for name in device_map.keys():
-        if name in decoder_layers or 'lm_head' in name:
-            device_map[name] = 'cpu'
-        else:
-            device_map[name] = 0
-    load_checkpoint_in_model(model,
-                             checkpoint,
-                             device_map,
-                             dtype=torch.float16)
-
     inp_stats = calibrate(model, tokenizer, calib_dataset, calib_samples,
                           calib_seqlen, device)
     act_scales = inp_stats['absmax']
diff --git a/lmdeploy/lite/utils/load.py b/lmdeploy/lite/utils/load.py
index 55aefd2389..3ab81aaf31 100644
--- a/lmdeploy/lite/utils/load.py
+++ b/lmdeploy/lite/utils/load.py
@@ -4,8 +4,8 @@
 from accelerate import infer_auto_device_map, init_empty_weights
 from transformers import AutoConfig, AutoModelForCausalLM
 
+from lmdeploy.legacy.pytorch.model import LoadWoInit
 from lmdeploy.lite.utils import collect_target_modules
-from lmdeploy.pytorch.model import LoadWoInit
 
 LAYER_TYPE_MAP = {
     'InternLMForCausalLM': 'InternLMDecoderLayer',
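
Not part of the patch: for context, below is a minimal sketch of what the `load_hf_from_pretrained` helper called above could look like, assuming it simply folds the empty-weights / device-map loading removed from `smooth_quant.py` behind a single call. Everything in it is inferred from the removed block and the imports visible in `load.py` (which also pulls in `LoadWoInit`, omitted here); the actual implementation in the repository may differ.

```python
# Hypothetical sketch of the helper assumed by the diff above; the real
# implementation lives in lmdeploy/lite/utils/load.py and may differ.
import torch
from accelerate import (infer_auto_device_map, init_empty_weights,
                        load_checkpoint_in_model)
from transformers import AutoConfig, AutoModelForCausalLM

from lmdeploy.lite.utils import collect_target_modules

# Truncated for illustration; load.py defines the full mapping.
LAYER_TYPE_MAP = {
    'InternLMForCausalLM': 'InternLMDecoderLayer',
}


def load_hf_from_pretrained(pretrained_model_name_or_path, **kwargs):
    """Load an HF causal LM with its decoder layers pinned to CPU."""
    hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
                                           trust_remote_code=True)
    checkpoint = hf_config._name_or_path

    # Build the module tree on the 'meta' device first: no real weights yet.
    with init_empty_weights():
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path, **kwargs)
        model.config.use_cache = False

    # Keep whole decoder layers (and lm_head) on CPU, everything else on
    # GPU 0 -- the same policy as the block removed from smooth_quant.py.
    layer_type = LAYER_TYPE_MAP[type(model).__name__]
    decoder_layers = collect_target_modules(model, layer_type)
    device_map = infer_auto_device_map(model,
                                       no_split_module_classes=[layer_type])
    for name in device_map:
        if name in decoder_layers or 'lm_head' in name:
            device_map[name] = 'cpu'
        else:
            device_map[name] = 0

    # Stream the real fp16 weights into place according to the device map.
    load_checkpoint_in_model(model, checkpoint, device_map,
                             dtype=torch.float16)
    return model
```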