diff --git a/examples/vl/README.md b/examples/vl/README.md
index cd9abdb25a..51fcc22391 100644
--- a/examples/vl/README.md
+++ b/examples/vl/README.md
@@ -30,5 +30,5 @@ A chatbot demo with image input.
 
 ## Limitations
 
-- this demo the code in their repo to extract image features that might not very efficiency.
-- this demo only contains the chat function. If you want to use localization ability in Qwen-VL-Chat or article generation function in InternLM-XComposer, you need implement these pre/post process. The difference compared to chat is how to build prompts and use the output of model.
+- this demo reuses the image-feature extraction code from the upstream model repos, which may not be very efficient.
+- this demo only contains the chat function. If you want to use the localization ability of Qwen-VL-Chat or the article generation function of InternLM-XComposer, you need to implement the corresponding pre/post processing yourself. The difference from chat is how the prompts are built and how the model output is used.
diff --git a/examples/vl/app.py b/examples/vl/app.py
index 39af89cf49..bb1b109594 100644
--- a/examples/vl/app.py
+++ b/examples/vl/app.py
@@ -8,6 +8,7 @@
 from typing import List, Tuple
 
 import gradio as gr
+from packaging.version import Version, parse
 
 from qwen_model import QwenVLChat
 from xcomposer_model import InternLMXComposer
@@ -27,6 +28,11 @@
     'qwen-vl-chat': QwenVLChat
 }
 
+if parse(gr.__version__) >= Version('4.0.0'):
+    que_kwargs = {'default_concurrency_limit': BATCH_SIZE}
+else:
+    que_kwargs = {'concurrency_count': BATCH_SIZE}
+
 
 @dataclass
 class Session:
@@ -117,8 +123,6 @@ def chat(
         chatbot,
         session,
     ):
-        yield chatbot, session, disable_btn, enable_btn, disable_btn
-
         generator = model.create_instance()
         history = session._message
         sequence_start = len(history) == 1
@@ -129,36 +133,38 @@ def chat(
         if len(input_ids) + session.step > model.model.session_len:
             gr.Warning('WARNING: exceed session max length.'
                        ' Please restart the session by reset button.')
-
-        response_size = 0
-        step = session.step
-        for outputs in generator.stream_infer(session_id=session.session_id,
-                                              input_ids=input_ids,
-                                              input_embeddings=features,
-                                              input_embedding_ranges=ranges,
-                                              stream_output=True,
-                                              sequence_start=sequence_start,
-                                              random_seed=seed,
-                                              step=step):
-            res, tokens = outputs[0]
-            # decode res
-            response = model.tokenizer.decode(res.tolist(),
-                                              offset=response_size)
-            if response.endswith('�'):
-                continue
-            response = valid_str(response)
-            response_size = tokens
-            if chatbot[-1][1] is None:
-                chatbot[-1][1] = ''
-                history[-1][1] = ''
-            chatbot[-1][1] += response
-            history[-1][1] += response
-            session._step = step + len(input_ids) + tokens
-            yield chatbot, session, disable_btn, enable_btn, disable_btn
-
-        yield chatbot, session, enable_btn, disable_btn, enable_btn
-
-    def cancel(chatbot, session):
+            yield chatbot, session, enable_btn, disable_btn, enable_btn
+        else:
+            response_size = 0
+            step = session.step
+            for outputs in generator.stream_infer(
+                    session_id=session.session_id,
+                    input_ids=input_ids,
+                    input_embeddings=features,
+                    input_embedding_ranges=ranges,
+                    stream_output=True,
+                    sequence_start=sequence_start,
+                    random_seed=seed,
+                    step=step):
+                res, tokens = outputs[0]
+                # decode res
+                response = model.tokenizer.decode(res.tolist(),
+                                                  offset=response_size)
+                if response.endswith('�'):
+                    continue
+                response = valid_str(response)
+                response_size = tokens
+                if chatbot[-1][1] is None:
+                    chatbot[-1][1] = ''
+                    history[-1][1] = ''
+                chatbot[-1][1] += response
+                history[-1][1] += response
+                session._step = step + len(input_ids) + tokens
+                yield chatbot, session, disable_btn, enable_btn, disable_btn
+
+            yield chatbot, session, enable_btn, disable_btn, enable_btn
+
+    def stop(session):
         generator = model.create_instance()
         for _ in generator.stream_infer(session_id=session.session_id,
                                         input_ids=[0],
@@ -167,18 +173,14 @@ def cancel(chatbot, session):
                                         sequence_end=False,
                                         stop=True):
             pass
-        return chatbot, session, disable_btn, enable_btn
+
+    def cancel(chatbot, session):
+        stop(session)
+        return chatbot, session, disable_btn, enable_btn, enable_btn
 
     def reset(session):
-        generator = model.create_instance()
-        for _ in generator.stream_infer(session_id=session.session_id,
-                                        input_ids=[0],
-                                        request_output_len=0,
-                                        sequence_start=False,
-                                        sequence_end=False,
-                                        stop=True):
-            pass
-        return [], Session()
+        stop(session)
+        return [], Session(), enable_btn
 
     with gr.Blocks(css=CSS, theme=THEME) as demo:
         with gr.Column(elem_id='container'):
@@ -197,7 +199,8 @@ def reset(session):
 
         addimg_btn.upload(add_image, [chatbot, session, addimg_btn],
                           [chatbot, session],
-                          show_progress=True)
+                          show_progress=True,
+                          queue=True)
 
         send_event = query.submit(
             add_text, [chatbot, session, query], [chatbot, session]).then(
@@ -206,15 +209,15 @@ def reset(session):
         query.submit(lambda: gr.update(value=''), None, [query])
 
         cancel_btn.click(cancel, [chatbot, session],
-                         [chatbot, session, cancel_btn, reset_btn],
+                         [chatbot, session, cancel_btn, reset_btn, query],
                          cancels=[send_event])
 
-        reset_btn.click(reset, [session], [chatbot, session],
+        reset_btn.click(reset, [session], [chatbot, session, query],
                         cancels=[send_event])
 
         demo.load(lambda: Session(), inputs=None, outputs=[session])
 
-    demo.queue(api_open=True, concurrency_count=BATCH_SIZE, max_size=100)
+    demo.queue(api_open=True, **que_kwargs, max_size=100)
     demo.launch(
         share=True,
         server_port=args.server_port,
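
The `que_kwargs` branch in the patch exists because gradio 4.0 renamed `queue()`'s `concurrency_count` parameter to `default_concurrency_limit`. A minimal standalone sketch of the same shim, assuming only that `gradio` and `packaging` are installed (`BATCH_SIZE = 4` is an illustrative value, not the demo's actual setting):

```python
import gradio as gr
from packaging.version import Version, parse

BATCH_SIZE = 4  # illustrative; the real demo sets its own batch size

# pick the kwarg name supported by the installed gradio major version
if parse(gr.__version__) >= Version('4.0.0'):
    que_kwargs = {'default_concurrency_limit': BATCH_SIZE}
else:
    que_kwargs = {'concurrency_count': BATCH_SIZE}

with gr.Blocks() as demo:
    gr.Markdown('queue-shim smoke test')

# the same call works on both majors; only the splatted kwarg differs
demo.queue(api_open=True, **que_kwargs, max_size=100)
```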
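
The `stop`/`cancel`/`reset` rewrite removes the duplicated interrupt loop: both buttons now funnel through one helper that drains a zero-length, stop-flagged request. A sketch of that pattern with a hypothetical `FakeGenerator` standing in for the real `model.create_instance()` engine (the keyword arguments mirror the ones in the diff):

```python
class FakeGenerator:
    """Stub inference engine; a real one would abort in-flight output."""

    def stream_infer(self, session_id, input_ids, request_output_len,
                     sequence_start, sequence_end, stop):
        # a real engine interrupts generation for `session_id`;
        # the stub yields nothing, matching the drained loop in app.py
        yield from ()


def stop(session_id, generator):
    """Interrupt any in-flight generation for one session."""
    for _ in generator.stream_infer(session_id=session_id,
                                    input_ids=[0],
                                    request_output_len=0,
                                    sequence_start=False,
                                    sequence_end=False,
                                    stop=True):
        pass


def cancel(session_id, generator):
    stop(session_id, generator)  # halt the stream, keep chat history


def reset(session_id, generator):
    stop(session_id, generator)  # halt the stream, then start fresh
    return []                    # empty chatbot history


if __name__ == '__main__':
    gen = FakeGenerator()
    cancel(session_id=1, generator=gen)
    assert reset(session_id=1, generator=gen) == []
```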