FIX: Final tweaks leading up to 1.0 release (#3)
* FIX: Move optional dependencies into try-except statement

* FIX: Updated log throttler doc strings

* FIX: Removed warnings when running tests

* FIX: Fixed broken link in log throttler doc string

* DOC: revert Sphinx workaround for HuggingfaceChat (no longer needed)

* DOC: Exclude artkit.api from sphinx doc build

* DOC: Update links in README.rst

* DOC: Remove show source button

* FIX: File formatting

---------

Co-authored-by: ALontke <[email protected]>
Co-authored-by: j-ittner <[email protected]>
Co-authored-by: breakbotz <[email protected]>
4 people authored Jun 19, 2024
1 parent afd58be commit 19ab84b
Showing 10 changed files with 50 additions and 38 deletions.
4 changes: 2 additions & 2 deletions README.rst
@@ -13,7 +13,7 @@ Getting started
---------------

- See the `ARTKIT Documentation <https://bcg-x-official.github.io/artkit/home.html>`_ for our User Guide, Examples, API reference, and more.
- See `Contributing <CONTRIBUTING.md>`_ or visit our `Contributor Guide <https://bcg-x-official.github.io/artkit/contributor_guide/index.html>`_ for information on contributing.
- See `Contributing <https://github.com/BCG-X-Official/artkit/blob/HEAD/CONTRIBUTING.md>`_ or visit our `Contributor Guide <https://bcg-x-official.github.io/artkit/contributor_guide/index.html>`_ for information on contributing.
- We have an `FAQ <https://bcg-x-official.github.io/artkit/faq.html>`_ for common questions. For anything else, please reach out to [email protected].

.. _Introduction:
@@ -326,7 +326,7 @@ and `Examples <https://bcg-x-official.github.io/artkit/examples/index.html>`_.
Contributing
------------

Contributions to ARTKIT are welcome and appreciated! Please see the `Contributing <https://bcg-x-official.github.io/artkit/contributor_guide/index.html>`_ section for information.
Contributions to ARTKIT are welcome and appreciated! Please see the `Contributor Guide <https://bcg-x-official.github.io/artkit/contributor_guide/index.html>`_ section for information.


License
2 changes: 1 addition & 1 deletion sphinx/make/make_base.py
@@ -45,7 +45,7 @@
]
assert len(PACKAGE_NAMES) == 1, "only one package per Sphinx build is supported"
PROJECT_NAME = PACKAGE_NAMES[0]
EXCLUDE_MODULES = []
EXCLUDE_MODULES = ["api"]
DIR_DOCS = os.path.join(DIR_REPO_ROOT, "docs")
DIR_DOCS_VERSION = os.path.join(DIR_DOCS, "docs-version")
DIR_SPHINX_SOURCE = os.path.join(DIR_SPHINX_ROOT, "source")
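Adding "api" to EXCLUDE_MODULES keeps artkit.api out of the generated API reference. A minimal sketch of how such a list can feed a Sphinx build; the DIR_PACKAGE_SRC layout and the sphinx-apidoc invocation below are illustrative assumptions, not the actual make_base.py logic:

import os
import subprocess

PROJECT_NAME = "artkit"
EXCLUDE_MODULES = ["api"]

# hypothetical source layout; the real build script may resolve paths differently
DIR_PACKAGE_SRC = os.path.join("src", PROJECT_NAME)

# sphinx-apidoc treats trailing positional arguments as exclude patterns,
# so excluded modules never get .rst stubs generated for them
exclude_patterns = [
    os.path.join(DIR_PACKAGE_SRC, module) for module in EXCLUDE_MODULES
]
subprocess.run(
    ["sphinx-apidoc", "-o", "sphinx/source/apidoc", DIR_PACKAGE_SRC]
    + exclude_patterns,
    check=True,
)
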
2 changes: 2 additions & 0 deletions sphinx/source/conf.py
@@ -19,3 +19,5 @@
project="artkit",
html_logo=os.path.join("_images", "ARTKIT_Logo_Light_RGB-small.png"),
)

html_show_sourcelink = False
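Setting html_show_sourcelink = False drops the "show source" link that Sphinx adds to each rendered page, matching the "Remove show source button" item in the commit message. For context only (not part of this commit), a related standard Sphinx option also stops the sources from being copied into the build output:

# conf.py; both options are standard Sphinx settings
html_show_sourcelink = False  # hide the "show source" link on every page
html_copy_source = False      # optionally also skip copying sources into _sources/
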
4 changes: 0 additions & 4 deletions src/artkit/model/llm/huggingface/_huggingface.py
@@ -23,7 +23,6 @@
try: # pragma: no cover
from aiohttp import ClientResponseError, ClientSession
from huggingface_hub import AsyncInferenceClient
from transformers import AutoTokenizer
except ImportError:

class AsyncInferenceClient( # type: ignore
@@ -37,9 +36,6 @@ class ClientResponseError(metaclass=MissingClassMeta, module="aiohttp"): # type
class ClientSession(metaclass=MissingClassMeta, module="aiohttp"): # type: ignore
"""Placeholder class for missing ``ClientSession`` class."""

class AutoTokenizer(metaclass=MissingClassMeta, module="transformers"): # type: ignore
"""Placeholder class for missing ``AutoTokenizer`` class."""


log = logging.getLogger(__name__)

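The placeholder classes above come from pytools.api. As a sketch of the idea only (not the actual MissingClassMeta implementation), a metaclass can accept the module keyword used in these class statements and defer the ImportError from import time to first use:

# Sketch; the real MissingClassMeta lives in pytools.api and may differ.
class MissingClassMeta(type):
    def __new__(mcs, name, bases, namespace, module: str):
        cls = super().__new__(mcs, name, bases, namespace)
        cls._missing_module = module  # remember which optional dependency is absent
        return cls

    def __init__(cls, name, bases, namespace, module: str) -> None:
        super().__init__(name, bases, namespace)

    def __call__(cls, *args, **kwargs):
        # importing the defining module always succeeds; only instantiating
        # the placeholder raises
        raise ImportError(
            f"class {cls.__name__!r} requires the missing optional "
            f"dependency {cls._missing_module!r}"
        )

class ClientSession(metaclass=MissingClassMeta, module="aiohttp"):
    """Placeholder class for missing ``ClientSession`` class."""
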
21 changes: 17 additions & 4 deletions src/artkit/model/llm/huggingface/base/_base.py
@@ -8,9 +8,13 @@
from contextlib import AsyncExitStack
from typing import Any, Generic, TypeVar, final

import torch

from pytools.api import MissingClassMeta, appenddoc, inheritdoc, subsdoc
from pytools.api import (
MissingClassMeta,
appenddoc,
inheritdoc,
missing_function,
subsdoc,
)

from ....base import ConnectorMixin
from ...base import ChatModelConnector, CompletionModelConnector
@@ -29,16 +29,24 @@
try:
# noinspection PyUnresolvedReferences
from huggingface_hub import AsyncInferenceClient
from torch.cuda import is_available

# noinspection PyUnresolvedReferences
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError: # pragma: no cover

is_available = missing_function(name="is_available", module="torch.cuda")

class AsyncInferenceClient( # type: ignore
metaclass=MissingClassMeta, module="huggingface_hub"
):
"""Placeholder class for missing ``AsyncInferenceClient`` class."""

class AutoModelForCausalLM( # type: ignore
metaclass=MissingClassMeta, module="transformers"
):
"""Placeholder class for missing ``AutoModelForCausalLM`` class."""

class AutoTokenizer(metaclass=MissingClassMeta, module="transformers"): # type: ignore
"""Placeholder class for missing ``AutoTokenizer`` class."""

@@ -128,7 +140,8 @@ def __init__(
model_params=model_params,
)

if use_cuda and not torch.cuda.is_available(): # pragma: no cover
# test if cuda is available
if use_cuda and not is_available(): # pragma: no cover
raise RuntimeError("CUDA requested but not available.")

self.use_cuda = use_cuda
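missing_function, imported from pytools.api above, plays the same role for module-level functions such as torch.cuda.is_available. A minimal sketch of the pattern, assuming only that the placeholder raises when called:

# Sketch of the deferred-import pattern for functions; the real
# missing_function helper in pytools.api may differ in detail.
try:
    from torch.cuda import is_available
except ImportError:  # torch not installed

    def is_available() -> bool:
        # raises only if CUDA support is actually requested
        raise ImportError(
            "function 'is_available' requires the missing optional dependency 'torch'"
        )

def check_cuda(use_cuda: bool) -> None:
    # mirrors the guard added to __init__ in the diff above
    if use_cuda and not is_available():
        raise RuntimeError("CUDA requested but not available.")
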
24 changes: 11 additions & 13 deletions src/artkit/util/_log_throttling.py
@@ -19,9 +19,9 @@ class LogThrottlingHandler(logging.Handler):
the maximum number of messages sent during a given time interval.
This is useful for preventing log spamming in high-throughput applications.
To follow the native logging flow (https://docs.python.org/3/howto/logging\
.html#logging-flow), functionality is implemented across
the :meth:`.filter` and :meth:`.emit` methods.
To follow the native logging flow
(https://docs.python.org/3/howto/logging.html#logging-flow),
functionality is implemented across the :meth:`.filter` and :meth:`.emit` methods.
Example:
@@ -44,12 +44,10 @@ def __init__(
self, handler: logging.Handler, interval: float, max_messages: int
) -> None:
"""
Initializes log throttling handler.
:param handler: the handler to wrap.
:param handler: the handler to wrap
:param interval: the minimum interval in seconds between log messages
with the same message.
:param max_messages: the maximum number of messages to log within the interval.
with the same message
:param max_messages: the maximum number of messages to log within the interval
"""
super().__init__()
self.handler = handler
@@ -63,9 +61,9 @@ def filter(self, record: logging.LogRecord) -> bool:
"""
Filter a log record based on the throttling settings.
:param record: the log record to filter.
:param record: the log record to filter
:return: ``True`` if the max messages are not exceeded
within the time interval, otherwise ``False``.
within the time interval, otherwise ``False``
"""

# if any other filter was registered and returns False, return False
@@ -90,7 +88,7 @@ def emit(self, record: logging.LogRecord) -> None:
"""
Emit a record if the max messages are not exceeded within the time interval.
:param record: the log record to emit.
:param record: the log record to emit
"""
count, last_log_time, buffer = self.log_counts[record.msg]

@@ -110,8 +108,8 @@ def _create_ellipsis_record(record: logging.LogRecord) -> logging.LogRecord:
"""
Create a log record with a custom message to indicate throttling.
:param record: the original log record.
:return: a new log record with the custom message message.
:param record: the original log record
:return: a new log record with the custom message
"""
return logging.LogRecord(
name=record.name,
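A usage sketch based on the constructor shown above; the import path is assumed from the module's location (src/artkit/util/_log_throttling.py) and may differ:

import logging

from artkit.util import LogThrottlingHandler  # assumed public import path

log = logging.getLogger("high_throughput")
log.setLevel(logging.INFO)

# emit at most 3 records with the same message every 10 seconds; the wrapped
# StreamHandler does the actual output
log.addHandler(
    LogThrottlingHandler(logging.StreamHandler(), interval=10.0, max_messages=3)
)

for attempt in range(100):
    log.info("rate limit exceeded; backing off")  # throttled after 3 emissions
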
6 changes: 3 additions & 3 deletions test/artkit_test/model/diffusion/test_diffusion_openai.py
@@ -2,7 +2,7 @@
import shutil
from collections.abc import Iterator
from pathlib import Path
from unittest.mock import AsyncMock, patch
from unittest.mock import Mock, patch

import pytest
from openai import RateLimitError
@@ -38,13 +38,13 @@ async def test_openai_retry(caplog: pytest.LogCaptureFixture) -> None:
)

# Response mock
response = AsyncMock()
response = Mock()
response.status_code = 429

MockClientSession.return_value.images.generate.side_effect = RateLimitError(
message="Rate Limit exceeded",
response=response,
body=AsyncMock(),
body=Mock(),
)

with pytest.raises(RateLimitException):
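These AsyncMock-to-Mock swaps line up with the "Removed warnings when running tests" item in the commit message: methods looked up on an AsyncMock are async as well, so any call the library makes on the fake response or body returns a coroutine that is never awaited, and pytest reports warnings. A small illustration of the rule of thumb the tests now follow:

from unittest.mock import AsyncMock, Mock

# plain attributes: use Mock, since nothing here is ever awaited
response = Mock()
response.status_code = 429

# mix in AsyncMock only for the methods the code under test awaits
mock_post = Mock()
mock_post.read = AsyncMock(return_value=b'[{"generated_text": "blue"}]')
# awaiting mock_post.read() now yields the bytes payload, while
# response.status_code stays a plain, warning-free attribute
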
15 changes: 9 additions & 6 deletions test/artkit_test/model/llm/test_huggingface.py
@@ -119,8 +119,8 @@ async def test_huggingface_chat_async(
with patch("aiohttp.ClientSession") as MockClientSession:

# Mock the response object
mock_post = AsyncMock()
mock_post.read.return_value = b'[{"generated_text": "blue"}]'
mock_post = Mock()
mock_post.read = AsyncMock(return_value=b'[{"generated_text": "blue"}]')
mock_post.return_value.status = 200

# Set up the mock connection object
@@ -149,10 +149,13 @@ async def test_huggingface_chat_aiohttp(
) as MockClientSession:

# Mock the response object
mock_post = AsyncMock()
mock_post.json.return_value = {
"choices": [{"message": {"role": "assistant", "content": "blue"}}]
}
mock_post = Mock()
mock_post.json = AsyncMock(
return_value={
"choices": [{"message": {"role": "assistant", "content": "blue"}}]
}
)
mock_post.text = AsyncMock()
mock_post.return_value.status = 200

# Set up the mock connection object
4 changes: 2 additions & 2 deletions test/artkit_test/model/llm/test_openai.py
@@ -38,15 +38,15 @@ async def test_openai_retry(
# Mock openai Client
with patch("artkit.model.llm.openai._openai.AsyncOpenAI") as mock_get_client:
# Set mock response as return value
response = AsyncMock()
response = MagicMock()
response.status_code = 429

# Mock exception on method call
mock_get_client.return_value.chat.completions.create.side_effect = (
RateLimitError(
message="Rate Limit exceeded",
response=response,
body=AsyncMock(),
body=MagicMock(),
)
)

6 changes: 3 additions & 3 deletions test/artkit_test/model/vision/test_vision_openai.py
@@ -2,7 +2,7 @@
import shutil
from collections.abc import Iterator
from pathlib import Path
from unittest.mock import AsyncMock, patch
from unittest.mock import Mock, patch

import pytest
from openai import RateLimitError
@@ -40,14 +40,14 @@ async def test_openai_retry(image: Image, caplog: pytest.LogCaptureFixture) -> None:
)

# Response mock
response = AsyncMock()
response = Mock()
response.status_code = 429

MockClientSession.return_value.chat.completions.create.side_effect = (
RateLimitError(
message="Rate Limit exceeded",
response=response,
body=AsyncMock(),
body=Mock(),
)
)

