Skip to content

Commit

Permalink
infr: update dependencies and openai models
Browse files Browse the repository at this point in the history
  • Loading branch information
nalgeon committed Nov 11, 2023
1 parent c717f7c commit 4a8c63a
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 51 deletions.
54 changes: 35 additions & 19 deletions bot/ai/chatgpt.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,33 @@
"""ChatGPT (GPT-3.5+) language model from OpenAI."""

import logging

from openai import AsyncAzureOpenAI, AsyncOpenAI
import tiktoken

from bot.config import config

# Use the Azure-hosted client when an "azure" config section is present,
# otherwise talk to the OpenAI API directly. Both clients expose the same
# async interface, so the rest of the module does not care which one it got.
if config.openai.azure:
    openai = AsyncAzureOpenAI(
        api_key=config.openai.api_key,
        api_version=config.openai.azure["version"],
        azure_endpoint=config.openai.azure["endpoint"],
        azure_deployment=config.openai.azure["deployment"],
    )
else:
    openai = AsyncOpenAI(api_key=config.openai.api_key)

# cl100k_base is the tokenizer for the GPT-3.5/GPT-4 model families.
encoding = tiktoken.get_encoding("cl100k_base")
logger = logging.getLogger(__name__)

# Supported models and their context windows: the maximum number of tokens
# per request (prompt + completion). _calc_n_input() reads this table to
# decide how much of the conversation history fits into the prompt.
MODELS = {
    "gpt-4-1106-preview": 128000,
    "gpt-4-vision-preview": 128000,
    "gpt-4": 8192,
    "gpt-4-32k": 32768,
    "gpt-3.5-turbo-1106": 16385,
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-16k": 16385,
}


class Model:
Expand All @@ -24,31 +43,30 @@ async def ask(self, question: str, history: list[tuple[str, str]]) -> str:
n_input = _calc_n_input(self.name, n_output=config.openai.params["max_tokens"])
messages = self._generate_messages(question, history)
messages = shorten(messages, length=n_input)
params = self._prepare_params()
resp = await openai.ChatCompletion.acreate(
params = config.openai.params
logger.debug(
f"> chat request: model=%s, params=%s, messages=%s",
self.name,
params,
messages,
)
resp = await openai.chat.completions.create(
model=self.name,
messages=messages,
**params,
)
logger.debug(
"prompt_tokens=%s, completion_tokens=%s, total_tokens=%s",
"< chat response: prompt_tokens=%s, completion_tokens=%s, total_tokens=%s",
resp.usage.prompt_tokens,
resp.usage.completion_tokens,
resp.usage.total_tokens,
)
answer = self._prepare_answer(resp)
return answer

def _prepare_params(self) -> dict:
params = config.openai.params.copy()
if config.openai.azure:
params["api_type"] = "azure"
params["api_base"] = config.openai.azure["endpoint"]
params["api_version"] = config.openai.azure["version"]
params["deployment_id"] = config.openai.azure["deployment"]
return params

def _generate_messages(self, question: str, history: list[tuple[str, str]]) -> list[dict]:
def _generate_messages(
self, question: str, history: list[tuple[str, str]]
) -> list[dict]:
"""Builds message history to provide context for the language model."""
messages = [{"role": "system", "content": config.openai.prompt}]
for prev_question, prev_answer in history:
def _calc_n_input(name: str, n_output: int) -> int:
    """
    Calculates the maximum number of input tokens
    for the given model name and expected output length.
    """
    # OpenAI counts length in tokens, not characters.
    # We need to leave some tokens reserved for the output.
    n_total = MODELS.get(name, 4096)  # max 4096 tokens total by default
    return n_total - n_output
6 changes: 3 additions & 3 deletions bot/ai/custom.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Fine-tuned language model from OpenAI."""

import re

from openai import AsyncOpenAI

from bot.config import config

openai = AsyncOpenAI(api_key=config.openai.api_key)

# Default sequence that tells the model to stop generating.
DEFAULT_STOP = "###"
# Matches "<pre" and "</pre" tag openings. NOTE(review): the captured page
# HTML-escaped "<" as "&lt;"; the literal "<" is restored here.
PRE_RE = re.compile(r"<(/?pre)")
Expand All @@ -29,7 +29,7 @@ async def ask(self, question, history=None) -> str:
try:
history = history or []
prompt = self._generate_prompt(question, history)
resp = await openai.Completion.acreate(
resp = await openai.completions.create(
model=self.name,
prompt=prompt,
temperature=0.7,
Expand Down
8 changes: 5 additions & 3 deletions bot/ai/dalle.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
"""DALL-E model from OpenAI."""

from openai import AsyncOpenAI

from bot.config import config

openai = AsyncOpenAI(api_key=config.openai.api_key)


class Model:
    """OpenAI DALL-E wrapper."""

    async def imagine(self, prompt: str, size: str) -> str:
        """Generates an image of the specified size according to the description."""
        resp = await openai.images.generate(
            model="dall-e-3", prompt=prompt, size=size, n=1
        )
        # n=1 should yield exactly one image; guard against an empty
        # response anyway.
        if len(resp.data) == 0:
            raise ValueError("received an empty answer")
        return resp.data[0].url
6 changes: 3 additions & 3 deletions bot/ai/davinci.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""DaVinci (GPT-3) language model from OpenAI."""

import re

from openai import AsyncOpenAI

from bot.config import config

openai = AsyncOpenAI(api_key=config.openai.api_key)

BASE_PROMPT = "Your primary goal is to answer my questions. This may involve writing code or providing helpful information. Be detailed and thorough in your responses. Write code inside <pre>, </pre> tags."

Expand All @@ -18,7 +18,7 @@ async def ask(self, question, history=None):
"""Asks the language model a question and returns an answer."""
history = history or []
prompt = self._generate_prompt(question, history)
resp = await openai.Completion.acreate(
resp = await openai.completions.create(
model="text-davinci-003",
prompt=prompt,
temperature=0.7,
Expand Down
26 changes: 19 additions & 7 deletions bot/askers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ async def ask(self, question: str, history: list[tuple[str, str]]) -> str:
"""Asks AI a question."""
pass

async def reply(self, message: Message, context: CallbackContext, answer: str) -> None:
async def reply(
self, message: Message, context: CallbackContext, answer: str
) -> None:
"""Replies with an answer from AI."""
pass

Expand All @@ -38,7 +40,9 @@ async def ask(self, question: str, history: list[tuple[str, str]]) -> str:
"""Asks AI a question."""
return await self.model.ask(question, history)

async def reply(self, message: Message, context: CallbackContext, answer: str) -> None:
async def reply(
self, message: Message, context: CallbackContext, answer: str
) -> None:
"""Replies with an answer from AI."""
html_answer = markdown.to_html(answer)
if len(html_answer) <= MessageLimit.MAX_TEXT_LENGTH:
Expand All @@ -47,7 +51,8 @@ async def reply(self, message: Message, context: CallbackContext, answer: str) -

doc = io.StringIO(answer)
caption = (
textwrap.shorten(answer, width=40, placeholder="...") + " (see attachment for the rest)"
textwrap.shorten(answer, width=40, placeholder="...")
+ " (see attachment for the rest)"
)
reply_to_message_id = message.id if message.chat.type != Chat.PRIVATE else None
await context.bot.send_document(
Expand All @@ -64,8 +69,13 @@ class ImagineAsker(Asker):

model = ai.dalle.Model()
size_re = re.compile(r"(256|512|1024)(?:x\1)?\s?(?:px)?")
sizes = {"256": "256x256", "512": "512x512", "1024": "1024x1024"}
default_size = "512x512"
sizes = {
"256": "256x256",
"512": "512x512",
"1024": "1024x1024",
"1792": "1792x1024",
}
default_size = "1024x1024"

def __init__(self) -> None:
self.caption = ""
Expand All @@ -76,7 +86,9 @@ async def ask(self, question: str, history: list[tuple[str, str]]) -> str:
self.caption = self._extract_caption(question)
return await self.model.imagine(prompt=self.caption, size=size)

async def reply(self, message: Message, context: CallbackContext, answer: str) -> None:
async def reply(
self, message: Message, context: CallbackContext, answer: str
) -> None:
"""Replies with an answer from AI."""
await message.reply_photo(answer, caption=self.caption)

Expand All @@ -85,7 +97,7 @@ def _extract_size(self, question: str) -> str:
if not match:
return self.default_size
width = match.group(1)
return self.sizes[width]
return self.sizes.get(width, width)

def _extract_caption(self, question: str) -> str:
caption = self.size_re.sub("", question).strip()
Expand Down
41 changes: 31 additions & 10 deletions bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("openai").setLevel(logging.WARNING)
logging.getLogger("bot.ai.chatgpt").setLevel(logging.INFO)
logging.getLogger("bot.commands").setLevel(logging.INFO)
Expand Down Expand Up @@ -65,24 +66,36 @@ def add_handlers(application: Application):

# info commands
application.add_handler(CommandHandler("start", commands.Start()))
application.add_handler(CommandHandler("help", commands.Help(), filters=filters.users))
application.add_handler(CommandHandler("version", commands.Version(), filters=filters.users))
application.add_handler(
CommandHandler("help", commands.Help(), filters=filters.users)
)
application.add_handler(
CommandHandler("version", commands.Version(), filters=filters.users)
)

# admin commands
application.add_handler(
CommandHandler("config", commands.Config(filters), filters=filters.admins_private)
CommandHandler(
"config", commands.Config(filters), filters=filters.admins_private
)
)

# message-related commands
application.add_handler(
CommandHandler("retry", commands.Retry(reply_to), filters=filters.users_or_chats)
CommandHandler(
"retry", commands.Retry(reply_to), filters=filters.users_or_chats
)
)
application.add_handler(
CommandHandler("imagine", commands.Imagine(reply_to), filters=filters.users_or_chats)
CommandHandler(
"imagine", commands.Imagine(reply_to), filters=filters.users_or_chats
)
)

# non-command handler: the default action is to reply to a message
application.add_handler(MessageHandler(filters.messages, commands.Message(reply_to)))
application.add_handler(
MessageHandler(filters.messages, commands.Message(reply_to))
)

# generic error handler
application.add_error_handler(commands.Error())
Expand All @@ -108,19 +121,25 @@ async def post_shutdown(application: Application) -> None:
def with_message_limit(func):
"""Refuses to reply if the user has exceeded the message limit."""

async def wrapper(message: Message, context: CallbackContext, question: str) -> None:
async def wrapper(
message: Message, context: CallbackContext, question: str
) -> None:
username = message.from_user.username
user = UserData(context.user_data)

# check if the message counter exceeds the message limit
if (
not filters.is_known_user(username)
and user.message_counter.value >= config.conversation.message_limit.count > 0
and user.message_counter.value
>= config.conversation.message_limit.count
> 0
and not user.message_counter.is_expired()
):
# this is a group user and they have exceeded the message limit
wait_for = models.format_timedelta(user.message_counter.expires_after())
await message.reply_text(f"Please wait {wait_for} before asking a new question.")
await message.reply_text(
f"Please wait {wait_for} before asking a new question."
)
return

# this is a known user or they have not exceeded the message limit,
Expand All @@ -137,7 +156,9 @@ async def wrapper(message: Message, context: CallbackContext, question: str) ->
@with_message_limit
async def reply_to(message: Message, context: CallbackContext, question: str) -> None:
"""Replies to a specific question."""
await message.chat.send_action(action="typing", message_thread_id=message.message_thread_id)
await message.chat.send_action(
action="typing", message_thread_id=message.message_thread_id
)

try:
asker = askers.create(question)
Expand Down
11 changes: 6 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
httpcore==1.0.2
httpx==0.25.1
openai==1.2.3
beautifulsoup4==4.12.2
python-telegram-bot==20.6
PyYAML==6.0.1
tiktoken==0.5.1
2 changes: 1 addition & 1 deletion tests/test_askers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_extract_size(self):
size = asker._extract_size(question="a cat 256px")
self.assertEqual(size, "256x256")
size = asker._extract_size(question="a cat 384")
self.assertEqual(size, "512x512")
self.assertEqual(size, "1024x1024")

def test_extract_caption(self):
asker = ImagineAsker()
Expand Down

0 comments on commit 4a8c63a

Please sign in to comment.