diff --git a/lmdeploy/lite/apis/calibrate.py b/lmdeploy/lite/apis/calibrate.py
index b164c638f8..9d3b543986 100644
--- a/lmdeploy/lite/apis/calibrate.py
+++ b/lmdeploy/lite/apis/calibrate.py
@@ -179,6 +179,8 @@ def calibrate(model: str,
     work_dir.mkdir(parents=True, exist_ok=True)
     calib_ctx.export(work_dir)
 
+    return model, tokenizer, work_dir
+
 
 if __name__ == '__main__':
     import fire
diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py
index c874286df4..b209825395 100644
--- a/lmdeploy/lite/apis/smooth_quant.py
+++ b/lmdeploy/lite/apis/smooth_quant.py
@@ -5,15 +5,14 @@
 import fire
 import torch
 from torch import nn
-from transformers import AutoTokenizer
 
-from lmdeploy.lite.quantization import CalibrationContext
 from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP,
                                             smooth_layers)
-from lmdeploy.lite.utils import (collect_target_modules, get_calib_loaders,
-                                 load_hf_from_pretrained)
+from lmdeploy.lite.utils import collect_target_modules
 from lmdeploy.pytorch.models import QLinear, QRMSNorm
 
+from .calibrate import calibrate
+
 LAYER_TYPE_MAP = {
     'InternLMForCausalLM': 'InternLMDecoderLayer',
     'QWenLMHeadModel': 'QWenBlock',
@@ -50,58 +49,6 @@
 }
 
 
-def calibrate(model,
-              tokenizer,
-              calib_dataset: str = 'c4',
-              calib_samples: int = 128,
-              calib_seqlen: int = 2048,
-              device: str = 'cuda') -> None:
-    """The main function for loading the model and performing calibration on a
-    given dataset.
-
-    Args:
-        model (nn.Module): The transformers model.
-        tokenizer: The corresponding tokenizer.
-        calib_dataset (str, optional): The calibration dataset name.
-            Defaults to 'c4'.
-        calib_samples (int, optional): The number of samples for calibration.
-            Defaults to 128.
-        calib_seqlen (int, optional): The sequence length for calibration.
-            Defaults to 2048.
-        device (str, optional): The device to be used for calculation.
-            Defaults to 'cuda'.
-    """
-
-    assert calib_dataset in ['c4', 'ptb', 'wikitext2', 'pileval'], \
-        'Support only `c4`, `ptb`, `wikitext2` or `pileval`.'
-
-    layer_type = LAYER_TYPE_MAP[type(model).__name__]
-    norm_type = NORM_TYPE_MAP[type(model).__name__]
-
-    print('Loading calibrate dataset ...')
-    calib_loader, _ = get_calib_loaders(calib_dataset,
-                                        tokenizer,
-                                        nsamples=calib_samples,
-                                        seqlen=calib_seqlen)
-
-    # Initialize calibration context
-    calib_ctx = CalibrationContext(model,
-                                   tokenizer,
-                                   layer_type=layer_type,
-                                   norm_type=norm_type,
-                                   device=device)
-
-    with calib_ctx:
-        all_data = torch.cat([
-            data if isinstance(data, torch.Tensor) else data[0]
-            for data in calib_loader
-        ]).to(device)
-        calib_ctx.calibrate(all_data)
-
-    inp_stats = calib_ctx.collect_inputs_stats()
-    return inp_stats
-
-
 def smooth_quant(model: str,
                  work_dir: str = './work_dir',
                  calib_dataset: str = 'c4',
@@ -109,14 +56,12 @@ def smooth_quant(model: str,
                  calib_seqlen: int = 2048,
                  device: str = 'cuda'):
 
-    # Load tokenizer and configuration
-    tokenizer = AutoTokenizer.from_pretrained(model,
-                                              use_fast=False,
-                                              trust_remote_code=True)
-
-    model = load_hf_from_pretrained(model,
-                                    torch_dtype=torch.float16,
-                                    trust_remote_code=True)
+    model, tokenizer, work_dir = calibrate(model, calib_dataset, calib_samples,
+                                           calib_seqlen, work_dir, device)
+    # calibrate function exports the calibration statistics
+    # (inputs, outputs, keys and values) to `work_dir`.
+    inp_stats = torch.load(work_dir / 'inputs_stats.pth')
+    act_scales = inp_stats['absmax']
 
     model_type = type(model).__name__
     if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
@@ -139,10 +84,6 @@ def smooth_quant(model: str,
     fc2fcs = FC_FCS_MAP[layer_type]
     norm2fcs = NORM_FCS_MAP[layer_type]
 
-    inp_stats = calibrate(model, tokenizer, calib_dataset, calib_samples,
-                          calib_seqlen, device)
-    act_scales = inp_stats['absmax']
-
     layers = collect_target_modules(model, layer_type)
     fcs = {}
     for l_name, layer in layers.items():
diff --git a/lmdeploy/lite/utils/load.py b/lmdeploy/lite/utils/load.py
index f24b9216b4..bfd306a743 100644
--- a/lmdeploy/lite/utils/load.py
+++ b/lmdeploy/lite/utils/load.py
@@ -1,55 +1,35 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 
 import torch
-from accelerate import infer_auto_device_map, init_empty_weights
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from lmdeploy.lite.utils import collect_target_modules
 from lmdeploy.pytorch.accel import LoadNoInit
 
-LAYER_TYPE_MAP = {
-    'InternLMForCausalLM': 'InternLMDecoderLayer',
-    'QWenLMHeadModel': 'QWenBlock',
-    'BaiChuanForCausalLM': 'DecoderLayer',  # Baichuan 7B
-    'BaichuanForCausalLM': 'DecoderLayer',  # Baichuan2 7B
-    'LlamaForCausalLM': 'LlamaDecoderLayer',
-}
 
+def load_hf_from_pretrained(pretrained_model_name_or_path,
+                            dtype=torch.float16,
+                            **kwargs):
 
-def load_hf_from_pretrained(pretrained_model_name_or_path, **kwargs):
+    if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+        raise RuntimeError('Your device does not supports bf16(bfloat16), '
+                           'please change to fp16(float16)')
 
     kwargs.pop('config', None)
 
     hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
-                                           torch_dtype=torch.float16,
+                                           torch_dtype=dtype,
                                            trust_remote_code=True)
 
-    # hard code for qwen, other configs do not have the `fp16` attribute.
-    hf_config.fp16 = True
+    # HACK hard code for qwen, other configs do not have the `fp16` attribute.
+    if dtype == torch.float16:
+        hf_config.fp16 = True
+    elif dtype == torch.bfloat16:
+        hf_config.bf16 = True
 
-    with init_empty_weights():
+    with LoadNoInit():
         # Load model
         model = AutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path, config=hf_config, **kwargs)
         model.config.use_cache = False
-        layer_type = LAYER_TYPE_MAP[type(model).__name__]
-        decoder_layers = collect_target_modules(model, layer_type)
-        # Infer device map
-        device_map = infer_auto_device_map(model,
-                                           no_split_module_classes=[layer_type])
-        for name in device_map.keys():
-            if name in decoder_layers or 'lm_head' in name:
-                device_map[name] = 'cpu'
-            else:
-                device_map[name] = 0
-    if 'device_map' in kwargs:
-        kwargs.pop('device_map')
-    with LoadNoInit():
-        model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path,
-            device_map=device_map,
-            config=hf_config,
-            **kwargs)
-        model.config.use_cache = False
 
     return model