Skip to content

Commit

Permalink
fix out-of-bounds index
Browse files Browse the repository at this point in the history
  • Loading branch information
AllentDan committed Dec 18, 2023
1 parent fc0386a commit 2a5f0eb
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
14 changes: 7 additions & 7 deletions lmdeploy/serve/async_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ def end_session(self, session_id: int):
input_ids,
request_output_len=0,
sequence_start=False,
sequence_end=True,
stop=True):
sequence_end=True):
pass
self.id2step[str(session_id)] = 0
if str(session_id) in self.id2generator and self.id2generator[str(
Expand Down Expand Up @@ -265,17 +264,17 @@ async def generate(
prompt = self.model.messages2prompt(prompt, sequence_start)
input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
finish_reason = None
if self.id2step[str(session_id)] + len(
if stop is True:
self.stop_session(session_id)
yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
finish_reason)
elif self.id2step[str(session_id)] + len(
input_ids) + request_output_len >= self.tm_model.session_len:
finish_reason = 'length'
yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
finish_reason)
if sequence_end is True and sequence_start is False:
self.end_session(session_id)
elif stop is True:
self.stop_session(session_id)
yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
finish_reason)
else:
generator = await self.get_generator(stop, session_id)
with self.safe_run(session_id):
Expand All @@ -296,6 +295,7 @@ async def generate(
ignore_eos=ignore_eos,
random_seed=seed if sequence_start else None):
res, tokens = outputs[0]
print(res.tolist()[response_size:], response_size)
# decode res
response = self.tokenizer.decode(res.tolist(),
offset=response_size)
Expand Down
13 changes: 10 additions & 3 deletions lmdeploy/serve/gradio/api_server_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import time
from threading import Lock
from typing import Sequence

Expand Down Expand Up @@ -89,16 +88,24 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
session_id (int): the session id
"""
yield (state_chatbot, disable_btn, disable_btn)
# end the session
# stop the session
for out in get_streaming_response(
'',
f'{InterFace.api_server_url}/v1/chat/interactive',
session_id=session_id,
request_output_len=0,
stop=True,
interactive_mode=True):
pass
# end the session
for out in get_streaming_response(
'',
f'{InterFace.api_server_url}/v1/chat/interactive',
session_id=session_id,
request_output_len=0,
interactive_mode=False):
pass
time.sleep(0.5)
# resume the session
messages = []
for qa in state_chatbot:
messages.append(dict(role='user', content=qa[0]))
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/serve/gradio/turbomind_coupled.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button,
session_id (int): the session id
"""
yield (state_chatbot, disable_btn, disable_btn)
InterFace.async_engine.stop_session(session_id)
InterFace.async_engine.end_session(session_id)
messages = []
for qa in state_chatbot:
Expand Down

0 comments on commit 2a5f0eb

Please sign in to comment.