Skip to content

Commit

Permalink
support gradio 4.x
Browse files Browse the repository at this point in the history
  • Loading branch information
irexyc committed Dec 20, 2023
1 parent f096cdf commit 2f5fa2e
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 48 deletions.
4 changes: 2 additions & 2 deletions examples/vl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ A chatbot demo with image input.

## Limitations

- this demo the code in their repo to extract image features that might not very efficiency.
- this demo only contains the chat function. If you want to use localization ability in Qwen-VL-Chat or article generation function in InternLM-XComposer, you need implement these pre/post process. The difference compared to chat is how to build prompts and use the output of model.
- this demo uses the code in their repo to extract image features, which might not be very efficient.
- this demo only contains the chat function. If you want to use the localization ability in Qwen-VL-Chat or the article generation function in InternLM-XComposer, you need to implement these pre/post processes. The difference compared to chat is how to build the prompts and use the output of the model.
95 changes: 49 additions & 46 deletions examples/vl/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import List, Tuple

import gradio as gr
from packaging.version import Version, parse
from qwen_model import QwenVLChat
from xcomposer_model import InternLMXComposer

Expand All @@ -27,6 +28,11 @@
'qwen-vl-chat': QwenVLChat
}

# Gradio 4.x renamed the queue concurrency argument from `concurrency_count`
# to `default_concurrency_limit`; pick the keyword that matches the installed
# version so `demo.queue(**que_kwargs)` works on both major versions.
if parse(gr.__version__) < Version('4.0.0'):
    que_kwargs = {'concurrency_count': BATCH_SIZE}
else:
    que_kwargs = {'default_concurrency_limit': BATCH_SIZE}


@dataclass
class Session:
Expand Down Expand Up @@ -117,8 +123,6 @@ def chat(
chatbot,
session,
):
yield chatbot, session, disable_btn, enable_btn, disable_btn

generator = model.create_instance()
history = session._message
sequence_start = len(history) == 1
Expand All @@ -129,36 +133,38 @@ def chat(
if len(input_ids) + session.step > model.model.session_len:
gr.Warning('WARNING: exceed session max length.'
' Please restart the session by reset button.')

response_size = 0
step = session.step
for outputs in generator.stream_infer(session_id=session.session_id,
input_ids=input_ids,
input_embeddings=features,
input_embedding_ranges=ranges,
stream_output=True,
sequence_start=sequence_start,
random_seed=seed,
step=step):
res, tokens = outputs[0]
# decode res
response = model.tokenizer.decode(res.tolist(),
offset=response_size)
if response.endswith('�'):
continue
response = valid_str(response)
response_size = tokens
if chatbot[-1][1] is None:
chatbot[-1][1] = ''
history[-1][1] = ''
chatbot[-1][1] += response
history[-1][1] += response
session._step = step + len(input_ids) + tokens
yield chatbot, session, disable_btn, enable_btn, disable_btn

yield chatbot, session, enable_btn, disable_btn, enable_btn

def cancel(chatbot, session):
yield chatbot, session, enable_btn, disable_btn, enable_btn
else:
response_size = 0
step = session.step
for outputs in generator.stream_infer(
session_id=session.session_id,
input_ids=input_ids,
input_embeddings=features,
input_embedding_ranges=ranges,
stream_output=True,
sequence_start=sequence_start,
random_seed=seed,
step=step):
res, tokens = outputs[0]
# decode res
response = model.tokenizer.decode(res.tolist(),
offset=response_size)
if response.endswith('�'):
continue
response = valid_str(response)
response_size = tokens
if chatbot[-1][1] is None:
chatbot[-1][1] = ''
history[-1][1] = ''
chatbot[-1][1] += response
history[-1][1] += response
session._step = step + len(input_ids) + tokens
yield chatbot, session, disable_btn, enable_btn, disable_btn

yield chatbot, session, enable_btn, disable_btn, enable_btn

def stop(session):
generator = model.create_instance()
for _ in generator.stream_infer(session_id=session.session_id,
input_ids=[0],
Expand All @@ -167,18 +173,14 @@ def cancel(chatbot, session):
sequence_end=False,
stop=True):
pass
return chatbot, session, disable_btn, enable_btn

def cancel(chatbot, session):
    # Abort the in-flight generation for this session, then return updated
    # UI state: (chatbot, session, cancel_btn, reset_btn, query) where the
    # cancel button is disabled and the reset button / query box re-enabled.
    stop(session)
    return chatbot, session, disable_btn, enable_btn, enable_btn

def reset(session):
generator = model.create_instance()
for _ in generator.stream_infer(session_id=session.session_id,
input_ids=[0],
request_output_len=0,
sequence_start=False,
sequence_end=False,
stop=True):
pass
return [], Session()
stop(session)
return [], Session(), enable_btn

with gr.Blocks(css=CSS, theme=THEME) as demo:
with gr.Column(elem_id='container'):
Expand All @@ -197,7 +199,8 @@ def reset(session):

addimg_btn.upload(add_image, [chatbot, session, addimg_btn],
[chatbot, session],
show_progress=True)
show_progress=True,
queue=True)

send_event = query.submit(
add_text, [chatbot, session, query], [chatbot, session]).then(
Expand All @@ -206,15 +209,15 @@ def reset(session):
query.submit(lambda: gr.update(value=''), None, [query])

cancel_btn.click(cancel, [chatbot, session],
[chatbot, session, cancel_btn, reset_btn],
[chatbot, session, cancel_btn, reset_btn, query],
cancels=[send_event])

reset_btn.click(reset, [session], [chatbot, session],
reset_btn.click(reset, [session], [chatbot, session, query],
cancels=[send_event])

demo.load(lambda: Session(), inputs=None, outputs=[session])

demo.queue(api_open=True, concurrency_count=BATCH_SIZE, max_size=100)
demo.queue(api_open=True, **que_kwargs, max_size=100)
demo.launch(
share=True,
server_port=args.server_port,
Expand Down

0 comments on commit 2f5fa2e

Please sign in to comment.