From ddfa8c47a4300caa70e50290a7800e862cfe6b2d Mon Sep 17 00:00:00 2001 From: sallyjunjun <72725839+sallyjunjun@users.noreply.github.com> Date: Wed, 27 Dec 2023 16:21:59 +0800 Subject: [PATCH] Support QoS in api_server (#877) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * merge qos feature to sched_main * update doc * Update UserRequestQueue, replace queue.Queue() to improve performance * fix lint and several review comments --------- Co-authored-by: 葛芮君 Co-authored-by: shenliu Co-authored-by: Zhou Zihan --- docs/en/qos.md | 219 ++++++++++ docs/zh_cn/qos.md | 225 ++++++++++ lmdeploy/serve/openai/api_server.py | 386 +++++++++++++++++- lmdeploy/serve/openai/protocol.py | 60 +++ lmdeploy/serve/qos_engine/__init__.py | 1 + lmdeploy/serve/qos_engine/inner_group_schd.py | 72 ++++ .../serve/qos_engine/qos_config.json.template | 58 +++ lmdeploy/serve/qos_engine/qos_engine.py | 255 ++++++++++++ lmdeploy/serve/qos_engine/usage_stats.py | 136 ++++++ 9 files changed, 1407 insertions(+), 5 deletions(-) create mode 100644 docs/en/qos.md create mode 100644 docs/zh_cn/qos.md create mode 100644 lmdeploy/serve/qos_engine/__init__.py create mode 100644 lmdeploy/serve/qos_engine/inner_group_schd.py create mode 100644 lmdeploy/serve/qos_engine/qos_config.json.template create mode 100644 lmdeploy/serve/qos_engine/qos_engine.py create mode 100644 lmdeploy/serve/qos_engine/usage_stats.py diff --git a/docs/en/qos.md b/docs/en/qos.md new file mode 100644 index 0000000000..28d8e41af9 --- /dev/null +++ b/docs/en/qos.md @@ -0,0 +1,219 @@ +## LMDeploy-QoS Introduce and Usage + +### Background + +With the rise of Large Language Model (LLM) and Artificial General Intelligence (AGI), numerous inference frameworks have emerged. These frameworks deliver scalable and high-performance services by serving online workloads with language models. However, these workloads often come from multiple user groups, exhibiting rapid changes in workload patterns within short periods. Many inference frameworks struggle to meet the demands of such multi-tenancy traffic patterns and fail to effectively shape user behaviors. Therefore, we believe that systematically considering these issues in LLM inference framework is both valuable and necessary. + +### User Categorizations for Multi-tenancy Handling + +LMDeploy-QoS is part of LMDeploy, offering a range of multi-tenancy functionalities. It requires users to tag their inference requests with appropriate user identifications (user_id in configuration or codebase). The system operates based on a dictionary-like configuration that serves as a multi-tenancy policy. In this configuration, users are mapped to different classes, known as "user groups", each configured with a ratio value. Our multi-tenancy strategy reads this configuration and schedules user inference requests according to class priority and the difference between the predefined ratio and real-time allocation ratio. Extensive testing shows that LMDeploy-QoS significantly enhances LLM serving reliability and GPU resource utilization for real-world large language model inference workloads. + +We categorize LMDeploy users into four groups: + +- Platinum +- Gold +- Silver +- Bronze + +Based on our experiences in delivering LLM services, we can map the following four types of users to these user groups: + +- Platinum: VIP or administrative users. Examples include service inspectors or product demo presenters who require uninterrupted online services. Their workloads are typically at a low frequency and require limited resources. + +- Gold: Contracted business user groups requiring specific quantities of reliable services. For instance, Company A signs a contract with the LLM service provider to secure X requests/sec service capability with Z% availability for its employees at the cost of Y million dollars per year. + +- Silver: The vast majority of users fall under this category. Most trial or monthly subscribed users are included in this group. They need a relatively small quantity of services, but their user experiences significantly affect the LLM service reputation. + +- Bronze: Heavy users who pay minimal fees to LLM providers. + +The above user group categorization is intended for guidance rather than as a recommendation for all LMDeploy users, as it may not be suitable for all LLM service providers. Users can develop their own method of categorizing users based on their observations of daily workloads. + +Next, we will discuss how LMDeploy schedules requests based on these categorizations. + +### Multi-tenancy Strategies + +#### Strategy 1: prioritized scheduling between groups + +This strategy works as simple as its title suggests. + +User groups are introduced for this strategy, with users in each group to be specified. Recommended user groups are as follows: + +- Platinum +- Gold +- Silver +- Bronze + +The priority of each group decreases sequentially. Requests with higher priority are always given precedence for inference. Be noted that the scheduling is performed at the time of request reception, so lower-priority requests will not be withdrawn from the GPU if they are already under inference. + +The below diagram shows how the prioritization works. As you can see, the platinum request is reprioritized and moved to the queue head. + +![](https://github.com/InternLM/lmdeploy/assets/52888924/9d63f081-7168-4c74-8456-24f0a4b41649) + +#### Strategy 2: proportionally rated scheduling with a pre-defined ratio within user group + +This strategy works only within the user group. We introduce a within-group user quota configuration table. This table defines users' "ideal share ratio" with a sum value of 100% GPU resource. Each "user" appears in the list as a user_id, and a user can only belong to one user group. Requests from different users will be scheduled according to each user's "ideal share ratio". To be specific, users with their real-time usage ratio lower than their quota ratio will have priority over users whose real-time usage ratio is higher than their quota ratio. It is worth noting that the scheduling only considers users in the request queue, ignoring any absent users from the configuration table. + +The below diagram shows a typical example of how this strategy works. + +![](https://github.com/InternLM/lmdeploy/assets/52888924/3e1d7135-6b11-4998-89a1-b72af6c962c3) + +#### Strategy 3: a combination strategy of 1 and 2 + +We can call it a hybrid strategy. The way we hybrid these 2 strategies is fairly simple: we adopt strategy 1 in between user groups, and adopt strategy 2 within a user group. So users belonging to different groups with different priorities will only obey strategy 1 to determine their privilege in resource allocation. That is, when both strategies are applied, the first strategy will overpower the second. When it comes to a situation that no cross-group requests are waiting for serving, the within-group strategy 2 comes into play. + +Below is a diagram showing it. + +![](https://github.com/InternLM/lmdeploy/assets/52888924/e335f976-ff15-48db-b1ff-abf1c3327d6e) + +To be noted, there could be other ways of hybrid strategies 1 & 2, and this doc only introduces one method that works well in our scenario. Considering that prioritization and pro-rated sharing are obviously conflicting strategies, there is no easy way to mix them to work within a single dimension. + +### A Sample QoS Configuration + +The configuration will be specified by the `--qos_config_path` flag, and will be loaded by program upon startup. + +```json +{ + "enable_user_qos": true, + "user_groups": [ + "Platinum", + "Gold", + "Silver", + "Bronze" + ], + "user_group_map": { + "Platinum": [ + { + "id": "user_id0", + "quota_pct": 100 + }, + { + "id": "default", + "quota_pct": 0 + } + ], + "Gold": [ + { + "id": "user_id1", + "quota_pct": 50 + }, + { + "id": "user_id2", + "quota_pct": 50 + } + ], + "Silver": [ + { + "id": "user_id3", + "quota_pct": 5 + }, + { + "id": "default", + "quota_pct": 95 + } + ], + "Bronze": [ + { + "id": "user_id4", + "quota_pct": 30 + }, + { + "id": "user_id5", + "quota_pct": 30 + }, + { + "id": "user_id6", + "quota_pct": 40 + }, + { + "id": "default", + "quota_pct": 0 + } + ] + } +} +``` + +### How to perform inference job with Lmdeploy-QoS aware + +We provide the code link below to show how to call infer requests with multi-tenancy strategy awarded. What the qos related argument appears as in http body: + +/v1/chat/interactive_qos + +```bash +curl -X POST http://localhost/v1/chat/interactive_qos \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "Hello,Hello", + "session_id": -1, + "interactive_mode": false, + "stream": false, + "stop": false, + "request_output_len": 512, + "top_p": 0.8, + "top_k": 40, + "temperature": 0.8, + "repetition_penalty": 1, + "ignore_eos": false, + "user_id": "user_id0" +}' +``` + +/v1/chat/completions_qos + +```bash +curl -X POST http://localhost/v1/chat/completions_qos \ + -H "Content-Type: application/json" \ + -d '{ + "model": "internlm-chat-7b", + "messages": "Hello,Hello", + "temperature": 0.7, + "top_p": 1, + "n": 1, + "max_tokens": 512, + "stop": false, + "stream": false, + "presence_penalty": 0, + "frequency_penalty": 0, + "repetition_penalty": 1, + "session_id": -1, + "ignore_eos": false, + "user_id": "user_id0" +}' +``` + +/v1/completions_qos + +```bash +curl -X POST http://localhost/v1/completions_qos \ + -H "Content-Type: application/json" \ + -d '{ + "model": "internlm-chat-7b", + "prompt": "Hello,Hello", + "suffix": "string", + "temperature": 0.7, + "n": 1, + "max_tokens": 16, + "stop": "string", + "stream": false, + "top_p": 1, + "repetition_penalty": 1, + "session_id": -1, + "ignore_eos": false, + "user_id": "user_id0" +}' +``` + +### File Configuration Modification + +The template of the configuration file is located at: `lmdeploy/server/qos_engine/qos_config.json.template`. Add the necessary users based on actual requirements, ensure correct priority assignment, and set appropriate quota values. + +### Passing Configuration Parameters + +Upon starting the api_server, pass the configuration file and its path using the `--qos_config_path` flag. An example is illustrated below: + +```bash +CUDA_VISIBLE_DEVICES=0 lmdeploy serve api_server InternLM/internlm-chat-7b --server_port 8000 --qos_config_path lmdeploy/serve/qos_engine/qos_config.json.template +``` + +### Contributor + +[Eric](https://github.com/rhinouser0), [sallyjunjun](https://github.com/sallyjunjun), [sfireworks](https://github.com/sfireworks), [Dofgal](https://github.com/Dofgal), [shadow](https://github.com/awslshadowstar) diff --git a/docs/zh_cn/qos.md b/docs/zh_cn/qos.md new file mode 100644 index 0000000000..0d6cfc5c39 --- /dev/null +++ b/docs/zh_cn/qos.md @@ -0,0 +1,225 @@ +## LMDeploy-QoS 介绍与用法 + +### 背景 + +在过去一段时间,推理框架伴随着LLM和AGI出现。许多推理框架为语言模型提供可扩展和高性能的在线工作负载服务。它们的工作负载通常涉及多个用户群体,而且工作负载在短时间内快速变化。许多推理框架在满足这些多租户流量模式的要求方面存在困难,而且未能很好的规范约束用户的行为,所以我们认为在LLM推理框架考虑多用户负载均衡是很有必要的。 + +### 多租户处理的用户分类 + +LMDeploy-QoS与LMDeploy 提供一系列多租户功能。它要求用户使用适当的用户标识(配置文件或代码库中的user_id)标记其推理请求。它是基于字典的配置作为多租户策略。在这个配置中,用户被映射到不同“用户组”中,并配备一个使用配额。我们的多租户策略可以读取配置,并根据其用户组的优先级和预定义配额与实时分配比率之间的差异安排用户推理请求的调度。经过完备的测试,我们的LMDeploy-QoS模块极大地提高了LLM的服务可靠性并提升了大型语言模型推理工作的GPU资源利用率。 + +LMDeploy将用户分为4组: + +- 白金(Platinum) +- 金(Gold) +- 银(Silver) +- 青铜(Bronze) + +根据我们在提供LLM服务方面的使用经验,我们可以将以下4种类型的用户映射到这些用户组中: + +- Platinum : VIP用户或管理员用户。包括需要不间断使用的的服务开发人员或演示人员。他们的工作负载频率低,对推理工作的资源需求也不高。 + +- Gold : 签署定期服务的高级用户,他们需要可衡量的可靠服务。例如,某个公司A与LLM服务提供商签订了合同,购买了每秒X个请求的服务能力,可用性为Z%,供A公司员工使用,年付Y百万美元。 + +- Silver : 绝大多数用户。大多数试用或每月订阅的用户被归类为此类别。他们需要相对较少的服务,但他们的用户体验对于LLM服务的声誉也很重要。 + +- Bronze : 支付很少费用给LLM提供商的重度用户。 + +以上引入用户组分类的目的是为了提供指导,而不是为所有LMDeploy用户提供建议,因为这并不一定适用于所有LLM业务提供商。管理员可以对用户的日常负载进行统计,自行决定如何对用户进行分类。 + +接下来让我们讨论一下LMDeploy如何根据这些分类进行分配请求。 + +### 多租户策略 + +#### 策略 1: 用户组之间的优先级调度 + +我们引入“用户组”概念。由模块使用者来定义哪些用户到用户组的映射(可以理解为 uid 到用户组的映射)。推荐用户组为4组如下: + +- Platinum +- Gold +- Silver +- Bronze + +四个用户组之间的优先级顺序是严格的 Platinum > Gold > Silver > Bronze 。当系统繁忙的时候,我们会优先执行排名靠前的请求。 + +下面的图表显示了优先级处理的工作原理。您可以看到 Platinum 请求已被重新设置优先级并移至队列头部。 + +![](https://github.com/InternLM/lmdeploy/assets/52888924/9d63f081-7168-4c74-8456-24f0a4b41649) + +#### 策略 2: 用户组内均摊与软隔离 + +这个策略仅适用于用户组内部。我们引入了一个用户组内的用户配额配置表。该表定义了用户在 100% GPU 资源中的 “理想份额比例”。每个 “用户” 在列表中以 user_id 的形式出现,并且一个用户只能属于一个用户组。低于配额表上额定值的用户会比高于额定值的用户拥有更高的优先级获得被释放资源而进行更多的推理,直到双方使用量趋近于原始配额比例。此处调度只考虑请求队列中的用户,忽略没有出现在请求队列中的已配置用户。 + +以下图表展示了这种策略的典型示例。 + +![](https://github.com/InternLM/lmdeploy/assets/52888924/3e1d7135-6b11-4998-89a1-b72af6c962c3) + +#### 策略3:混合机制 + +是指在一个系统中优先级+均摊/隔离同时开启。执行顺序是先用户组间优先级,再在组内做均摊/隔离实现。这里略去时序图描写。需要注意的是,用户组间的优先级可以压倒性覆盖组内的决策。例如,当低优先级内部的两个用户互相之间有请求顺序调度时,高优先级的请求一旦抵达,将会覆盖所有低优先级的分配逻辑而有限执行高优任务。 + +![](https://github.com/InternLM/lmdeploy/assets/52888924/e335f976-ff15-48db-b1ff-abf1c3327d6e) + +需要注意的是,混合机制可能有其他方法,本文档只介绍了一种在我们场景下有效的方法。其他混合方法需要考虑到优先级和按比例共享明显是相互冲突的策略,因此没有简单的方法将它们混合在单一维度内工作。 + +### QoS 配置项模板 + +配置文件通过启动参数`--qos_config_path`指定,并由程序在启动时加载。 + +配置会和lmdeploy启动脚本等文件放置在一起。配置内容包含: + +1. QoS的启用开关,设置为True时后续的QoS和用户相关配置才会生效,设置为False后续配置不会生效; + +2. user_groups 是一个列表,包含了多种不同的组间优先级; + +3. user_group_map 的映射配置,包含了用户组优先级,组内用户id以及每个用户组内用户的配额分配。 + +配置项模板如下: + +```json +{ + "enable_user_qos": true, + "user_groups": [ + "Platinum", + "Gold", + "Silver", + "Bronze" + ], + "user_group_map": { + "Platinum": [ + { + "id": "user_id0", + "quota_pct": 100 + }, + { + "id": "default", + "quota_pct": 0 + } + ], + "Gold": [ + { + "id": "user_id1", + "quota_pct": 50 + }, + { + "id": "user_id2", + "quota_pct": 50 + } + ], + "Silver": [ + { + "id": "user_id3", + "quota_pct": 5 + }, + { + "id": "default", + "quota_pct": 95 + } + ], + "Bronze": [ + { + "id": "user_id4", + "quota_pct": 30 + }, + { + "id": "user_id5", + "quota_pct": 30 + }, + { + "id": "user_id6", + "quota_pct": 40 + }, + { + "id": "default", + "quota_pct": 0 + } + ] + } +} +``` + +### 如何使用 LMDeploy-QoS 感知进行推理 + +我们提供以下代码链接,展示如何调用具有多租户策略感知的推理请求,在 HTTP Body 中,与 QoS 相关的参数如下: + +/v1/chat/interactive_qos + +```bash +curl -X POST http://localhost/v1/chat/interactive_qos \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "Hello,Hello", + "session_id": -1, + "interactive_mode": false, + "stream": false, + "stop": false, + "request_output_len": 512, + "top_p": 0.8, + "top_k": 40, + "temperature": 0.8, + "repetition_penalty": 1, + "ignore_eos": false, + "user_id": "user_id0" +}' +``` + +/v1/chat/completions_qos + +```bash +curl -X POST http://localhost/v1/chat/completions_qos \ + -H "Content-Type: application/json" \ + -d '{ + "model": "internlm-chat-7b", + "messages": "Hello,Hello", + "temperature": 0.7, + "top_p": 1, + "n": 1, + "max_tokens": 512, + "stop": false, + "stream": false, + "presence_penalty": 0, + "frequency_penalty": 0, + "repetition_penalty": 1, + "session_id": -1, + "ignore_eos": false, + "user_id": "user_id0" +}' +``` + +/v1/completions_qos + +```bash +curl -X POST http://localhost/v1/completions_qos \ + -H "Content-Type: application/json" \ + -d '{ + "model": "internlm-chat-7b", + "prompt": "Hello,Hello", + "suffix": "string", + "temperature": 0.7, + "n": 1, + "max_tokens": 16, + "stop": "string", + "stream": false, + "top_p": 1, + "repetition_penalty": 1, + "session_id": -1, + "ignore_eos": false, + "user_id": "user_id0" +}' +``` + +### 配置文件修改 + +配置文件模板路径为:`lmdeploy/server/qos_engine/qos_config.json.template`,可以根据实际需求添加需要配置的用户,设置正确的优先级以及quota值。 + +### 配置参数传入 + +启动api_server时,通过`--qos_config_path`,将配置文件及路径传入,示例如下: + +```bash +CUDA_VISIBLE_DEVICES=0 lmdeploy serve api_server InternLM/internlm-chat-7b --server_port 8000 --qos_config_path lmdeploy/serve/qos_engine/qos_config.json.template +``` + +### 贡献者 + +[Eric](https://github.com/rhinouser0), [sallyjunjun](https://github.com/sallyjunjun), [sfireworks](https://github.com/sfireworks), [Dofgal](https://github.com/Dofgal), [shadow](https://github.com/awslshadowstar) diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 7867c6ca7c..f821897331 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -13,19 +13,21 @@ from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.serve.openai.protocol import ( # noqa: E501 - ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionRequest, ChatCompletionRequestQos, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, CompletionRequest, - CompletionResponse, CompletionResponseChoice, + CompletionRequestQos, CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage, EmbeddingsRequest, EncodeRequest, EncodeResponse, ErrorResponse, - GenerateRequest, GenerateResponse, ModelCard, ModelList, ModelPermission, - UsageInfo) + GenerateRequest, GenerateRequestQos, GenerateResponse, ModelCard, + ModelList, ModelPermission, UsageInfo) +from lmdeploy.serve.qos_engine.qos_engine import QosEngine class VariableInterface: """A IO interface maintaining variables.""" async_engine: AsyncEngine = None + qos_engine: QosEngine = None request_hosts = [] @@ -84,6 +86,148 @@ def ip2id(host_ip: str): return 0 +@app.post('/v1/chat/completions_qos') +async def chat_completions_v1_qos(request: ChatCompletionRequestQos, + raw_request: Request = None): + """Completion API similar to OpenAI's API. + + Refer to `https://platform.openai.com/docs/api-reference/chat/create` + for the API specification. + + The request should be a JSON object with the following fields: + - model: model name. Available from /v1/models. + - messages: string prompt or chat history in OpenAI format. + - temperature (float): to modulate the next token probability + - top_p (float): If set to float < 1, only the smallest set of most + probable tokens with probabilities that add up to top_p or higher + are kept for generation. + - n (int): How many chat completion choices to generate for each input + message. Only support one here. + - stream: whether to stream the results or not. Default to false. + - max_tokens (int): output token nums + - repetition_penalty (float): The parameter for repetition penalty. + 1.0 means no penalty + + Additional arguments supported by LMDeploy: + - ignore_eos (bool): indicator for ignoring eos + - session_id (int): if not specified, will set random value + - user_id (str): for qos; if not specified, will set to "default" + + Currently we do not support the following features: + - function_call (Users should implement this by themselves) + - logit_bias (not supported yet) + - presence_penalty (replaced with repetition_penalty) + - frequency_penalty (replaced with repetition_penalty) + """ + if request.session_id == -1: + request.session_id = random.randint(1, 10086) + error_check_ret = await check_request(request) + if error_check_ret is not None: + return error_check_ret + + model_name = request.model + request_id = str(request.session_id) + created_time = int(time.time()) + + if VariableInterface.qos_engine is None: + return create_error_response( + HTTPStatus.NOT_FOUND, + 'cannot parse qos engine config, this api is not work') + + result_generator = await VariableInterface.qos_engine.generate_with_qos( + request) + + if result_generator is None: + return create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, + 'Failed to generate completions') + + def create_stream_response_json( + index: int, + text: str, + finish_reason: Optional[str] = None, + ) -> str: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(role='assistant', content=text), + finish_reason=finish_reason, + ) + response = ChatCompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[choice_data], + ) + response_json = response.model_dump_json() + + return response_json + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + # First chunk with role + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role='assistant'), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse(id=request_id, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f'data: {data}\n\n' + + async for res in result_generator: + response_json = create_stream_response_json( + index=0, + text=res.response, + ) + yield f'data: {response_json}\n\n' + yield 'data: [DONE]\n\n' + + # Streaming response + if request.stream: + return StreamingResponse(completion_stream_generator(), + media_type='text/event-stream') + + # Non-streaming response + final_res = None + text = '' + async for res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + VariableInterface.async_engine.stop_session(request.session_id) + return create_error_response(HTTPStatus.BAD_REQUEST, + 'Client disconnected') + final_res = res + text += res.response + assert final_res is not None + choices = [] + choice_data = ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role='assistant', content=text), + finish_reason=final_res.finish_reason, + ) + choices.append(choice_data) + + total_tokens = sum([ + final_res.history_token_len, final_res.input_token_len, + final_res.generate_token_len + ]) + usage = UsageInfo( + prompt_tokens=final_res.input_token_len, + completion_tokens=final_res.generate_token_len, + total_tokens=total_tokens, + ) + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + return response + + @app.post('/v1/chat/completions') async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Request = None): @@ -230,6 +374,154 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: return response +@app.post('/v1/completions_qos') +async def completions_v1_qos(request: CompletionRequestQos, + raw_request: Request = None): + """Completion API similar to OpenAI's API. + + Go to `https://platform.openai.com/docs/api-reference/completions/create` + for the API specification. + + The request should be a JSON object with the following fields: + - model (str): model name. Available from /v1/models. + - prompt (str): the input prompt. + - suffix (str): The suffix that comes after a completion of inserted text. + - max_tokens (int): output token nums + - temperature (float): to modulate the next token probability + - top_p (float): If set to float < 1, only the smallest set of most + probable tokens with probabilities that add up to top_p or higher + are kept for generation. + - n (int): How many chat completion choices to generate for each input + message. Only support one here. + - stream: whether to stream the results or not. Default to false. + - repetition_penalty (float): The parameter for repetition penalty. + 1.0 means no penalty + - user (str): A unique identifier representing your end-user. + + Additional arguments supported by LMDeploy: + - top_k (int): The number of the highest probability vocabulary + tokens to keep for top-k-filtering + - ignore_eos (bool): indicator for ignoring eos + - session_id (int): if not specified, will set random value + - user_id (str): for qos; if not specified, will set to "default" + + Currently we do not support the following features: + - logprobs (not supported yet) + - presence_penalty (replaced with repetition_penalty) + - frequency_penalty (replaced with repetition_penalty) + """ + if request.session_id == -1: + request.session_id = random.randint(1, 10086) + error_check_ret = await check_request(request) + if error_check_ret is not None: + return error_check_ret + + model_name = request.model + request_id = str(request.session_id) + created_time = int(time.time()) + if isinstance(request.prompt, str): + request.prompt = [request.prompt] + + if VariableInterface.qos_engine is None: + return create_error_response( + HTTPStatus.NOT_FOUND, + 'cannot parse qos engine config, this api is not work') + + generators = await VariableInterface.qos_engine.generate_with_qos(request) + + def create_stream_response_json( + index: int, + text: str, + finish_reason: Optional[str] = None, + ) -> str: + choice_data = CompletionResponseStreamChoice( + index=index, + text=text, + finish_reason=finish_reason, + ) + response = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[choice_data], + ) + response_json = response.model_dump_json() + + return response_json + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + # First chunk with role + for generator in generators: + for i in range(request.n): + choice_data = CompletionResponseStreamChoice( + index=i, + text='', + finish_reason=None, + ) + chunk = CompletionStreamResponse(id=request_id, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f'data: {data}\n\n' + + async for res in generator: + response_json = create_stream_response_json( + index=0, + text=res.response, + ) + yield f'data: {response_json}\n\n' + yield 'data: [DONE]\n\n' + + # Streaming response + if request.stream: + return StreamingResponse(completion_stream_generator(), + media_type='text/event-stream') + + # Non-streaming response + usage = UsageInfo() + choices = [] + + async def _inner_call(i, generator): + final_res = None + text = '' + async for res in generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + VariableInterface.async_engine.stop_session(request.session_id) + return create_error_response(HTTPStatus.BAD_REQUEST, + 'Client disconnected') + final_res = res + text += res.response + assert final_res is not None + choice_data = CompletionResponseChoice( + index=0, + text=text, + finish_reason=final_res.finish_reason, + ) + choices.append(choice_data) + + total_tokens = sum([ + final_res.history_token_len, final_res.input_token_len, + final_res.generate_token_len + ]) + usage.prompt_tokens += final_res.input_token_len + usage.completion_tokens += final_res.generate_token_len + usage.total_tokens += total_tokens + + await asyncio.gather( + *[_inner_call(i, generators[i]) for i in range(len(generators))]) + + response = CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + return response + + @app.post('/v1/completions') async def completions_v1(request: CompletionRequest, raw_request: Request = None): @@ -428,6 +720,76 @@ def encode(prompt: str, do_preprocess: bool, add_bos: bool): return EncodeResponse(input_ids=encoded, length=length) +@app.post('/v1/chat/interactive_qos') +async def chat_interactive_v1_qos(request: GenerateRequestQos, + raw_request: Request = None): + """Generate completion for the request. + + - On interactive mode, the chat history is kept on the server. Please set + `interactive_mode = True`. + - On normal mode, no chat history is kept on the server. Set + `interactive_mode = False`. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - session_id: determine which instance will be called. If not specified + with a value other than -1, using random value directly. + - interactive_mode (bool): turn on interactive mode or not. On interactive + mode, session history is kept on the server (and vice versa). + - stream: whether to stream the results or not. + - stop: whether to stop the session response or not. + - request_output_len (int): output token nums + - top_p (float): If set to float < 1, only the smallest set of most + probable tokens with probabilities that add up to top_p or higher + are kept for generation. + - top_k (int): The number of the highest probability vocabulary + tokens to keep for top-k-filtering + - temperature (float): to modulate the next token probability + - repetition_penalty (float): The parameter for repetition penalty. + 1.0 means no penalty + - ignore_eos (bool): indicator for ignoring eos + - user_id (str): for qos; if not specified, will set to "default" + """ + if request.session_id == -1: + request.session_id = random.randint(10087, 23333) + + if VariableInterface.qos_engine is None: + return create_error_response( + HTTPStatus.NOT_FOUND, + 'cannot parse qos engine config, this api is not work') + + generation = await VariableInterface.qos_engine.generate_with_qos(request) + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for out in generation: + chunk = GenerateResponse(text=out.response, + tokens=out.generate_token_len, + finish_reason=out.finish_reason) + data = chunk.model_dump_json() + yield f'{data}\n' + + if request.stream: + return StreamingResponse(stream_results(), + media_type='text/event-stream') + else: + ret = {} + text = '' + tokens = 0 + finish_reason = None + async for out in generation: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + VariableInterface.qos_engine.stop_session(request.session_id) + return create_error_response(HTTPStatus.BAD_REQUEST, + 'Client disconnected') + text += out.response + tokens = out.generate_token_len + finish_reason = out.finish_reason + ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason} + return JSONResponse(ret) + + @app.post('/generate', tags=['deprecated'], description='please use /v1/chat/interactive') @@ -522,6 +884,7 @@ def serve(model_path: str, allow_methods: List[str] = ['*'], allow_headers: List[str] = ['*'], log_level: str = 'ERROR', + qos_config_path: str = '', **kwargs): """An example to perform model inference through the command line interface. @@ -551,6 +914,7 @@ def serve(model_path: str, allow_methods (List[str]): a list of allowed HTTP methods for CORS allow_headers (List[str]): a list of allowed HTTP headers for CORS log_level(str): set log level whose value among [CRITICAL, ERROR, WARNING, INFO, DEBUG] + qos_config_path (str): qos policy config path """ # noqa E501 os.environ['TM_LOG_LEVEL'] = log_level @@ -562,12 +926,24 @@ def serve(model_path: str, allow_methods=allow_methods, allow_headers=allow_headers, ) - VariableInterface.async_engine = AsyncEngine(model_path=model_path, model_name=model_name, instance_num=instance_num, tp=tp, **kwargs) + + if qos_config_path: + try: + with open(qos_config_path, 'r') as file: + qos_config_str = file.read() + VariableInterface.qos_engine = QosEngine( + qos_tag=qos_config_str, + engine=VariableInterface.async_engine, + **kwargs) + VariableInterface.qos_engine.start() + except FileNotFoundError: + VariableInterface.qos_engine = None + for i in range(3): print(f'HINT: Please open \033[93m\033[1mhttp://{server_name}:' f'{server_port}\033[0m in a browser for detailed api usage!!!') diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index 942f554034..53e247510d 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -55,6 +55,26 @@ class UsageInfo(BaseModel): completion_tokens: Optional[int] = 0 +class ChatCompletionRequestQos(BaseModel): + """Chat completion request.""" + model: str + messages: Union[str, List[Dict[str, str]]] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + max_tokens: Optional[int] = 512 + stop: Optional[bool] = False + stream: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + user_id: Optional[str] = None + # additional argument of lmdeploy + repetition_penalty: Optional[float] = 1.0 + session_id: Optional[int] = -1 + ignore_eos: Optional[bool] = False + + class ChatCompletionRequest(BaseModel): """Chat completion request.""" model: str @@ -142,6 +162,30 @@ class CompletionRequest(BaseModel): top_k: Optional[int] = 40 # for opencompass +class CompletionRequestQos(BaseModel): + """Completion request.""" + model: str + prompt: Union[str, List[Any]] + suffix: Optional[str] = None + temperature: Optional[float] = 0.7 + n: Optional[int] = 1 + max_tokens: Optional[int] = 16 + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + top_p: Optional[float] = 1.0 + logprobs: Optional[int] = None + echo: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + # additional argument of lmdeploy + top_k: int = 40 + repetition_penalty: Optional[float] = 1.0 + session_id: Optional[int] = -1 + ignore_eos: Optional[bool] = False + user_id: Optional[str] = None + + class CompletionResponseChoice(BaseModel): """Completion response choices.""" index: int @@ -220,6 +264,22 @@ class GenerateRequest(BaseModel): ignore_eos: bool = False +class GenerateRequestQos(BaseModel): + """Generate request.""" + prompt: Union[str, List[Dict[str, str]]] + session_id: int = -1 + interactive_mode: bool = False + stream: bool = False + stop: bool = False + request_output_len: int = 512 + top_p: float = 0.8 + top_k: int = 40 + temperature: float = 0.8 + repetition_penalty: float = 1.0 + ignore_eos: bool = False + user_id: Optional[str] = None + + class GenerateResponse(BaseModel): """Generate response.""" text: str diff --git a/lmdeploy/serve/qos_engine/__init__.py b/lmdeploy/serve/qos_engine/__init__.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/lmdeploy/serve/qos_engine/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/lmdeploy/serve/qos_engine/inner_group_schd.py b/lmdeploy/serve/qos_engine/inner_group_schd.py new file mode 100644 index 0000000000..73a57e1efa --- /dev/null +++ b/lmdeploy/serve/qos_engine/inner_group_schd.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections +import logging + +logger = logging.getLogger(__name__) + + +class UserRequestQueue: + """Inner group user request queues.""" + + def __init__(self, group: str, user_id_map: dict): + self.group = group + self.user_queue_map = dict() + self.user_quota_map = dict() + self.user_id_maps = user_id_map + + total_quota = 0 + for item in user_id_map: + total_quota += item['quota_pct'] + for item in user_id_map: + user_id = item['id'] + self.user_queue_map[user_id] = collections.deque() + self.user_quota_map[user_id] = item['quota_pct'] / total_quota + + def enqueue(self, request_event): + """Enqueue request to corresponding user queue.""" + if request_event[0].user_id in self.user_queue_map: + self.user_queue_map[request_event[0].user_id].append(request_event) + else: + self.user_queue_map['default'].append(request_event) + + def empty(self): + """Whether all user queues are empty.""" + for _, req_queue in self.user_queue_map.items(): + if len(req_queue) != 0: + return False + return True + + def dequeue(self, usage_stats): + """Dequeue the request to serve.""" + uid_to_serve = self.user_to_serve(usage_stats) + if uid_to_serve in self.user_queue_map: + return self.user_queue_map[uid_to_serve].popleft() + + return None + + def user_to_serve(self, usage_stats): + """Inner group scheduling. + + Find the user to serve from user request queues. + """ + min_usage = 100 + uid_to_serve = '' + for uid, req_queue in self.user_queue_map.items(): + if len(req_queue) == 0: + continue + + # TODO: include token length + # Calculate current user's actual used share and quota share + user_usage, _, group_usage, _ = usage_stats.get_user_usage( + uid, self.group) + actual_share = (user_usage / group_usage) if group_usage > 0 else 0 + due_share = self.user_quota_map[uid] + + # Serve the user with the relatively least usage share + curr_usage = (actual_share / due_share) if due_share > 0 else 0 + if curr_usage == 0: + return uid + if curr_usage < min_usage: + uid_to_serve = uid + min_usage = curr_usage + return uid_to_serve diff --git a/lmdeploy/serve/qos_engine/qos_config.json.template b/lmdeploy/serve/qos_engine/qos_config.json.template new file mode 100644 index 0000000000..1120fbdd27 --- /dev/null +++ b/lmdeploy/serve/qos_engine/qos_config.json.template @@ -0,0 +1,58 @@ +{ + "enable_user_qos": 1, + "user_groups": ["Platinum", "Gold", "Silver", "Bronze"], + "user_group_map": { + "Platinum": [ + { + "id": "user_id0", + "quota_pct": 100 + }, + { + "id": "default", + "quota_pct": 0 + } + ], + "Gold": [ + { + "id": "user_id1", + "quota_pct": 50 + }, + { + "id": "user_id2", + "quota_pct": 50 + }, + { + "id": "default", + "quota_pct": 0 + } + ], + "Silver": [ + { + "id": "user_id3", + "quota_pct": 5 + }, + { + "id": "default", + "quota_pct": 95 + } + ], + "Bronze": [ + { + "id": "user_id4", + "quota_pct": 30 + }, + { + "id": "user_id5", + "quota_pct": 30 + }, + { + "id": "user_id6", + "quota_pct": 40 + }, + { + "id": "default", + "quota_pct": 0 + } + ] + } +} diff --git a/lmdeploy/serve/qos_engine/qos_engine.py b/lmdeploy/serve/qos_engine/qos_engine.py new file mode 100644 index 0000000000..ca5762de54 --- /dev/null +++ b/lmdeploy/serve/qos_engine/qos_engine.py @@ -0,0 +1,255 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import asyncio +import json +import logging +import threading +import time +from typing import List + +from lmdeploy.serve.openai.protocol import (ChatCompletionRequestQos, + CompletionRequestQos, + GenerateRequestQos) +from lmdeploy.serve.qos_engine.inner_group_schd import UserRequestQueue +from lmdeploy.serve.qos_engine.usage_stats import UsageStats + +logger = logging.getLogger(__name__) + + +class QosConfig: + """qos config class: parse qosconfig for qos engine.""" + + def __init__(self, qos_tag=''): + qos_config = json.loads(qos_tag) + self.is_qos_enabled = qos_config.get('enable_user_qos', False) + logger.debug(f'is_qos_enabled: {self.is_qos_enabled}') + + if self.is_qos_enabled: + self.user_id_maps = qos_config['user_group_map'] + self.user_group_prio = qos_config['user_groups'] + logger.debug(f'user_id_maps: {self.user_id_maps}') + logger.debug(f'user_group_prio: {self.user_group_prio}') + + +class QosEngine: + """impl for qos engine, docs/en/qos.md.""" + + def __init__(self, qos_tag='', engine=None, **kwargs) -> None: + self.engine = engine + self.availSlots = engine.instance_num + self._stop_event = threading.Event() + self._dequeue_thread = threading.Thread(target=self._serve, + daemon=True) + self.qos_config = QosConfig(qos_tag) + + self.qos_user_group = QosGroupQueue(self.qos_config) + + self.usage_stats = UsageStats( + total_duration=60, + buffer_count=6, + start_index=0, + user_groups=self.qos_config.user_group_prio) + self.user_served_reqs = dict() + self._dump_stats_thread = threading.Thread(target=self._dump_stats, + daemon=True) + + self.lock = threading.Lock() + self.stats_lock = threading.Lock() + + def start(self): + """start qos engine.""" + if self.is_qos_enabled(): + self._dequeue_thread.start() + self._dump_stats_thread.start() + + def is_qos_enabled(self): + """check while qos engine is enabled.""" + return self.qos_config.is_qos_enabled + + def stop_session(self, session_id: int): + """Stop a session by a session_id.""" + self.engine.stop_session(session_id) + + async def generate(self, request): + """entry of qos engine generate for three api.""" + if isinstance(request, CompletionRequestQos): + if isinstance(request.prompt, str): + request.prompt = [request.prompt] + generators = [] + for i in range(len(request.prompt)): + result_generator = self.engine.generate( + request.prompt[i], + request.session_id + i, + True, # always use stream to enable batching + sequence_start=True, + sequence_end=True, + request_output_len=request.max_tokens + if request.max_tokens else 512, + stop=False, + top_p=request.top_p, + temperature=request.temperature, + repetition_penalty=request.repetition_penalty, + ignore_eos=request.ignore_eos, + do_preprocess=False) + generators.append(result_generator) + return generators + + elif isinstance(request, GenerateRequestQos): + async_engine = self.engine + sequence_start = async_engine.id2step.get(str(request.session_id), + 0) == 0 + sequence_end = not request.interactive_mode + + generation = async_engine.generate( + request.prompt, + request.session_id, + stream_response=True, # always use stream to enable batching + sequence_start=sequence_start, + sequence_end=sequence_end, + request_output_len=request.request_output_len, + top_p=request.top_p, + top_k=request.top_k, + stop=request.stop, + temperature=request.temperature, + repetition_penalty=request.repetition_penalty, + ignore_eos=request.ignore_eos) + return generation + + elif isinstance(request, ChatCompletionRequestQos): + # default chat/completions + result_generator = self.engine.generate( + request.messages, + request.session_id, + True, # always use stream to enable batching + sequence_start=True, + sequence_end=True, + request_output_len=request.max_tokens + if request.max_tokens else 512, + stop=request.stop, + top_p=request.top_p, + temperature=request.temperature, + repetition_penalty=request.repetition_penalty, + ignore_eos=request.ignore_eos) + return result_generator + + return time.sleep(0.01) + + async def generate_with_qos(self, request): + """called by api server for qos generate.""" + if not self.is_qos_enabled(): + return await self.generate(request) + + # push (request,event) to queue + event = asyncio.Event() + request_event = (request, event) + with self.lock: + self.qos_user_group.enqueue(request_event) + + await event.wait() + + result_generator = await self.generate(request) + + # release self.availSlots resources + with self.lock: + if isinstance(request, CompletionRequestQos) and isinstance( + request.prompt, List): + self.availSlots += len(request.prompt) + else: + self.availSlots += 1 + + # Update number of served requests for each user + with self.stats_lock: + if request.user_id not in self.user_served_reqs: + self.user_served_reqs[request.user_id] = 1 + else: + self.user_served_reqs[request.user_id] += 1 + + return result_generator + + def _serve(self): + """backend thread for dequeue.""" + while not self._stop_event.is_set(): + if self.availSlots > 0: + with self.lock: + request_event = self.dequeue(self.usage_stats) + if request_event is not None: + # Update usage_stats + user_group = self.qos_user_group.get_user_group( + request_event[0].user_id) + self.usage_stats.update_usage(request_event[0].user_id, + user_group, 100, + int(time.time())) + if isinstance(request_event[0], + CompletionRequestQos) and isinstance( + request_event[0].prompt, List): + self.availSlots -= len(request_event[0].prompt) + else: + self.availSlots -= 1 + request_event[1].set() + logger.debug( + f'Available slot decrease, now: {self.availSlots}') + time.sleep(0) + + def _dump_stats(self): + """dump usage states for debugs.""" + ts = 0 + while not self._stop_event.is_set(): + outdata = '' + with self.stats_lock: + if not self.user_served_reqs: + outdata = 'none' + else: + sorted_uids = sorted(self.user_served_reqs.keys()) + for uid in sorted_uids: + outdata += f'{uid} {self.user_served_reqs[uid]} reqs, ' + self.user_served_reqs = dict() + logger.info( + f'qos svc running for {ts} seconds,last 20 seconds: {outdata}') + ts += 20 + time.sleep(20) + + def dequeue(self, usage_stats): + """dequeue from multiqueue.""" + return self.qos_user_group.dequeue(usage_stats) + + +class QosGroupQueue: + """create groups for qos outer group schedule.""" + + def __init__(self, qos_config): + if qos_config is None: + self.user_list = {} + self.queues = {} + else: + self.user_list = qos_config.user_id_maps + self.queues = {} + for user_group in qos_config.user_group_prio: + self.queues[user_group] = UserRequestQueue( + user_group, self.user_list[user_group]) + self.user_group_list = list(self.user_list.keys()) + self.default_user_group = self.user_group_list[2] if len( + self.user_group_list) >= 3 else 'None' + logger.debug(self.user_list) + logger.debug(self.queues) + logger.debug(self.default_user_group) + + def get_user_group(self, user_id): + """input: user, output user_id""" + for category, users in self.user_list.items(): + for user in users: + if user_id == user['id']: + return category + return self.default_user_group + + def enqueue(self, request_event): + """enqueue outer group waiting for schedule.""" + user_id = self.get_user_group(request_event[0].user_id) + self.queues[user_id].enqueue(request_event) + + def dequeue(self, usage_stats): + """dequeue outer group schedule.""" + for user_group_id, user_group_queue in self.queues.items(): + if user_group_queue.empty(): + continue + else: + return user_group_queue.dequeue(usage_stats) + return None diff --git a/lmdeploy/serve/qos_engine/usage_stats.py b/lmdeploy/serve/qos_engine/usage_stats.py new file mode 100644 index 0000000000..05ed97cc13 --- /dev/null +++ b/lmdeploy/serve/qos_engine/usage_stats.py @@ -0,0 +1,136 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import threading +from typing import List + + +class Buffer: + """Ring buffer for calculate tokens and reqs usage.""" + + def __init__(self, ts: int, user_groups: List[str]): + self.ts = ts + # Per user usage + self.uid_to_tokens_ps = dict() + self.uid_to_reqs_ps = dict() + + # Per group usage + self.group_to_tokens_ps = dict() + self.group_to_reqs_ps = dict() + + for group in user_groups: + self.group_to_tokens_ps[group] = 0 + self.group_to_reqs_ps[group] = 0 + + +class UsageStats: + """calculate usage for qos engine for inner group schedule.""" + + def __init__(self, total_duration: int, buffer_count: int, + start_index: int, user_groups: List[str]): + self.total_duration = total_duration + self.buffer_count = buffer_count + self.start_index = start_index + self.start_ts = int(0) + + self.buffer_duration = int(total_duration / buffer_count) + self.circular_buffer = [ + Buffer(self.buffer_duration * i, user_groups) + for i in range(buffer_count) + ] + + self.user_groups = user_groups + + self.lock = threading.Lock() + + def update_usage(self, uid: str, group: str, out_token_len: int, + req_ts: int): + """Update UsageStats when a request is returned.""" + with self.lock: + intervals = int((req_ts - self.start_ts) / self.buffer_duration) + + curr_idx = (self.start_index + intervals) % self.buffer_count + curr_ts = self.start_ts + intervals * self.buffer_duration + + # Current request outside the sliding window + if intervals >= self.buffer_count: + reset_buf_cnt = intervals - self.buffer_count + curr_buf_ts = 0 + + if reset_buf_cnt >= self.buffer_count: + # All buffers are reset + for i in range(1, self.buffer_count): + reset_idx = (curr_idx + i) % self.buffer_count + self.circular_buffer[reset_idx] = Buffer( + req_ts + i * self.buffer_duration, + self.user_groups) + # Update self.start_index + self.start_index = curr_idx + self.start_ts = req_ts + curr_buf_ts = req_ts + else: + # buffers between self.start_index and curr_idx are reset + for i in range(reset_buf_cnt): + reset_idx = (self.start_index + i) % self.buffer_count + reset_ts = self.circular_buffer[ + reset_idx].ts + self.total_duration + self.circular_buffer[reset_idx] = Buffer( + reset_ts, self.user_groups) + + # Update self.start_index + self.start_index = (curr_idx + 1) % self.buffer_count + self.start_ts = self.circular_buffer[self.start_index].ts + curr_buf_ts = self.circular_buffer[ + curr_idx].ts + self.total_duration + + # Set corresponding buffer + self.circular_buffer[curr_idx] = Buffer( + curr_buf_ts, self.user_groups) + self.circular_buffer[curr_idx].uid_to_reqs_ps[uid] = 1 + self.circular_buffer[curr_idx].uid_to_tokens_ps[ + uid] = out_token_len + self.circular_buffer[curr_idx].group_to_reqs_ps[group] = 1 + self.circular_buffer[curr_idx].group_to_tokens_ps[ + group] = out_token_len + + # Otherwise update corresponding buffer + else: + self.circular_buffer[curr_idx].ts = curr_ts + + if uid in self.circular_buffer[curr_idx].uid_to_reqs_ps: + self.circular_buffer[curr_idx].uid_to_reqs_ps[uid] += 1 + else: + self.circular_buffer[curr_idx].uid_to_reqs_ps[uid] = 1 + + if uid in self.circular_buffer[curr_idx].uid_to_tokens_ps: + self.circular_buffer[curr_idx].uid_to_tokens_ps[ + uid] += out_token_len + else: + self.circular_buffer[curr_idx].uid_to_tokens_ps[ + uid] = out_token_len + + self.circular_buffer[curr_idx].group_to_reqs_ps[group] += 1 + self.circular_buffer[curr_idx].group_to_tokens_ps[ + group] += out_token_len + + def get_user_usage(self, uid: str, group: str): + """Calculate usage stats of the given user and group.""" + user_req_usage = 0 + user_token_usage = 0 + group_req_usage = 0 + group_token_usage = 0 + + # TODO: use reader lock + with self.lock: + for i in range(self.buffer_count): + if uid in self.circular_buffer[i].uid_to_reqs_ps: + user_req_usage += self.circular_buffer[i].uid_to_reqs_ps[ + uid] + user_token_usage += self.circular_buffer[ + i].uid_to_tokens_ps[uid] + + group_req_usage += self.circular_buffer[i].group_to_reqs_ps[ + group] + group_token_usage += self.circular_buffer[ + i].group_to_tokens_ps[group] + + return (user_req_usage, user_token_usage, group_req_usage, + group_token_usage)