
Commit

Adapt to the PyTorch PoC branch
HIT-cwh committed Dec 18, 2023
1 parent e48bd07 commit c174db9
Showing 2 changed files with 23 additions and 28 deletions.
49 changes: 22 additions & 27 deletions lmdeploy/lite/apis/smooth_quant.py
@@ -4,15 +4,14 @@
 
 import fire
 import torch
-from accelerate import (infer_auto_device_map, init_empty_weights,
-                        load_checkpoint_in_model)
 from torch import nn
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoTokenizer
 
 from lmdeploy.lite.quantization import CalibrationContext
 from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP,
                                             smooth_layers)
-from lmdeploy.lite.utils import collect_target_modules, get_calib_loaders
+from lmdeploy.lite.utils import (collect_target_modules, get_calib_loaders,
+                                 load_hf_from_pretrained)
 from lmdeploy.pytorch.models import QLinear, QRMSNorm
 
 LAYER_TYPE_MAP = {
@@ -114,36 +113,32 @@ def smooth_quant(model: str,
     tokenizer = AutoTokenizer.from_pretrained(model,
                                               use_fast=False,
                                               trust_remote_code=True)
-    hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
-    checkpoint = hf_config._name_or_path
-
-    with init_empty_weights():
-        # Load model
-        model = AutoModelForCausalLM.from_pretrained(model,
-                                                     torch_dtype=torch.float16,
-                                                     trust_remote_code=True)
-        model.config.use_cache = False
+    model = load_hf_from_pretrained(model,
+                                    torch_dtype=torch.float16,
+                                    trust_remote_code=True)
 
     model_type = type(model).__name__
     if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
         raise RuntimeError(
             f'Currently, quantification and calibration of {model_type} are '
             f'not supported. The supported model types are '
             f"{', '.join(LAYER_TYPE_MAP.keys())}.")
 
     if model_type == 'QWenLMHeadModel':
         try:
             import flash_attn  # noqa: F401
         except ImportError:
             raise RuntimeError(
                 'When using Qwen, you need to `pip install flash-attn` first, '
                 'otherwise calibration and quantification will not work '
                 'properly.')
 
     layer_type = LAYER_TYPE_MAP[type(model).__name__]
     norm_type = NORM_TYPE_MAP[type(model).__name__]
     fc2fcs = FC_FCS_MAP[layer_type]
     norm2fcs = NORM_FCS_MAP[layer_type]
 
     decoder_layers = collect_target_modules(model, layer_type)
 
-    # Infer device map
-    device_map = infer_auto_device_map(model,
-                                       no_split_module_classes=[layer_type])
-    for name in device_map.keys():
-        if name in decoder_layers or 'lm_head' in name:
-            device_map[name] = 'cpu'
-        else:
-            device_map[name] = 0
-    load_checkpoint_in_model(model,
-                             checkpoint,
-                             device_map,
-                             dtype=torch.float16)
-
     inp_stats = calibrate(model, tokenizer, calib_dataset, calib_samples,
                           calib_seqlen, device)
     act_scales = inp_stats['absmax']
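For context on the smooth_quant.py hunk: the manual loading path (empty-weight init via accelerate, an explicit CPU/GPU device map, and load_checkpoint_in_model) is collapsed into a single load_hf_from_pretrained call from lmdeploy.lite.utils. The helper's real body lives in lmdeploy/lite/utils/load.py and is not part of this diff, so the stand-in below is only a sketch modelled on the removed block; the "_sketch" name and the example model id are illustrative, not taken from the repository.

import torch
from transformers import AutoModelForCausalLM


def load_hf_from_pretrained_sketch(pretrained_model_name_or_path, **kwargs):
    """Illustrative stand-in for lmdeploy.lite.utils.load_hf_from_pretrained."""
    kwargs.setdefault('torch_dtype', torch.float16)
    model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path,
                                                 **kwargs)
    # The removed block also disabled the KV cache before calibration; the real
    # helper presumably does the same.
    model.config.use_cache = False
    return model


# Usage mirroring the new call site in smooth_quant.py; the model id is only an
# example of an architecture listed in LAYER_TYPE_MAP.
model = load_hf_from_pretrained_sketch('internlm/internlm-chat-7b',
                                       torch_dtype=torch.float16,
                                       trust_remote_code=True)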
2 changes: 1 addition & 1 deletion lmdeploy/lite/utils/load.py
@@ -4,8 +4,8 @@
 from accelerate import infer_auto_device_map, init_empty_weights
 from transformers import AutoConfig, AutoModelForCausalLM
 
+from lmdeploy.legacy.pytorch.model import LoadWoInit
 from lmdeploy.lite.utils import collect_target_modules
-from lmdeploy.pytorch.model import LoadWoInit
 
 LAYER_TYPE_MAP = {
     'InternLMForCausalLM': 'InternLMDecoderLayer',
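The load.py change is an import-path move: LoadWoInit now comes from lmdeploy.legacy.pytorch.model rather than lmdeploy.pytorch.model. As background, the pattern the class name suggests ("load without init") is sketched below: temporarily replace common weight initializers with no-ops so that constructing a model ahead of a checkpoint load skips random initialization. This is an illustration of that assumed pattern, not the class's actual implementation.

import torch


class LoadWoInitSketch:
    """Illustrative skip-weight-init context manager (not the lmdeploy class)."""

    _TARGETS = ('uniform_', 'normal_', 'kaiming_uniform_', 'kaiming_normal_')

    def __enter__(self):
        # Save the real initializers, then swap in no-ops that leave the
        # tensor untouched.
        self._saved = {name: getattr(torch.nn.init, name)
                       for name in self._TARGETS}
        for name in self._TARGETS:
            setattr(torch.nn.init, name,
                    lambda tensor, *args, **kwargs: tensor)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original initializers on exit.
        for name, fn in self._saved.items():
            setattr(torch.nn.init, name, fn)


# Example: layers built inside the context skip their random initialization.
with LoadWoInitSketch():
    layer = torch.nn.Linear(1024, 1024)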

