
Commit

[Fix] Fix conflicts in lite (#878)
* cherry-pick Fix meta tensor error commits

* fix smooth quant

---------

Co-authored-by: pppppM <[email protected]>
HIT-cwh and pppppM authored Dec 21, 2023
1 parent 10a8912 commit 32bc114
Showing 3 changed files with 24 additions and 101 deletions.
2 changes: 2 additions & 0 deletions lmdeploy/lite/apis/calibrate.py
@@ -179,6 +179,8 @@ def calibrate(model: str,
     work_dir.mkdir(parents=True, exist_ok=True)
     calib_ctx.export(work_dir)
+
+    return model, tokenizer, work_dir
 
 
 if __name__ == '__main__':
     import fire
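With this change, `calibrate` hands its results back to the caller instead of only writing them to disk: it now returns the loaded model, the tokenizer, and the work directory. A minimal usage sketch, assuming the positional argument order used by the new `smooth_quant` call below; the model path is a placeholder, not part of the commit:

    from lmdeploy.lite.apis.calibrate import calibrate

    # Placeholder model path; arguments follow the order used in smooth_quant.py.
    model, tokenizer, work_dir = calibrate('internlm/internlm-chat-7b',
                                           'c4',          # calib_dataset
                                           128,           # calib_samples
                                           2048,          # calib_seqlen
                                           './work_dir',  # work_dir
                                           'cuda')        # device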
77 changes: 9 additions & 68 deletions lmdeploy/lite/apis/smooth_quant.py
@@ -5,15 +5,14 @@
 import fire
 import torch
 from torch import nn
-from transformers import AutoTokenizer
 
-from lmdeploy.lite.quantization import CalibrationContext
 from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP,
                                             smooth_layers)
-from lmdeploy.lite.utils import (collect_target_modules, get_calib_loaders,
-                                 load_hf_from_pretrained)
+from lmdeploy.lite.utils import collect_target_modules
 from lmdeploy.pytorch.models import QLinear, QRMSNorm
 
+from .calibrate import calibrate
+
 LAYER_TYPE_MAP = {
     'InternLMForCausalLM': 'InternLMDecoderLayer',
     'QWenLMHeadModel': 'QWenBlock',
@@ -50,73 +49,19 @@
 }
 
 
-def calibrate(model,
-              tokenizer,
-              calib_dataset: str = 'c4',
-              calib_samples: int = 128,
-              calib_seqlen: int = 2048,
-              device: str = 'cuda') -> None:
-    """The main function for loading the model and performing calibration on a
-    given dataset.
-
-    Args:
-        model (nn.Module): The transformers model.
-        tokenizer: The corresponding tokenizer.
-        calib_dataset (str, optional): The calibration dataset name.
-            Defaults to 'c4'.
-        calib_samples (int, optional): The number of samples for calibration.
-            Defaults to 128.
-        calib_seqlen (int, optional): The sequence length for calibration.
-            Defaults to 2048.
-        device (str, optional): The device to be used for calculation.
-            Defaults to 'cuda'.
-    """
-
-    assert calib_dataset in ['c4', 'ptb', 'wikitext2', 'pileval'], \
-        'Support only `c4`, `ptb`, `wikitext2` or `pileval`.'
-
-    layer_type = LAYER_TYPE_MAP[type(model).__name__]
-    norm_type = NORM_TYPE_MAP[type(model).__name__]
-
-    print('Loading calibrate dataset ...')
-    calib_loader, _ = get_calib_loaders(calib_dataset,
-                                        tokenizer,
-                                        nsamples=calib_samples,
-                                        seqlen=calib_seqlen)
-
-    # Initialize calibration context
-    calib_ctx = CalibrationContext(model,
-                                   tokenizer,
-                                   layer_type=layer_type,
-                                   norm_type=norm_type,
-                                   device=device)
-
-    with calib_ctx:
-        all_data = torch.cat([
-            data if isinstance(data, torch.Tensor) else data[0]
-            for data in calib_loader
-        ]).to(device)
-        calib_ctx.calibrate(all_data)
-
-    inp_stats = calib_ctx.collect_inputs_stats()
-    return inp_stats
-
-
 def smooth_quant(model: str,
                  work_dir: str = './work_dir',
                  calib_dataset: str = 'c4',
                  calib_samples: int = 128,
                  calib_seqlen: int = 2048,
                  device: str = 'cuda'):
 
-    # Load tokenizer and configuration
-    tokenizer = AutoTokenizer.from_pretrained(model,
-                                              use_fast=False,
-                                              trust_remote_code=True)
-
-    model = load_hf_from_pretrained(model,
-                                    torch_dtype=torch.float16,
-                                    trust_remote_code=True)
+    model, tokenizer, work_dir = calibrate(model, calib_dataset, calib_samples,
+                                           calib_seqlen, work_dir, device)
+    # calibrate function exports the calibration statistics
+    # (inputs, outputs, keys and values) to `work_dir`.
+    inp_stats = torch.load(work_dir / 'inputs_stats.pth')
+    act_scales = inp_stats['absmax']
 
     model_type = type(model).__name__
     if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
@@ -139,10 +84,6 @@ def smooth_quant(model: str,
     fc2fcs = FC_FCS_MAP[layer_type]
     norm2fcs = NORM_FCS_MAP[layer_type]
 
-    inp_stats = calibrate(model, tokenizer, calib_dataset, calib_samples,
-                          calib_seqlen, device)
-    act_scales = inp_stats['absmax']
-
     layers = collect_target_modules(model, layer_type)
     fcs = {}
     for l_name, layer in layers.items():
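After this refactor, `smooth_quant` drives calibration through the shared `calibrate` helper and reloads the exported statistics from `work_dir`, so a caller only needs the one entry point. A sketch of the Python-side invocation, with keyword names taken from the signature shown in the diff above and a placeholder model path:

    from lmdeploy.lite.apis.smooth_quant import smooth_quant

    smooth_quant('internlm/internlm-chat-7b',  # placeholder HF model path
                 work_dir='./work_dir',
                 calib_dataset='c4',
                 calib_samples=128,
                 calib_seqlen=2048,
                 device='cuda')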
46 changes: 13 additions & 33 deletions lmdeploy/lite/utils/load.py
@@ -1,55 +1,35 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 
 import torch
-from accelerate import infer_auto_device_map, init_empty_weights
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from lmdeploy.lite.utils import collect_target_modules
 from lmdeploy.pytorch.accel import LoadNoInit
 
-LAYER_TYPE_MAP = {
-    'InternLMForCausalLM': 'InternLMDecoderLayer',
-    'QWenLMHeadModel': 'QWenBlock',
-    'BaiChuanForCausalLM': 'DecoderLayer',  # Baichuan 7B
-    'BaichuanForCausalLM': 'DecoderLayer',  # Baichuan2 7B
-    'LlamaForCausalLM': 'LlamaDecoderLayer',
-}
 
+def load_hf_from_pretrained(pretrained_model_name_or_path,
+                            dtype=torch.float16,
+                            **kwargs):
 
-def load_hf_from_pretrained(pretrained_model_name_or_path, **kwargs):
+    if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+        raise RuntimeError('Your device does not supports bf16(bfloat16), '
+                           'please change to fp16(float16)')
 
     kwargs.pop('config', None)
 
     hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
-                                           torch_dtype=torch.float16,
+                                           torch_dtype=dtype,
                                            trust_remote_code=True)
 
-    # hard code for qwen, other configs do not have the `fp16` attribute.
-    hf_config.fp16 = True
+    # HACK hard code for qwen, other configs do not have the `fp16` attribute.
+    if dtype == torch.float16:
+        hf_config.fp16 = True
+    elif dtype == torch.bfloat16:
+        hf_config.bf16 = True
 
-    with init_empty_weights():
+    with LoadNoInit():
         # Load model
         model = AutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path, config=hf_config, **kwargs)
         model.config.use_cache = False
-        layer_type = LAYER_TYPE_MAP[type(model).__name__]
-        decoder_layers = collect_target_modules(model, layer_type)
-        # Infer device map
-        device_map = infer_auto_device_map(model,
-                                           no_split_module_classes=[layer_type])
-    for name in device_map.keys():
-        if name in decoder_layers or 'lm_head' in name:
-            device_map[name] = 'cpu'
-        else:
-            device_map[name] = 0
-    if 'device_map' in kwargs:
-        kwargs.pop('device_map')
-    with LoadNoInit():
-        model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path,
-            device_map=device_map,
-            config=hf_config,
-            **kwargs)
-        model.config.use_cache = False
 
     return model
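`load_hf_from_pretrained` now accepts an explicit `dtype`, propagates it to the HF config, and rejects bf16 on devices without bf16 support, while `LoadNoInit` replaces the previous meta-tensor and device-map handling. A usage sketch, assuming the `from lmdeploy.lite.utils import load_hf_from_pretrained` re-export used elsewhere in this commit; the model name is a placeholder:

    import torch

    from lmdeploy.lite.utils import load_hf_from_pretrained

    # Prefer bf16 only when the GPU supports it; otherwise fall back to fp16.
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    model = load_hf_from_pretrained('internlm/internlm-chat-7b',  # placeholder
                                    dtype=dtype,
                                    trust_remote_code=True)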
