Support minicpm3-4b
AllentDan committed Sep 13, 2024
1 parent e8a1a33 commit 5cedcb9
Showing 4 changed files with 673 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lmdeploy/model.py
@@ -887,6 +887,7 @@ def match(cls, model_path: str) -> Optional[str]:


@MODELS.register_module(name='minicpmv-2d6')
@MODELS.register_module(name='minicpm3')
@MODELS.register_module(name='qwen')
class Qwen7BChat(BaseChatTemplate):
"""Chat template for Qwen-7B-Chat."""
@@ -924,6 +925,8 @@ def match(cls, model_path: str) -> Optional[str]:
            return 'qwen'
        if 'minicpm-v-2_6' in model_path.lower():
            return 'minicpmv-2d6'
        if 'minicpm3-' in model_path.lower():
            return 'minicpm3'


@MODELS.register_module(name='codellama')
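For reference, here is a minimal sketch (not part of the commit; the model path and expected outputs are assumptions) of how the registration and match rule above would resolve a MiniCPM3 checkpoint to its chat template:

# Sketch: resolving the chat template registered above (model path is illustrative).
from lmdeploy.model import MODELS, best_match_model

# match() returns 'minicpm3' for any path containing 'minicpm3-'.
name = best_match_model('openbmb/MiniCPM3-4B')
print(name)  # expected: 'minicpm3'

# 'minicpm3' is registered on Qwen7BChat, so MiniCPM3 reuses that template class.
template = MODELS.get('minicpm3')()
print(type(template).__name__)  # Qwen7BChat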
32 changes: 32 additions & 0 deletions lmdeploy/pytorch/configurations/minicpm3.py
@@ -0,0 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.config import ModelConfig

from .builder import AutoModelConfigBuilder


class DeepseekV2ModelConfigBuilder(AutoModelConfigBuilder):

    @classmethod
    def condition(cls, hf_config):
        """Check whether this builder handles the given HF config."""
        return hf_config.architectures[0] in ['MiniCPM3ForCausalLM']

    @classmethod
    def build(cls, hf_config, model_path: str = None):
        """Build the ModelConfig for MiniCPM3."""
        # MLA-style attention: the query/key head dim is the sum of the
        # no-position (nope) and rotary (rope) components.
        head_dim = (hf_config.qk_nope_head_dim + hf_config.qk_rope_head_dim)
        k_head_dim = head_dim
        v_head_dim = head_dim
        num_attention_heads = hf_config.num_attention_heads
        num_key_value_heads = hf_config.num_key_value_heads
        return ModelConfig(hidden_size=hf_config.hidden_size,
                           num_layers=hf_config.num_hidden_layers,
                           num_attention_heads=num_attention_heads,
                           num_key_value_heads=num_key_value_heads,
                           bos_token_id=hf_config.bos_token_id,
                           eos_token_id=hf_config.eos_token_id,
                           head_dim=head_dim,
                           k_head_dim=k_head_dim,
                           v_head_dim=v_head_dim,
                           vocab_size=hf_config.vocab_size,
                           multi_query_attention=True)
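
A minimal sketch (not part of the commit) of exercising the new builder with a mocked HF config; the numeric values below are illustrative assumptions loosely resembling MiniCPM3-4B, not values copied from its actual config file:

# Sketch: feeding the builder a mocked HF config (values are illustrative assumptions).
from types import SimpleNamespace

from lmdeploy.pytorch.configurations.minicpm3 import DeepseekV2ModelConfigBuilder

hf_config = SimpleNamespace(
    architectures=['MiniCPM3ForCausalLM'],
    hidden_size=2560,
    num_hidden_layers=62,
    num_attention_heads=40,
    num_key_value_heads=40,
    qk_nope_head_dim=64,
    qk_rope_head_dim=32,
    bos_token_id=1,
    eos_token_id=2,
    vocab_size=73448,
)

assert DeepseekV2ModelConfigBuilder.condition(hf_config)
cfg = DeepseekV2ModelConfigBuilder.build(hf_config)
# head_dim = qk_nope_head_dim + qk_rope_head_dim = 96, shared by keys and values.
print(cfg.head_dim, cfg.k_head_dim, cfg.v_head_dim)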
