From 324237b2c9e223c2392088cecb57b3703d1f7d54 Mon Sep 17 00:00:00 2001 From: zhoushenglong <87467364+Reinerzhou@users.noreply.github.com> Date: Thu, 21 Nov 2024 19:01:12 +0800 Subject: [PATCH] [Feature] support minicpm-v_2_6 for pytorch engine. (#2767) * support minicpmv_2_6. * update supported_models. * update supported_models. --- docs/en/supported_models/supported_models.md | 1 + .../supported_models/supported_models.md | 1 + lmdeploy/pytorch/models/minicpmv26.py | 430 ++++++++++++++++++ lmdeploy/pytorch/models/module_map.py | 6 + lmdeploy/pytorch/supported_models.py | 2 + 5 files changed, 440 insertions(+) create mode 100644 lmdeploy/pytorch/models/minicpmv26.py diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index 283ce596f6..da52241253 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -72,6 +72,7 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | | MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | +| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes | | Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No | | Dbrx | 132B | LLM | Yes | Yes | Yes | No | No | | StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No | diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index 908f9a17f5..502e91b6d3 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -72,6 +72,7 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | | MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | +| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes | | Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No | | Dbrx | 132B | LLM | Yes | Yes | Yes | No | No | | StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No | diff --git a/lmdeploy/pytorch/models/minicpmv26.py b/lmdeploy/pytorch/models/minicpmv26.py new file mode 100644 index 0000000000..725e97d9d7 --- /dev/null +++ b/lmdeploy/pytorch/models/minicpmv26.py @@ -0,0 +1,430 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, + SiluAndMul, build_rotary_embedding, + build_rotary_params) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .utils.cudagraph import CudaGraphMixin + + +class MiniCPMV26Attention(nn.Module): + """Rewrite module of MiniCPMV26Attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + + # packed qkv + self.qkv_proj = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, + sliding_window=config.sliding_window, + ) + + # o_proj + self.o_proj = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv( + qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + inplace=True, + ) + + # attention + attn_output = self.attn_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_metadata, + k_scales_zeros=None + if len(past_key_value) == 2 else past_key_value[2], + v_scales_zeros=None + if len(past_key_value) == 2 else past_key_value[3], + inplace=True, + ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + + # o proj + attn_output = self.o_proj(attn_output) + return attn_output + + +class MiniCPMV26MLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class MiniCPMV26DecoderLayer(nn.Module): + """decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = MiniCPMV26Attention(config, + dtype=dtype, + device=device) + + # build MLP + self.mlp = MiniCPMV26MLP(config, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states, residual) + return outputs + + +class MiniCPMV26Model(nn.Module): + """model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + quantization_config = getattr(config, 'quantization_config', None) + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + MiniCPMV26DecoderLayer(config, + layer_idx, + dtype=dtype, + device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build rotary embedding + rope_params = build_rotary_params(config) + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + **rope_params, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) + + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens + + +class MiniCPMVForCausalLM(nn.Module, CudaGraphMixin): + """rewrote model of MiniCPMVForCausalLM.""" + + packed_modules_mapping = { + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.model = MiniCPMV26Model(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def get_logits(self, hidden_states: torch.Tensor): + """compute logits of the model output.""" + return self.lm_head(hidden_states) + + def update_weights(self): + """update weights.""" + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + params_dict = dict(self.named_parameters(prefix='llm')) + for name, loaded_weight in weights: + if 'vpm' in name or 'resampler' in name: + continue + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index e6b5f6e29e..1059bfee4e 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -173,6 +173,12 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.minicpm3.MiniCPM3ForCausalLM', }) +# minicpmv2_6 +MODULE_MAP.update({ + 'MiniCPMV': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.minicpmv26.MiniCPMVForCausalLM', +}) + # mllama MODULE_MAP.update({ 'MllamaForConditionalGeneration': diff --git a/lmdeploy/pytorch/supported_models.py b/lmdeploy/pytorch/supported_models.py index 21418188dd..7fa568651b 100644 --- a/lmdeploy/pytorch/supported_models.py +++ b/lmdeploy/pytorch/supported_models.py @@ -70,6 +70,8 @@ PhiMoEForCausalLM=True, # mllama MllamaForConditionalGeneration=True, + # MiniCPM-V-2_6 + MiniCPMVForCausalLM=True, )