diff --git a/autotest/config.yaml b/autotest/config.yaml
index 6c92d2cf0b..e31a40f0d4 100644
--- a/autotest/config.yaml
+++ b/autotest/config.yaml
@@ -163,8 +163,6 @@ pytorch_base_model:
 
 turbomind_quatization:
   no_awq:
-    - Qwen/Qwen2-VL-2B-Instruct
-    - Qwen/Qwen2-VL-7B-Instruct
     - mistralai/Mistral-7B-Instruct-v0.3
     - deepseek-ai/deepseek-coder-1.3b-instruct
     - codellama/CodeLlama-7b-Instruct-hf
@@ -189,6 +187,8 @@ pytorch_quatization:
     - Qwen/Qwen2-7B-Instruct
    - Qwen/Qwen2-1.5B-Instruct
     - microsoft/Phi-3-mini-4k-instruct
+    - Qwen/Qwen2-VL-2B-Instruct
+    - Qwen/Qwen2-VL-7B-Instruct
   w8a8:
     - meta-llama/Meta-Llama-3-8B-Instruct
     - meta-llama/Llama-2-7b-chat-hf
diff --git a/lmdeploy/lite/apis/calibrate.py b/lmdeploy/lite/apis/calibrate.py
index cd5178793d..0780e93594 100644
--- a/lmdeploy/lite/apis/calibrate.py
+++ b/lmdeploy/lite/apis/calibrate.py
@@ -26,6 +26,7 @@
     'Phi3ForCausalLM': 'Phi3DecoderLayer',
     'ChatGLMForConditionalGeneration': 'GLMBlock',
     'MixtralForCausalLM': 'MixtralDecoderLayer',
+    'Qwen2VLForConditionalGeneration': 'Qwen2VLDecoderLayer',
 }
 
 NORM_TYPE_MAP = {
@@ -42,6 +43,7 @@
     'Phi3ForCausalLM': 'Phi3RMSNorm',
     'ChatGLMForConditionalGeneration': 'RMSNorm',
     'MixtralForCausalLM': 'MixtralRMSNorm',
+    'Qwen2VLForConditionalGeneration': 'Qwen2RMSNorm',
 }
 
 HEAD_NAME_MAP = {
@@ -58,6 +60,7 @@
     'Phi3ForCausalLM': 'lm_head',
     'ChatGLMForConditionalGeneration': 'output_layer',
     'MixtralForCausalLM': 'lm_head',
+    'Qwen2VLForConditionalGeneration': 'lm_head',
 }
 
 
diff --git a/lmdeploy/lite/quantization/awq.py b/lmdeploy/lite/quantization/awq.py
index 068ad9357e..2efe41b6da 100644
--- a/lmdeploy/lite/quantization/awq.py
+++ b/lmdeploy/lite/quantization/awq.py
@@ -45,6 +45,11 @@
         ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],
         'post_attention_layernorm':
         ['block_sparse_moe.experts.{i}.w1', 'block_sparse_moe.experts.{i}.w3']
+    },
+    'Qwen2VLDecoderLayer': {
+        'input_layernorm':
+        ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],
+        'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']
     }
 }
 
@@ -83,6 +88,10 @@
     'MixtralDecoderLayer': {
         'self_attn.v_proj': ['self_attn.o_proj'],
         'block_sparse_moe.experts.{i}.w3': ['block_sparse_moe.experts.{i}.w2']
+    },
+    'Qwen2VLDecoderLayer': {
+        'self_attn.v_proj': ['self_attn.o_proj'],
+        'mlp.up_proj': ['mlp.down_proj']
     }
 }
 
diff --git a/lmdeploy/lite/utils/batch_split.py b/lmdeploy/lite/utils/batch_split.py
index 3bd208f609..4e30f61d34 100644
--- a/lmdeploy/lite/utils/batch_split.py
+++ b/lmdeploy/lite/utils/batch_split.py
@@ -46,6 +46,14 @@ def split_decoder_layer_inputs(
         for name, val in kwargs.items():
             if isinstance(val, torch.Tensor) and val.size(0) == bs:
                 new_kwargs[name] = val[i:i + batch_size]
+            elif isinstance(val, torch.Tensor) and len(
+                    val.shape) > 1 and val.size(1) == bs:  # qwen2-vl
+                new_kwargs[name] = val[:, i:i + batch_size]
+            elif name == 'position_embeddings' and isinstance(
+                    val, Tuple) and len(
+                        val[0].shape) > 1 and val[0].size(1) == bs:  # qwen2-vl
+                new_kwargs[name] = (val[0][:, i:i + batch_size],
+                                    val[1][:, i:i + batch_size])
             else:
                 new_kwargs[name] = val
 
diff --git a/lmdeploy/vl/model/qwen2.py b/lmdeploy/vl/model/qwen2.py
index 2e53d8e0f0..3eb3c1541c 100644
--- a/lmdeploy/vl/model/qwen2.py
+++ b/lmdeploy/vl/model/qwen2.py
@@ -33,33 +33,35 @@ class Qwen2VLModel(VisonModel):
 
     def build_model(self):
         check_qwen_vl_deps_install()
-
-        from accelerate import init_empty_weights
-        with init_empty_weights():
-            config = self.hf_config
-            config.quantization_config = {}  # disable vision part quantization
-            # disable accelerate check_tied_parameters_in_config
-            # for Qwen2-VL-2B-Instruct
-            config.tie_word_embeddings = False
-
-            from transformers import Qwen2VLForConditionalGeneration
-            model = Qwen2VLForConditionalGeneration._from_config(config)
-            if not self.with_llm:
+        from transformers import Qwen2VLForConditionalGeneration
+        if self.with_llm:
+            model = Qwen2VLForConditionalGeneration.from_pretrained(
+                self.hf_config._name_or_path, trust_remote_code=True)
+            model.half()
+            self.vl_model = model
+        else:
+            from accelerate import init_empty_weights
+            with init_empty_weights():
+                config = self.hf_config
+                config.quantization_config = {
+                }  # disable vision part quantization
+                # disable accelerate check_tied_parameters_in_config
+                # for Qwen2-VL-2B-Instruct
+                config.tie_word_embeddings = False
+
+                model = Qwen2VLForConditionalGeneration._from_config(config)
                 del model.model
                 del model.lm_head
-            else:
-                self.vl_model = model
-            model.half()
-
-        from accelerate import load_checkpoint_and_dispatch
-        with disable_logging():
-            load_checkpoint_and_dispatch(
-                model=model,
-                checkpoint=self.model_path,
-                device_map='auto' if not self.with_llm else {'': 'cpu'},
-                max_memory=self.max_memory,
-                no_split_module_classes=['Qwen2VLVisionBlock'],
-                dtype=torch.half)
+                model.half()
+            from accelerate import load_checkpoint_and_dispatch
+            with disable_logging():
+                load_checkpoint_and_dispatch(
+                    model=model,
+                    checkpoint=self.model_path,
+                    device_map='auto' if not self.with_llm else {'': 'cpu'},
+                    max_memory=self.max_memory,
+                    no_split_module_classes=['Qwen2VLVisionBlock'],
+                    dtype=torch.half)
         self.model = model.eval()
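
Note: the extra `elif` branches added to `batch_split.py` exist because some Qwen2-VL decoder-layer inputs carry the batch on dim 1 rather than dim 0; in particular the M-RoPE `position_embeddings` kwarg is a `(cos, sin)` tuple batched on dim 1. The standalone sketch below (tensor shapes are illustrative assumptions, not taken from the model) shows the slicing pattern those branches apply during calibration:

import torch

# Illustrative shapes only: hidden_states is batched on dim 0, while the
# (cos, sin) position-embedding tuple is batched on dim 1, mimicking the
# qwen2-vl case handled by the new branches in split_decoder_layer_inputs.
bs, batch_size = 8, 2
hidden_states = torch.randn(bs, 32, 16)
position_embeddings = (torch.randn(3, bs, 32, 16),  # cos-like tensor
                       torch.randn(3, bs, 32, 16))  # sin-like tensor

for i in range(0, bs, batch_size):
    chunk = hidden_states[i:i + batch_size]  # split on dim 0
    pos = tuple(t[:, i:i + batch_size] for t in position_embeddings)  # split on dim 1
    assert chunk.size(0) == batch_size and pos[0].size(1) == batch_size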