Merge branch 'main' into gradio-hf
AllentDan committed Dec 18, 2023
2 parents 02704fb + e3ac7fd commit a67c249
Showing 2 changed files with 194 additions and 52 deletions.
1 change: 1 addition & 0 deletions lmdeploy/serve/async_engine.py
@@ -135,6 +135,7 @@ def safe_run(self, session_id: Optional[int] = None):
yield
except (Exception, asyncio.CancelledError) as e: # noqa
self.stop_session(session_id)
raise e
if str(session_id) in self.id2generator and self.id2generator[str(
session_id)] not in self.gens_set:
self.gens_set.add(self.id2generator[str(session_id)])
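
The one-line addition above makes safe_run re-raise the caught exception after it stops the session, so a failure inside the guarded block is no longer silently swallowed and callers can still handle it. A minimal sketch of that pattern (an illustration with a placeholder engine object, not the lmdeploy implementation):

import asyncio
from contextlib import contextmanager

@contextmanager
def safe_run(engine, session_id=None):
    """Guard a block of per-session work."""
    try:
        yield
    except (Exception, asyncio.CancelledError) as e:  # noqa
        # Clean up the session first ...
        engine.stop_session(session_id)
        # ... then propagate the failure to the caller (the line this commit adds).
        raise e
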
245 changes: 193 additions & 52 deletions lmdeploy/turbomind/turbomind.py
@@ -456,58 +456,24 @@ def _func(device_id, enque_output):
t.start()
self.threads[device_id] = t

async def async_stream_infer(self, *args, **kwargs):
"""Async wrapper of self.stream_infer."""
for output in self.stream_infer(*args, **kwargs):
# Allow the pipeline to add new requests into the queue.
await asyncio.sleep(0)
yield output

def stream_infer(self,
session_id,
input_ids,
input_embeddings=None,
input_embedding_ranges=None,
request_output_len: int = 512,
sequence_start: bool = True,
sequence_end: bool = False,
step=0,
stop=False,
top_p=0.8,
top_k=40,
temperature=0.8,
repetition_penalty=1.0,
ignore_eos=False,
random_seed=None,
stream_output=False):
"""Perform model inference.
Args:
session_id (int): the id of a session
input_ids (numpy.ndarray): the token ids of a prompt
input_embeddings (List[numpy.ndarray]): embeddings features
input_embedding_ranges (List[Tuple[int,int]]): the begin/end
offsets of input_embeddings to input_ids
request_output_len (int): the max number of to-be-generated tokens
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
step (int): the offset of the k/v cache
stop (bool): indicator for cancelling the session
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
temperature (float): to modulate the next token probability
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
random_seed (int): seed used by sampling
stream_output (bool): indicator for stream output
"""
if stream_output and not stop:
self.model_insts[0].register_callback(self._forward_callback)

def prepare_inputs(self,
session_id,
input_ids,
input_embeddings=None,
input_embedding_ranges=None,
request_output_len: int = 512,
sequence_start: bool = True,
sequence_end: bool = False,
step=0,
stop=False,
top_p=0.8,
top_k=40,
temperature=0.8,
repetition_penalty=1.0,
ignore_eos=False,
random_seed=None,
stream_output=False):
"""Convert inputs format."""
if len(input_ids) == 0:
input_ids = [[]]
if isinstance(input_ids[0], int):
@@ -608,8 +574,183 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):

if random_seed is not None:
inputs['random_seed'] = _broadcast_np(random_seed, np.uint64)
return inputs, input_lengths
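
# Editor's sketch (not part of this commit): prepare_inputs relies on the
# _broadcast_np helper named in the hunk header above to turn scalar sampling
# parameters (top_p, top_k, temperature, repetition_penalty, random_seed, ...)
# into per-batch numpy arrays. A minimal stand-in, taking an explicit shape
# where the real helper defaults to the enclosing batch_size:
import numpy as np

def _broadcast_np(data, dtype, shape=(1, )):
    """Broadcast a scalar to a numpy array of the given shape and dtype."""
    return np.full(shape, data, dtype=dtype)

# e.g. _broadcast_np(0.8, np.float32, shape=(4, ))
#      -> array([0.8, 0.8, 0.8, 0.8], dtype=float32)
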

async def async_stream_infer(self,
session_id,
input_ids,
input_embeddings=None,
input_embedding_ranges=None,
request_output_len: int = 512,
sequence_start: bool = True,
sequence_end: bool = False,
step=0,
stop=False,
top_p=0.8,
top_k=40,
temperature=0.8,
repetition_penalty=1.0,
ignore_eos=False,
random_seed=None,
stream_output=False):
"""Perform model inference.
Args:
session_id (int): the id of a session
input_ids (numpy.ndarray): the token ids of a prompt
input_embeddings (List[numpy.ndarray]): embeddings features
input_embedding_ranges (List[Tuple[int,int]]): the begin/end
offsets of input_embeddings to input_ids
request_output_len (int): the max number of to-be-generated tokens
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
step (int): the offset of the k/v cache
stop (bool): indicator for cancelling the session
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
temperature (float): to modulate the next token probability
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
random_seed (int): seed used by sampling
stream_output (bool): indicator for stream output
"""
if stream_output and not stop:
self.model_insts[0].register_callback(self._forward_callback)
inputs, input_lengths = self.prepare_inputs(
session_id=session_id,
input_ids=input_ids,
input_embeddings=input_embeddings,
input_embedding_ranges=input_embedding_ranges,
request_output_len=request_output_len,
sequence_start=sequence_start,
sequence_end=sequence_end,
step=step,
stop=stop,
top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=ignore_eos,
random_seed=random_seed,
stream_output=stream_output)

tm_inputs = _np_dict_to_tm_dict(inputs)
# start forward thread
self.que = Queue()
self._forward_thread(tm_inputs)

seq_start = input_lengths + input_lengths.new_tensor(step)

# generator
while True:
# Thanks to https://github.com/frankxyy for the issue
# https://github.com/InternLM/lmdeploy/issues/832
while self.que.qsize() == 0:
await asyncio.sleep(0)
while self.que.qsize() > 1:
self.que.get()

finish, tm_outputs = self.que.get()

outputs = _tm_dict_to_torch_dict(tm_outputs)

output_ids = outputs['output_ids'][:, 0, :]
sequence_length = outputs['sequence_length'].long()[:, 0]
output_ids = [
output_id[s:l] for output_id, s, l in zip(
output_ids, seq_start, sequence_length)
]
sequence_length -= seq_start.to(sequence_length.device)

outputs = []
for output, len_ in zip(output_ids, sequence_length):
output, len_ = output, len_.item()
if len(output) > 0 and output[-1].item(
) == self.eos_id and not ignore_eos:
outputs.append((output[:-1], len_ - 1))
elif len(output) > 0 and output[-1].item() in self.stop_tokens:
outputs.append((output[:-1], len_))
else:
outputs.append((output, len_))
yield outputs

if finish:
for t in self.threads:
t.join()
while self.que.qsize() > 0:
self.que.get()
break

if stream_output and not stop:
self.model_insts[0].unregister_callback()
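
# Editor's sketch (not part of this commit): one way a caller could consume the
# new async_stream_infer generator. create_instance() on the TurboMind model and
# the surrounding names are assumptions, shown only to illustrate the async-for
# usage; each yielded entry already has a trailing eos / stop token stripped by
# the loop above.
#
#     async def demo(tm_model, prompt_ids):
#         generator = tm_model.create_instance()
#         async for outputs in generator.async_stream_infer(
#                 session_id=1,
#                 input_ids=prompt_ids,
#                 request_output_len=128,
#                 sequence_start=True,
#                 sequence_end=True,
#                 stream_output=True):
#             token_ids, length = outputs[0]  # one (ids, length) pair per sequence
#             print(length, token_ids.tolist())
#
#     # asyncio.run(demo(tm_model, prompt_ids)) drives the generator to completion.
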

def stream_infer(self,
session_id,
input_ids,
input_embeddings=None,
input_embedding_ranges=None,
request_output_len: int = 512,
sequence_start: bool = True,
sequence_end: bool = False,
step=0,
stop=False,
top_p=0.8,
top_k=40,
temperature=0.8,
repetition_penalty=1.0,
ignore_eos=False,
random_seed=None,
stream_output=False):
"""Perform model inference.
Args:
session_id (int): the id of a session
input_ids (numpy.ndarray): the token ids of a prompt
input_embeddings (List[numpy.ndarray]): embeddings features
input_embedding_ranges (List[Tuple[int,int]]): the begin/end
offsets of input_embeddings to input_ids
request_output_len (int): the max number of to-be-generated tokens
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
step (int): the offset of the k/v cache
stop (bool): indicator for cancelling the session
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
temperature (float): to modulate the next token probability
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
random_seed (int): seed used by sampling
stream_output (bool): indicator for stream output
"""
if stream_output and not stop:
self.model_insts[0].register_callback(self._forward_callback)
inputs, input_lengths = self.prepare_inputs(
session_id=session_id,
input_ids=input_ids,
input_embeddings=input_embeddings,
input_embedding_ranges=input_embedding_ranges,
request_output_len=request_output_len,
sequence_start=sequence_start,
sequence_end=sequence_end,
step=step,
stop=stop,
top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=ignore_eos,
random_seed=random_seed,
stream_output=stream_output)

tm_inputs = _np_dict_to_tm_dict(inputs)
# start forward thread
self.que = Queue()
self._forward_thread(tm_inputs)
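
Both async_stream_infer and stream_infer now delegate input packing to prepare_inputs and hand the packed tensors to a background forward thread through a Queue. The async generator above then polls that queue cooperatively: it awaits asyncio.sleep(0) while the queue is empty so other requests can make progress, drops stale intermediate results, and stops once the forward thread reports completion. A stripped-down sketch of that producer/consumer pattern, with a dummy worker standing in for the TurboMind forward thread (illustrative only, not lmdeploy code):

import asyncio
import threading
from queue import Queue

def _dummy_forward(que: Queue, steps: int = 5):
    """Stand-in for the forward thread: push (finished, payload) tuples."""
    for i in range(steps):
        que.put((i == steps - 1, f'partial result {i}'))

async def poll(que: Queue):
    """Drain the queue the way async_stream_infer does: keep only the newest item."""
    while True:
        while que.qsize() == 0:
            await asyncio.sleep(0)  # yield control so other coroutines can run
        while que.qsize() > 1:
            que.get()               # discard stale intermediate outputs
        finished, payload = que.get()
        yield payload
        if finished:
            break

async def main():
    que = Queue()
    worker = threading.Thread(target=_dummy_forward, args=(que, ))
    worker.start()
    async for payload in poll(que):
        print(payload)
    worker.join()

if __name__ == '__main__':
    asyncio.run(main())
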