From b8cf464a874fba1c83df150c83ecf683aa19a361 Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Fri, 12 Apr 2024 18:03:52 +0800 Subject: [PATCH] Fix generation parameters in API models (#181) * fix gen_params * remove from --------- Co-authored-by: wangzy --- lagent/llms/base_api.py | 2 +- lagent/llms/base_llm.py | 2 +- lagent/llms/openai.py | 14 +++++++++++--- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/lagent/llms/base_api.py b/lagent/llms/base_api.py index 6088fe79..d6667263 100644 --- a/lagent/llms/base_api.py +++ b/lagent/llms/base_api.py @@ -154,7 +154,7 @@ def __init__(self, *, max_new_tokens: int = 512, top_p: float = 0.8, - top_k: float = None, + top_k: int = 40, temperature: float = 0.8, repetition_penalty: float = 0.0, stop_words: Union[List[str], str] = None): diff --git a/lagent/llms/base_llm.py b/lagent/llms/base_llm.py index b9687eae..6677e5b2 100644 --- a/lagent/llms/base_llm.py +++ b/lagent/llms/base_llm.py @@ -118,7 +118,7 @@ def __init__(self, *, max_new_tokens: int = 512, top_p: float = 0.8, - top_k: float = None, + top_k: float = 40, temperature: float = 0.8, repetition_penalty: float = 1.0, stop_words: Union[List[str], str] = None): diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py index 91b3f236..9a887c48 100644 --- a/lagent/llms/openai.py +++ b/lagent/llms/openai.py @@ -1,6 +1,7 @@ import json import os import time +import warnings from concurrent.futures import ThreadPoolExecutor from logging import getLogger from threading import Lock @@ -10,6 +11,8 @@ from .base_api import BaseAPIModel +warnings.simplefilter('default') + OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions' @@ -54,6 +57,10 @@ def __init__(self, ], openai_api_base: str = OPENAI_API_BASE, **gen_params): + if 'top_k' in gen_params: + warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', + DeprecationWarning) + gen_params.pop('top_k') super().__init__( model_type=model_type, meta_template=meta_template, @@ -170,14 +177,15 @@ def _chat(self, messages: List[dict], **gen_params) -> str: header['OpenAI-Organization'] = self.orgs[self.org_ctr] try: + gen_params_new = gen_params.copy() data = dict( model=self.model_type, messages=messages, max_tokens=max_tokens, n=1, - stop=gen_params.pop('stop_words'), - frequency_penalty=gen_params.pop('repetition_penalty'), - **gen_params, + stop=gen_params_new.pop('stop_words'), + frequency_penalty=gen_params_new.pop('repetition_penalty'), + **gen_params_new, ) raw_response = requests.post( self.url, headers=header, data=json.dumps(data))