Showing 6 changed files with 626 additions and 2 deletions.
@@ -0,0 +1,34 @@
# Vision-Language Web Demo

A chatbot demo with image input.

## Supported Models

- [InternLM/InternLM-XComposer](https://github.com/InternLM/InternLM-XComposer/tree/main)
- [Qwen/Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat)

## Quick Start

### internlm/internlm-xcomposer-7b

- Extract the LLM part from the HuggingFace model (this merges the LoRA weights into the base model so that TurboMind can load it):
```shell
python extract_xcomposer_llm.py
# the LLM part will be saved to the internlm_model folder
```
- Launch the demo:
```shell
python app.py --model-name internlm-xcomposer-7b --llm-ckpt internlm_model
```

### Qwen-VL-Chat

- Launch the demo:
```shell
python app.py --model-name qwen-vl-chat --hf-ckpt Qwen/Qwen-VL-Chat
```

## Limitations

- This demo reuses the code from the upstream model repositories to extract image features, which may not be very efficient.
- This demo only implements the chat function. To use the localization ability of Qwen-VL-Chat or the article-generation function of InternLM-XComposer, you need to implement the corresponding pre/post-processing yourself; the difference from chat lies in how the prompts are built and how the model output is consumed (see the sketch below).
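
For the Qwen-VL-Chat localization case, the post-processing step is mostly output parsing. Below is a minimal sketch, assuming the `<ref>...</ref><box>(x1,y1),(x2,y2)</box>` grounding format documented in the Qwen-VL repo, with coordinates on a 0-1000 normalized grid; the helper name and regex are illustrative, not part of this demo:

```python
import re

# Matches grounding output such as:
#   <ref>the dog</ref><box>(123,456),(789,1000)</box>
BOX_PATTERN = re.compile(
    r'<ref>(.*?)</ref><box>\((\d+),(\d+)\),\((\d+),(\d+)\)</box>')


def parse_boxes(response: str, img_w: int, img_h: int):
    """Extract (label, pixel-space box) pairs from a grounding response."""
    results = []
    for label, x1, y1, x2, y2 in BOX_PATTERN.findall(response):
        # rescale from the 0-1000 grid to the actual image size
        box = (int(x1) * img_w // 1000, int(y1) * img_h // 1000,
               int(x2) * img_w // 1000, int(y2) * img_h // 1000)
        results.append((label.strip(), box))
    return results
```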
@@ -0,0 +1,238 @@
import argparse
import os
import random
from dataclasses import dataclass, field
from itertools import count
from pathlib import Path
from threading import Lock
from typing import List, Optional, Tuple

import gradio as gr
from qwen_model import QwenVLChat
from xcomposer_model import InternLMXComposer

from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.turbomind import TurboMind
from lmdeploy.turbomind.chat import valid_str

BATCH_SIZE = 32
DEFAULT_MODEL_NAME = 'internlm-xcomposer-7b'
DEFAULT_HF_CKPT = 'internlm/internlm-xcomposer-7b'
# when using internlm-xcomposer-7b, run extract_xcomposer_llm.py first
# to extract the LLM part and pass the result via --llm-ckpt
DEFAULT_LLM_CKPT = None

SUPPORTED_MODELS = {
    'internlm-xcomposer-7b': InternLMXComposer,
    'qwen-vl-chat': QwenVLChat
}


@dataclass
class Session:
    """Per-client chat state: a unique id, message history and decode step."""
    _lock = Lock()
    _count = count()
    _session_id: Optional[int] = None
    _message: List[Tuple[str, str]] = field(default_factory=list)
    _step: int = 0

    def __init__(self):
        with Session._lock:
            self._session_id = next(Session._count)
            self._message = []
            self._step = 0

    @property
    def session_id(self):
        return self._session_id

    @property
    def message(self):
        return self._message

    @property
    def step(self):
        return self._step


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name',
                        type=str,
                        default=DEFAULT_MODEL_NAME,
                        help='Model name, defaults to %(default)s')
    parser.add_argument(
        '--hf-ckpt',
        type=str,
        default=DEFAULT_HF_CKPT,
        help='hf checkpoint name or path, defaults to %(default)s')
    parser.add_argument(
        '--llm-ckpt',
        type=str,
        default=DEFAULT_LLM_CKPT,
        help='LLM checkpoint name or path, defaults to %(default)s')
    parser.add_argument('--server-port',
                        type=int,
                        default=9006,
                        help='Server port, defaults to %(default)s')
    parser.add_argument('--server-name',
                        type=str,
                        default='127.0.0.1',
                        help='Server name, defaults to %(default)s')
    args = parser.parse_args()
    return args


def load_preprocessor_model(args):
    assert args.model_name in SUPPORTED_MODELS
    llm_ckpt = args.hf_ckpt if args.llm_ckpt is None else args.llm_ckpt
    preprocessor = SUPPORTED_MODELS[args.model_name](args.hf_ckpt)
    model = TurboMind.from_pretrained(llm_ckpt, model_name=args.model_name)
    return preprocessor, model


def launch_demo(args, preprocessor, model):

    def add_image(chatbot, session, file):
        chatbot = chatbot + [((file.name, ), None)]
        # history entries look like ([text, image_path, ...], assistant_reply)
        history = session._message
        if len(history) == 0 or history[-1][-1] is not None:
            history.append([[file.name], None])
        else:
            history[-1][0].append(file.name)
        return chatbot, session

    def add_text(chatbot, session, text):
        chatbot = chatbot + [(text, None)]
        history = session._message
        if len(history) == 0 or history[-1][-1] is not None:
            history.append([text, None])
        else:
            history[-1][0].insert(0, text)
        return chatbot, session, disable_btn, enable_btn

    def chat(
        chatbot,
        session,
    ):
        yield chatbot, session, disable_btn, enable_btn, disable_btn

        generator = model.create_instance()
        history = session._message
        sequence_start = len(history) == 1
        seed = random.getrandbits(64) if sequence_start else None
        input_ids, features, ranges = preprocessor.prepare_query(
            history[-1][0], sequence_start)
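        # `features` holds the vision encoder outputs and `ranges` marks the
        # spans of `input_ids` (image placeholder tokens) whose embeddings
        # TurboMind replaces with those features, so image and text are fed
        # to the LLM as a single sequence.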

        if len(input_ids) + session.step > model.model.session_len:
            gr.Warning('WARNING: session length exceeded.'
                       ' Please restart the session with the Reset button.')

        response_size = 0
        step = session.step
        for outputs in generator.stream_infer(session_id=session.session_id,
                                              input_ids=input_ids,
                                              input_embeddings=features,
                                              input_embedding_ranges=ranges,
                                              stream_output=True,
                                              sequence_start=sequence_start,
                                              random_seed=seed,
                                              step=step):
            res, tokens = outputs[0]
            # decode incrementally, skipping the part already emitted
            response = model.tokenizer.decode(res.tolist(),
                                              offset=response_size)
            # a trailing replacement char means the last token is an
            # incomplete multi-byte sequence; wait for more tokens
            if response.endswith('�'):
                continue
            response = valid_str(response)
            response_size = tokens
            if chatbot[-1][1] is None:
                chatbot[-1][1] = ''
                history[-1][1] = ''
            chatbot[-1][1] += response
            history[-1][1] += response
            session._step = step + len(input_ids) + tokens
            yield chatbot, session, disable_btn, enable_btn, disable_btn

        yield chatbot, session, enable_btn, disable_btn, enable_btn


    def cancel(chatbot, session):
        # send a stop request to abort the in-flight generation
        generator = model.create_instance()
        for _ in generator.stream_infer(session_id=session.session_id,
                                        input_ids=[0],
                                        request_output_len=0,
                                        sequence_start=False,
                                        sequence_end=False,
                                        stop=True):
            pass
        return chatbot, session, disable_btn, enable_btn

    def reset(session):
        # stop any in-flight generation, then start a fresh session
        generator = model.create_instance()
        for _ in generator.stream_infer(session_id=session.session_id,
                                        input_ids=[0],
                                        request_output_len=0,
                                        sequence_start=False,
                                        sequence_end=False,
                                        stop=True):
            pass
        return [], Session()

    with gr.Blocks(css=CSS, theme=THEME) as demo:
        with gr.Column(elem_id='container'):
            gr.Markdown('## LMDeploy VL Playground')

            chatbot = gr.Chatbot(elem_id='chatbot', label=model.model_name)
            query = gr.Textbox(placeholder='Please input the instruction',
                               label='Instruction')
            session = gr.State()

            with gr.Row():
                addimg_btn = gr.UploadButton('Upload Image',
                                             file_types=['image'])
                cancel_btn = gr.Button(value='Cancel', interactive=False)
                reset_btn = gr.Button(value='Reset')

            addimg_btn.upload(add_image, [chatbot, session, addimg_btn],
                              [chatbot, session],
                              show_progress=True)

            send_event = query.submit(
                add_text, [chatbot, session, query], [chatbot, session]).then(
                    chat, [chatbot, session],
                    [chatbot, session, query, cancel_btn, reset_btn])
            query.submit(lambda: gr.update(value=''), None, [query])

            cancel_btn.click(cancel, [chatbot, session],
                             [chatbot, session, cancel_btn, reset_btn],
                             cancels=[send_event])

            reset_btn.click(reset, [session], [chatbot, session],
                            cancels=[send_event])

        demo.load(lambda: Session(), inputs=None, outputs=[session])

    demo.queue(api_open=True, concurrency_count=BATCH_SIZE, max_size=100)
    demo.launch(
        share=True,
        server_port=args.server_port,
        server_name=args.server_name,
    )


def main():
    args = parse_args()

    cur_folder = Path(__file__).parent.as_posix()
    if cur_folder != os.getcwd():
        os.chdir(cur_folder)
        print(f'changed working dir to {cur_folder}')

    preprocessor, model = load_preprocessor_model(args)
    launch_demo(args, preprocessor, model)


if __name__ == '__main__':
    main()
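
For orientation, the UI-free core of the pipeline above is small. A minimal sketch reusing the same modules, with the checkpoint, prompt and image path purely illustrative and the sampling-related arguments from `chat()` omitted:

```python
from qwen_model import QwenVLChat

from lmdeploy.turbomind import TurboMind

# same loading path as load_preprocessor_model()
preprocessor = QwenVLChat('Qwen/Qwen-VL-Chat')
model = TurboMind.from_pretrained('Qwen/Qwen-VL-Chat',
                                  model_name='qwen-vl-chat')
generator = model.create_instance()

# one user turn as add_image()/add_text() build it: text first, then image
query = ['What is in this image?', 'demo.jpg']
input_ids, features, ranges = preprocessor.prepare_query(query, True)

for outputs in generator.stream_infer(session_id=0,
                                      input_ids=input_ids,
                                      input_embeddings=features,
                                      input_embedding_ranges=ranges,
                                      sequence_start=True,
                                      stream_output=True):
    res, tokens = outputs[0]
# decode the final token ids once streaming ends; chat() does this
# incrementally instead
print(model.tokenizer.decode(res.tolist()))
```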
@@ -0,0 +1,41 @@
import os
from pathlib import Path

import torch
from transformers import AutoModel, AutoTokenizer
# unused import kept for its import-time side effects
from xcomposer_model import InternLMXComposerTemplate  # noqa

model = AutoModel.from_pretrained('internlm/internlm-xcomposer-7b',
                                  trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('internlm/internlm-xcomposer-7b',
                                          trust_remote_code=True)

# keep only the language model; the vision parts are not exported
internlm_model = model.internlm_model

# modules that carry LoRA adapters in InternLM-XComposer
lora_layers = [
    'self_attn.q_proj', 'self_attn.v_proj', 'mlp.down_proj', 'mlp.up_proj'
]


def get_attr(m, key):
    # resolve a dotted attribute path, e.g. 'self_attn.q_proj'
    keys = key.split('.')
    for key in keys:
        m = getattr(m, key)
    return m


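# Merging folds each low-rank update into its base weight:
#   W_merged = W + B @ A
# so the saved checkpoint contains plain linear weights that an inference
# engine can load without any LoRA-aware code.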
# merge lora
for i in range(len(internlm_model.model.layers)):
    layer = internlm_model.model.layers[i]
    for key in lora_layers:
        lora_linear = get_attr(layer, key)
        lora_b = lora_linear.lora_B
        lora_a = lora_linear.lora_A
        w_ba = torch.matmul(lora_b.weight, lora_a.weight)
        lora_linear.weight.data += w_ba.data

# save the merged model next to this script
cur_folder = Path(__file__).parent
dst_path = os.path.join(cur_folder, 'internlm_model')
internlm_model.save_pretrained(dst_path)
tokenizer.save_pretrained(dst_path)