diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 7867c6ca7c..5b1ec0980a 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -13,19 +13,20 @@ from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.serve.openai.protocol import ( # noqa: E501 - ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionRequest, ChatCompletionRequestQos, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, CompletionRequest, + ChatCompletionStreamResponse, ChatMessage, CompletionRequest, CompletionRequestQos, CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage, EmbeddingsRequest, EncodeRequest, EncodeResponse, ErrorResponse, - GenerateRequest, GenerateResponse, ModelCard, ModelList, ModelPermission, - UsageInfo) - + GenerateRequest, GenerateRequestQos, GenerateResponse, ModelCard, ModelList, + ModelPermission, UsageInfo) +from lmdeploy.serve.qos_engine.qos_engine import QosEngine class VariableInterface: """A IO interface maintaining variables.""" async_engine: AsyncEngine = None + qos_engine: QosEngine = None request_hosts = [] @@ -83,6 +84,140 @@ def ip2id(host_ip: str): print('Warning, could not get session id from ip, set it 0') return 0 +@app.post('/v1/chat/completions_qos') +async def chat_completions_v1_qos(request: ChatCompletionRequestQos, + raw_request: Request = None): + """Completion API similar to OpenAI's API. + + Refer to `https://platform.openai.com/docs/api-reference/chat/create` + for the API specification. + + The request should be a JSON object with the following fields: + - model: model name. Available from /v1/models. + - messages: string prompt or chat history in OpenAI format. 
+ - temperature (float): to modulate the next token probability + - top_p (float): If set to float < 1, only the smallest set of most + probable tokens with probabilities that add up to top_p or higher + are kept for generation. + - n (int): How many chat completion choices to generate for each input + message. Only support one here. + - stream: whether to stream the results or not. Default to false. + - max_tokens (int): output token nums + - repetition_penalty (float): The parameter for repetition penalty. + 1.0 means no penalty + + Additional arguments supported by LMDeploy: + - ignore_eos (bool): indicator for ignoring eos + - session_id (int): if not specified, will set random value + + Currently we do not support the following features: + - function_call (Users should implement this by themselves) + - logit_bias (not supported yet) + - presence_penalty (replaced with repetition_penalty) + - frequency_penalty (replaced with repetition_penalty) + - user_id (str): qos_tag, user who call this function + """ + if request.session_id == -1: + request.session_id = random.randint(1, 10086) + error_check_ret = await check_request(request) + if error_check_ret is not None: + return error_check_ret + + model_name = request.model + request_id = str(request.session_id) + created_time = int(time.time()) + + result_generator = await VariableInterface.qos_engine.generate_with_qos(request) + + if result_generator is None: + return create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'Failed to generate completions') + + def create_stream_response_json( + index: int, + text: str, + finish_reason: Optional[str] = None, + ) -> str: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(role='assistant', content=text), + finish_reason=finish_reason, + ) + response = ChatCompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[choice_data], + ) + response_json = response.model_dump_json() + + return 
response_json + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + # First chunk with role + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role='assistant'), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse(id=request_id, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f'data: {data}\n\n' + + async for res in result_generator: + response_json = create_stream_response_json( + index=0, + text=res.response, + ) + yield f'data: {response_json}\n\n' + yield 'data: [DONE]\n\n' + + # Streaming response + if request.stream: + return StreamingResponse(completion_stream_generator(), + media_type='text/event-stream') + + # Non-streaming response + final_res = None + text = '' + async for res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + VariableInterface.async_engine.stop_session(request.session_id) + return create_error_response(HTTPStatus.BAD_REQUEST, + 'Client disconnected') + final_res = res + text += res.response + assert final_res is not None + choices = [] + choice_data = ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role='assistant', content=text), + finish_reason=final_res.finish_reason, + ) + choices.append(choice_data) + + total_tokens = sum([ + final_res.history_token_len, final_res.input_token_len, + final_res.generate_token_len + ]) + usage = UsageInfo( + prompt_tokens=final_res.input_token_len, + completion_tokens=final_res.generate_token_len, + total_tokens=total_tokens, + ) + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + return response + @app.post('/v1/chat/completions') async def chat_completions_v1(request: ChatCompletionRequest, @@ -230,6 +365,167 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: 
return response +@app.post('/v1/completions_qos') +async def completions_v1_qos(request: CompletionRequestQos, + raw_request: Request = None): + """Completion API similar to OpenAI's API. + + Go to `https://platform.openai.com/docs/api-reference/completions/create` + for the API specification. + + The request should be a JSON object with the following fields: + - model (str): model name. Available from /v1/models. + - prompt (str): the input prompt. + - suffix (str): The suffix that comes after a completion of inserted text. + - max_tokens (int): output token nums + - temperature (float): to modulate the next token probability + - top_p (float): If set to float < 1, only the smallest set of most + probable tokens with probabilities that add up to top_p or higher + are kept for generation. + - n (int): How many chat completion choices to generate for each input + message. Only support one here. + - stream: whether to stream the results or not. Default to false. + - repetition_penalty (float): The parameter for repetition penalty. + 1.0 means no penalty + - user (str): A unique identifier representing your end-user. 
+ + Additional arguments supported by LMDeploy: + - ignore_eos (bool): indicator for ignoring eos + - session_id (int): if not specified, will set random value + + Currently we do not support the following features: + - logprobs (not supported yet) + - presence_penalty (replaced with repetition_penalty) + - frequency_penalty (replaced with repetition_penalty) + - user_id (str): qos_tag, user who call this function + """ + if request.session_id == -1: + request.session_id = random.randint(1, 10086) + error_check_ret = await check_request(request) + if error_check_ret is not None: + return error_check_ret + + model_name = request.model + request_id = str(request.session_id) + created_time = int(time.time()) + if isinstance(request.prompt, str): + request.prompt = [request.prompt] + + generators = await VariableInterface.qos_engine.generate_with_qos(request) + + # generators = [] + # for i in range(len(request.prompt)): + # result_generator = VariableInterface.async_engine.generate( + # request.prompt[i], + # request.session_id + i, + # True, # always use stream to enable batching + # sequence_start=True, + # sequence_end=True, + # request_output_len=request.max_tokens + # if request.max_tokens else 512, + # stop=False, + # top_p=request.top_p, + # temperature=request.temperature, + # repetition_penalty=request.repetition_penalty, + # ignore_eos=request.ignore_eos, + # do_preprocess=False) + # generators.append(result_generator) + + def create_stream_response_json( + index: int, + text: str, + finish_reason: Optional[str] = None, + ) -> str: + choice_data = CompletionResponseStreamChoice( + index=index, + text=text, + finish_reason=finish_reason, + ) + response = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[choice_data], + ) + response_json = response.model_dump_json() + + return response_json + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + # First chunk with role + for generator in 
generators: + for i in range(request.n): + choice_data = CompletionResponseStreamChoice( + index=i, + text='', + finish_reason=None, + ) + chunk = CompletionStreamResponse(id=request_id, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f'data: {data}\n\n' + + async for res in generator: + response_json = create_stream_response_json( + index=0, + text=res.response, + ) + yield f'data: {response_json}\n\n' + yield 'data: [DONE]\n\n' + + # Streaming response + if request.stream: + return StreamingResponse(completion_stream_generator(), + media_type='text/event-stream') + + # Non-streaming response + usage = UsageInfo() + choices = [] + + async def _inner_call(i, generator): + final_res = None + text = '' + async for res in generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + VariableInterface.async_engine.stop_session(request.session_id) + return create_error_response(HTTPStatus.BAD_REQUEST, + 'Client disconnected') + final_res = res + text += res.response + assert final_res is not None + choice_data = CompletionResponseChoice( + index=0, + text=text, + finish_reason=final_res.finish_reason, + ) + choices.append(choice_data) + + total_tokens = sum([ + final_res.history_token_len, final_res.input_token_len, + final_res.generate_token_len + ]) + usage.prompt_tokens += final_res.input_token_len + usage.completion_tokens += final_res.generate_token_len + usage.total_tokens += total_tokens + + await asyncio.gather( + *[_inner_call(i, generators[i]) for i in range(len(generators))]) + + response = CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + return response + + + + @app.post('/v1/completions') async def completions_v1(request: CompletionRequest, raw_request: Request = None): @@ -427,6 +723,71 @@ def encode(prompt: str, do_preprocess: bool, add_bos: bool): length.append(len(ids)) return 
EncodeResponse(input_ids=encoded, length=length) +@app.post('/v1/chat/interactive_qos') +async def chat_interactive_v1_qos(request: GenerateRequestQos, + raw_request: Request = None): + """Generate completion for the request. + + - On interactive mode, the chat history is kept on the server. Please set + `interactive_mode = True`. + - On normal mode, no chat history is kept on the server. Set + `interactive_mode = False`. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - session_id: determine which instance will be called. If not specified + with a value other than -1, using random value directly. + - interactive_mode (bool): turn on interactive mode or not. On interactive + mode, session history is kept on the server (and vice versa). + - stream: whether to stream the results or not. + - stop: whether to stop the session response or not. + - request_output_len (int): output token nums + - top_p (float): If set to float < 1, only the smallest set of most + probable tokens with probabilities that add up to top_p or higher + are kept for generation. + - top_k (int): The number of the highest probability vocabulary + tokens to keep for top-k-filtering + - temperature (float): to modulate the next token probability + - repetition_penalty (float): The parameter for repetition penalty. 
+ 1.0 means no penalty + - ignore_eos (bool): indicator for ignoring eos + - user_id (str): qos_tag, user who call this function + """ + if request.session_id == -1: + request.session_id = random.randint(10087, 23333) + + generation = await VariableInterface.qos_engine.generate_with_qos(request) + + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for out in generation: + chunk = GenerateResponse(text=out.response, + tokens=out.generate_token_len, + finish_reason=out.finish_reason) + data = chunk.model_dump_json() + yield f'{data}\n' + + if request.stream: + return StreamingResponse(stream_results(), + media_type='text/event-stream') + else: + ret = {} + text = '' + tokens = 0 + finish_reason = None + async for out in generation: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + async_engine.stop_session(request.session_id) + return create_error_response(HTTPStatus.BAD_REQUEST, + 'Client disconnected') + text += out.response + tokens = out.generate_token_len + finish_reason = out.finish_reason + ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason} + return JSONResponse(ret) + @app.post('/generate', tags=['deprecated'], @@ -551,6 +912,7 @@ def serve(model_path: str, allow_methods (List[str]): a list of allowed HTTP methods for CORS allow_headers (List[str]): a list of allowed HTTP headers for CORS log_level(str): set log level whose value among [CRITICAL, ERROR, WARNING, INFO, DEBUG] + qos_config_path (str): qos policy config path """ # noqa E501 os.environ['TM_LOG_LEVEL'] = log_level @@ -562,12 +924,25 @@ def serve(model_path: str, allow_methods=allow_methods, allow_headers=allow_headers, ) + qos_config_str = "" + if qos_config_path: + try: + with open(qos_config_path, 'r') as file: + qos_config_str = file.read() + except FileNotFoundError: + qos_config_str = "" VariableInterface.async_engine = AsyncEngine(model_path=model_path, model_name=model_name, 
instance_num=instance_num, tp=tp, **kwargs) + VariableInterface.qos_engine = QosEngine(instance_num=instance_num, + qos_tag=qos_config_str, + engine=VariableInterface.async_engine, + **kwargs) + VariableInterface.qos_engine.start() + for i in range(3): print(f'HINT: Please open \033[93m\033[1mhttp://{server_name}:' f'{server_port}\033[0m in a browser for detailed api usage!!!') diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index 942f554034..aecde4ed7e 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -55,6 +55,26 @@ class UsageInfo(BaseModel): completion_tokens: Optional[int] = 0 +class ChatCompletionRequestQos(BaseModel): + """Chat completion request.""" + model: str + messages: Union[str, List[Dict[str, str]]] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + max_tokens: Optional[int] = 512 + stop: Optional[bool] = False + stream: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + user_id: Optional[str] = None + # additional argument of lmdeploy + repetition_penalty: Optional[float] = 1.0 + session_id: Optional[int] = -1 + ignore_eos: Optional[bool] = False + + class ChatCompletionRequest(BaseModel): """Chat completion request.""" model: str @@ -141,6 +161,28 @@ class CompletionRequest(BaseModel): ignore_eos: Optional[bool] = False top_k: Optional[int] = 40 # for opencompass +class CompletionRequestQos(BaseModel): + """Completion request.""" + model: str + prompt: Union[str, List[Any]] + suffix: Optional[str] = None + temperature: Optional[float] = 0.7 + n: Optional[int] = 1 + max_tokens: Optional[int] = 16 + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + top_p: Optional[float] = 1.0 + logprobs: Optional[int] = None + echo: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 
0.0 + user: Optional[str] = None + # additional argument of lmdeploy + repetition_penalty: Optional[float] = 1.0 + session_id: Optional[int] = -1 + ignore_eos: Optional[bool] = False + user_id: Optional[str] = None + class CompletionResponseChoice(BaseModel): """Completion response choices.""" @@ -220,6 +262,22 @@ class GenerateRequest(BaseModel): ignore_eos: bool = False +class GenerateRequestQos(BaseModel): + """Generate request.""" + prompt: Union[str, List[Dict[str, str]]] + session_id: int = -1 + interactive_mode: bool = False + stream: bool = False + stop: bool = False + request_output_len: int = 512 + top_p: float = 0.8 + top_k: int = 40 + temperature: float = 0.8 + repetition_penalty: float = 1.0 + ignore_eos: bool = False + user_id: Optional[str] = None + + class GenerateResponse(BaseModel): """Generate response.""" text: str diff --git a/lmdeploy/serve/qos_engine/__init__.py b/lmdeploy/serve/qos_engine/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lmdeploy/serve/qos_engine/inner_group_schd.py b/lmdeploy/serve/qos_engine/inner_group_schd.py new file mode 100644 index 0000000000..965697b954 --- /dev/null +++ b/lmdeploy/serve/qos_engine/inner_group_schd.py @@ -0,0 +1,81 @@ +import queue +import threading +import logging +logger = logging.getLogger(__name__) + +class UserRequestQueue: + """ + Inner group user request queues + """ + + def __init__(self, group: str, user_id_map: dict): + self.group = group + self.user_queue_map = dict() + self.user_quota_map = dict() + self.user_id_maps = user_id_map + + total_quota = 0 + for item in user_id_map: + total_quota += item["quota_pct"] + for item in user_id_map: + user_id = item["id"] + self.user_queue_map[user_id] = queue.Queue() + self.user_quota_map[user_id] = item["quota_pct"] / total_quota + + self.lock = threading.Lock() + + def enqueue(self, request_event): + """ + Enqueue request to correspoding user queue. 
+ """ + if request_event[0].user_id in self.user_queue_map: + self.user_queue_map[request_event[0].user_id].put(request_event) + else: + self.user_queue_map["default"].put(request_event) + + def empty(self): + """ + Whether all user queues are empty. + """ + with self.lock: + for _, user_queue in self.user_queue_map.items(): + if not user_queue.empty(): + return False + return True + + def dequeue(self, usage_stats): + """ + Dequeue the request to serve. + """ + with self.lock: + uid_to_serve = self.user_to_serve(usage_stats) + if uid_to_serve in self.user_queue_map: + return self.user_queue_map[uid_to_serve].get() + + return None + + def user_to_serve(self, usage_stats): + """ + Inner group scheduling. + Find the user to serve from user request queues. + """ + min_usage = 100 + uid_to_serve = "" + for uid, req_queue in self.user_queue_map.items(): + if req_queue.empty(): + continue + + # TODO: include token length + # Calculate current user's actual used share and quota share + user_usage, _, group_usage, _ = usage_stats.get_user_usage(uid, self.group) + actual_share = (user_usage / group_usage) if group_usage > 0 else 0 + due_share = self.user_quota_map[uid] + + # Serve the user with the relatively least usage share + curr_usage = actual_share / due_share + if curr_usage == 0: + return uid + if curr_usage < min_usage: + uid_to_serve = uid + min_usage = curr_usage + return uid_to_serve diff --git a/lmdeploy/serve/qos_engine/qos_config.json.template b/lmdeploy/serve/qos_engine/qos_config.json.template new file mode 100644 index 0000000000..73bc7723ee --- /dev/null +++ b/lmdeploy/serve/qos_engine/qos_config.json.template @@ -0,0 +1,58 @@ +{ + "enable_user_qos": 1, + "user_groups": ["Platinum", "Gold", "Silver", "Bronze"], + "user_group_map": { + "Platinum": [ + { + "id": "user_id0", + "quota_pct": 100 + }, + { + "id": "default", + "quota_pct": 0 + } + ], + "Gold": [ + { + "id": "user_id1", + "quota_pct": 50 + }, + { + "id": "user_id2", + "quota_pct": 50 + }, + { + 
"id": "default", + "quota_pct": 0 + } + ], + "Silver": [ + { + "id": "user_id3", + "quota_pct": 5 + }, + { + "id": "default", + "quota_pct": 95 + } + ], + "Bronze": [ + { + "id": "user_id4", + "quota_pct": 30 + }, + { + "id": "user_id5", + "quota_pct": 30 + }, + { + "id": "user_id6", + "quota_pct": 40 + }, + { + "id": "default", + "quota_pct": 0 + } + ] + } +} \ No newline at end of file diff --git a/lmdeploy/serve/qos_engine/qos_engine.py b/lmdeploy/serve/qos_engine/qos_engine.py new file mode 100644 index 0000000000..e5ab80457b --- /dev/null +++ b/lmdeploy/serve/qos_engine/qos_engine.py @@ -0,0 +1,224 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import asyncio +import queue +import json +import threading +import time +from lmdeploy.serve.async_engine import AsyncEngine +from lmdeploy.serve.qos_engine.usage_stats import UsageStats +from lmdeploy.serve.qos_engine.inner_group_schd import UserRequestQueue + +from lmdeploy.serve.openai.protocol import (ChatCompletionRequestQos, CompletionRequestQos, GenerateRequestQos) + +import logging +logger = logging.getLogger(__name__) + +class QosConfig: + def __init__(self, qos_tag=""): + try: + qos_config = json.loads(qos_tag) + self.is_qos_enabled = qos_config["enable_user_qos"] + self.user_id_maps = qos_config["user_group_map"] + self.user_group_prio = qos_config["user_groups"] + except: + self.is_qos_enabled = False + self.user_id_maps = dict() + self.user_group_prio = [] + logger.debug(f"is_qos_enabled: {self.is_qos_enabled}") + logger.debug(f"user_id_maps: {self.user_id_maps}") + logger.debug(f"user_group_prio: {self.user_group_prio}") + +class QosEngine: + def __init__(self, instance_num=16, qos_tag="", engine=None, **kwargs) -> None: + self.engine = engine + self.availSlots = instance_num + self._stop_event = threading.Event() + self._dequeue_thread = threading.Thread(target=self._serve, daemon=True) + self.qos_config = QosConfig(qos_tag) + + self.qos_user_group = QosGroupQueue(self.qos_config) + + 
self.usage_stats = UsageStats(60, 6, 0, self.qos_config.user_group_prio) + self.user_served_reqs = dict() + self._dump_stats_thread = threading.Thread(target=self._dump_stats, daemon=True) + + self.lock = threading.Lock() + self.stats_lock = threading.Lock() + + def start(self): + if self.is_qos_enabled(): + self._dequeue_thread.start() + self._dump_stats_thread.start() + + def is_qos_enabled(self): + return self.qos_config.is_qos_enabled + + def stop_session(self, session_id: int): + """Stop a session by a session_id.""" + self.engine.stop_session(session_id) + + async def generate(self, request): + + if isinstance(request,CompletionRequestQos): + generators = [] + for i in range(len(request.prompt)): + result_generator = self.engine.generate( + request.prompt[i], + request.session_id + i, + True, # always use stream to enable batching + sequence_start=True, + sequence_end=True, + request_output_len=request.max_tokens + if request.max_tokens else 512, + stop=False, + top_p=request.top_p, + temperature=request.temperature, + repetition_penalty=request.repetition_penalty, + ignore_eos=request.ignore_eos, + do_preprocess=False) + generators.append(result_generator) + return generators + + elif isinstance(request,GenerateRequestQos): + async_engine = self.engine + sequence_start = async_engine.steps.get(str(request.session_id), 0) == 0 + sequence_end = not request.interactive_mode + + generation = async_engine.generate( + request.prompt, + request.session_id, + stream_response=True, # always use stream to enable batching + sequence_start=sequence_start, + sequence_end=sequence_end, + request_output_len=request.request_output_len, + top_p=request.top_p, + top_k=request.top_k, + stop=request.stop, + temperature=request.temperature, + repetition_penalty=request.repetition_penalty, + ignore_eos=request.ignore_eos) + return generation + + elif isinstance(request,ChatCompletionRequestQos): + # default chat/completions + result_generator =self.engine.generate( + 
request.messages, + request.session_id, + True, # always use stream to enable batching + sequence_start=True, + sequence_end=True, + request_output_len=request.max_tokens if request.max_tokens else 512, + stop=request.stop, + top_p=request.top_p, + temperature=request.temperature, + repetition_penalty=request.repetition_penalty, + ignore_eos=request.ignore_eos) + return result_generator + + return time.sleep(0.01) + + async def generate_with_qos(self, request): + if not self.is_qos_enabled(): + return await self.generate(request) + + # push (request,event) to queue + event = asyncio.Event() + request_event = (request,event) + self.qos_user_group.enqueue(request_event) + + await event.wait() + + result_generator = await self.generate(request) + + # release self.availSlots resources + with self.lock: + if hasattr(request,'prompt'): + self.availSlots += len(request.prompt) + else: + self.availSlots += 1 + + # Update number of served requests for each user + with self.stats_lock: + if request.user_id not in self.user_served_reqs: + self.user_served_reqs[request.user_id] = 1 + else: + self.user_served_reqs[request.user_id] += 1 + logger.debug(f"Available slot increase, now: {self.availSlots}") + + return result_generator + + + def _serve(self): + while not self._stop_event.is_set(): + if self.availSlots > 0: + with self.lock: + request_event = self.dequeue(self.usage_stats) + if request_event != None: + # Update usage_stats + user_group = self.qos_user_group.get_user_group(request_event[0].user_id) + self.usage_stats.update_usage(request_event[0].user_id, user_group, 100, int(time.time())) + if hasattr(request_event[0],'prompt'): + self.availSlots -= len(request_event[0].prompt) + else: + self.availSlots -= 1 + request_event[1].set() + logger.debug(f"Available slot decrease, now: {self.availSlots}") + continue + + def _dump_stats(self): + ts = 0 + while not self._stop_event.is_set(): + outdata = "" + with self.stats_lock: + if not self.user_served_reqs: + outdata = 
"none" + else: + sorted_uids = sorted(self.user_served_reqs.keys()) + for uid in sorted_uids: + outdata += f"{uid} {self.user_served_reqs[uid]} reqs, " + self.user_served_reqs = dict() + logger.info(f"qos service running for {ts} seconds, served in last 20 seconds: {outdata}") + ts += 20 + time.sleep(20) + + def dequeue(self, usage_stats): + return self.qos_user_group.dequeue(usage_stats) + + def stop(self): + self._stop_event.set() + self._dequeue_thread.join() + +class QosGroupQueue: + def __init__(self,qos_config): + if qos_config == None: + self.user_list = {} + self.queues = {} + else: + self.user_list = qos_config.user_id_maps + self.queues = {} + for user_group in qos_config.user_group_prio: + self.queues[user_group] = UserRequestQueue(user_group, self.user_list[user_group]) + self.user_group_list = list(self.user_list.keys()) + self.default_user_group = self.user_group_list[2] if len(self.user_group_list)>=3 else "None" + logger.debug(self.user_list) + logger.debug(self.queues) + logger.debug(self.default_user_group) + + def get_user_group(self, user_id): + for category, users in self.user_list.items(): + for user in users: + if user_id == user['id']: + return category + return self.default_user_group + def enqueue(self, request_event): + user_id = self.get_user_group(request_event[0].user_id) + self.queues[user_id].enqueue(request_event) + + def dequeue(self, usage_stats): + for user_group_id, user_group_queue in self.queues.items(): + if user_group_queue.empty(): + continue + else: + return user_group_queue.dequeue(usage_stats) + return None + diff --git a/lmdeploy/serve/qos_engine/usage_stats.py b/lmdeploy/serve/qos_engine/usage_stats.py new file mode 100644 index 0000000000..b47a4dcda4 --- /dev/null +++ b/lmdeploy/serve/qos_engine/usage_stats.py @@ -0,0 +1,115 @@ +import json +import threading +from typing import List + + +class Buffer: + def __init__(self, ts: int, user_groups: List[str]): + self.ts = ts + # Per user usage + self.uid_to_tokens_ps = 
dict() + self.uid_to_reqs_ps = dict() + + # Per group usage + self.group_to_tokens_ps = dict() + self.group_to_reqs_ps = dict() + + for group in user_groups: + self.group_to_tokens_ps[group] = 0 + self.group_to_reqs_ps[group] = 0 + + +class UsageStats: + def __init__(self, total_duration: int, buffer_count: int, start_index: int, user_groups: List[str]): + self.total_duration = total_duration + self.buffer_count = buffer_count + self.start_index = start_index + self.start_ts = int(0) + + self.buffer_duration = int(total_duration / buffer_count) + self.circular_buffer = [Buffer(self.buffer_duration * i, user_groups) for i in range(buffer_count)] + + self.user_groups = user_groups + + self.lock = threading.Lock() + + def update_usage(self, uid: str, group: str, out_token_len: int, req_ts: int): + """ + Update UsageStats when a request is returned + """ + with self.lock: + intervals = int((req_ts-self.start_ts) / self.buffer_duration) + + curr_idx = (self.start_index+intervals) % self.buffer_count + curr_ts = self.start_ts + intervals*self.buffer_duration + + # Current request outside the sliding window + if intervals >= self.buffer_count: + reset_buf_cnt = intervals - self.buffer_count + curr_buf_ts = 0 + + if reset_buf_cnt >= self.buffer_count: + # All buffers are reset + for i in range(1, self.buffer_count): + reset_idx = (curr_idx+i) % self.buffer_count + self.circular_buffer[reset_idx] = Buffer(req_ts + i*self.buffer_duration, self.user_groups) + # Update self.start_index + self.start_index = curr_idx + self.start_ts = req_ts + curr_buf_ts = req_ts + else: + # Only buffers between self.start_index and curr_idx are reset + for i in range(reset_buf_cnt): + reset_idx = (self.start_index+i) % self.buffer_count + reset_ts = self.circular_buffer[reset_idx].ts + self.total_duration + self.circular_buffer[reset_idx] = Buffer(reset_ts, self.user_groups) + + # Update self.start_index + self.start_index = (curr_idx+1) % self.buffer_count + self.start_ts = 
self.circular_buffer[self.start_index].ts + curr_buf_ts = self.circular_buffer[curr_idx].ts + self.total_duration + + # Set corresponding buffer + self.circular_buffer[curr_idx] = Buffer(curr_buf_ts, self.user_groups) + self.circular_buffer[curr_idx].uid_to_reqs_ps[uid] = 1 + self.circular_buffer[curr_idx].uid_to_tokens_ps[uid] = out_token_len + self.circular_buffer[curr_idx].group_to_reqs_ps[group] = 1 + self.circular_buffer[curr_idx].group_to_tokens_ps[group] = out_token_len + + # Otherwise update corresponding buffer + else: + self.circular_buffer[curr_idx].ts = curr_ts + + if uid in self.circular_buffer[curr_idx].uid_to_reqs_ps: + self.circular_buffer[curr_idx].uid_to_reqs_ps[uid] += 1 + else: + self.circular_buffer[curr_idx].uid_to_reqs_ps[uid] = 1 + + if uid in self.circular_buffer[curr_idx].uid_to_tokens_ps: + self.circular_buffer[curr_idx].uid_to_tokens_ps[uid] += out_token_len + else: + self.circular_buffer[curr_idx].uid_to_tokens_ps[uid] = out_token_len + + self.circular_buffer[curr_idx].group_to_reqs_ps[group] += 1 + self.circular_buffer[curr_idx].group_to_tokens_ps[group] += out_token_len + + def get_user_usage(self, uid: str, group: str): + """ + Calculate usage stats of the given user and group + """ + user_req_usage = 0 + user_token_usage = 0 + group_req_usage = 0 + group_token_usage = 0 + + # TODO: use reader lock + with self.lock: + for i in range(self.buffer_count): + if uid in self.circular_buffer[i].uid_to_reqs_ps: + user_req_usage += self.circular_buffer[i].uid_to_reqs_ps[uid] + user_token_usage += self.circular_buffer[i].uid_to_tokens_ps[uid] + + group_req_usage += self.circular_buffer[i].group_to_reqs_ps[group] + group_token_usage += self.circular_buffer[i].group_to_tokens_ps[group] + + return user_req_usage, user_token_usage, group_req_usage, group_token_usage diff --git a/tests/test_qos/test_continuous.py b/tests/test_qos/test_continuous.py new file mode 100644 index 0000000000..0bfa8a75d5 --- /dev/null +++ 
"""Continuous-load smoke test for the /v1/chat/completions_qos endpoint.

Fires bursts of requests for one or two user ids in four stages so the
relative throughput each user receives under QoS can be read off the
printed per-request summaries.
"""
import json
import threading
import time

import requests
from requests import exceptions as rex

# Base payload for the QoS chat-completions endpoint; each request
# overrides user_id and session_id before sending.
payload_template = {
    "model": "internlm-chat-7b",
    "messages": "string",
    "temperature": 0.7,
    "top_p": 1,
    "n": 1,
    "max_tokens": 128,
    "stop": False,
    "stream": False,
    "presence_penalty": 0,
    "frequency_penalty": 0,
    "user_id": "template",
    "repetition_penalty": 1,
    "session_id": -1,
    "ignore_eos": False
}

url = "http://localhost:64546/v1/chat/completions_qos"


def send_request(payload, stage, i):
    """POST one request and print a one-line summary (or the error name).

    ``requests.exceptions.RequestException`` is the root of requests'
    exception hierarchy and covers timeouts, connection failures and the
    wrapped urllib3 errors in one clause; the previous tuple of raw
    urllib3 internals let requests' own wrapper exceptions (e.g.
    ``Timeout``) propagate and kill the worker thread.
    """
    # print(f"send: {stage}#{i}, uid: {payload['user_id']}")
    try:
        with requests.post(url, json=payload) as response:
            rep_json = json.loads(response.content)
            print(f"{stage}#{i}, uid: {payload['user_id']}, {response.status_code}, {rep_json['usage']['completion_tokens']}")
    except (TimeoutError, rex.RequestException) as e:
        print(f"{stage}#{i}, uid: {payload['user_id']}, ERROR, {type(e).__name__}")


session_id = 0  # monotonically increasing; mutated only by the main thread


def create_thread(uids, req_cnt, stage):
    """Spawn one sender thread per (uid, request) pair.

    Args:
        uids: list of uids to send requests in one stage.
        req_cnt: number of requests to send for each uid.
        stage: stage number, echoed in the printed summaries.
    """
    global session_id
    for i in range(req_cnt):
        for uid in uids:
            payload = payload_template.copy()
            payload["user_id"] = uid
            payload["session_id"] = session_id
            session_id += 1
            th = threading.Thread(target=send_request, args=[payload, stage, i])
            th.start()
            time.sleep(0.03)  # ~33 req/s pacing: keep the server loaded, not flooded


if __name__ == "__main__":
    # Guarded so that merely importing this module does not fire
    # thousands of HTTP requests (the original ran these at import time).
    print("\nstage 1:\n")
    create_thread(["user_id4"], 4000, 1)
    print("\nstage 2:\n")
    create_thread(["user_id4", "user_id5"], 1000, 2)
    print("\nstage 3:\n")
    create_thread(["user_id4"], 4000, 3)
    print("\nstage 4:\n")
    create_thread(["user_id4", "user_id5"], 2000, 4)
"""Long-running load client for the /v1/chat/completions_qos endpoint.

Replays prompts from ``prompt_1115.csv`` at inter-request intervals read
from a per-scenario CSV, tagging each request with a scenario-specific
user_id and session-id range.  Runs until killed (see
``test_puyu_online.sh``).
"""
import csv
import json
import os
import sys
import threading
import time

import requests
from requests import exceptions as rex

# Base payload; user_id, session_id and messages are overridden per request.
payload_template = {
    "model": "internlm-chat-7b",
    "messages": "string",
    "temperature": 0.7,
    "top_p": 1,
    "n": 1,
    "max_tokens": 512,
    "stop": False,
    "stream": False,
    "presence_penalty": 0,
    "frequency_penalty": 0,
    "user_id": "template",
    "repetition_penalty": 1,
    "session_id": -1,
    "ignore_eos": False
}
url = "http://localhost:64546/v1/chat/completions_qos"

# Usage text was previously duplicated in two branches; keep one copy.
USAGE = ("Usage: python test_regular.py 0|1|2"
         "\n\t0: high priority user\n\t1: normal user\n\t2: pressure test")

# pattern -> (user_id, [session-id range), interval file); replaces the
# if/elif chain so uid / range / file for a scenario cannot drift apart.
SCENARIOS = {
    '0': ("user_id0", [0, 6400], "interval_high_priority.csv"),
    '1': ("user_id3", [6422, 12800], "interval_normal_state.csv"),
    '2': ("user_id4", [12844, 21600], "interval_press_test.csv"),
}


def send_request(payload, i):
    """POST one request and print ``index, uid, completion_tokens, latency``.

    ``requests.exceptions.RequestException`` covers timeouts, connection
    failures and wrapped urllib3 errors; the previous tuple of raw urllib3
    internals missed requests' own wrapper exceptions (e.g. ``Timeout``),
    which then killed the worker thread.
    """
    try:
        start_ts = time.time()
        with requests.post(url, json=payload) as response:
            rep_json = json.loads(response.content)
            end_ts = time.time()
            print(f"{i}, {payload['user_id']}, {rep_json['usage']['completion_tokens']}, {end_ts-start_ts}")
    except (TimeoutError, rex.RequestException) as e:
        print(f"{i}, {payload['user_id']}, ERROR, {type(e).__name__}")


def create_threads(uid, messages, intervals, session_ids):
    """Endlessly spawn one sender thread per request.

    Args:
        uid: user id stamped on every request.
        messages: prompts, cycled by request index.
        intervals: sleep times between sends, cycled by request index.
        session_ids: ``[low, high)`` — session ids cycle within this range.
    """
    i = 0
    # Set number of requests sent by each use case
    while True:
        payload = payload_template.copy()
        payload["user_id"] = uid
        payload["session_id"] = i % (session_ids[1]-session_ids[0]) + session_ids[0]

        payload["messages"] = messages[i % len(messages)]
        th = threading.Thread(target=send_request, args=[payload, i])
        th.start()
        time.sleep(intervals[i % len(intervals)])
        i += 1


if __name__ == "__main__":
    # Validate argv and resolve the scenario in one step.
    if len(sys.argv) != 2 or sys.argv[1] not in SCENARIOS:
        print(USAGE)
        sys.exit(1)
    pattern = sys.argv[1]
    uid, session_ids, itv_name = SCENARIOS[pattern]

    # Load files relative to this script's directory.
    script_directory = os.path.dirname(os.path.abspath(__file__))

    csv_file_name = os.path.join(script_directory, "prompt_1115.csv")
    messages = []
    with open(csv_file_name, "r") as f:
        for row in csv.reader(f):
            # row[0]: prompt
            messages.append(row[0])

    itv_file_name = os.path.join(script_directory, itv_name)
    intervals = []
    with open(itv_file_name, "r") as f:
        for row in csv.reader(f):
            # row[0]: sleep interval
            intervals.append(float(row[0]))

    # Send requests until externally killed.
    create_threads(uid, messages, intervals, session_ids)