diff --git a/README.rst b/README.rst
index bc57dcb..e80136c 100644
--- a/README.rst
+++ b/README.rst
@@ -13,7 +13,7 @@ Getting started
 ---------------
 
 - See the `ARTKIT Documentation `_ for our User Guide, Examples, API reference, and more.
-- See `Contributing `_ or visit our `Contributor Guide `_ for information on contributing.
+- See `Contributing `_ or visit our `Contributor Guide `_ for information on contributing.
 - We have an `FAQ `_ for common questions. For anything else, please reach out to ARTKIT@bcg.com.
 
 .. _Introduction:
@@ -326,7 +326,7 @@ and `Examples `_.
 Contributing
 ------------
 
-Contributions to ARTKIT are welcome and appreciated! Please see the `Contributing `_ section for information.
+Contributions to ARTKIT are welcome and appreciated! Please see the `Contributor Guide `_ section for information.
 
 
 License
diff --git a/sphinx/make/make_base.py b/sphinx/make/make_base.py
index 1b48320..0098b47 100755
--- a/sphinx/make/make_base.py
+++ b/sphinx/make/make_base.py
@@ -45,7 +45,7 @@
 ]
 assert len(PACKAGE_NAMES) == 1, "only one package per Sphinx build is supported"
 PROJECT_NAME = PACKAGE_NAMES[0]
-EXCLUDE_MODULES = []
+EXCLUDE_MODULES = ["api"]
 DIR_DOCS = os.path.join(DIR_REPO_ROOT, "docs")
 DIR_DOCS_VERSION = os.path.join(DIR_DOCS, "docs-version")
 DIR_SPHINX_SOURCE = os.path.join(DIR_SPHINX_ROOT, "source")
diff --git a/sphinx/source/conf.py b/sphinx/source/conf.py
index 850be0f..2897621 100644
--- a/sphinx/source/conf.py
+++ b/sphinx/source/conf.py
@@ -19,3 +19,5 @@
     project="artkit",
     html_logo=os.path.join("_images", "ARTKIT_Logo_Light_RGB-small.png"),
 )
+
+html_show_sourcelink = False
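For context on the two docs-build changes above: adding `"api"` to `EXCLUDE_MODULES` keeps that module out of the generated API reference, and `html_show_sourcelink` is a standard Sphinx option that removes the "View page source" link from rendered pages. A minimal standalone `conf.py` sketch showing the same option outside the ARTKIT build helpers (the project name and extension list here are placeholders for illustration):

```python
# conf.py -- minimal Sphinx configuration; only html_show_sourcelink mirrors
# the change above, everything else is illustrative.
project = "example-project"
extensions = ["sphinx.ext.autodoc"]

# Drop the "View page source" link that Sphinx adds to each HTML page by
# default, so readers are pointed at the repository rather than raw sources.
html_show_sourcelink = False
```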
module="torch.cuda") + class AsyncInferenceClient( # type: ignore metaclass=MissingClassMeta, module="huggingface_hub" ): """Placeholder class for missing ``AsyncInferenceClient`` class.""" + class AutoModelForCausalLM( # type: ignore + metaclass=MissingClassMeta, module="transformers" + ): + """Placeholder class for missing ``AutoModelForCausalLM`` class.""" + class AutoTokenizer(metaclass=MissingClassMeta, module="transformers"): # type: ignore """Placeholder class for missing ``AutoTokenizer`` class.""" @@ -128,7 +140,8 @@ def __init__( model_params=model_params, ) - if use_cuda and not torch.cuda.is_available(): # pragma: no cover + # test if cuda is available + if use_cuda and not is_available(): # pragma: no cover raise RuntimeError("CUDA requested but not available.") self.use_cuda = use_cuda diff --git a/src/artkit/util/_log_throttling.py b/src/artkit/util/_log_throttling.py index f92fce4..cd08a0a 100644 --- a/src/artkit/util/_log_throttling.py +++ b/src/artkit/util/_log_throttling.py @@ -19,9 +19,9 @@ class LogThrottlingHandler(logging.Handler): the maximum number of messages sent during a given time interval. This is useful for preventing log spamming in high-throughput applications. - To follow the native logging flow (https://docs.python.org/3/howto/logging\ - .html#logging-flow), functionality is implemented across - the :meth:`.filter` and :meth:`.emit` methods. + To follow the native logging flow + (https://docs.python.org/3/howto/logging.html#logging-flow), + functionality is implemented across the :meth:`.filter` and :meth:`.emit` methods. Example: @@ -44,12 +44,10 @@ def __init__( self, handler: logging.Handler, interval: float, max_messages: int ) -> None: """ - Initializes log throttling handler. - - :param handler: the handler to wrap. + :param handler: the handler to wrap :param interval: the minimum interval in seconds between log messages - with the same message. - :param max_messages: the maximum number of messages to log within the interval. + with the same message + :param max_messages: the maximum number of messages to log within the interval """ super().__init__() self.handler = handler @@ -63,9 +61,9 @@ def filter(self, record: logging.LogRecord) -> bool: """ Filter a log record based on the throttling settings. - :param record: the log record to filter. + :param record: the log record to filter :return: ``True`` if the max messages are not exceeded - within the time interval, otherwise ``False``. + within the time interval, otherwise ``False`` """ # if any other filter was registered and returns False, return False @@ -90,7 +88,7 @@ def emit(self, record: logging.LogRecord) -> None: """ Emit a record if the max messages are not exceeded within the time interval. - :param record: the log record to emit. + :param record: the log record to emit """ count, last_log_time, buffer = self.log_counts[record.msg] @@ -110,8 +108,8 @@ def _create_ellipsis_record(record: logging.LogRecord) -> logging.LogRecord: """ Create a log record with a custom message to indicate throttling. - :param record: the original log record. - :return: a new log record with the custom message message. 
diff --git a/src/artkit/util/_log_throttling.py b/src/artkit/util/_log_throttling.py
index f92fce4..cd08a0a 100644
--- a/src/artkit/util/_log_throttling.py
+++ b/src/artkit/util/_log_throttling.py
@@ -19,9 +19,9 @@ class LogThrottlingHandler(logging.Handler):
     the maximum number of messages sent during a given time interval.
     This is useful for preventing log spamming in high-throughput applications.
 
-    To follow the native logging flow (https://docs.python.org/3/howto/logging\
-    .html#logging-flow), functionality is implemented across
-    the :meth:`.filter` and :meth:`.emit` methods.
+    To follow the native logging flow
+    (https://docs.python.org/3/howto/logging.html#logging-flow),
+    functionality is implemented across the :meth:`.filter` and :meth:`.emit` methods.
 
     Example:
 
@@ -44,12 +44,10 @@ def __init__(
         self, handler: logging.Handler, interval: float, max_messages: int
     ) -> None:
         """
-        Initializes log throttling handler.
-
-        :param handler: the handler to wrap.
+        :param handler: the handler to wrap
         :param interval: the minimum interval in seconds between log messages
-            with the same message.
-        :param max_messages: the maximum number of messages to log within the interval.
+            with the same message
+        :param max_messages: the maximum number of messages to log within the interval
         """
         super().__init__()
         self.handler = handler
@@ -63,9 +61,9 @@ def filter(self, record: logging.LogRecord) -> bool:
         """
         Filter a log record based on the throttling settings.
 
-        :param record: the log record to filter.
+        :param record: the log record to filter
         :return: ``True`` if the max messages are not exceeded
-            within the time interval, otherwise ``False``.
+            within the time interval, otherwise ``False``
         """
 
         # if any other filter was registered and returns False, return False
@@ -90,7 +88,7 @@ def emit(self, record: logging.LogRecord) -> None:
         """
         Emit a record if the max messages are not exceeded within the time interval.
 
-        :param record: the log record to emit.
+        :param record: the log record to emit
         """
 
         count, last_log_time, buffer = self.log_counts[record.msg]
@@ -110,8 +108,8 @@ def _create_ellipsis_record(record: logging.LogRecord) -> logging.LogRecord:
         """
        Create a log record with a custom message to indicate throttling.
 
-        :param record: the original log record.
-        :return: a new log record with the custom message message.
+        :param record: the original log record
+        :return: a new log record with the custom message
        """
        return logging.LogRecord(
            name=record.name,
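Since the docstring hunks above spell out the constructor (`handler`, `interval`, `max_messages`), a short usage sketch may help; note that the public import path is an assumption based on the module's location under `src/artkit/util/`:

```python
# Usage sketch for LogThrottlingHandler; import path assumed from the
# module location src/artkit/util/_log_throttling.py.
import logging

from artkit.util import LogThrottlingHandler

log = logging.getLogger("spammy")
# Forward at most 5 occurrences of the same message per 10-second window,
# delegating actual output to a plain StreamHandler.
log.addHandler(
    LogThrottlingHandler(logging.StreamHandler(), interval=10.0, max_messages=5)
)

for _ in range(1_000):
    log.warning("rate limit hit")  # throttled after the first 5 in the window
```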
diff --git a/test/artkit_test/model/diffusion/test_diffusion_openai.py b/test/artkit_test/model/diffusion/test_diffusion_openai.py
index fa8cf93..d0d69a6 100644
--- a/test/artkit_test/model/diffusion/test_diffusion_openai.py
+++ b/test/artkit_test/model/diffusion/test_diffusion_openai.py
@@ -2,7 +2,7 @@
 import shutil
 from collections.abc import Iterator
 from pathlib import Path
-from unittest.mock import AsyncMock, patch
+from unittest.mock import Mock, patch
 
 import pytest
 from openai import RateLimitError
@@ -38,13 +38,13 @@ async def test_openai_retry(caplog: pytest.LogCaptureFixture) -> None:
         )
 
         # Response mock
-        response = AsyncMock()
+        response = Mock()
         response.status_code = 429
 
         MockClientSession.return_value.images.generate.side_effect = RateLimitError(
             message="Rate Limit exceeded",
             response=response,
-            body=AsyncMock(),
+            body=Mock(),
         )
 
         with pytest.raises(RateLimitException):
diff --git a/test/artkit_test/model/llm/test_huggingface.py b/test/artkit_test/model/llm/test_huggingface.py
index d77b531..84f09f0 100644
--- a/test/artkit_test/model/llm/test_huggingface.py
+++ b/test/artkit_test/model/llm/test_huggingface.py
@@ -119,8 +119,8 @@ async def test_huggingface_chat_async(
     with patch("aiohttp.ClientSession") as MockClientSession:
 
         # Mock the response object
-        mock_post = AsyncMock()
-        mock_post.read.return_value = b'[{"generated_text": "blue"}]'
+        mock_post = Mock()
+        mock_post.read = AsyncMock(return_value=b'[{"generated_text": "blue"}]')
         mock_post.return_value.status = 200
 
         # Set up the mock connection object
@@ -149,10 +149,13 @@ async def test_huggingface_chat_aiohttp(
     ) as MockClientSession:
 
         # Mock the response object
-        mock_post = AsyncMock()
-        mock_post.json.return_value = {
-            "choices": [{"message": {"role": "assistant", "content": "blue"}}]
-        }
+        mock_post = Mock()
+        mock_post.json = AsyncMock(
+            return_value={
+                "choices": [{"message": {"role": "assistant", "content": "blue"}}]
+            }
+        )
+        mock_post.text = AsyncMock()
         mock_post.return_value.status = 200
 
         # Set up the mock connection object
diff --git a/test/artkit_test/model/llm/test_openai.py b/test/artkit_test/model/llm/test_openai.py
index 81f651a..00f8559 100644
--- a/test/artkit_test/model/llm/test_openai.py
+++ b/test/artkit_test/model/llm/test_openai.py
@@ -38,7 +38,7 @@ async def test_openai_retry(
     # Mock openai Client
     with patch("artkit.model.llm.openai._openai.AsyncOpenAI") as mock_get_client:
         # Set mock response as return value
-        response = AsyncMock()
+        response = MagicMock()
         response.status_code = 429
 
         # Mock exception on method call
@@ -46,7 +46,7 @@
             RateLimitError(
                 message="Rate Limit exceeded",
                 response=response,
-                body=AsyncMock(),
+                body=MagicMock(),
             )
         )
 
diff --git a/test/artkit_test/model/vision/test_vision_openai.py b/test/artkit_test/model/vision/test_vision_openai.py
index 9417c0d..c001177 100644
--- a/test/artkit_test/model/vision/test_vision_openai.py
+++ b/test/artkit_test/model/vision/test_vision_openai.py
@@ -2,7 +2,7 @@
 import shutil
 from collections.abc import Iterator
 from pathlib import Path
-from unittest.mock import AsyncMock, patch
+from unittest.mock import Mock, patch
 
 import pytest
 from openai import RateLimitError
@@ -40,14 +40,14 @@ async def test_openai_retry(image: Image, caplog: pytest.LogCaptureFixture) -> N
         )
 
         # Response mock
-        response = AsyncMock()
+        response = Mock()
         response.status_code = 429
 
         MockClientSession.return_value.chat.completions.create.side_effect = (
             RateLimitError(
                 message="Rate Limit exceeded",
                 response=response,
-                body=AsyncMock(),
+                body=Mock(),
             )
         )
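The common thread in these test changes: using `AsyncMock()` for a whole response object makes every attribute access return an async mock, so plain attributes like `status_code` behave as coroutine factories and unawaited-coroutine warnings pile up. The fix adopted above is a plain `Mock`/`MagicMock` for the object itself, with `AsyncMock` attached only to the methods the client code actually awaits. In miniature:

```python
# Mocking pattern from the tests above: synchronous object, async methods.
from unittest.mock import AsyncMock, Mock

response = Mock()       # attribute access stays synchronous
response.status = 200   # plain attribute, no coroutine involved
response.read = AsyncMock(return_value=b'{"ok": true}')  # awaited by client code

# In an async test:
#     assert response.status == 200   # no await needed
#     data = await response.read()    # -> b'{"ok": true}'
```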