From fd8906c1c4bc37a359b9677d0cbef694a23ab00e Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Thu, 14 Nov 2024 13:07:01 +0800 Subject: [PATCH] Support molmo in turbomind (#2716) * initial moe support * dynamic grouped gemm * benchmark * moe benchmark * moe sampling * split-k * refactor tuning * simplify * n-major weight * add `num` for `MatrixLayout` * packed rows * packed cols * dispatch for packed rows * w4a16 moe * refactor model loading * fix pytorch loader * refactor * dispatch w4a16 moe * fix loader * add comment * fix msvc build * fix msvc build * fix msvc build * fix ut * fix ut * fix p-lora * add all support arches * minor * fix lint * fix lint * fix lint * fix ut * bf16 support * minor * checkin molmo conversion * add chat template * refactor * fix lint * fix ut * Just for test: hardcode vocab_size * minor * minor * minor * fix inter_size config * load with non-standard filenames * fix loader * fix missing default param * defer the loading of misc weights for safetensors * add embedding_size * update * update * tmp * tmp * update molmo template * vision embedding * fix * update * fix * fix messages2prompt in templates * fix order of out_messages * fix * add user guide * update is_supported --------- Co-authored-by: Li Zhang --- docs/en/multi_modal/index.rst | 2 + docs/en/multi_modal/molmo.md | 92 +++++++++ docs/zh_cn/multi_modal/index.rst | 2 + docs/zh_cn/multi_modal/molmo.md | 92 +++++++++ lmdeploy/archs.py | 3 +- lmdeploy/model.py | 31 +++ lmdeploy/serve/vl_async_engine.py | 5 + lmdeploy/turbomind/deploy/config.py | 7 + .../turbomind/deploy/source_model/__init__.py | 1 + .../turbomind/deploy/source_model/molmo.py | 122 ++++++++++++ .../turbomind/deploy/target_model/base.py | 3 + lmdeploy/turbomind/supported_models.py | 8 +- lmdeploy/vl/model/builder.py | 10 +- lmdeploy/vl/model/molmo.py | 177 ++++++++++++++++++ lmdeploy/vl/templates.py | 80 ++++++++ src/turbomind/models/llama/LlamaWeight.cc | 15 +- src/turbomind/models/llama/LlamaWeight.h | 2 + src/turbomind/models/llama/llama_params.h | 1 + .../triton_backend/llama/LlamaTritonModel.cc | 8 + 19 files changed, 653 insertions(+), 8 deletions(-) create mode 100644 docs/en/multi_modal/molmo.md create mode 100644 docs/zh_cn/multi_modal/molmo.md create mode 100644 lmdeploy/turbomind/deploy/source_model/molmo.py create mode 100644 lmdeploy/vl/model/molmo.py diff --git a/docs/en/multi_modal/index.rst b/docs/en/multi_modal/index.rst index 62f724070f..a68fe3da4f 100644 --- a/docs/en/multi_modal/index.rst +++ b/docs/en/multi_modal/index.rst @@ -12,3 +12,5 @@ Vision-Language Models minicpmv.md phi3.md mllama.md + qwen2_vl.md + molmo.md diff --git a/docs/en/multi_modal/molmo.md b/docs/en/multi_modal/molmo.md new file mode 100644 index 0000000000..dfff43dc64 --- /dev/null +++ b/docs/en/multi_modal/molmo.md @@ -0,0 +1,92 @@ +# Molmo + +LMDeploy supports the following molmo series of models, which are detailed in the table below: + +| Model | Size | Supported Inference Engine | +| :-------------: | :--: | :------------------------: | +| Molmo-7B-D-0924 | 7B | TurboMind | +| Molmo-72-0924 | 72B | TurboMind | + +The next chapter demonstrates how to deploy a molmo model using LMDeploy, with [Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924) as an example. + +## Installation + +Please install LMDeploy by following the [installation guide](../get_started/installation.md) + +## Offline inference + +The following sample code shows the basic usage of VLM pipeline. 
For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).

```python
from lmdeploy import pipeline
from lmdeploy.vl import load_image

pipe = pipeline('allenai/Molmo-7B-D-0924')

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
print(response)
```

More examples are listed below:
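**chat with custom generation parameters**

Decoding behavior can be tuned through `GenerationConfig`. The snippet below is a minimal sketch that is not part of the original example set; it assumes the standard `lmdeploy.GenerationConfig` fields (`top_k`, `max_new_tokens`).

```python
from lmdeploy import pipeline, GenerationConfig
from lmdeploy.vl import load_image

pipe = pipeline('allenai/Molmo-7B-D-0924')

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
# top_k=1 gives near-greedy decoding; max_new_tokens caps the reply length
gen_config = GenerationConfig(top_k=1, max_new_tokens=512)
response = pipe(('describe this image', image), gen_config=gen_config)
print(response)
```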
**multi-image multi-round conversation, combined images**

```python
from lmdeploy import pipeline, GenerationConfig

pipe = pipeline('allenai/Molmo-7B-D-0924', log_level='INFO')
messages = [
    dict(role='user', content=[
        dict(type='text', text='Describe the two images in detail.'),
        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),
        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))
    ])
]
out = pipe(messages, gen_config=GenerationConfig(do_sample=False))

messages.append(dict(role='assistant', content=out.text))
messages.append(dict(role='user', content='What are the similarities and differences between these two images.'))
out = pipe(messages, gen_config=GenerationConfig(do_sample=False))
```
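**batch inference over multiple image-prompt pairs**

The pipeline can also be fed a list of (prompt, image) tuples and will handle them as independent requests. This is a sketch rather than part of the original guide; it assumes the batching interface behaves as in the other LMDeploy VLM pipelines.

```python
from lmdeploy import pipeline
from lmdeploy.vl import load_image

pipe = pipeline('allenai/Molmo-7B-D-0924')

image_urls = [
    'https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg',
    'https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'
]
# each (prompt, image) tuple is an independent request;
# responses come back in the same order as the inputs
prompts = [('describe this image', load_image(url)) for url in image_urls]
responses = pipe(prompts)
print([r.text for r in responses])
```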
## Online serving

You can launch the server with the `lmdeploy serve api_server` CLI:

```shell
lmdeploy serve api_server allenai/Molmo-7B-D-0924
```

You can also start the service using the docker image:

```shell
docker run --runtime nvidia --gpus all \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HUGGING_FACE_HUB_TOKEN=" \
    -p 23333:23333 \
    --ipc=host \
    openmmlab/lmdeploy:latest \
    lmdeploy serve api_server allenai/Molmo-7B-D-0924
```

If the following logs appear, the service has launched successfully.

```text
HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
INFO: Started server process [2439]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit)
```

The arguments of `lmdeploy serve api_server` can be reviewed in detail with `lmdeploy serve api_server -h`.

More information about `api_server`, as well as how to access the service, can be found [here](api_server_vl.md).
diff --git a/docs/zh_cn/multi_modal/index.rst b/docs/zh_cn/multi_modal/index.rst
index 0942d8d31c..bd141ea90f 100644
--- a/docs/zh_cn/multi_modal/index.rst
+++ b/docs/zh_cn/multi_modal/index.rst
@@ -12,3 +12,5 @@
    minicpmv.md
    phi3.md
    mllama.md
+   qwen2_vl.md
+   molmo.md
diff --git a/docs/zh_cn/multi_modal/molmo.md b/docs/zh_cn/multi_modal/molmo.md
new file mode 100644
index 0000000000..1dc8f8f79b
--- /dev/null
+++ b/docs/zh_cn/multi_modal/molmo.md
@@ -0,0 +1,92 @@
# Molmo

LMDeploy 支持 Molmo 系列模型，具体如下：

| Model | Size | Supported Inference Engine |
| :-------------: | :--: | :------------------------: |
| Molmo-7B-D-0924 | 7B | TurboMind |
| Molmo-72B-0924 | 72B | TurboMind |

本文将以 [Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924) 为例，演示使用 LMDeploy 部署 Molmo 系列模型的方法。

## 安装

请参考[安装文档](../get_started/installation.md)安装 LMDeploy。

## 离线推理

以下是使用 pipeline 进行离线推理的示例，更多用法参考 [VLM 离线推理 pipeline](./vl_pipeline.md)。

```python
from lmdeploy import pipeline
from lmdeploy.vl import load_image

pipe = pipeline('allenai/Molmo-7B-D-0924')

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
print(response)
```

更多例子如下：
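**设置采样参数**

下面的示例不属于原文档，仅作补充示意：通过 `GenerationConfig` 传入采样参数（假设使用标准的 `lmdeploy.GenerationConfig` 字段 `top_k`、`max_new_tokens`）。

```python
from lmdeploy import pipeline, GenerationConfig
from lmdeploy.vl import load_image

pipe = pipeline('allenai/Molmo-7B-D-0924')

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
# top_k=1 近似贪心解码，max_new_tokens 限制生成长度
gen_config = GenerationConfig(top_k=1, max_new_tokens=512)
response = pipe(('describe this image', image), gen_config=gen_config)
print(response)
```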
**多图多轮对话**

```python
from lmdeploy import pipeline, GenerationConfig

pipe = pipeline('allenai/Molmo-7B-D-0924', log_level='INFO')
messages = [
    dict(role='user', content=[
        dict(type='text', text='Describe the two images in detail.'),
        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),
        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))
    ])
]
out = pipe(messages, gen_config=GenerationConfig(top_k=1))

messages.append(dict(role='assistant', content=out.text))
messages.append(dict(role='user', content='What are the similarities and differences between these two images.'))
out = pipe(messages, gen_config=GenerationConfig(top_k=1))
```
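**多图批量推理**

pipeline 也可以一次接收多组 (prompt, image) 请求并独立处理。以下示例并非原文档内容，仅作示意，假设批量接口与 LMDeploy 其他 VLM pipeline 的行为一致。

```python
from lmdeploy import pipeline
from lmdeploy.vl import load_image

pipe = pipeline('allenai/Molmo-7B-D-0924')

image_urls = [
    'https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg',
    'https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'
]
# 每个 (prompt, image) 元组是一条独立请求，返回结果按输入顺序排列
prompts = [('describe this image', load_image(url)) for url in image_urls]
responses = pipe(prompts)
print([r.text for r in responses])
```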
## 在线服务

你可以通过 `lmdeploy serve api_server` CLI 工具启动服务：

```shell
lmdeploy serve api_server allenai/Molmo-7B-D-0924
```

也可以基于 docker image 启动服务：

```shell
docker run --runtime nvidia --gpus all \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HUGGING_FACE_HUB_TOKEN=" \
    -p 23333:23333 \
    --ipc=host \
    openmmlab/lmdeploy:latest \
    lmdeploy serve api_server allenai/Molmo-7B-D-0924
```

如果日志中出现如下信息，就表明服务启动成功了。

```text
HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
INFO: Started server process [2439]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit)
```

有关 `lmdeploy serve api_server` 的详细参数可以通过 `lmdeploy serve api_server -h` 查阅。

关于 `api_server` 更多的介绍，以及访问 `api_server` 的方法，请阅读[此处](api_server_vl.md)。
diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py
index 8284c99741..ce5cbd98ff 100644
--- a/lmdeploy/archs.py
+++ b/lmdeploy/archs.py
@@ -121,7 +121,8 @@ def check_vl_llm(config: dict) -> bool:
         'InternVLChatModel', 'MiniGeminiLlamaForCausalLM',
         'MGMLlamaForCausalLM', 'MiniCPMV', 'LlavaForConditionalGeneration',
         'LlavaNextForConditionalGeneration', 'Phi3VForCausalLM',
-        'Qwen2VLForConditionalGeneration', 'MllamaForConditionalGeneration'
+        'Qwen2VLForConditionalGeneration', 'MllamaForConditionalGeneration',
+        'MolmoForCausalLM'
     ])
     if arch == 'QWenLMHeadModel' and 'visual' in config:
         return True
diff --git a/lmdeploy/model.py b/lmdeploy/model.py
index db864a8344..c9eb71c2c3 100644
--- a/lmdeploy/model.py
+++ b/lmdeploy/model.py
@@ -1747,6 +1747,37 @@ def match(cls, model_path: str) -> Optional[str]:
         return 'internvl-phi3'


+@MODELS.register_module(name='molmo')
+class Molmo(BaseChatTemplate):
+
+    def __init__(self,
+                 user=' User: ',
+                 eoh='',
+                 assistant=' Assistant:',
+                 eoa='',
+                 separator=' ',
+                 stop_words=['<|endoftext|>'],
+                 **kwargs):
+        super().__init__(user=user,
+                         eoh=eoh,
+                         assistant=assistant,
+                         eoa=eoa,
+                         separator=separator,
+                         stop_words=stop_words,
+                         **kwargs)
+
+    @classmethod
+    def match(cls, model_path: str) -> Optional[str]:
+        """Return the model_name that was registered to MODELS.
+
+        Args:
+            model_path (str): the model path used for matching.
+        """
+        path = model_path.lower()
+        if 'molmo' in path:
+            return 'molmo'
+
+
 def best_match_model(query: str) -> Optional[str]:
     """Get the model that matches the query.
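The chat template registered above decorates each turn with ` User: ` / ` Assistant:` and a single-space separator. As a quick illustration (not part of the patch, and assuming the `MODELS` registry and `BaseChatTemplate.messages2prompt` interfaces already used in `lmdeploy/model.py`), the template can be exercised like this:

```python
from lmdeploy.model import MODELS

chat_template = MODELS.get('molmo')()
messages = [dict(role='user', content='Describe this image.')]
prompt = chat_template.messages2prompt(messages)
# With the defaults above, this yields roughly:
#   ' User: Describe this image. Assistant:'
print(repr(prompt))
```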
diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py index fd0b0bb5e4..c293cd71c8 100644 --- a/lmdeploy/serve/vl_async_engine.py +++ b/lmdeploy/serve/vl_async_engine.py @@ -64,6 +64,7 @@ async def _get_prompt_input(self, results = {} input_ids = [] from lmdeploy.vl.templates import (MllamaTempateWrapper, + MolmoChatTemplateWrapper, Qwen2VLChatTemplateWrapper) ranges = None grid_thws = None @@ -99,6 +100,10 @@ async def _get_prompt_input(self, results['cross_attention_states'] = features[0] return results + if isinstance(self.vl_prompt_template, + MolmoChatTemplateWrapper): + return features[0] + features = [x.cpu().numpy() for x in features] input_ids = [] begins = [] diff --git a/lmdeploy/turbomind/deploy/config.py b/lmdeploy/turbomind/deploy/config.py index a535b0d4c1..c724b085a0 100644 --- a/lmdeploy/turbomind/deploy/config.py +++ b/lmdeploy/turbomind/deploy/config.py @@ -35,6 +35,13 @@ class ModelConfig: kv_head_num: int = None hidden_units: int = None vocab_size: int = None + # Turbomind used to assume token_embedding and lm_head has the same size + # at vocab dim, i.e. `vocab_size` + # But in molmo, embedding.shape is [vocab_size + 128, hidden_units] + # while lm_head shape is [hidden_units, vocab_size]. + # Therefore, we add a new attr "embedding_size" to represent the vocab dim + # of token_embedding + embedding_size: int = 0 num_layer: int = None inter_size: int = None norm_eps: float = None diff --git a/lmdeploy/turbomind/deploy/source_model/__init__.py b/lmdeploy/turbomind/deploy/source_model/__init__.py index b1da698e2e..de16bdc0a0 100644 --- a/lmdeploy/turbomind/deploy/source_model/__init__.py +++ b/lmdeploy/turbomind/deploy/source_model/__init__.py @@ -9,5 +9,6 @@ from .meta_llama import MetaLlamaModel # noqa: F401 from .minicpmv import MiniCPMVModel # noqa: F401 from .mixtral import MixtralModel # noqa: F401 +from .molmo import MolmoModel # noqa: F401 from .qwen import QwenModel # noqa: F401 from .xcomposer2 import Xcomposer2Model # noqa: F401 diff --git a/lmdeploy/turbomind/deploy/source_model/molmo.py b/lmdeploy/turbomind/deploy/source_model/molmo.py new file mode 100644 index 0000000000..541e201046 --- /dev/null +++ b/lmdeploy/turbomind/deploy/source_model/molmo.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp + +import torch + +from .base import INPUT_MODELS +from .llama import LlamaModel, LlamaReader + + +class MolmoReader(LlamaReader): + attn_layer_prefix = 'model.transformer.blocks' + attn_layer_patten = r'model.transformer.blocks.([0-9]+).' + norm_weight_key = 'model.transformer.ln_f.weight' + output_weight_key = 'model.transformer.ff_out.weight' + + # In molmo, names of attention parameters are "att_proj.bias", + # "att_proj.weight", "attn_norm.weight", "attn_out.weight", and names + # of ffn parameters are "ff_norm", "ff_out", "ff_proj", so we + # make the patterns are r'att' and r'ffn_', respectively. 
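+    # For example (illustrative, based on the parameter names above):
+    #   "model.transformer.blocks.0.att_proj.weight" and
+    #   "model.transformer.blocks.0.attn_out.weight" match r'att', while
+    #   "model.transformer.blocks.0.ff_proj.weight" and
+    #   "model.transformer.blocks.0.ff_out.weight" match r'ff_'.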
+ attn_pattern = r'att' + ffn_pattern = r'ff_' + + def tok_embeddings(self): + embed1 = self.params.get('model.transformer.wte.embedding', None) + embed2 = self.params.get('model.transformer.wte.new_embedding', None) + if embed1 is not None and embed2 is not None: + return torch.cat((embed1, embed2), dim=0) + else: + assert embed1 is None and embed2 is None + return None + + def attn_norm(self, i: int): + """Get attn norm for layer i.""" + return self.params[f'{self.attn_layer_prefix}.{i}.attn_norm.weight'] + + def _attn(self, i: int, kind: str): + """Get q, k, v, o kind(weight, bias, qweight) for layer i. + + Args: + i (int): layer id + kind (str): can be one of ["weight", "bias", "qweight"] + """ + q, k, v = (None, ) * 3 + hidden_size = self.model_cfg['hidden_size'] + head_num = self.model_cfg['num_attention_heads'] + kv_head_num = self.model_cfg['num_key_value_heads'] + head_dim = hidden_size // head_num + assert head_dim == 128 + fused_dims = (hidden_size, kv_head_num * head_dim, + kv_head_num * head_dim) + qkv = self.params.get(f'{self.attn_layer_prefix}.{i}.att_proj.{kind}') + qkv = self.transform(qkv, kind) + if qkv is not None: + q, k, v = qkv.split(fused_dims, dim=0) + o = self.params.get(f'{self.attn_layer_prefix}.{i}.attn_out.{kind}') + o = self.transform(o, kind) + if o is None: # handle the case when qkv has bias but o doesn't + o = torch.zeros_like(q) + return (q, k, v, o) + + def _ffn(self, i: int, kind: str): + """Get ffn kind(weight, qweight) for layer i.""" + up_and_gate = self.params[ + f'{self.attn_layer_prefix}.{i}.ff_proj.{kind}'] + up_and_gate = self.transform(up_and_gate, kind) + gate, up = up_and_gate.chunk(2, dim=0) + down = self.params[f'{self.attn_layer_prefix}.{i}.ff_out.{kind}'] + down = self.transform(down, kind) + return (up, down, gate) + + def ffn_norm(self, i: int): + """Get ffn norm for layer i.""" + return self.params[f'{self.attn_layer_prefix}.{i}.ff_norm.weight'] + + +@INPUT_MODELS.register_module(name='molmo') +class MolmoModel(LlamaModel): + + Reader = MolmoReader + + def __init__(self, model_path: str, tokenizer_path: str, **kwargs): + super().__init__(model_path, tokenizer_path, **kwargs) + config_path = osp.join(self.model_path, 'config.json') + with open(config_path) as f: + self.config = json.load(f) + + def tokenizer_info(self): + + n_words = 152064 + bos_id = 151643 + eos_id = 151643 + return n_words, bos_id, eos_id + + def model_info(self): + config = self.config + num_layer = config['num_hidden_layers'] + norm_eps = config['layer_norm_eps'] + attn_head_num = config['num_attention_heads'] + kv_head_num = config['num_key_value_heads'] + hidden_units = config['hidden_size'] + rope_theta = config['rope_theta'] + max_position_embeddings = config['max_position_embeddings'] + vocab_size = config['vocab_size'] + # https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/modeling_molmo.py#L2041 + additional_vocab_size = 128 + inter_size = config['intermediate_size'] // 2 + attn_bias = config['qkv_bias'] + return dict( + num_layer=num_layer, + norm_eps=norm_eps, + head_num=attn_head_num, + kv_head_num=kv_head_num, + hidden_units=hidden_units, + attn_bias=int(attn_bias), + inter_size=inter_size, + vocab_size=vocab_size, + # https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/modeling_molmo.py#L564 + embedding_size=vocab_size + additional_vocab_size, + rope_theta=rope_theta, + max_position_embeddings=max_position_embeddings, + ) diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index 
abd570cd00..09699ade09 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -92,6 +92,9 @@ def update_model_config(self): final_cfg = config_to_dict(self.model_config) final_cfg.update(dict(start_id=bos_id, end_id=eos_id)) final_cfg.update(self.input_model_info) + if 'embedding_size' not in self.input_model_info.keys(): + final_cfg.update( + embedding_size=self.input_model_info['vocab_size']) self.model_config = config_from_dict(ModelConfig, final_cfg) diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py index bb3533254b..e66da22df0 100644 --- a/lmdeploy/turbomind/supported_models.py +++ b/lmdeploy/turbomind/supported_models.py @@ -42,7 +42,9 @@ ChatGLMModel='glm4', ChatGLMForConditionalGeneration='glm4', # mixtral - MixtralForCausalLM='mixtral') + MixtralForCausalLM='mixtral', + MolmoForCausalLM='molmo', +) def is_supported(model_path: str): @@ -104,5 +106,9 @@ def _is_head_dim_supported(cfg): if llm_arch in ['Qwen2ForCausalLM', 'LlamaForCausalLM']: support_by_turbomind = _is_head_dim_supported( cfg.text_config) + elif arch == 'MolmoForCausalLM': + kv_heads = cfg.num_key_value_heads + # TM hasn't supported allenai/Molmo-7B-O-0924 yet + support_by_turbomind = kv_heads is not None return support_by_turbomind diff --git a/lmdeploy/vl/model/builder.py b/lmdeploy/vl/model/builder.py index 9e71f7d1c0..2401b42259 100644 --- a/lmdeploy/vl/model/builder.py +++ b/lmdeploy/vl/model/builder.py @@ -18,6 +18,7 @@ from .mini_gemeni import MiniGeminiVisionModel # noqa F401 from .minicpmv import MiniCPMVModel # noqa F401 from .mllama import MllamaVLModel # noqa F401 +from .molmo import MolmoVisionModel # noqa F401 from .phi3_vision import Phi3VisionModel # noqa F401 from .qwen import QwenVisionModel # noqa F401 from .qwen2 import Qwen2VLModel # noqa F401 @@ -31,7 +32,14 @@ def load_vl_model(model_path: str, with_llm: bool = False, backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None): - """load visual model.""" + """load visual model. + + Args: + model_path(str): the path or repo_id from model hub of the model + with_llm(bool): whether to remove the LLM part from the model. + When it is False, it means removing LLM part + backend_config: the config of the inference engine + """ if not os.path.exists(model_path): revision = getattr(backend_config, 'revision', None) download_dir = getattr(backend_config, 'download_dir', None) diff --git a/lmdeploy/vl/model/molmo.py b/lmdeploy/vl/model/molmo.py new file mode 100644 index 0000000000..9abae7a309 --- /dev/null +++ b/lmdeploy/vl/model/molmo.py @@ -0,0 +1,177 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +from typing import Dict, List + +import torch +from PIL.Image import Image +from transformers import AutoModelForCausalLM, AutoProcessor + +from lmdeploy.utils import get_logger +from lmdeploy.vl.constants import IMAGE_TOKEN +from lmdeploy.vl.model.base import VISION_MODELS, VisonModel +from lmdeploy.vl.model.utils import disable_logging + +logger = get_logger('lmdeploy') + + +@VISION_MODELS.register_module() +class MolmoVisionModel(VisonModel): + """molmo's vision model.""" + + _arch = 'MolmoForCausalLM' + + def build_model(self): + """Load model.""" + from accelerate import init_empty_weights, load_checkpoint_and_dispatch + with init_empty_weights(): + config = self.hf_config + model = AutoModelForCausalLM.from_config(config, + trust_remote_code=True) + if not self.with_llm: + # Remove nn modules other than embedding from the LLM model + for key in ['emb_drop', 'ln_f', 'blocks', 'ff_out']: + del model.model.transformer[key] + self.token_embedding = model.model.transformer.wte + else: + self.vl_model = model + + with disable_logging(): + load_checkpoint_and_dispatch( + model=model, + checkpoint=self.model_path, + device_map='auto' if not self.with_llm else {'': 'cpu'}, + max_memory=self.max_memory, + no_split_module_classes=[ + 'ResidualAttentionBlock', 'Embedding' + ]) + + # We need eval mode to freeze the weights in model, thus, + # avoid randomness in inference. + self.model = model.eval() + self.config = config + + self.processor = AutoProcessor.from_pretrained(self.model_path, + trust_remote_code=True, + torch_dtype='auto', + device_map='auto') + + @torch.no_grad() + def forward(self, + images: List[Image], + params: List[Dict] = None) -> List[Dict]: + """forward the model with given input. + + Args: + images (List): [None] it is not used + params (List): the inputs after precessing GPT4V messages in + `MolmoChatTemplateWrapper`. Its format is like the following: + [[ + {'role': 'user', 'content': 'user prompt'}, + {'role': 'asssistant', 'content': 'assistant prompt'}, + {'role': 'user', 'content': 'user prompt', 'images': [PIL image list]}, + ... + ]] + """ # noqa + + messages = params[0] + assert isinstance(messages, List) + # append an assistant message to `messages` + messages.append(dict(role='assistant', content='')) + # results is a list of tuple(input_ids, embeddings) + results = [] + # the concat prompt. It is not used during inference but to adhere the + # interface definition of `_get_prompt_input` in `class VLAsyncEngine` + prompts = '' + # Prepend BOS + # qwen2 and olmo do not have a BOS, and instead use EOS as a generic + # separator token. + bos = (self.processor.tokenizer.bos_token_id + or self.processor.tokenizer.eos_token_id) + results.append(([bos], None)) + for i, message in enumerate(messages): + if 'images' in message.keys(): + prompts += ' User: ' + (IMAGE_TOKEN + '\n') * len( + message['images']) + message['content'] + prompt = f' User: {message["content"]}' + tokens = self.processor.tokenizer.encode( + prompt, add_special_tokens=False) + # preprocess images. 
The output is a dict + inputs = self.processor.process(images=message['images'], + tokens=tokens) + inputs = { + k: v.to(self.model.device).unsqueeze(0) + for k, v in inputs.items() + } + input_ids = inputs['input_ids'] + # remove the bos from input_ids which is prepended by molmo's + # processor + input_ids = input_ids[:, 1:] + images = inputs[ + 'images'] # (batch_size, num_image, num_patch, d_model) + image_input_idx = inputs[ + 'image_input_idx'] # (batch_size, num_image, num_patch) + image_masks = inputs['image_masks'] + batch_size, seq_len = input_ids.size() + assert batch_size == 1 + + # Get embeddings of input. + if input_ids is not None: + input_ids = input_ids * (input_ids != -1).to( + input_ids.dtype) + embeddings = self.model.model.transformer.wte(input_ids) + image_features, _ = self.model.model.vision_backbone( + images, image_masks) + num_image, num_patch = image_features.shape[1:3] + assert image_input_idx.shape == (batch_size, num_image, + num_patch) + + # insert the image feature into the embedding. + image_features = image_features.view(batch_size, + num_image * num_patch, -1) + image_input_idx = image_input_idx.view(batch_size, + num_image * num_patch) + + valid = image_input_idx >= 0 + batch_idx = torch.arange(batch_size, device=embeddings.device) + batch_idx = torch.tile(batch_idx[:, None], + [1, image_features.shape[1]]) + image_features = image_features.to(embeddings.device) + embeddings[batch_idx[valid], + image_input_idx[valid]] += image_features[valid] + assert embeddings.shape[:2] == (batch_size, seq_len) + results.append((input_ids.flatten().tolist(), embeddings)) + else: + role = message['role'] + content = message['content'] + assert isinstance(content, str) + prompt = '' + if role == 'user': + prompt = f' User: {content}' + elif role == 'assistant': + prompt = f' Assistant:{content}' + else: + assert 0, f'molmo does not support role {role}, message is {message}' # noqa + input_ids = self.processor.tokenizer.encode( + prompt, add_special_tokens=False) + results.append((input_ids, None)) + prompts += prompt + + # concat input_ids from results, calculate the range in the input_ids + # where embeddings will be copied to + input_ids = [] + input_embeddings = [] + input_embedding_ranges = [] + start = 0 + for _input_ids, _embeddings in results: + if _embeddings is not None: + input_embeddings.append(_embeddings.cpu()) + end = start + len(_input_ids) + input_embedding_ranges.append((start, end)) + input_ids += _input_ids + start += len(_input_ids) + return [ + dict(prompt=prompts, + input_ids=input_ids, + input_embeddings=input_embeddings, + input_embedding_ranges=input_embedding_ranges) + ] diff --git a/lmdeploy/vl/templates.py b/lmdeploy/vl/templates.py index 45e457ad2c..cdf398868a 100644 --- a/lmdeploy/vl/templates.py +++ b/lmdeploy/vl/templates.py @@ -428,6 +428,84 @@ class GLM4VChatTemplateWrapper(VLChatTemplateWrapper): pass +class MolmoChatTemplateWrapper(VLChatTemplateWrapper): + + async def async_collect_pil_images( + self, messages: List[Dict]) -> List[Tuple[PIL.Image.Image, Dict]]: + """collect images from messages. 
+ + Args: + messages (List[Dict]): a user request of GPT4V message format + """ + if isinstance(messages, Dict): + messages = [messages] + assert isinstance(messages, List) + + out_messages = [None] * len(messages) + + def _inner_call(i, in_messages, out_messages): + role = in_messages[i]['role'] + content = in_messages[i]['content'] + if role != 'user' or isinstance(content, str): + # means message is user's prompt input or assistant's prompt, + # returning it directory + out_messages[i] = in_messages[i] + return + # the role is a user and the content is a list + assert isinstance(content, List) + message = dict(role=role, content='', images=[]) + for item in content: + # 'image_url': means url or local path to image. + # 'image_data': means PIL.Image.Image object. + if item['type'] == 'image_url': + try: + image = load_image(item['image_url']['url']) + message['images'].append(image) + except KeyError: + logger.error(f'invalid format {message}') + elif item['type'] == 'image_data': + try: + image = load_image(item['image_data']['data']) + message['images'].append(image) + except KeyError: + logger.error(f'invalid format {message}') + elif item['type'] == 'text': + message['content'] = item['text'] + else: + logger.error(f'unexpected content type {message}') + out_messages[i] = message + + await asyncio.gather(*[ + asyncio.get_event_loop().run_in_executor(None, _inner_call, i, + messages, out_messages) + for i in range(len(messages)) + ]) + return [(None, out_messages)] + + def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str: + """Return a placeholder "IMAGE_TOKEN" so that + `vl_asyn_engine._get_prompt_input` can know that it.""" + if isinstance(messages, str): + return self.chat_template.messages2prompt(messages, sequence_start) + else: + _messages = [] + for message in messages: + role, content = message['role'], message['content'] + if role != 'user' or isinstance(content, str): + _messages.append(message) + continue + for item in content: + item_type = item['type'] + if item_type in ['image_url', 'image_data']: + # Return the image placeholder so that + # `vl_asyn_engine._get_prompt_input` can know that the + # request contains images + return IMAGE_TOKEN + _messages.append(dict(role=role, content=item[item_type])) + return self.chat_template.messages2prompt(_messages, + sequence_start) + + def get_vl_prompt_template(model_path: str, chat_template: BaseModel, model_name: str) -> VLChatTemplateWrapper: """get vision language prompt template.""" @@ -467,4 +545,6 @@ def get_vl_prompt_template(model_path: str, chat_template: BaseModel, return GLM4VChatTemplateWrapper(chat_template) elif arch == 'Qwen2VLForConditionalGeneration': return Qwen2VLChatTemplateWrapper(chat_template) + elif arch == 'MolmoForCausalLM': + return MolmoChatTemplateWrapper(chat_template) raise ValueError(f'unsupported vl_prompt_template with arch {arch}') diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc index 1ac2d82dd9..9d62042d62 100644 --- a/src/turbomind/models/llama/LlamaWeight.cc +++ b/src/turbomind/models/llama/LlamaWeight.cc @@ -32,6 +32,7 @@ LlamaWeight::LlamaWeight(size_t head_num, size_t hidden_units, size_t inter_size, size_t vocab_size, + size_t embedding_size, size_t num_layer, bool attn_bias, WeightType weight_type, @@ -44,16 +45,20 @@ LlamaWeight::LlamaWeight(size_t head_num, inter_size_(inter_size), vocab_size_(vocab_size), vocab_size_padded_(vocab_size), + embedding_size_(embedding_size), num_layer_(num_layer), 
weight_type_(weight_type), tensor_para_size_(tensor_para_size), tensor_para_rank_(tensor_para_rank) { if (vocab_size_padded_ % tensor_para_size_ != 0) { - vocab_size_padded_ = (vocab_size_padded_ + tensor_para_size_ - 1) / tensor_para_size_ * tensor_para_size_; + vocab_size_padded_ = (vocab_size_ + tensor_para_size_ - 1) / tensor_para_size_ * tensor_para_size_; TM_LOG_WARNING("pad vocab size from %d to %d", vocab_size_, vocab_size_padded_); } - + if (embedding_size_ % tensor_para_size_ != 0) { + embedding_size_ = (embedding_size_ + tensor_para_size_ - 1) / tensor_para_size_ * tensor_para_size_; + TM_LOG_WARNING("pad embed size from %d to %d", embedding_size_, embedding_size_); + } FT_CHECK(hidden_units_ % tensor_para_size_ == 0); decoder_layer_weights.reserve(num_layer_); @@ -96,7 +101,7 @@ template void LlamaWeight::mallocWeights() { FT_CHECK(vocab_size_padded_ % tensor_para_size_ == 0); - deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_padded_ * hidden_units_ / tensor_para_size_); + deviceMalloc((T**)&pre_decoder_embedding_table, embedding_size_ * hidden_units_ / tensor_para_size_); deviceMalloc((T**)&output_norm_weight, hidden_units_); deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_ / tensor_para_size_); } @@ -111,7 +116,7 @@ void LlamaWeight::loadModel(std::string dir_path) dir_path += '/'; loadWeightFromBin((T*)pre_decoder_embedding_table, - {vocab_size_padded_ * hidden_units_ / tensor_para_size_}, + {embedding_size_ * hidden_units_ / tensor_para_size_}, dir_path + "tok_embeddings." + std::to_string(tensor_para_rank_) + ".weight", model_file_type); @@ -135,7 +140,7 @@ TensorMap LlamaWeight::getParams() output.insert("tok_embeddings." + std::to_string(tensor_para_rank_) + ".weight", Tensor{MEMORY_GPU, getTensorType(), - {vocab_size_padded_ * hidden_units_ / tensor_para_size_ * sizeof(T)}, + {embedding_size_ * hidden_units_ / tensor_para_size_ * sizeof(T)}, pre_decoder_embedding_table}); output.insert("norm.weight", diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h index c04bf6c5a6..c30e753565 100644 --- a/src/turbomind/models/llama/LlamaWeight.h +++ b/src/turbomind/models/llama/LlamaWeight.h @@ -35,6 +35,7 @@ struct LlamaWeight { size_t hidden_units, size_t inter_size, size_t vocab_size, + size_t embedding_size, size_t num_layer, bool attn_bias, WeightType weight_type, @@ -67,6 +68,7 @@ struct LlamaWeight { size_t inter_size_; size_t vocab_size_; size_t vocab_size_padded_; + size_t embedding_size_; size_t num_layer_; WeightType weight_type_; size_t tensor_para_size_; diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index 2ea63f0410..e6b9d690ae 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -18,6 +18,7 @@ struct ModelParam { size_t layer_num; size_t inter_size; size_t vocab_size; + size_t embedding_size; float norm_eps; int quant_policy; // diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 38552be0cf..2deca46380 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -133,6 +133,12 @@ void LlamaTritonModel::handleMissingParams() (int)model_param_.kv_head_num); } + if (model_param_.embedding_size == 0) { + model_param_.embedding_size = model_param_.vocab_size; + TM_LOG_WARNING("[LlamaTritonModel] `embedding_size` is not set, default 
to `vocab_size` (%d).", + (int)model_param_.vocab_size); + } + if (!attn_param_.max_position_embeddings) { attn_param_.max_position_embeddings = 2048; TM_LOG_WARNING("[LlamaTritonModel] `max_position_embeddings` is not set, default to %d.", @@ -252,6 +258,7 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, model_param_.layer_num = model_reader["num_layer"].as(); model_param_.inter_size = model_reader["inter_size"].as(); model_param_.vocab_size = model_reader["vocab_size"].as(); + model_param_.embedding_size = model_reader["embedding_size"].as(); model_param_.norm_eps = model_reader["norm_eps"].as(); model_param_.start_id = model_reader["start_id"].as(); model_param_.end_id = model_reader["end_id"].as(); @@ -417,6 +424,7 @@ void LlamaTritonModel::createSharedWeights(int device_id, int rank) model_param_.hidden_units, model_param_.inter_size, model_param_.vocab_size, + model_param_.embedding_size, model_param_.layer_num, attn_bias_, weight_type_,