diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index fc18825a2e..5f03304203 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -36,7 +36,6 @@ def __init__(self, model_path, instance_num=32, tp=1, **kwargs) -> None: self.id2step = {} self.id2generator = {} self.loop = asyncio.get_event_loop() - self.special_gen = self.tm_model.create_instance() self.gens_set = set() for i in range(instance_num): self.gens_set.add(self.tm_model.create_instance()) @@ -44,8 +43,7 @@ def __init__(self, model_path, instance_num=32, tp=1, **kwargs) -> None: def stop_session(self, session_id: int): """Stop a session by a session_id.""" input_ids = [self.tm_model.eos_id] - stop_generator = self.id2generator.get(str(session_id), - self.special_gen) + stop_generator = self.tm_model.create_instance() for outputs in stop_generator.stream_infer(session_id, input_ids, request_output_len=0, @@ -60,8 +58,7 @@ def stop_session(self, session_id: int): def end_session(self, session_id: int): """Clear a session by a session_id.""" input_ids = [self.tm_model.eos_id] - end_generator = self.id2generator.get(str(session_id), - self.special_gen) + end_generator = self.tm_model.create_instance() for outputs in end_generator.stream_infer(session_id, input_ids, request_output_len=0, @@ -94,10 +91,12 @@ async def get_embeddings(self, prompt, do_prerpocess=False): async def get_generator(self, stop: bool, session_id: int): """Only return the model instance if it is available.""" if stop: - return self.id2generator.get(str(session_id), self.special_gen) + return self.tm_model.create_instance() while self.gens_set == set(): await asyncio.sleep(0) - return self.gens_set.pop() + generator = self.gens_set.pop() + self.id2generator[str(session_id)] = generator + return generator def batch_infer(self, prompts: List[str], @@ -214,7 +213,6 @@ async def generate( self.end_session(session_id) else: generator = await self.get_generator(stop, session_id) - self.id2generator[str(session_id)] = generator with self.safe_run(session_id): response_size = 0 async for outputs in generator.async_stream_infer(