Skip to content

Commit

Permalink
[Feature] Add params config to turbomind backend and triton server backend for the WebUI
Browse files Browse the repository at this point in the history
  • Loading branch information
amulil committed Dec 18, 2023
1 parent 4b7af81 commit d959508
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
23 changes: 20 additions & 3 deletions lmdeploy/serve/gradio/triton_server_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ class InterFace:


def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int):
cancel_btn: gr.Button, reset_btn: gr.Button,
session_id: int, top_p: float, temperature: float,
request_output_len: int):
"""Chat with AI assistant.
Args:
Expand All @@ -30,7 +32,10 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
instruction = state_chatbot[-1][0]

bot_response = llama_chatbot.stream_infer(
session_id, instruction, f'{session_id}-{len(state_chatbot)}')
session_id, instruction, f'{session_id}-{len(state_chatbot)}',
request_output_len=request_output_len,
top_p=top_p,
temperature=temperature)

for status, tokens, _ in bot_response:
state_chatbot[-1] = (state_chatbot[-1][0], tokens)
Expand Down Expand Up @@ -108,12 +113,24 @@ def run_triton_server(triton_server_addr: str,
with gr.Row():
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
with gr.Row():
request_output_len = gr.Slider(1,
2048,
value=512,
step=1,
label='Maximum new tokens')
top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
temperature = gr.Slider(0.01,
1.5,
value=0.7,
step=0.01,
label='Temperature')

send_event = instruction_txtbox.submit(
add_instruction, [instruction_txtbox, state_chatbot],
[instruction_txtbox, state_chatbot]).then(chat_stream, [
state_chatbot, llama_chatbot, cancel_btn, reset_btn,
state_session_id
state_session_id, top_p, temperature, request_output_len
], [state_chatbot, chatbot, cancel_btn, reset_btn])

cancel_btn.click(cancel_func,
Expand Down
22 changes: 20 additions & 2 deletions lmdeploy/serve/gradio/turbomind_coupled.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ async def chat_stream_local(
cancel_btn: gr.Button,
reset_btn: gr.Button,
session_id: int,
top_p: float,
temperature: float,
request_output_len: int
):
"""Chat with AI assistant.
Expand All @@ -39,7 +42,10 @@ async def chat_stream_local(
session_id,
stream_response=True,
sequence_start=(len(state_chatbot) == 1),
sequence_end=False):
sequence_end=False,
request_output_len=request_output_len,
top_p=top_p,
temperature=temperature):
response = outputs.response
if outputs.finish_reason == 'length':
gr.Warning('WARNING: exceed session max length.'
Expand Down Expand Up @@ -150,10 +156,22 @@ def run_local(model_path: str,
with gr.Row():
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
with gr.Row():
request_output_len = gr.Slider(1,
2048,
value=512,
step=1,
label='Maximum new tokens')
top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
temperature = gr.Slider(0.01,
1.5,
value=0.7,
step=0.01,
label='Temperature')

send_event = instruction_txtbox.submit(chat_stream_local, [
instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
state_session_id
state_session_id, top_p, temperature, request_output_len
], [state_chatbot, chatbot, cancel_btn, reset_btn])
instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''),
Expand Down

0 comments on commit d959508

Please sign in to comment.