diff --git a/docs/en/multi_modal/index.rst b/docs/en/multi_modal/index.rst
index 62f724070f..a68fe3da4f 100644
--- a/docs/en/multi_modal/index.rst
+++ b/docs/en/multi_modal/index.rst
@@ -12,3 +12,5 @@ Vision-Language Models
minicpmv.md
phi3.md
mllama.md
+ qwen2_vl.md
+ molmo.md
diff --git a/docs/en/multi_modal/molmo.md b/docs/en/multi_modal/molmo.md
new file mode 100644
index 0000000000..dfff43dc64
--- /dev/null
+++ b/docs/en/multi_modal/molmo.md
@@ -0,0 +1,92 @@
+# Molmo
+
+LMDeploy supports the following Molmo series models, as detailed in the table below:
+
+| Model | Size | Supported Inference Engine |
+| :-------------: | :--: | :------------------------: |
+| Molmo-7B-D-0924 | 7B | TurboMind |
+| Molmo-72B-0924  | 72B  | TurboMind                  |
+
+The following sections demonstrate how to deploy a Molmo model with LMDeploy, taking [Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924) as an example.
+
+## Installation
+
+Please install LMDeploy by following the [installation guide](../get_started/installation.md).
+
+## Offline inference
+
+The following sample code shows the basic usage of the VLM pipeline. For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).
+
+```python
+from lmdeploy import pipeline
+from lmdeploy.vl import load_image
+
+pipe = pipeline('allenai/Molmo-7B-D-0924')
+
+image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
+response = pipe(('describe this image', image))
+print(response)
+```
+
+More examples are listed below:
+
+<details>
+  <summary>
+    <b>multi-image multi-round conversation, combined images</b>
+  </summary>
+
+```python
+from lmdeploy import pipeline, GenerationConfig
+
+pipe = pipeline('allenai/Molmo-7B-D-0924', log_level='INFO')
+messages = [
+ dict(role='user', content=[
+ dict(type='text', text='Describe the two images in detail.'),
+ dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),
+ dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))
+ ])
+]
+out = pipe(messages, gen_config=GenerationConfig(do_sample=False))
+
+messages.append(dict(role='assistant', content=out.text))
+messages.append(dict(role='user', content='What are the similarities and differences between these two images.'))
+out = pipe(messages, gen_config=GenerationConfig(do_sample=False))
+```
+
+</details>
+
+## Online serving
+
+You can launch the server by the `lmdeploy serve api_server` CLI:
+
+```shell
+lmdeploy serve api_server allenai/Molmo-7B-D-0924
+```
+
+You can also start the service using the docker image:
+
+```shell
+docker run --runtime nvidia --gpus all \
+ -v ~/.cache/huggingface:/root/.cache/huggingface \
+ --env "HUGGING_FACE_HUB_TOKEN=" \
+ -p 23333:23333 \
+ --ipc=host \
+ openmmlab/lmdeploy:latest \
+ lmdeploy serve api_server allenai/Molmo-7B-D-0924
+```
+
+If the following logs appear, it means the service has launched successfully.
+
+```text
+HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
+HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
+HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
+INFO: Started server process [2439]
+INFO: Waiting for application startup.
+INFO: Application startup complete.
+INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit)
+```
+
+The arguments of `lmdeploy serve api_server` can be reviewed in detail by `lmdeploy serve api_server -h`.
+
+More information about `api_server` and how to access the service can be found [here](api_server_vl.md).
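+
+For example, once the server is up, you can query the OpenAI-compatible endpoint with the official `openai` Python package. The snippet below is a minimal sketch that assumes the server listens on the default port 23333:
+
+```python
+from openai import OpenAI
+
+client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')
+model_name = client.models.list().data[0].id
+response = client.chat.completions.create(
+    model=model_name,
+    messages=[{
+        'role': 'user',
+        'content': [{
+            'type': 'text',
+            'text': 'Describe this image.',
+        }, {
+            'type': 'image_url',
+            'image_url': {
+                'url':
+                'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',
+            },
+        }],
+    }],
+    temperature=0.8,
+    top_p=0.8)
+print(response)
+```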
diff --git a/docs/zh_cn/multi_modal/index.rst b/docs/zh_cn/multi_modal/index.rst
index 0942d8d31c..bd141ea90f 100644
--- a/docs/zh_cn/multi_modal/index.rst
+++ b/docs/zh_cn/multi_modal/index.rst
@@ -12,3 +12,5 @@
minicpmv.md
phi3.md
mllama.md
+ qwen2_vl.md
+ molmo.md
diff --git a/docs/zh_cn/multi_modal/molmo.md b/docs/zh_cn/multi_modal/molmo.md
new file mode 100644
index 0000000000..1dc8f8f79b
--- /dev/null
+++ b/docs/zh_cn/multi_modal/molmo.md
@@ -0,0 +1,92 @@
+# Molmo
+
+LMDeploy 支持 Molmo 系列模型,具体如下:
+
+| Model | Size | Supported Inference Engine |
+| :-------------: | :--: | :------------------------: |
+| Molmo-7B-D-0924 | 7B | TurboMind |
+| Molmo-72B-0924  | 72B  | TurboMind                  |
+
+本文将以 [Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924) 为例，演示使用 LMDeploy 部署 Molmo 系列模型的方法。
+
+## 安装
+
+请参考[安装文档](../get_started/installation.md)安装 LMDeploy。
+
+## 离线推理
+
+以下是使用 pipeline 进行离线推理的示例,更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md)
+
+```python
+from lmdeploy import pipeline
+from lmdeploy.vl import load_image
+
+pipe = pipeline('allenai/Molmo-7B-D-0924')
+
+image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
+response = pipe(('describe this image', image))
+print(response)
+```
+
+更多例子如下:
+
+<details>
+  <summary>
+    <b>多图多轮对话</b>
+  </summary>
+
+```python
+from lmdeploy import pipeline, GenerationConfig
+
+pipe = pipeline('allenai/Molmo-7B-D-0924', log_level='INFO')
+messages = [
+ dict(role='user', content=[
+ dict(type='text', text='Describe the two images in detail.'),
+ dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),
+ dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))
+ ])
+]
+out = pipe(messages, gen_config=GenerationConfig(top_k=1))
+
+messages.append(dict(role='assistant', content=out.text))
+messages.append(dict(role='user', content='What are the similarities and differences between these two images.'))
+out = pipe(messages, gen_config=GenerationConfig(top_k=1))
+```
+
+</details>
+
+## 在线服务
+
+你可以通过 `lmdeploy serve api_server` CLI 工具启动服务:
+
+```shell
+lmdeploy serve api_server allenai/Molmo-7B-D-0924
+```
+
+也可以基于 docker image 启动服务:
+
+```shell
+docker run --runtime nvidia --gpus all \
+ -v ~/.cache/huggingface:/root/.cache/huggingface \
+ --env "HUGGING_FACE_HUB_TOKEN=" \
+ -p 23333:23333 \
+ --ipc=host \
+    openmmlab/lmdeploy:latest \
+    lmdeploy serve api_server allenai/Molmo-7B-D-0924
+```
+
+如果日志中有如下信息,就表明服务启动成功了。
+
+```text
+HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
+HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
+HINT: Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
+INFO: Started server process [2439]
+INFO: Waiting for application startup.
+INFO: Application startup complete.
+INFO: Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit)
+```
+
+有关 `lmdeploy serve api_server` 的详细参数可以通过`lmdeploy serve api_server -h`查阅。
+
+关于 `api_server` 更多的介绍，以及访问 `api_server` 的方法，请阅读[此处](api_server_vl.md)。
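+
+例如，服务启动后，可以使用 `openai` Python 包访问 OpenAI 兼容接口。下面是一个简单示例（假设服务监听默认端口 23333）：
+
+```python
+from openai import OpenAI
+
+client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')
+model_name = client.models.list().data[0].id
+response = client.chat.completions.create(
+    model=model_name,
+    messages=[{
+        'role': 'user',
+        'content': [{
+            'type': 'text',
+            'text': 'Describe this image.',
+        }, {
+            'type': 'image_url',
+            'image_url': {
+                'url':
+                'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',
+            },
+        }],
+    }],
+    temperature=0.8,
+    top_p=0.8)
+print(response)
+```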
diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py
index 8284c99741..ce5cbd98ff 100644
--- a/lmdeploy/archs.py
+++ b/lmdeploy/archs.py
@@ -121,7 +121,8 @@ def check_vl_llm(config: dict) -> bool:
'InternVLChatModel', 'MiniGeminiLlamaForCausalLM',
'MGMLlamaForCausalLM', 'MiniCPMV', 'LlavaForConditionalGeneration',
'LlavaNextForConditionalGeneration', 'Phi3VForCausalLM',
- 'Qwen2VLForConditionalGeneration', 'MllamaForConditionalGeneration'
+ 'Qwen2VLForConditionalGeneration', 'MllamaForConditionalGeneration',
+ 'MolmoForCausalLM'
])
if arch == 'QWenLMHeadModel' and 'visual' in config:
return True
diff --git a/lmdeploy/model.py b/lmdeploy/model.py
index db864a8344..c9eb71c2c3 100644
--- a/lmdeploy/model.py
+++ b/lmdeploy/model.py
@@ -1747,6 +1747,37 @@ def match(cls, model_path: str) -> Optional[str]:
return 'internvl-phi3'
+@MODELS.register_module(name='molmo')
+class Molmo(BaseChatTemplate):
+
+ def __init__(self,
+ user=' User: ',
+ eoh='',
+ assistant=' Assistant:',
+ eoa='',
+ separator=' ',
+ stop_words=['<|endoftext|>'],
+ **kwargs):
+ super().__init__(user=user,
+ eoh=eoh,
+ assistant=assistant,
+ eoa=eoa,
+ separator=separator,
+ stop_words=stop_words,
+ **kwargs)
+
+ @classmethod
+ def match(cls, model_path: str) -> Optional[str]:
+ """Return the model_name that was registered to MODELS.
+
+ Args:
+ model_path (str): the model path used for matching.
+ """
+ path = model_path.lower()
+ if 'molmo' in path:
+ return 'molmo'
+
+
def best_match_model(query: str) -> Optional[str]:
"""Get the model that matches the query.
diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py
index fd0b0bb5e4..c293cd71c8 100644
--- a/lmdeploy/serve/vl_async_engine.py
+++ b/lmdeploy/serve/vl_async_engine.py
@@ -64,6 +64,7 @@ async def _get_prompt_input(self,
results = {}
input_ids = []
from lmdeploy.vl.templates import (MllamaTempateWrapper,
+ MolmoChatTemplateWrapper,
Qwen2VLChatTemplateWrapper)
ranges = None
grid_thws = None
@@ -99,6 +100,10 @@ async def _get_prompt_input(self,
results['cross_attention_states'] = features[0]
return results
+ if isinstance(self.vl_prompt_template,
+ MolmoChatTemplateWrapper):
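+            # Molmo's vision model forward() already packs the prompt,
+            # input_ids, embeddings and embedding ranges, so return it as is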
+ return features[0]
+
features = [x.cpu().numpy() for x in features]
input_ids = []
begins = []
diff --git a/lmdeploy/turbomind/deploy/config.py b/lmdeploy/turbomind/deploy/config.py
index a535b0d4c1..c724b085a0 100644
--- a/lmdeploy/turbomind/deploy/config.py
+++ b/lmdeploy/turbomind/deploy/config.py
@@ -35,6 +35,13 @@ class ModelConfig:
kv_head_num: int = None
hidden_units: int = None
vocab_size: int = None
+    # TurboMind used to assume that token_embedding and lm_head have the same
+    # size along the vocab dim, i.e. `vocab_size`.
+    # But in molmo, embedding.shape is [vocab_size + 128, hidden_units]
+    # while lm_head.shape is [hidden_units, vocab_size].
+    # Therefore, we add a new attr "embedding_size" to represent the vocab dim
+    # of token_embedding
+ embedding_size: int = 0
num_layer: int = None
inter_size: int = None
norm_eps: float = None
diff --git a/lmdeploy/turbomind/deploy/source_model/__init__.py b/lmdeploy/turbomind/deploy/source_model/__init__.py
index b1da698e2e..de16bdc0a0 100644
--- a/lmdeploy/turbomind/deploy/source_model/__init__.py
+++ b/lmdeploy/turbomind/deploy/source_model/__init__.py
@@ -9,5 +9,6 @@
from .meta_llama import MetaLlamaModel # noqa: F401
from .minicpmv import MiniCPMVModel # noqa: F401
from .mixtral import MixtralModel # noqa: F401
+from .molmo import MolmoModel # noqa: F401
from .qwen import QwenModel # noqa: F401
from .xcomposer2 import Xcomposer2Model # noqa: F401
diff --git a/lmdeploy/turbomind/deploy/source_model/molmo.py b/lmdeploy/turbomind/deploy/source_model/molmo.py
new file mode 100644
index 0000000000..541e201046
--- /dev/null
+++ b/lmdeploy/turbomind/deploy/source_model/molmo.py
@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os.path as osp
+
+import torch
+
+from .base import INPUT_MODELS
+from .llama import LlamaModel, LlamaReader
+
+
+class MolmoReader(LlamaReader):
+ attn_layer_prefix = 'model.transformer.blocks'
+ attn_layer_patten = r'model.transformer.blocks.([0-9]+).'
+ norm_weight_key = 'model.transformer.ln_f.weight'
+ output_weight_key = 'model.transformer.ff_out.weight'
+
+    # In molmo, the attention parameters are named "att_proj.bias",
+    # "att_proj.weight", "attn_norm.weight" and "attn_out.weight", while the
+    # ffn parameters are named "ff_norm", "ff_out" and "ff_proj". So we set
+    # the patterns to r'att' and r'ff_', respectively.
+ attn_pattern = r'att'
+ ffn_pattern = r'ff_'
+
+ def tok_embeddings(self):
+ embed1 = self.params.get('model.transformer.wte.embedding', None)
+ embed2 = self.params.get('model.transformer.wte.new_embedding', None)
+ if embed1 is not None and embed2 is not None:
+ return torch.cat((embed1, embed2), dim=0)
+ else:
+ assert embed1 is None and embed2 is None
+ return None
+
+ def attn_norm(self, i: int):
+ """Get attn norm for layer i."""
+ return self.params[f'{self.attn_layer_prefix}.{i}.attn_norm.weight']
+
+ def _attn(self, i: int, kind: str):
+ """Get q, k, v, o kind(weight, bias, qweight) for layer i.
+
+ Args:
+ i (int): layer id
+ kind (str): can be one of ["weight", "bias", "qweight"]
+ """
+ q, k, v = (None, ) * 3
+ hidden_size = self.model_cfg['hidden_size']
+ head_num = self.model_cfg['num_attention_heads']
+ kv_head_num = self.model_cfg['num_key_value_heads']
+ head_dim = hidden_size // head_num
+ assert head_dim == 128
+ fused_dims = (hidden_size, kv_head_num * head_dim,
+ kv_head_num * head_dim)
+ qkv = self.params.get(f'{self.attn_layer_prefix}.{i}.att_proj.{kind}')
+ qkv = self.transform(qkv, kind)
+ if qkv is not None:
+ q, k, v = qkv.split(fused_dims, dim=0)
+ o = self.params.get(f'{self.attn_layer_prefix}.{i}.attn_out.{kind}')
+ o = self.transform(o, kind)
+ if o is None: # handle the case when qkv has bias but o doesn't
+ o = torch.zeros_like(q)
+ return (q, k, v, o)
+
+ def _ffn(self, i: int, kind: str):
+ """Get ffn kind(weight, qweight) for layer i."""
+ up_and_gate = self.params[
+ f'{self.attn_layer_prefix}.{i}.ff_proj.{kind}']
+ up_and_gate = self.transform(up_and_gate, kind)
+ gate, up = up_and_gate.chunk(2, dim=0)
+ down = self.params[f'{self.attn_layer_prefix}.{i}.ff_out.{kind}']
+ down = self.transform(down, kind)
+ return (up, down, gate)
+
+ def ffn_norm(self, i: int):
+ """Get ffn norm for layer i."""
+ return self.params[f'{self.attn_layer_prefix}.{i}.ff_norm.weight']
+
+
+@INPUT_MODELS.register_module(name='molmo')
+class MolmoModel(LlamaModel):
+
+ Reader = MolmoReader
+
+ def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
+ super().__init__(model_path, tokenizer_path, **kwargs)
+ config_path = osp.join(self.model_path, 'config.json')
+ with open(config_path) as f:
+ self.config = json.load(f)
+
+ def tokenizer_info(self):
+
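+        # Molmo's LLM backbone reuses the Qwen2 tokenizer; the values below
+        # are Qwen2's vocab size and its `<|endoftext|>` token id for bos/eos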
+ n_words = 152064
+ bos_id = 151643
+ eos_id = 151643
+ return n_words, bos_id, eos_id
+
+ def model_info(self):
+ config = self.config
+ num_layer = config['num_hidden_layers']
+ norm_eps = config['layer_norm_eps']
+ attn_head_num = config['num_attention_heads']
+ kv_head_num = config['num_key_value_heads']
+ hidden_units = config['hidden_size']
+ rope_theta = config['rope_theta']
+ max_position_embeddings = config['max_position_embeddings']
+ vocab_size = config['vocab_size']
+ # https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/modeling_molmo.py#L2041
+ additional_vocab_size = 128
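+        # molmo's `intermediate_size` covers the fused gate & up projection
+        # produced by `ff_proj`, so the actual MLP width is half of it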
+ inter_size = config['intermediate_size'] // 2
+ attn_bias = config['qkv_bias']
+ return dict(
+ num_layer=num_layer,
+ norm_eps=norm_eps,
+ head_num=attn_head_num,
+ kv_head_num=kv_head_num,
+ hidden_units=hidden_units,
+ attn_bias=int(attn_bias),
+ inter_size=inter_size,
+ vocab_size=vocab_size,
+ # https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/modeling_molmo.py#L564
+ embedding_size=vocab_size + additional_vocab_size,
+ rope_theta=rope_theta,
+ max_position_embeddings=max_position_embeddings,
+ )
diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py
index abd570cd00..09699ade09 100644
--- a/lmdeploy/turbomind/deploy/target_model/base.py
+++ b/lmdeploy/turbomind/deploy/target_model/base.py
@@ -92,6 +92,9 @@ def update_model_config(self):
final_cfg = config_to_dict(self.model_config)
final_cfg.update(dict(start_id=bos_id, end_id=eos_id))
final_cfg.update(self.input_model_info)
+ if 'embedding_size' not in self.input_model_info.keys():
+ final_cfg.update(
+ embedding_size=self.input_model_info['vocab_size'])
self.model_config = config_from_dict(ModelConfig, final_cfg)
diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py
index bb3533254b..e66da22df0 100644
--- a/lmdeploy/turbomind/supported_models.py
+++ b/lmdeploy/turbomind/supported_models.py
@@ -42,7 +42,9 @@
ChatGLMModel='glm4',
ChatGLMForConditionalGeneration='glm4',
# mixtral
- MixtralForCausalLM='mixtral')
+ MixtralForCausalLM='mixtral',
+ MolmoForCausalLM='molmo',
+)
def is_supported(model_path: str):
@@ -104,5 +106,9 @@ def _is_head_dim_supported(cfg):
if llm_arch in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
support_by_turbomind = _is_head_dim_supported(
cfg.text_config)
+ elif arch == 'MolmoForCausalLM':
+ kv_heads = cfg.num_key_value_heads
+ # TM hasn't supported allenai/Molmo-7B-O-0924 yet
+ support_by_turbomind = kv_heads is not None
return support_by_turbomind
diff --git a/lmdeploy/vl/model/builder.py b/lmdeploy/vl/model/builder.py
index 9e71f7d1c0..2401b42259 100644
--- a/lmdeploy/vl/model/builder.py
+++ b/lmdeploy/vl/model/builder.py
@@ -18,6 +18,7 @@
from .mini_gemeni import MiniGeminiVisionModel # noqa F401
from .minicpmv import MiniCPMVModel # noqa F401
from .mllama import MllamaVLModel # noqa F401
+from .molmo import MolmoVisionModel # noqa F401
from .phi3_vision import Phi3VisionModel # noqa F401
from .qwen import QwenVisionModel # noqa F401
from .qwen2 import Qwen2VLModel # noqa F401
@@ -31,7 +32,14 @@ def load_vl_model(model_path: str,
with_llm: bool = False,
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None):
- """load visual model."""
+    """load visual model.
+
+    Args:
+        model_path (str): the path or repo_id from the model hub of the model
+        with_llm (bool): whether to keep the LLM part in the loaded model.
+            When it is False, the LLM part is removed
+        backend_config: the config of the inference engine
+    """
if not os.path.exists(model_path):
revision = getattr(backend_config, 'revision', None)
download_dir = getattr(backend_config, 'download_dir', None)
diff --git a/lmdeploy/vl/model/molmo.py b/lmdeploy/vl/model/molmo.py
new file mode 100644
index 0000000000..9abae7a309
--- /dev/null
+++ b/lmdeploy/vl/model/molmo.py
@@ -0,0 +1,177 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Dict, List
+
+import torch
+from PIL.Image import Image
+from transformers import AutoModelForCausalLM, AutoProcessor
+
+from lmdeploy.utils import get_logger
+from lmdeploy.vl.constants import IMAGE_TOKEN
+from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
+from lmdeploy.vl.model.utils import disable_logging
+
+logger = get_logger('lmdeploy')
+
+
+@VISION_MODELS.register_module()
+class MolmoVisionModel(VisonModel):
+ """molmo's vision model."""
+
+ _arch = 'MolmoForCausalLM'
+
+ def build_model(self):
+ """Load model."""
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+ with init_empty_weights():
+ config = self.hf_config
+ model = AutoModelForCausalLM.from_config(config,
+ trust_remote_code=True)
+ if not self.with_llm:
+ # Remove nn modules other than embedding from the LLM model
+ for key in ['emb_drop', 'ln_f', 'blocks', 'ff_out']:
+ del model.model.transformer[key]
+ self.token_embedding = model.model.transformer.wte
+ else:
+ self.vl_model = model
+
+ with disable_logging():
+ load_checkpoint_and_dispatch(
+ model=model,
+ checkpoint=self.model_path,
+ device_map='auto' if not self.with_llm else {'': 'cpu'},
+ max_memory=self.max_memory,
+ no_split_module_classes=[
+ 'ResidualAttentionBlock', 'Embedding'
+ ])
+
+        # We need eval mode to freeze the weights in the model, thus
+        # avoiding randomness during inference.
+ self.model = model.eval()
+ self.config = config
+
+ self.processor = AutoProcessor.from_pretrained(self.model_path,
+ trust_remote_code=True,
+ torch_dtype='auto',
+ device_map='auto')
+
+ @torch.no_grad()
+ def forward(self,
+ images: List[Image],
+ params: List[Dict] = None) -> List[Dict]:
+ """forward the model with given input.
+
+ Args:
+ images (List): [None] it is not used
+            params (List): the inputs after processing GPT4V messages in
+ `MolmoChatTemplateWrapper`. Its format is like the following:
+ [[
+ {'role': 'user', 'content': 'user prompt'},
+                    {'role': 'assistant', 'content': 'assistant prompt'},
+ {'role': 'user', 'content': 'user prompt', 'images': [PIL image list]},
+ ...
+ ]]
+ """ # noqa
+
+ messages = params[0]
+ assert isinstance(messages, List)
+ # append an assistant message to `messages`
+ messages.append(dict(role='assistant', content=''))
+ # results is a list of tuple(input_ids, embeddings)
+ results = []
+        # the concatenated prompt. It is not used during inference, but is
+        # kept to adhere to the interface definition of `_get_prompt_input`
+        # in `class VLAsyncEngine`
+ prompts = ''
+ # Prepend BOS
+ # qwen2 and olmo do not have a BOS, and instead use EOS as a generic
+ # separator token.
+ bos = (self.processor.tokenizer.bos_token_id
+ or self.processor.tokenizer.eos_token_id)
+ results.append(([bos], None))
+ for i, message in enumerate(messages):
+ if 'images' in message.keys():
+ prompts += ' User: ' + (IMAGE_TOKEN + '\n') * len(
+ message['images']) + message['content']
+ prompt = f' User: {message["content"]}'
+ tokens = self.processor.tokenizer.encode(
+ prompt, add_special_tokens=False)
+ # preprocess images. The output is a dict
+ inputs = self.processor.process(images=message['images'],
+ tokens=tokens)
+ inputs = {
+ k: v.to(self.model.device).unsqueeze(0)
+ for k, v in inputs.items()
+ }
+ input_ids = inputs['input_ids']
+ # remove the bos from input_ids which is prepended by molmo's
+ # processor
+ input_ids = input_ids[:, 1:]
+ images = inputs[
+ 'images'] # (batch_size, num_image, num_patch, d_model)
+ image_input_idx = inputs[
+ 'image_input_idx'] # (batch_size, num_image, num_patch)
+ image_masks = inputs['image_masks']
+ batch_size, seq_len = input_ids.size()
+ assert batch_size == 1
+
+ # Get embeddings of input.
+ if input_ids is not None:
+ input_ids = input_ids * (input_ids != -1).to(
+ input_ids.dtype)
+ embeddings = self.model.model.transformer.wte(input_ids)
+ image_features, _ = self.model.model.vision_backbone(
+ images, image_masks)
+ num_image, num_patch = image_features.shape[1:3]
+ assert image_input_idx.shape == (batch_size, num_image,
+ num_patch)
+
+ # insert the image feature into the embedding.
+ image_features = image_features.view(batch_size,
+ num_image * num_patch, -1)
+ image_input_idx = image_input_idx.view(batch_size,
+ num_image * num_patch)
+
+ valid = image_input_idx >= 0
+ batch_idx = torch.arange(batch_size, device=embeddings.device)
+ batch_idx = torch.tile(batch_idx[:, None],
+ [1, image_features.shape[1]])
+ image_features = image_features.to(embeddings.device)
+ embeddings[batch_idx[valid],
+ image_input_idx[valid]] += image_features[valid]
+ assert embeddings.shape[:2] == (batch_size, seq_len)
+ results.append((input_ids.flatten().tolist(), embeddings))
+ else:
+ role = message['role']
+ content = message['content']
+ assert isinstance(content, str)
+ prompt = ''
+ if role == 'user':
+ prompt = f' User: {content}'
+ elif role == 'assistant':
+ prompt = f' Assistant:{content}'
+ else:
+ assert 0, f'molmo does not support role {role}, message is {message}' # noqa
+ input_ids = self.processor.tokenizer.encode(
+ prompt, add_special_tokens=False)
+ results.append((input_ids, None))
+ prompts += prompt
+
+ # concat input_ids from results, calculate the range in the input_ids
+ # where embeddings will be copied to
+ input_ids = []
+ input_embeddings = []
+ input_embedding_ranges = []
+ start = 0
+ for _input_ids, _embeddings in results:
+ if _embeddings is not None:
+ input_embeddings.append(_embeddings.cpu())
+ end = start + len(_input_ids)
+ input_embedding_ranges.append((start, end))
+ input_ids += _input_ids
+ start += len(_input_ids)
+ return [
+ dict(prompt=prompts,
+ input_ids=input_ids,
+ input_embeddings=input_embeddings,
+ input_embedding_ranges=input_embedding_ranges)
+ ]
diff --git a/lmdeploy/vl/templates.py b/lmdeploy/vl/templates.py
index 45e457ad2c..cdf398868a 100644
--- a/lmdeploy/vl/templates.py
+++ b/lmdeploy/vl/templates.py
@@ -428,6 +428,84 @@ class GLM4VChatTemplateWrapper(VLChatTemplateWrapper):
pass
+class MolmoChatTemplateWrapper(VLChatTemplateWrapper):
+
+ async def async_collect_pil_images(
+ self, messages: List[Dict]) -> List[Tuple[PIL.Image.Image, Dict]]:
+ """collect images from messages.
+
+ Args:
+ messages (List[Dict]): a user request of GPT4V message format
+ """
+ if isinstance(messages, Dict):
+ messages = [messages]
+ assert isinstance(messages, List)
+
+ out_messages = [None] * len(messages)
+
+ def _inner_call(i, in_messages, out_messages):
+ role = in_messages[i]['role']
+ content = in_messages[i]['content']
+ if role != 'user' or isinstance(content, str):
+                # the message is either a plain user prompt or an assistant
+                # reply, so return it directly
+ out_messages[i] = in_messages[i]
+ return
+ # the role is a user and the content is a list
+ assert isinstance(content, List)
+ message = dict(role=role, content='', images=[])
+ for item in content:
+ # 'image_url': means url or local path to image.
+ # 'image_data': means PIL.Image.Image object.
+ if item['type'] == 'image_url':
+ try:
+ image = load_image(item['image_url']['url'])
+ message['images'].append(image)
+ except KeyError:
+ logger.error(f'invalid format {message}')
+ elif item['type'] == 'image_data':
+ try:
+ image = load_image(item['image_data']['data'])
+ message['images'].append(image)
+ except KeyError:
+ logger.error(f'invalid format {message}')
+ elif item['type'] == 'text':
+ message['content'] = item['text']
+ else:
+ logger.error(f'unexpected content type {message}')
+ out_messages[i] = message
+
+ await asyncio.gather(*[
+ asyncio.get_event_loop().run_in_executor(None, _inner_call, i,
+ messages, out_messages)
+ for i in range(len(messages))
+ ])
+ return [(None, out_messages)]
+
+ def messages2prompt(self, messages, sequence_start=True, **kwargs) -> str:
+        """Return the placeholder `IMAGE_TOKEN` when the request contains
+        images, so that `vl_async_engine._get_prompt_input` can detect it."""
+ if isinstance(messages, str):
+ return self.chat_template.messages2prompt(messages, sequence_start)
+ else:
+ _messages = []
+ for message in messages:
+ role, content = message['role'], message['content']
+ if role != 'user' or isinstance(content, str):
+ _messages.append(message)
+ continue
+ for item in content:
+ item_type = item['type']
+ if item_type in ['image_url', 'image_data']:
+ # Return the image placeholder so that
+                        # `vl_async_engine._get_prompt_input` can know that the
+ # request contains images
+ return IMAGE_TOKEN
+ _messages.append(dict(role=role, content=item[item_type]))
+ return self.chat_template.messages2prompt(_messages,
+ sequence_start)
+
+
def get_vl_prompt_template(model_path: str, chat_template: BaseModel,
model_name: str) -> VLChatTemplateWrapper:
"""get vision language prompt template."""
@@ -467,4 +545,6 @@ def get_vl_prompt_template(model_path: str, chat_template: BaseModel,
return GLM4VChatTemplateWrapper(chat_template)
elif arch == 'Qwen2VLForConditionalGeneration':
return Qwen2VLChatTemplateWrapper(chat_template)
+ elif arch == 'MolmoForCausalLM':
+ return MolmoChatTemplateWrapper(chat_template)
raise ValueError(f'unsupported vl_prompt_template with arch {arch}')
diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc
index 1ac2d82dd9..9d62042d62 100644
--- a/src/turbomind/models/llama/LlamaWeight.cc
+++ b/src/turbomind/models/llama/LlamaWeight.cc
@@ -32,6 +32,7 @@ LlamaWeight::LlamaWeight(size_t head_num,
size_t hidden_units,
size_t inter_size,
size_t vocab_size,
+ size_t embedding_size,
size_t num_layer,
bool attn_bias,
WeightType weight_type,
@@ -44,16 +45,20 @@ LlamaWeight::LlamaWeight(size_t head_num,
inter_size_(inter_size),
vocab_size_(vocab_size),
vocab_size_padded_(vocab_size),
+ embedding_size_(embedding_size),
num_layer_(num_layer),
weight_type_(weight_type),
tensor_para_size_(tensor_para_size),
tensor_para_rank_(tensor_para_rank)
{
if (vocab_size_padded_ % tensor_para_size_ != 0) {
- vocab_size_padded_ = (vocab_size_padded_ + tensor_para_size_ - 1) / tensor_para_size_ * tensor_para_size_;
+ vocab_size_padded_ = (vocab_size_ + tensor_para_size_ - 1) / tensor_para_size_ * tensor_para_size_;
TM_LOG_WARNING("pad vocab size from %d to %d", vocab_size_, vocab_size_padded_);
}
-
+    if (embedding_size_ % tensor_para_size_ != 0) {
+        const size_t padded = (embedding_size_ + tensor_para_size_ - 1) / tensor_para_size_ * tensor_para_size_;
+        TM_LOG_WARNING("pad embed size from %d to %d", embedding_size_, padded);
+        embedding_size_ = padded;
+    }
FT_CHECK(hidden_units_ % tensor_para_size_ == 0);
decoder_layer_weights.reserve(num_layer_);
@@ -96,7 +101,7 @@ template
void LlamaWeight::mallocWeights()
{
FT_CHECK(vocab_size_padded_ % tensor_para_size_ == 0);
- deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_padded_ * hidden_units_ / tensor_para_size_);
+ deviceMalloc((T**)&pre_decoder_embedding_table, embedding_size_ * hidden_units_ / tensor_para_size_);
deviceMalloc((T**)&output_norm_weight, hidden_units_);
deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_ / tensor_para_size_);
}
@@ -111,7 +116,7 @@ void LlamaWeight::loadModel(std::string dir_path)
dir_path += '/';
loadWeightFromBin((T*)pre_decoder_embedding_table,
- {vocab_size_padded_ * hidden_units_ / tensor_para_size_},
+ {embedding_size_ * hidden_units_ / tensor_para_size_},
dir_path + "tok_embeddings." + std::to_string(tensor_para_rank_) + ".weight",
model_file_type);
@@ -135,7 +140,7 @@ TensorMap LlamaWeight::getParams()
output.insert("tok_embeddings." + std::to_string(tensor_para_rank_) + ".weight",
Tensor{MEMORY_GPU,
getTensorType(),
- {vocab_size_padded_ * hidden_units_ / tensor_para_size_ * sizeof(T)},
+ {embedding_size_ * hidden_units_ / tensor_para_size_ * sizeof(T)},
pre_decoder_embedding_table});
output.insert("norm.weight",
diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h
index c04bf6c5a6..c30e753565 100644
--- a/src/turbomind/models/llama/LlamaWeight.h
+++ b/src/turbomind/models/llama/LlamaWeight.h
@@ -35,6 +35,7 @@ struct LlamaWeight {
size_t hidden_units,
size_t inter_size,
size_t vocab_size,
+ size_t embedding_size,
size_t num_layer,
bool attn_bias,
WeightType weight_type,
@@ -67,6 +68,7 @@ struct LlamaWeight {
size_t inter_size_;
size_t vocab_size_;
size_t vocab_size_padded_;
+ size_t embedding_size_;
size_t num_layer_;
WeightType weight_type_;
size_t tensor_para_size_;
diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h
index 2ea63f0410..e6b9d690ae 100644
--- a/src/turbomind/models/llama/llama_params.h
+++ b/src/turbomind/models/llama/llama_params.h
@@ -18,6 +18,7 @@ struct ModelParam {
size_t layer_num;
size_t inter_size;
size_t vocab_size;
+ size_t embedding_size;
float norm_eps;
int quant_policy;
//
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 38552be0cf..2deca46380 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -133,6 +133,12 @@ void LlamaTritonModel::handleMissingParams()
(int)model_param_.kv_head_num);
}
+ if (model_param_.embedding_size == 0) {
+ model_param_.embedding_size = model_param_.vocab_size;
+ TM_LOG_WARNING("[LlamaTritonModel] `embedding_size` is not set, default to `vocab_size` (%d).",
+ (int)model_param_.vocab_size);
+ }
+
if (!attn_param_.max_position_embeddings) {
attn_param_.max_position_embeddings = 2048;
TM_LOG_WARNING("[LlamaTritonModel] `max_position_embeddings` is not set, default to %d.",
@@ -252,6 +258,7 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size,
model_param_.layer_num = model_reader["num_layer"].as();
model_param_.inter_size = model_reader["inter_size"].as();
model_param_.vocab_size = model_reader["vocab_size"].as();
+ model_param_.embedding_size = model_reader["embedding_size"].as();
model_param_.norm_eps = model_reader["norm_eps"].as();
model_param_.start_id = model_reader["start_id"].as();
model_param_.end_id = model_reader["end_id"].as();
@@ -417,6 +424,7 @@ void LlamaTritonModel::createSharedWeights(int device_id, int rank)
model_param_.hidden_units,
model_param_.inter_size,
model_param_.vocab_size,
+ model_param_.embedding_size,
model_param_.layer_num,
attn_bias_,
weight_type_,