add vl gradio demo
irexyc committed Dec 20, 2023
1 parent 477f2db commit f096cdf
Showing 6 changed files with 626 additions and 2 deletions.
34 changes: 34 additions & 0 deletions examples/vl/README.md
@@ -0,0 +1,34 @@
# Vision-Language Web Demo

A chatbot demo with image input.

## Supported Models

- [InternLM/InternLM-XComposer](https://github.com/InternLM/InternLM-XComposer/tree/main)
- [Qwen/Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat)

## Quick Start

### internlm/internlm-xcomposer-7b

- Extract the LLM part from the HuggingFace model (see the note after these steps)
```shell
python extract_xcomposer_llm.py
# the LLM part will be saved to the internlm_model folder
```
- Launch the demo
```shell
python app.py --model-name internlm-xcomposer-7b --llm-ckpt internlm_model
```
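
The extraction step merges the InternLM-XComposer LoRA weights into the base InternLM weights (see `extract_xcomposer_llm.py` below), so that TurboMind can load the language-model part as a plain checkpoint.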

### Qwen-VL-Chat

- Launch the demo
```shell
python app.py --model-name qwen-vl-chat --hf-ckpt Qwen/Qwen-VL-Chat
```
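
By default the demo is served at `http://127.0.0.1:9006`; pass `--server-name` and `--server-port` to change the address.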

## Limitations

- This demo reuses the code from the models' original repositories to extract image features, which might not be very efficient.
- This demo only implements the chat function. If you want to use the localization ability of Qwen-VL-Chat or the article-generation function of InternLM-XComposer, you need to implement the corresponding pre/post-processing yourself. Compared to chat, the difference lies in how the prompts are built and how the model output is used; see the sketch below.
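
As an illustration of such post-processing, here is a minimal sketch for the localization case. It assumes the `<ref>...</ref><box>(x1,y1),(x2,y2)</box>` output format documented for Qwen-VL, with box coordinates normalized to a 0-1000 grid; the helper `parse_boxes` is hypothetical and not part of this demo.

```python
import re

# one grounded phrase: <ref>label</ref><box>(x1,y1),(x2,y2)</box>
BOX_PATTERN = re.compile(
    r'<ref>(.*?)</ref>\s*<box>\((\d+),(\d+)\),\((\d+),(\d+)\)</box>')


def parse_boxes(response: str, img_w: int, img_h: int):
    """Extract (label, pixel box) pairs from a grounding response."""
    boxes = []
    for label, x1, y1, x2, y2 in BOX_PATTERN.findall(response):
        # map the normalized 0-1000 coordinates back to pixel coordinates
        box = tuple(
            int(v) * size // 1000
            for v, size in zip((x1, y1, x2, y2), (img_w, img_h, img_w, img_h)))
        boxes.append((label.strip(), box))
    return boxes
```

Drawing the returned boxes on the uploaded image would then complete the localization flow.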
238 changes: 238 additions & 0 deletions examples/vl/app.py
@@ -0,0 +1,238 @@
import argparse
import os
import random
from dataclasses import dataclass, field
from itertools import count
from pathlib import Path
from threading import Lock
from typing import List, Tuple

import gradio as gr
from qwen_model import QwenVLChat
from xcomposer_model import InternLMXComposer

from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.turbomind import TurboMind
from lmdeploy.turbomind.chat import valid_str

BATCH_SIZE = 32
DEFAULT_MODEL_NAME = 'internlm-xcomposer-7b'
DEFAULT_HF_CKPT = 'internlm/internlm-xcomposer-7b'
# use extract_xcomposer_llm.py to extract the llm part
# when using internlm-xcomposer-7b
DEFAULT_LLM_CKPT = None

SUPPORTED_MODELS = {
    'internlm-xcomposer-7b': InternLMXComposer,
    'qwen-vl-chat': QwenVLChat
}


@dataclass
class Session:
    _lock = Lock()
    _count = count()
    _session_id: int = None
    _message: List[Tuple[str, str]] = field(default_factory=list)
    _step: int = 0

    def __init__(self):
        with Session._lock:
            self._session_id = next(Session._count)
        self._message = []
        self._step = 0

    @property
    def session_id(self):
        return self._session_id

    @property
    def message(self):
        return self._message

    @property
    def step(self):
        return self._step


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name',
                        type=str,
                        default=DEFAULT_MODEL_NAME,
                        help='Model name, default to %(default)s')
    parser.add_argument(
        '--hf-ckpt',
        type=str,
        default=DEFAULT_HF_CKPT,
        help='hf checkpoint name or path, default to %(default)s')
    parser.add_argument(
        '--llm-ckpt',
        type=str,
        default=DEFAULT_LLM_CKPT,
        help='LLM checkpoint name or path, default to %(default)s')
    parser.add_argument('--server-port',
                        type=int,
                        default=9006,
                        help='Server port, default %(default)s')
    parser.add_argument('--server-name',
                        type=str,
                        default='127.0.0.1',
                        help='Server name, default %(default)s')
    args = parser.parse_args()
    return args


def load_preprocessor_model(args):
    assert args.model_name in SUPPORTED_MODELS
    llm_ckpt = args.hf_ckpt if args.llm_ckpt is None else args.llm_ckpt
    preprocessor = SUPPORTED_MODELS[args.model_name](args.hf_ckpt)
    model = TurboMind.from_pretrained(llm_ckpt, model_name=args.model_name)
    return preprocessor, model


def launch_demo(args, preprocessor, model):

    def add_image(chatbot, session, file):
        chatbot = chatbot + [((file.name, ), None)]
        # print('add_image', chatbot)
        history = session._message
        # [([user, url, url], assistant), ...]
        if len(history) == 0 or history[-1][-1] is not None:
            history.append([[file.name], None])
        else:
            history[-1][0].append(file.name)
        return chatbot, session

    def add_text(chatbot, session, text):
        chatbot = chatbot + [(text, None)]
        history = session._message
        if len(history) == 0 or history[-1][-1] is not None:
            history.append([text, None])
        else:
            history[-1][0].insert(0, text)
        return chatbot, session, disable_btn, enable_btn

    def chat(
        chatbot,
        session,
    ):
        # lock the input box and the reset button while streaming
        yield chatbot, session, disable_btn, enable_btn, disable_btn

        generator = model.create_instance()
        history = session._message
        # the first round of a session starts a new turbomind sequence
        sequence_start = len(history) == 1
        seed = random.getrandbits(64) if sequence_start else None
        input_ids, features, ranges = preprocessor.prepare_query(
            history[-1][0], sequence_start)

        if len(input_ids) + session.step > model.model.session_len:
            gr.Warning('WARNING: exceed session max length.'
                       ' Please restart the session by reset button.')

        response_size = 0
        step = session.step
        for outputs in generator.stream_infer(session_id=session.session_id,
                                              input_ids=input_ids,
                                              input_embeddings=features,
                                              input_embedding_ranges=ranges,
                                              stream_output=True,
                                              sequence_start=sequence_start,
                                              random_seed=seed,
                                              step=step):
            res, tokens = outputs[0]
            # decode the newly generated part of the response
            response = model.tokenizer.decode(res.tolist(),
                                              offset=response_size)
            if response.endswith('�'):
                continue
            response = valid_str(response)
            response_size = tokens
            if chatbot[-1][1] is None:
                chatbot[-1][1] = ''
                history[-1][1] = ''
            chatbot[-1][1] += response
            history[-1][1] += response
            # advance the session step past the prompt and generated tokens
            session._step = step + len(input_ids) + tokens
            yield chatbot, session, disable_btn, enable_btn, disable_btn

        yield chatbot, session, enable_btn, disable_btn, enable_btn

    def cancel(chatbot, session):
        # send an empty request with stop=True to interrupt generation
        generator = model.create_instance()
        for _ in generator.stream_infer(session_id=session.session_id,
                                        input_ids=[0],
                                        request_output_len=0,
                                        sequence_start=False,
                                        sequence_end=False,
                                        stop=True):
            pass
        return chatbot, session, disable_btn, enable_btn

    def reset(session):
        generator = model.create_instance()
        for _ in generator.stream_infer(session_id=session.session_id,
                                        input_ids=[0],
                                        request_output_len=0,
                                        sequence_start=False,
                                        sequence_end=False,
                                        stop=True):
            pass
        return [], Session()

    with gr.Blocks(css=CSS, theme=THEME) as demo:
        with gr.Column(elem_id='container'):
            gr.Markdown('## LMDeploy VL Playground')

            chatbot = gr.Chatbot(elem_id='chatbot', label=model.model_name)
            query = gr.Textbox(placeholder='Please input the instruction',
                               label='Instruction')
            session = gr.State()

            with gr.Row():
                addimg_btn = gr.UploadButton('Upload Image',
                                             file_types=['image'])
                cancel_btn = gr.Button(value='Cancel', interactive=False)
                reset_btn = gr.Button(value='Reset')

            addimg_btn.upload(add_image, [chatbot, session, addimg_btn],
                              [chatbot, session],
                              show_progress=True)

            send_event = query.submit(
                add_text, [chatbot, session, query], [chatbot, session]).then(
                    chat, [chatbot, session],
                    [chatbot, session, query, cancel_btn, reset_btn])
            query.submit(lambda: gr.update(value=''), None, [query])

            cancel_btn.click(cancel, [chatbot, session],
                             [chatbot, session, cancel_btn, reset_btn],
                             cancels=[send_event])

            reset_btn.click(reset, [session], [chatbot, session],
                            cancels=[send_event])

        demo.load(lambda: Session(), inputs=None, outputs=[session])

    demo.queue(api_open=True, concurrency_count=BATCH_SIZE, max_size=100)
    demo.launch(
        share=True,
        server_port=args.server_port,
        server_name=args.server_name,
    )


def main():
    args = parse_args()

    cur_folder = Path(__file__).parent.as_posix()
    if cur_folder != os.getcwd():
        os.chdir(cur_folder)
        print(f'change working dir to {cur_folder}')

    preprocessor, model = load_preprocessor_model(args)
    launch_demo(args, preprocessor, model)


if __name__ == '__main__':
    main()
41 changes: 41 additions & 0 deletions examples/vl/extract_xcomposer_llm.py
@@ -0,0 +1,41 @@
import os
from pathlib import Path

import torch
from transformers import AutoModel, AutoTokenizer
from xcomposer_model import InternLMXComposerTemplate # noqa

model = AutoModel.from_pretrained('internlm/internlm-xcomposer-7b',
                                  trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('internlm/internlm-xcomposer-7b',
                                          trust_remote_code=True)

internlm_model = model.internlm_model

lora_layers = [
    'self_attn.q_proj', 'self_attn.v_proj', 'mlp.down_proj', 'mlp.up_proj'
]


def get_attr(m, key):
    keys = key.split('.')
    for key in keys:
        m = getattr(m, key)
    return m


# merge lora: fold each low-rank update into its base weight (W <- W + B @ A)
# so the saved model no longer needs the lora modules at load time
for i in range(len(internlm_model.model.layers)):
    layer = internlm_model.model.layers[i]
    for key in lora_layers:
        lora_linear = get_attr(layer, key)
        lora_b = lora_linear.lora_B
        lora_a = lora_linear.lora_A
        w_ba = torch.matmul(lora_b.weight, lora_a.weight)
        lora_linear.weight.data += w_ba.data

# save model
cur_folder = Path(__file__).parent
dst_path = os.path.join(cur_folder, 'internlm_model')
internlm_model.save_pretrained(dst_path)
tokenizer.save_pretrained(dst_path)
