Commit
Feat: no_skip_speicial_token (#148)
* Feat: no_skip_speicial_token

* fix: get_logger of lmdeploy

* update lmdeploy requirement
liujiangning30 authored Feb 19, 2024
1 parent 3cf20f5 commit 6620117
Showing 2 changed files with 34 additions and 11 deletions.
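The commit threads one new keyword, skip_special_tokens, through every lmdeploy wrapper in lagent/llms/lmdepoly_wrapper.py, defaulting to False so special tokens survive decoding (agent code often parses tool-call markers out of them). A minimal usage sketch, assuming lagent's LMDeployPipeline wrapper and a placeholder model path (neither appears in this diff):

    # Sketch under assumptions: LMDeployPipeline is lagent's in-process wrapper
    # around lmdeploy's pipeline; the model path is a placeholder. Only the
    # skip_special_tokens keyword is new in this commit.
    from lagent.llms import LMDeployPipeline

    model = LMDeployPipeline(path='internlm/internlm2-chat-7b')  # placeholder

    # Default (False): special tokens stay in the completion, so agent code
    # that looks for them keeps working.
    raw = model.generate('Hello', skip_special_tokens=False)
    # Opt in to stripping them when plain chat text is wanted.
    clean = model.generate('Hello', skip_special_tokens=True)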
43 changes: 33 additions & 10 deletions lagent/llms/lmdepoly_wrapper.py
@@ -48,6 +48,7 @@ def generate(self,
                  request_id: str = '',
                  sequence_start: bool = True,
                  sequence_end: bool = True,
+                 skip_special_tokens: bool = False,
                  **kwargs):
         """Start a new round conversation of a session. Return the chat
         completions in non-stream mode.
@@ -58,7 +59,8 @@ def generate(self,
             request_id (str): the identical id of this round conversation
             sequence_start (bool): start flag of a session
             sequence_end (bool): end flag of a session
-
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be False.
         Returns:
             (a list of/batched) text/chat completion
         """
@@ -73,7 +75,7 @@ def generate(self,
         self.chatbot.cfg = self._update_gen_params(**kwargs)
         max_new_tokens = self.chatbot.cfg.max_new_tokens

-        logger = get_logger(log_level=self.chatbot.log_level)
+        logger = get_logger('service.ft', log_level=self.chatbot.log_level)
         logger.info(f'session {session_id}, request_id {request_id}, '
                     f'max_out_len {max_new_tokens}')

@@ -91,8 +93,12 @@

         status, res, _ = None, '', 0
         for status, res, _ in self.chatbot._stream_infer(
-                self.chatbot._session, prompt, max_new_tokens, sequence_start,
-                sequence_end):
+                self.chatbot._session,
+                prompt,
+                max_new_tokens,
+                sequence_start,
+                sequence_end,
+                skip_special_tokens=skip_special_tokens):
             status = self.state_map.get(status)
             if status < ModelStatusCode.END:
                 return ''
@@ -111,6 +117,7 @@ def stream_chat(self,
                     request_id: str = '',
                     sequence_start: bool = True,
                     sequence_end: bool = True,
+                    skip_special_tokens: bool = False,
                     **kwargs):
         """Start a new round conversation of a session. Return the chat
         completions in stream mode.
@@ -121,7 +128,8 @@ def stream_chat(self,
             request_id (str): the identical id of this round conversation
             sequence_start (bool): start flag of a session
             sequence_end (bool): end flag of a session
-
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be False.
         Returns:
             tuple(Status, str, int): status, text/chat completion,
                 generated token number
@@ -133,7 +141,7 @@ def stream_chat(self,
         self.chatbot.cfg = self._update_gen_params(**kwargs)
         max_new_tokens = self.chatbot.cfg.max_new_tokens

-        logger = get_logger(log_level=self.chatbot.log_level)
+        logger = get_logger('service.ft', log_level=self.chatbot.log_level)
         logger.info(f'session {session_id}, request_id {request_id}, '
                     f'max_out_len {max_new_tokens}')

@@ -152,8 +160,12 @@ def stream_chat(self,
         prompt = self.template_parser(inputs)
         status, res, _ = None, '', 0
         for status, res, _ in self.chatbot._stream_infer(
-                self.chatbot._session, prompt, max_new_tokens, sequence_start,
-                sequence_end):
+                self.chatbot._session,
+                prompt,
+                max_new_tokens,
+                sequence_start,
+                sequence_end,
+                skip_special_tokens=skip_special_tokens):
             status = self.state_map.get(status)
             # The stop symbol also appears in the output of the last STREAM_ING state.
             res = filter_suffix(res, self.gen_params.get('stop_words'))
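Both hunks above reshape the self.chatbot._stream_infer(...) call to one argument per line and forward the new flag as a keyword. A sketch of driving this streaming path, assuming the wrapper is lagent's TritonClient and that the status enum lives in lagent.schema (server address and model name are placeholders):

    # Sketch under assumptions: TritonClient's constructor arguments and the
    # ModelStatusCode import path follow lagent's public layout; address and
    # model name are placeholders.
    from lagent.llms import TritonClient
    from lagent.schema import ModelStatusCode

    chatbot = TritonClient(
        tritonserver_addr='0.0.0.0:33337',  # placeholder address
        model_name='internlm2-chat-7b')     # placeholder model
    for status, text, _ in chatbot.stream_chat(
            'Write a haiku about GPUs.',
            skip_special_tokens=False):     # the keyword added by this commit
        if status == ModelStatusCode.END:   # final state: text is complete
            print(text)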
@@ -223,14 +235,16 @@ def __init__(self,
     def generate(self,
                  inputs: Union[str, List[str]],
                  do_preprocess: bool = None,
+                 skip_special_tokens: bool = False,
                  **kwargs):
         """Return the chat completions in non-stream mode.

         Args:
             inputs (Union[str, List[str]]): input texts to be completed.
             do_preprocess (bool): whether pre-process the messages. Default to
                 True, which means chat_template will be applied.
-
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be False.
         Returns:
             (a list of/batched) text/chat completion
         """
@@ -242,7 +256,8 @@ def generate(self,
             batched = False
         prompt = inputs
         gen_params = self.update_gen_params(**kwargs)
-        gen_config = GenerationConfig(**gen_params)
+        gen_config = GenerationConfig(
+            skip_special_tokens=skip_special_tokens, **gen_params)
         response = self.model.batch_infer(
             prompt, gen_config=gen_config, do_preprocess=do_preprocess)
         response = [resp.text for resp in response]
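This hunk is where the flag actually reaches lmdeploy in the pipeline wrapper: it is baked into the GenerationConfig handed to batch_infer. The same mechanism shown against lmdeploy directly; pipeline and GenerationConfig are public lmdeploy APIs, the model id is a placeholder:

    # Standalone illustration, independent of lagent. Assumes lmdeploy>=0.2.3,
    # the floor this commit sets, where GenerationConfig accepts
    # skip_special_tokens.
    from lmdeploy import GenerationConfig, pipeline

    pipe = pipeline('internlm/internlm2-chat-7b')  # placeholder model id
    gen_config = GenerationConfig(
        max_new_tokens=512,
        skip_special_tokens=False)  # keep special tokens, the new lagent default
    outputs = pipe(['Hello'], gen_config=gen_config)
    print(outputs[0].text)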
@@ -308,6 +323,7 @@ def generate(self,
                  sequence_start: bool = True,
                  sequence_end: bool = True,
                  ignore_eos: bool = False,
+                 skip_special_tokens: Optional[bool] = False,
                  timeout: int = 30,
                  **kwargs) -> List[str]:
         """Start a new round conversation of a session. Return the chat
@@ -319,6 +335,8 @@ def generate(self,
             sequence_start (bool): start flag of a session
             sequence_end (bool): end flag of a session
             ignore_eos (bool): indicator for ignoring eos
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be False.
             timeout (int): max time to wait for response
         Returns:
             (a list of/batched) text/chat completion
@@ -342,6 +360,7 @@ def generate(self,
                 sequence_end=sequence_end,
                 stream=False,
                 ignore_eos=ignore_eos,
+                skip_special_tokens=skip_special_tokens,
                 timeout=timeout,
                 **gen_params):
             resp = [
@@ -361,6 +380,7 @@ def stream_chat(self,
                     sequence_end: bool = True,
                     stream: bool = True,
                     ignore_eos: bool = False,
+                    skip_special_tokens: Optional[bool] = False,
                     timeout: int = 30,
                     **kwargs):
         """Start a new round conversation of a session. Return the chat
@@ -373,6 +393,8 @@ def stream_chat(self,
             sequence_end (bool): end flag of a session
             stream (bool): return in a streaming format if enabled
             ignore_eos (bool): indicator for ignoring eos
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be False.
             timeout (int): max time to wait for response
         Returns:
             tuple(Status, str, int): status, text/chat completion,
@@ -394,6 +416,7 @@ def stream_chat(self,
                 sequence_end=sequence_end,
                 stream=stream,
                 ignore_eos=ignore_eos,
+                skip_special_tokens=skip_special_tokens,
                 timeout=timeout,
                 **gen_params):
             resp += text['choices'][0]['text']
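In the HTTP-client wrapper the flag is not interpreted locally at all; it rides along to the api_server next to ignore_eos and timeout. A sketch of consuming the streaming endpoint through this wrapper, assuming it is lagent's LMDeployClient (URL and model name are placeholders):

    # Sketch under assumptions: LMDeployClient wraps an lmdeploy api_server;
    # URL and model name are placeholders. Per the docstring above, stream_chat
    # yields (status, text, generated-token-count) tuples.
    from lagent.llms import LMDeployClient

    client = LMDeployClient(
        model_name='internlm2-chat-7b',   # placeholder
        url='http://127.0.0.1:23333')     # placeholder
    for status, text, n_token in client.stream_chat(
            [{'role': 'user', 'content': 'Stream one short sentence.'}],
            skip_special_tokens=False,    # forwarded to the server as-is
            timeout=30):
        pass                              # text accumulates across iterations
    print(text)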
2 changes: 1 addition & 1 deletion requirements/optional.txt

@@ -1,5 +1,5 @@
 google-search-results
-lmdeploy>=0.2.2
+lmdeploy>=0.2.3
 pillow
 python-pptx
 timeout_decorator
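The requirement floor moves from 0.2.2 to 0.2.3, presumably because the skip_special_tokens keyword forwarded above only exists in lmdeploy from that release on (an inference from this diff, not from the changelog). A quick environment check:

    # Confirm the installed lmdeploy satisfies the new floor; packaging compares
    # version strings safely, and lmdeploy exposes __version__ at the top level.
    from packaging.version import Version

    import lmdeploy

    assert Version(lmdeploy.__version__) >= Version('0.2.3'), lmdeploy.__version__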
