diff --git a/.vscode/settings.json b/.vscode/settings.json
index aa5b7cfb..5c9e7b91 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -6,13 +6,9 @@
         "-p",
         "*test*.py"
     ],
-    "python.testing.pytestEnabled": true,
-    "python.testing.unittestEnabled": false,
+    "python.testing.unittestEnabled": true,
     "editor.defaultFormatter": "charliermarsh.ruff",
     "[python]": {
         "editor.formatOnSave": true,
     },
-    "python.testing.pytestArgs": [
-        "tests"
-    ]
 }
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 941253d7..3b2bfcd0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,9 @@ requires-python = ">=3.7"
 dependencies = [
     "requests",
     "tqdm",
-    "packaging"
+    "packaging",
+    "aiohttp[speedups]",
 ]
 
 [project.urls]
diff --git a/src/kagglehub/auth.py b/src/kagglehub/auth.py
index fee737f7..3c4f2356 100644
--- a/src/kagglehub/auth.py
+++ b/src/kagglehub/auth.py
@@ -7,6 +7,8 @@
 from kagglehub.config import get_kaggle_credentials, set_kaggle_credentials
 from kagglehub.exceptions import UnauthenticatedError
 
+from aiohttp import ClientSession
+
 _logger = logging.getLogger(__name__)
 
 INVALID_CREDENTIALS_ERROR = 401
@@ -60,7 +62,7 @@ def _is_in_notebook() -> bool:
         return False  # Probably standard Python interpreter
 
 
-def _notebook_login(validate_credentials: bool) -> None:  # noqa: FBT001
+async def _notebook_login(validate_credentials: bool) -> None:  # noqa: FBT001
     """Prompt the user for their Kaggle token and save it in a widget (Jupyter or Colab)."""
     library_error = "You need the `ipywidgets` module: `pip install ipywidgets`."
     try:
@@ -87,7 +89,7 @@ def _notebook_login(validate_credentials: bool) -> None:  # noqa: FBT001
     )
     display(login_token_widget)
 
-    def on_click_login_button(_: str) -> None:
+    async def on_click_login_button(_: str) -> None:
         username = username_widget.value
         token = token_widget.value
         # Erase token and clear value to make sure it's not saved in the notebook.
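A caveat on the hunk above: `ipywidgets` invokes `on_click` callbacks synchronously and never awaits them, so registering the new `async def on_click_login_button` directly would only create a coroutine object that is never run. A minimal sketch of scheduling it on the notebook kernel's running event loop (the wrapper name `_schedule_login` is hypothetical, not part of this change):

```python
import asyncio

def _schedule_login(button) -> None:
    # Jupyter and Colab kernels run an asyncio event loop, so the coroutine
    # can be scheduled as a task instead of being awaited inline.
    asyncio.ensure_future(on_click_login_button(button))

login_button.on_click(_schedule_login)
```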
@@ -102,7 +104,8 @@ def on_click_login_button(_: str) -> None:
 
             # Validate credentials if necessary
             if validate_credentials is True:
-                _validate_credentials_helper()
+                async with ClientSession() as session:
+                    await _validate_credentials_helper(session)
                 message = captured.getvalue()
         except Exception as error:
             message = str(error)
@@ -112,9 +115,9 @@ def on_click_login_button(_: str) -> None:
     login_button.on_click(on_click_login_button)
 
 
-def _validate_credentials_helper() -> None:
-    api_client = KaggleApiV1Client()
-    response = api_client.get("/hello")
+async def _validate_credentials_helper(session: ClientSession) -> None:
+    api_client = KaggleApiV1Client(session)
+    response = await api_client.get("/hello")
     if "code" not in response:
         _logger.info("Kaggle credentials successfully validated.")
     elif response["code"] == INVALID_CREDENTIALS_ERROR:
@@ -125,11 +128,11 @@ def _validate_credentials_helper() -> None:
         _logger.warning("Unable to validate Kaggle credentials at this time.")
 
 
-def login(validate_credentials: bool = True) -> None:  # noqa: FBT002, FBT001
+async def login(validate_credentials: bool = True) -> None:  # noqa: FBT002, FBT001
     """Prompt the user for their Kaggle username and API key and save them globally."""
     if _is_in_notebook():
-        _notebook_login(validate_credentials)
+        await _notebook_login(validate_credentials)
         return
     else:
         username = input("Enter your Kaggle username: ")
@@ -140,7 +143,8 @@ def login(validate_credentials: bool = True) -> None:  # noqa: FBT002, FBT001
     if not validate_credentials:
         return
 
-    _validate_credentials_helper()
+    async with ClientSession() as session:
+        await _validate_credentials_helper(session)
 
 
 def whoami() -> dict:
diff --git a/src/kagglehub/clients.py b/src/kagglehub/clients.py
index f2544a5c..b0218b90 100644
--- a/src/kagglehub/clients.py
+++ b/src/kagglehub/clients.py
@@ -6,6 +6,7 @@
 from urllib.parse import urljoin
 
 import requests
+from aiohttp import ClientResponse, ClientSession, ClientTimeout
 from packaging.version import parse
 from requests.auth import HTTPBasicAuth
 from tqdm import tqdm
@@ -81,11 +82,12 @@ def get_user_agent() -> str:
 class KaggleApiV1Client:
     BASE_PATH = "api/v1"
 
-    def __init__(self) -> None:
+    def __init__(self, session: ClientSession) -> None:
         self.credentials = get_kaggle_credentials()
         self.endpoint = get_kaggle_api_endpoint()
+        self.session = session
 
-    def _check_for_version_update(self, response: requests.Response) -> None:
+    def _check_for_version_update(self, response: ClientResponse) -> None:
         latest_version_str = response.headers.get("X-Kaggle-HubVersion")
         if latest_version_str:
             current_version = parse(kagglehub.__version__)
@@ -96,26 +98,26 @@ def _check_for_version_update(self, response: requests.Response) -> None:
                     f"version, please consider updating (latest version: {latest_version})"
                 )
 
-    def get(self, path: str, resource_handle: Optional[ResourceHandle] = None) -> dict:
+    async def get(self, path: str, resource_handle: Optional[ResourceHandle] = None) -> dict:
         url = self._build_url(path)
-        with requests.get(
+        async with self.session.get(
             url,
             headers={"User-Agent": get_user_agent()},
-            auth=self._get_http_basic_auth(),
-            timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT),
-        ) as response:
-            kaggle_api_raise_for_status(response, resource_handle)
-            self._check_for_version_update(response)
-            return response.json()
-
-    def post(self, path: str, data: dict) -> dict:
+            timeout=ClientTimeout(connect=DEFAULT_CONNECT_TIMEOUT,
+                                  sock_read=DEFAULT_READ_TIMEOUT),
+        ) as resp:
+            kaggle_api_raise_for_status(resp, resource_handle)
+            self._check_for_version_update(resp)
+            return await resp.json()
+
+    async def post(self, path: str, data: dict) -> dict:
         url = self._build_url(path)
-        with requests.post(
+        async with self.session.post(
             url,
             headers={"User-Agent": get_user_agent()},
+            timeout=ClientTimeout(connect=DEFAULT_CONNECT_TIMEOUT,
+                                  sock_read=DEFAULT_READ_TIMEOUT),
             json=data,
-            auth=self._get_http_basic_auth(),
-            timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT),
         ) as response:
             response.raise_for_status()
-            response_dict = response.json()
+            response_dict = await response.json()
@@ -210,7 +212,7 @@ def _download_file(
 class KaggleJwtClient:
     BASE_PATH = "/kaggle-jwt-handler/"
 
-    def __init__(self) -> None:
+    def __init__(self, session: ClientSession) -> None:
         self.endpoint = os.getenv(KAGGLE_DATA_PROXY_URL_ENV_VAR_NAME)
         if self.endpoint is None:
             msg = f"The {KAGGLE_DATA_PROXY_URL_ENV_VAR_NAME} should be set."
@@ -236,20 +238,16 @@ def __init__(self) -> None:
             "X-Kaggle-Authorization": f"Bearer {jwt_token}",
             "X-KAGGLE-PROXY-DATA": data_proxy_token,
         }
+        self.session = session
 
-    def post(
+    async def post(
         self,
         request_name: str,
         data: dict,
         timeout: Tuple[float, float] = (DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT),
     ) -> dict:
         url = f"{self.endpoint}{KaggleJwtClient.BASE_PATH}{request_name}"
-        with requests.post(
-            url,
-            headers=self.headers,
-            data=bytes(json.dumps(data), "utf-8"),
-            timeout=timeout,
-        ) as response:
+        async with self.session.post(
+            url,
+            headers=self.headers,
+            json=data,
+            timeout=ClientTimeout(connect=timeout[0], sock_read=timeout[1]),
+        ) as response:
             response.raise_for_status()
-            json_response = response.json()
+            json_response = await response.json()
             if "wasSuccessful" not in json_response:
@@ -271,7 +269,7 @@ class ColabClient:
     # of ModelColabCacheResolver.
     TBE_RUNTIME_ADDR_ENV_VAR_NAME = "TBE_RUNTIME_ADDR"
 
-    def __init__(self) -> None:
+    def __init__(self, session: ClientSession) -> None:
         self.endpoint = os.getenv(self.TBE_RUNTIME_ADDR_ENV_VAR_NAME)
         if self.endpoint is None:
             msg = f"The {self.TBE_RUNTIME_ADDR_ENV_VAR_NAME} should be set."
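Two requests-specific keyword arguments do not carry over one-for-one in these client hunks: aiohttp has no per-request `read_timeout`/`conn_timeout` parameters (it takes a `ClientTimeout` object), and `HTTPBasicAuth` must become `aiohttp.BasicAuth`. Note also that the rewritten `get`/`post` drop the `auth=self._get_http_basic_auth()` argument entirely, which would silently unauthenticate every API call. A sketch of how both could be carried over; the helper name `_get_aiohttp_basic_auth` is hypothetical, and the constant values below are placeholders for the real ones in `clients.py`:

```python
from typing import Optional

import aiohttp

# Placeholder values; the real constants live in src/kagglehub/clients.py.
DEFAULT_CONNECT_TIMEOUT = 5.0
DEFAULT_READ_TIMEOUT = 15.0

def _get_aiohttp_basic_auth(credentials) -> Optional[aiohttp.BasicAuth]:
    # Mirrors the requests-era HTTPBasicAuth(username, key) behavior.
    return aiohttp.BasicAuth(credentials.username, credentials.key) if credentials else None

# requests' timeout=(connect, read) tuple maps onto ClientTimeout fields.
REQUEST_TIMEOUT = aiohttp.ClientTimeout(
    connect=DEFAULT_CONNECT_TIMEOUT,  # time allowed to establish the connection
    sock_read=DEFAULT_READ_TIMEOUT,  # max interval between socket reads
)

# Usage inside KaggleApiV1Client.get:
#   async with self.session.get(
#       url, auth=_get_aiohttp_basic_auth(self.credentials), timeout=REQUEST_TIMEOUT
#   ) as resp:
#       ...
```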
@@ -279,17 +277,20 @@ def __init__(self) -> None:
 
         self.credentials = get_kaggle_credentials()
         self.headers = {"Content-type": "application/json"}
+        self.session = session
 
-    def post(self, data: dict, handle_path: str, resource_handle: Optional[ResourceHandle] = None) -> Optional[dict]:
+    async def post(
+        self, data: dict, handle_path: str, resource_handle: Optional[ResourceHandle] = None
+    ) -> Optional[dict]:
         url = f"http://{self.endpoint}{handle_path}"
-        with requests.post(
+        async with self.session.post(
             url,
             data=json.dumps(data),
-            auth=self._get_http_basic_auth(),
             headers=self.headers,
-            timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT),
+            timeout=ClientTimeout(connect=DEFAULT_CONNECT_TIMEOUT,
+                                  sock_read=DEFAULT_READ_TIMEOUT),
         ) as response:
-            if response.status_code == HTTP_STATUS_404:
+            if response.status == HTTP_STATUS_404:
                 raise NotFoundError()
             colab_raise_for_status(response, resource_handle)
-            if response.text:
+            if await response.text():
diff --git a/src/kagglehub/colab_cache_resolver.py b/src/kagglehub/colab_cache_resolver.py
index c3b68ade..2df06426 100644
--- a/src/kagglehub/colab_cache_resolver.py
+++ b/src/kagglehub/colab_cache_resolver.py
@@ -8,6 +8,8 @@
 from kagglehub.handle import ModelHandle
 from kagglehub.resolver import Resolver
 
+from aiohttp import ClientSession
+
 COLAB_CACHE_MOUNT_FOLDER_ENV_VAR_NAME = "COLAB_CACHE_MOUNT_FOLDER"
 DEFAULT_COLAB_CACHE_MOUNT_FOLDER = "/kaggle/input"
 
@@ -15,29 +17,32 @@
 
 class ModelColabCacheResolver(Resolver[ModelHandle]):
-    def is_supported(self, handle: ModelHandle, *_, **__) -> bool:  # noqa: ANN002, ANN003
+    async def is_supported(self, handle: ModelHandle, *_, **__) -> bool:  # noqa: ANN002, ANN003
         if ColabClient.TBE_RUNTIME_ADDR_ENV_VAR_NAME not in os.environ or is_colab_cache_disabled():
             return False
-        api_client = ColabClient()
-        data = {
-            "owner": handle.owner,
-            "model": handle.model,
-            "framework": handle.framework,
-            "variation": handle.variation,
-        }
-
-        if handle.is_versioned():
-            # Colab treats version as int in the request
-            data["version"] = handle.version  # type: ignore
-
-        try:
-            api_client.post(data, ColabClient.IS_SUPPORTED_PATH, handle)
-        except NotFoundError:
-            return False
-        return True
+        async with ClientSession() as session:
+            api_client = ColabClient(session)
+            data = {
+                "owner": handle.owner,
+                "model": handle.model,
+                "framework": handle.framework,
+                "variation": handle.variation,
+            }
+
+            if handle.is_versioned():
+                # Colab treats version as int in the request
+                data["version"] = handle.version  # type: ignore
 
-    def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
+            try:
+                await api_client.post(data, ColabClient.IS_SUPPORTED_PATH, handle)
+            except NotFoundError:
+                return False
+            return True
+
+    async def __call__(
+        self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
+    ) -> str:
         if force_download:
             logger.warning("Ignoring invalid input: force_download flag cannot be used in a Colab notebook")
 
@@ -46,37 +51,38 @@ def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download
         else:
             logger.info(f"Attaching model '{h}' to your Colab notebook...")
 
-        api_client = ColabClient()
-        data = {
-            "owner": h.owner,
-            "model": h.model,
-            "framework": h.framework,
-            "variation": h.variation,
-        }
-        if h.is_versioned():
-            # Colab treats version as int in the request
-            data["version"] = h.version  # type: ignore
-
-        response = api_client.post(data, ColabClient.MOUNT_PATH, h)
-
-        if response is not None:
-            if "slug" not in response:
-                msg = "'slug' field missing from response"
-                raise BackendError(msg)
-
-            base_mount_path = os.getenv(COLAB_CACHE_MOUNT_FOLDER_ENV_VAR_NAME, DEFAULT_COLAB_CACHE_MOUNT_FOLDER)
-            cached_path = f"{base_mount_path}/{response['slug']}"
-
-            if path:
-                cached_filepath = f"{cached_path}/{path}"
-                if not os.path.exists(cached_filepath):
-                    msg = (
-                        f"'{path}' is not present in the model files. "
-                        f"You can access the other files of the attached model at '{cached_path}'"
-                    )
-                    raise ValueError(msg)
-                return cached_filepath
-            return cached_path
-        else:
-            no_response = "No response received or response was empty."
-            raise ValueError(no_response)
+        async with ClientSession() as session:
+            api_client = ColabClient(session)
+            data = {
+                "owner": h.owner,
+                "model": h.model,
+                "framework": h.framework,
+                "variation": h.variation,
+            }
+            if h.is_versioned():
+                # Colab treats version as int in the request
+                data["version"] = h.version  # type: ignore
+
+            response = await api_client.post(data, ColabClient.MOUNT_PATH, h)
+
+            if response is not None:
+                if "slug" not in response:
+                    msg = "'slug' field missing from response"
+                    raise BackendError(msg)
+
+                base_mount_path = os.getenv(COLAB_CACHE_MOUNT_FOLDER_ENV_VAR_NAME, DEFAULT_COLAB_CACHE_MOUNT_FOLDER)
+                cached_path = f"{base_mount_path}/{response['slug']}"
+
+                if path:
+                    cached_filepath = f"{cached_path}/{path}"
+                    if not os.path.exists(cached_filepath):
+                        msg = (
+                            f"'{path}' is not present in the model files. "
+                            f"You can access the other files of the attached model at '{cached_path}'"
+                        )
+                        raise ValueError(msg)
+                    return cached_filepath
+                return cached_path
+            else:
+                no_response = "No response received or response was empty."
+                raise ValueError(no_response)
diff --git a/src/kagglehub/datasets.py b/src/kagglehub/datasets.py
index 083e8806..c33821ec 100644
--- a/src/kagglehub/datasets.py
+++ b/src/kagglehub/datasets.py
@@ -8,7 +8,7 @@
 logger = logging.getLogger(__name__)
 
 
-def dataset_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
+async def dataset_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
     """Download dataset files
     Args:
         handle: (string) the dataset handle
@@ -20,4 +20,4 @@ def dataset_download(handle: str, path: Optional[str] = None, *, force_download:
     h = parse_dataset_handle(handle)
     logger.info(f"Downloading Dataset: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK})
 
-    return registry.dataset_resolver(h, path, force_download=force_download)
+    return await registry.dataset_resolver(h, path, force_download=force_download)
diff --git a/src/kagglehub/exceptions.py b/src/kagglehub/exceptions.py
index 86e41bfd..11f53173 100644
--- a/src/kagglehub/exceptions.py
+++ b/src/kagglehub/exceptions.py
@@ -1,6 +1,8 @@
 from http import HTTPStatus
 from typing import Any, Dict, Optional
 
+import aiohttp
 import requests
 
 from kagglehub.handle import ResourceHandle
@@ -32,6 +34,18 @@ class DataCorruptionError(Exception):
     pass
 
 
+class KaggleApiHTTP1Error(Exception):
+    def __init__(self, message: str, response: Optional[aiohttp.ClientResponse] = None) -> None:
+        self.message = message
+        self.response = response
+
+
+class ColabHTTP1Error(Exception):
+    def __init__(self, message: str, response: Optional[aiohttp.ClientResponse] = None) -> None:
+        self.message = message
+        self.response = response
+
+
 class KaggleApiHTTPError(requests.HTTPError):
     def __init__(self, message: str, response: Optional[requests.Response] = None) -> None:
         super().__init__(message, response=response)
@@ -49,58 +63,68 @@ def __init__(self, message: str = "User is not authenticated") -> None:
         super().__init__(message)
 
 
-def kaggle_api_raise_for_status(response: requests.Response, resource_handle: Optional[ResourceHandle] = None) -> None:
+def kaggle_api_raise_for_status(
+    response: aiohttp.ClientResponse, resource_handle: Optional[ResourceHandle] = None
+) -> None:
     """
     Wrapper around `response.raise_for_status()` that provides nicer error messages
     See: https://requests.readthedocs.io/en/latest/api/#requests.Response.raise_for_status
     """
     try:
         response.raise_for_status()
-    except requests.HTTPError as e:
+    except aiohttp.ClientResponseError as e:
         message = str(e)
         resource_url = resource_handle.to_url() if resource_handle else response.url
 
-        if response.status_code in {HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN}:
+        if response.status in {HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN}:
             message = (
-                f"{response.status_code} Client Error."
+                f"{response.status} Client Error."
                 "\n\n"
                 f"You don't have permission to access resource at URL: {resource_url}"
                 "\nPlease make sure you are authenticated if you are trying to access a private resource or a resource"
                 " requiring consent."
             )
-        if response.status_code == HTTPStatus.NOT_FOUND:
+        if response.status == HTTPStatus.NOT_FOUND:
             message = (
-                f"{response.status_code} Client Error."
+                f"{response.status} Client Error."
                 "\n\n"
                 f"Resource not found at URL: {resource_url}"
                 "\nPlease make sure you specified the correct resource identifiers."
            )
 
         # Default handling
-        raise KaggleApiHTTPError(message, response=response) from e
+        raise KaggleApiHTTP1Error(message, response=response) from e
 
 
-def colab_raise_for_status(response: requests.Response, resource_handle: Optional[ResourceHandle] = None) -> None:
+def colab_raise_for_status(response: aiohttp.ClientResponse, resource_handle: Optional[ResourceHandle] = None) -> None:
     """
     Wrapper around `response.raise_for_status()` that provides nicer error messages
     See: https://requests.readthedocs.io/en/latest/api/#requests.Response.raise_for_status
     """
     try:
         response.raise_for_status()
-    except requests.HTTPError as e:
+    except aiohttp.ClientResponseError as e:
         message = str(e)
         resource_url = resource_handle.to_url() if resource_handle else response.url
 
-        if response.status_code in {HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN}:
+        if response.status in {HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN}:
             message = (
-                f"{response.status_code} Client Error."
+                f"{response.status} Client Error."
                 "\n\n"
                 f"You don't have permission to access resource at URL: {resource_url}"
                 "\nPlease make sure you are authenticated if you are trying to access a private resource or a resource"
                 " requiring consent."
             )
 
         # Default handling
-        raise ColabHTTPError(message, response=response) from e
+        raise ColabHTTP1Error(message, response=response) from e
 
 
 def process_post_response(response: Dict[str, Any]) -> None:
diff --git a/src/kagglehub/gcs_upload.py b/src/kagglehub/gcs_upload.py
index a50b6298..ff0fe557 100644
--- a/src/kagglehub/gcs_upload.py
+++ b/src/kagglehub/gcs_upload.py
@@ -7,9 +7,11 @@
 from typing import Dict, List, Optional, Union
 
 import requests
+from aiohttp import ClientSession
 from requests.exceptions import ConnectionError, Timeout
 from tqdm import tqdm
 from tqdm.utils import CallbackIOWrapper
 
 from kagglehub.clients import KaggleApiV1Client
 from kagglehub.exceptions import BackendError
@@ -91,7 +93,7 @@ def _check_uploaded_size(session_uri: str, file_size: int, backoff_factor: int =
     return 0  # Return 0 if all retries fail
 
 
-def _upload_blob(file_path: str, model_type: str) -> str:
+async def _upload_blob(session: ClientSession, file_path: str, model_type: str) -> str:
     """Uploads a file to a remote server as a blob and returns an upload token.
 
     Parameters
     ==========
@@ -106,8 +108,8 @@
         "contentLength": file_size,
         "lastModifiedEpochSeconds": int(os.path.getmtime(file_path)),
     }
-    api_client = KaggleApiV1Client()
-    response = api_client.post("/blobs/upload", data=data)
+    api_client = KaggleApiV1Client(session)
+    response = await api_client.post("/blobs/upload", data=data)
 
     # Validate response content
     if "createUrl" not in response:
@@ -154,7 +156,8 @@ def _upload_blob(file_path: str, model_type: str) -> str:
     return response["token"]
 
 
-def upload_files_and_directories(
+async def upload_files_and_directories(
+    session: ClientSession,
     folder: str,
     model_type: str,
     quiet: bool = False,  # noqa: FBT002, FBT001
@@ -176,18 +179,14 @@
                     file_path = os.path.join(root, file)
                     zipf.write(file_path, os.path.relpath(file_path, folder))
 
-            tokens = [
-                token
-                for token in [_upload_file_or_folder(temp_dir, TEMP_ARCHIVE_FILE, model_type, quiet)]
-                if token is not None
-            ]
-            return UploadDirectoryInfo(name="archive", files=tokens)
+            token = await _upload_file_or_folder(session, temp_dir, TEMP_ARCHIVE_FILE, model_type, quiet)
+            return UploadDirectoryInfo(name="archive", files=[token] if token else [])
 
     root_dict = UploadDirectoryInfo(name="root")
     if os.path.isfile(folder):
         # Directly upload the file if the path is a file
         file_name = os.path.basename(folder)
-        token = _upload_file_or_folder(os.path.dirname(folder), file_name, model_type, quiet)
+        token = await _upload_file_or_folder(session, os.path.dirname(folder), file_name, model_type, quiet)
         if token:
             root_dict.files.append(token)
     else:
@@ -212,14 +211,15 @@
 
             # Add file tokens to the current directory in the dictionary
             for file in files:
-                token = _upload_file_or_folder(root, file, model_type, quiet)
+                token = await _upload_file_or_folder(session, root, file, model_type, quiet)
                 if token:
                     current_dict.files.append(token)
 
     return root_dict
 
 
-def _upload_file_or_folder(
+async def _upload_file_or_folder(
+    session: ClientSession,
     parent_path: str,
     file_or_folder_name: str,
     model_type: str,
@@ -238,11 +238,11 @@
     """
     full_path = os.path.join(parent_path, file_or_folder_name)
     if os.path.isfile(full_path):
-        return _upload_file(full_path, quiet, model_type)
+        return await _upload_file(session, full_path, quiet, model_type)
 
     return None
 
 
-def _upload_file(full_path: str, quiet: bool, model_type: str) -> Optional[str]:  # noqa: FBT001
+async def _upload_file(session: ClientSession, full_path: str, quiet: bool, model_type: str) -> Optional[str]:  # noqa: FBT001
     """Helper function to upload a single file
     Parameters
     ==========
@@ -256,7 +256,7 @@
     logger.info("Starting upload for file " + full_path)
 
     content_length = os.path.getsize(full_path)
-    token = _upload_blob(full_path, model_type)
+    token = await _upload_blob(session, full_path, model_type)
     if not quiet:
         logger.info("Upload successful: " + full_path + " (" + File.get_size(content_length) + ")")
     return token
diff --git a/src/kagglehub/http_resolver.py b/src/kagglehub/http_resolver.py
index 508757d0..e34f3624 100644
--- a/src/kagglehub/http_resolver.py
+++ b/src/kagglehub/http_resolver.py
@@ -4,6 +4,7 @@
 import zipfile
 from typing import List, Optional, Tuple
 
+from aiohttp import ClientSession
 from tqdm.contrib.concurrent import thread_map
 
 from kagglehub.cache import (
@@ -26,82 +27,39 @@
 
 class DatasetHttpResolver(Resolver[DatasetHandle]):
-    def is_supported(self, *_, **__) -> bool:  # noqa: ANN002, ANN003
+    async def is_supported(self, *_, **__) -> bool:  # noqa: ANN002, ANN003
         # Downloading files over HTTP is supported in all environments for all handles / paths.
         return True
 
-    def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
-        api_client = KaggleApiV1Client()
-
-        if not h.is_versioned():
-            h.version = _get_current_version(api_client, h)
-
-        dataset_path = load_from_cache(h, path)
-        if dataset_path and not force_download:
-            return dataset_path  # Already cached
-        elif dataset_path and force_download:
-            delete_from_cache(h, path)
-
-        url_path = _build_dataset_download_url_path(h)
-        out_path = get_cached_path(h, path)
-
-        # Create the intermediary directories
-        if path:
-            # Downloading a single file.
-            os.makedirs(os.path.dirname(out_path), exist_ok=True)
-            api_client.download_file(url_path + "&file_name=" + path, out_path, h)
-        else:
-            # TODO(b/345800027) Implement parallel download when < 25 files in databundle.
-            # Downloading the full archived bundle.
-            archive_path = get_cached_archive_path(h)
-            os.makedirs(os.path.dirname(archive_path), exist_ok=True)
-
-            # First, we download the archive.
-            api_client.download_file(url_path, archive_path, h)
-
-            # Create the directory to extract the archive to.
-            os.makedirs(out_path, exist_ok=True)
-
-            _extract_archive(archive_path, out_path)
-
-            # Delete the archive
-            os.remove(archive_path)
-
-        mark_as_complete(h, path)
-        return out_path
-
-
-class ModelHttpResolver(Resolver[ModelHandle]):
-    def is_supported(self, *_, **__) -> bool:  # noqa: ANN002, ANN003
-        # Downloading files over HTTP is supported in all environments for all handles / path.
- return True - - def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: - api_client = KaggleApiV1Client() - - if not h.is_versioned(): - h.version = _get_current_version(api_client, h) - - model_path = load_from_cache(h, path) - if model_path and not force_download: - return model_path # Already cached - elif model_path and force_download: - delete_from_cache(h, path) - - url_path = _build_download_url_path(h) - out_path = get_cached_path(h, path) - - # Create the intermediary directories - if path: - # Downloading a single file. - os.makedirs(os.path.dirname(out_path), exist_ok=True) - api_client.download_file(url_path + "/" + path, out_path, h) - else: - # List the files and decide how to download them: - # - <= 25 files: Download files in parallel - # > 25 files: Download the archive and uncompress - (files, has_more) = _list_files(api_client, h) - if has_more: + async def __call__( + self, + h: DatasetHandle, + path: Optional[str] = None, + *, + force_download: Optional[bool] = False, + ) -> str: + async with ClientSession() as session: + api_client = KaggleApiV1Client(session) + + if not h.is_versioned(): + h.version = await _get_current_version(api_client, h) + + dataset_path = load_from_cache(h, path) + if dataset_path and not force_download: + return dataset_path # Already cached + elif dataset_path and force_download: + delete_from_cache(h, path) + + url_path = _build_dataset_download_url_path(h) + out_path = get_cached_path(h, path) + + # Create the intermediary directories + if path: + # Downloading a single file. + os.makedirs(os.path.dirname(out_path), exist_ok=True) + api_client.download_file(url_path + "&file_name=" + path, out_path, h) + else: + # TODO(b/345800027) Implement parallel download when < 25 files in databundle. # Downloading the full archived bundle. archive_path = get_cached_archive_path(h) os.makedirs(os.path.dirname(archive_path), exist_ok=True) @@ -116,22 +74,79 @@ def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download # Delete the archive os.remove(archive_path) - else: - # Download files individually in parallel - def _inner_download_file(file: str) -> None: - file_out_path = out_path + "/" + file - os.makedirs(os.path.dirname(file_out_path), exist_ok=True) - api_client.download_file(url_path + "/" + file, file_out_path, h) - thread_map( - _inner_download_file, - files, - desc=f"Downloading {len(files)} files", - max_workers=8, # Never use more than 8 threads in parallel to download files. - ) + mark_as_complete(h, path) + return out_path + - mark_as_complete(h, path) - return out_path +class ModelHttpResolver(Resolver[ModelHandle]): + async def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 + # Downloading files over HTTP is supported in all environments for all handles / path. + return True + + async def __call__( + self, + h: ModelHandle, + path: Optional[str] = None, + *, + force_download: Optional[bool] = False, + ) -> str: + async with ClientSession() as session: + api_client = KaggleApiV1Client(session) + + if not h.is_versioned(): + h.version = await _get_current_version(api_client, h) + + model_path = load_from_cache(h, path) + if model_path and not force_download: + return model_path # Already cached + elif model_path and force_download: + delete_from_cache(h, path) + + url_path = _build_download_url_path(h) + out_path = get_cached_path(h, path) + + # Create the intermediary directories + if path: + # Downloading a single file. 
+ os.makedirs(os.path.dirname(out_path), exist_ok=True) + api_client.download_file(url_path + "/" + path, out_path, h) + else: + # List the files and decide how to download them: + # - <= 25 files: Download files in parallel + # > 25 files: Download the archive and uncompress + (files, has_more) = await _list_files(api_client, h) + if has_more: + # Downloading the full archived bundle. + archive_path = get_cached_archive_path(h) + os.makedirs(os.path.dirname(archive_path), exist_ok=True) + + # First, we download the archive. + api_client.download_file(url_path, archive_path, h) + + # Create the directory to extract the archive to. + os.makedirs(out_path, exist_ok=True) + + _extract_archive(archive_path, out_path) + + # Delete the archive + os.remove(archive_path) + else: + # Download files individually in parallel + def _inner_download_file(file: str) -> None: + file_out_path = out_path + "/" + file + os.makedirs(os.path.dirname(file_out_path), exist_ok=True) + api_client.download_file(url_path + "/" + file, file_out_path, h) + + thread_map( + _inner_download_file, + files, + desc=f"Downloading {len(files)} files", + max_workers=8, # Never use more than 8 threads in parallel to download files. + ) + + mark_as_complete(h, path) + return out_path def _extract_archive(archive_path: str, out_path: str) -> None: @@ -147,9 +162,9 @@ def _extract_archive(archive_path: str, out_path: str) -> None: raise ValueError(msg) -def _get_current_version(api_client: KaggleApiV1Client, h: ResourceHandle) -> int: +async def _get_current_version(api_client: KaggleApiV1Client, h: ResourceHandle) -> int: if isinstance(h, ModelHandle): - json_response = api_client.get(_build_get_instance_url_path(h), h) + json_response = await api_client.get(_build_get_instance_url_path(h), h) if MODEL_INSTANCE_VERSION_FIELD not in json_response: msg = f"Invalid GetModelInstance API response. Expected to include a {MODEL_INSTANCE_VERSION_FIELD} field" raise ValueError(msg) @@ -157,7 +172,7 @@ def _get_current_version(api_client: KaggleApiV1Client, h: ResourceHandle) -> in return json_response[MODEL_INSTANCE_VERSION_FIELD] elif isinstance(h, DatasetHandle): - json_response = api_client.get(_build_get_dataset_url_path(h), h) + json_response = await api_client.get(_build_get_dataset_url_path(h), h) if DATASET_CURRENT_VERSION_FIELD not in json_response: msg = f"Invalid GetDataset API response. Expected to include a {DATASET_CURRENT_VERSION_FIELD} field" raise ValueError(msg) @@ -169,8 +184,8 @@ def _get_current_version(api_client: KaggleApiV1Client, h: ResourceHandle) -> in raise ValueError(msg) -def _list_files(api_client: KaggleApiV1Client, h: ModelHandle) -> Tuple[List[str], bool]: - json_response = api_client.get(_build_list_model_instance_version_files_url_path(h), h) +async def _list_files(api_client: KaggleApiV1Client, h: ModelHandle) -> Tuple[List[str], bool]: + json_response = await api_client.get(_build_list_model_instance_version_files_url_path(h), h) if "files" not in json_response: msg = "Invalid ListModelInstanceVersionFiles API response. 
Expected to include a 'files' field" raise ValueError(msg) diff --git a/src/kagglehub/kaggle_cache_resolver.py b/src/kagglehub/kaggle_cache_resolver.py index 0a7ca1cb..912b9101 100644 --- a/src/kagglehub/kaggle_cache_resolver.py +++ b/src/kagglehub/kaggle_cache_resolver.py @@ -3,6 +3,8 @@ import time from typing import Optional +import aiohttp + from kagglehub.clients import ( DEFAULT_CONNECT_TIMEOUT, KaggleJwtClient, @@ -26,7 +28,7 @@ class ModelKaggleCacheResolver(Resolver[ModelHandle]): - def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 + async def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 if is_kaggle_cache_disabled(): return False @@ -35,49 +37,52 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 return False - def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: + async def __call__( + self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + ) -> str: if force_download: logger.warning("Ignoring invalid input: force_download flag cannot be used in a Kaggle notebook") - client = KaggleJwtClient() - model_ref = { - "OwnerSlug": h.owner, - "ModelSlug": h.model, - "Framework": h.framework, - "InstanceSlug": h.variation, - } - if h.is_versioned(): - model_ref["VersionNumber"] = str(h.version) - - result = client.post( - ATTACH_DATASOURCE_REQUEST_NAME, - { - "modelRef": model_ref, - }, - timeout=(DEFAULT_CONNECT_TIMEOUT, ATTACH_DATASOURCE_READ_TIMEOUT), - ) - if "mountSlug" not in result: - msg = "'result.mountSlug' field missing from response" - raise BackendError(msg) - - base_mount_path = os.getenv(KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME, DEFAULT_KAGGLE_CACHE_MOUNT_FOLDER) - cached_path = f"{base_mount_path}/{result['mountSlug']}" - - if not os.path.exists(cached_path): - # Only print this if the model is not already mounted. - logger.info(f"Mounting files to {cached_path}...", extra={**EXTRA_CONSOLE_BLOCK}) - else: - logger.info(f"Attaching '{path}' from model '{h}' to your Kaggle notebook...") - - while not os.path.exists(cached_path): - time.sleep(5) - - if path: - cached_filepath = f"{cached_path}/{path}" - if not os.path.exists(cached_filepath): - msg = ( - f"'{path}' is not present in the model files. " - f"You can access the other files of the attached model at '{cached_path}'" - ) - raise ValueError(msg) - return cached_filepath - return cached_path + async with aiohttp.ClientSession() as session: + client = KaggleJwtClient(session) + model_ref = { + "OwnerSlug": h.owner, + "ModelSlug": h.model, + "Framework": h.framework, + "InstanceSlug": h.variation, + } + if h.is_versioned(): + model_ref["VersionNumber"] = str(h.version) + + result = await client.post( + ATTACH_DATASOURCE_REQUEST_NAME, + { + "modelRef": model_ref, + }, + timeout=(DEFAULT_CONNECT_TIMEOUT, ATTACH_DATASOURCE_READ_TIMEOUT), + ) + if "mountSlug" not in result: + msg = "'result.mountSlug' field missing from response" + raise BackendError(msg) + + base_mount_path = os.getenv(KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME, DEFAULT_KAGGLE_CACHE_MOUNT_FOLDER) + cached_path = f"{base_mount_path}/{result['mountSlug']}" + + if not os.path.exists(cached_path): + # Only print this if the model is not already mounted. 
+ logger.info(f"Mounting files to {cached_path}...", extra={**EXTRA_CONSOLE_BLOCK}) + else: + logger.info(f"Attaching '{path}' from model '{h}' to your Kaggle notebook...") + + while not os.path.exists(cached_path): + time.sleep(5) + + if path: + cached_filepath = f"{cached_path}/{path}" + if not os.path.exists(cached_filepath): + msg = ( + f"'{path}' is not present in the model files. " + f"You can access the other files of the attached model at '{cached_path}'" + ) + raise ValueError(msg) + return cached_filepath + return cached_path diff --git a/src/kagglehub/models.py b/src/kagglehub/models.py index c712b6f3..2ceae1c7 100644 --- a/src/kagglehub/models.py +++ b/src/kagglehub/models.py @@ -1,6 +1,8 @@ import logging from typing import Optional +import aiohttp + from kagglehub import registry from kagglehub.gcs_upload import upload_files_and_directories from kagglehub.handle import parse_model_handle @@ -10,7 +12,7 @@ logger = logging.getLogger(__name__) -def model_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: +async def model_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: """Download model files. Args: @@ -24,10 +26,10 @@ def model_download(handle: str, path: Optional[str] = None, *, force_download: O """ h = parse_model_handle(handle) logger.info(f"Downloading Model: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK}) - return registry.model_resolver(h, path, force_download=force_download) + return await registry.model_resolver(h, path, force_download=force_download) -def model_upload( +async def model_upload( handle: str, local_model_dir: str, license_name: Optional[str] = None, version_notes: str = "" ) -> None: """Upload model files. @@ -39,17 +41,23 @@ def model_upload( version_notes: (string) Optional to write to model versions. 
""" # parse slug + h = parse_model_handle(handle) logger.info(f"Uploading Model {h.to_url()} ...") if h.is_versioned(): is_versioned_exception = "The model handle should not include the version" raise ValueError(is_versioned_exception) - # Create the model if it doesn't already exist - create_model_if_missing(h.owner, h.model) + async with aiohttp.ClientSession() as session: + # Create the model if it doesn't already exist + await create_model_if_missing( + session, + h.owner, + h.model, + ) - # Upload the model files to GCS - tokens = upload_files_and_directories(local_model_dir, "model") + # Upload the model files to GCS + tokens = await upload_files_and_directories(session, local_model_dir, "model") - # Create a model instance if it doesn't exist, and create a new instance version if an instance exists - create_model_instance_or_version(h, tokens, license_name, version_notes) + # Create a model instance if it doesn't exist, and create a new instance version if an instance exists + await create_model_instance_or_version(session, h, tokens, license_name, version_notes) diff --git a/src/kagglehub/models_helpers.py b/src/kagglehub/models_helpers.py index 45efed79..956d6cbc 100644 --- a/src/kagglehub/models_helpers.py +++ b/src/kagglehub/models_helpers.py @@ -2,23 +2,28 @@ from http import HTTPStatus from typing import Optional +from aiohttp import ClientSession + from kagglehub.clients import BackendError, KaggleApiV1Client -from kagglehub.exceptions import KaggleApiHTTPError +from kagglehub.exceptions import KaggleApiHTTPError, KaggleApiHTTP1Error from kagglehub.gcs_upload import UploadDirectoryInfo from kagglehub.handle import ModelHandle logger = logging.getLogger(__name__) -def _create_model(owner_slug: str, model_slug: str) -> None: +async def _create_model(session: ClientSession, owner_slug: str, model_slug: str) -> None: data = {"ownerSlug": owner_slug, "slug": model_slug, "title": model_slug, "isPrivate": True} - api_client = KaggleApiV1Client() - api_client.post("/models/create/new", data) + api_client = KaggleApiV1Client(session) + await api_client.post("/models/create/new", data) logger.info(f"Model '{model_slug}' Created.") -def _create_model_instance( - model_handle: ModelHandle, files_and_directories: UploadDirectoryInfo, license_name: Optional[str] = None +async def _create_model_instance( + session: ClientSession, + model_handle: ModelHandle, + files_and_directories: UploadDirectoryInfo, + license_name: Optional[str] = None, ) -> None: serialized_data = files_and_directories.serialize() data = { @@ -30,13 +35,16 @@ def _create_model_instance( if license_name is not None: data["licenseName"] = license_name - api_client = KaggleApiV1Client() - api_client.post(f"/models/{model_handle.owner}/{model_handle.model}/create/instance", data) + api_client = KaggleApiV1Client(session) + await api_client.post(f"/models/{model_handle.owner}/{model_handle.model}/create/instance", data) logger.info(f"Your model instance has been created.\nFiles are being processed...\nSee at: {model_handle.to_url()}") -def _create_model_instance_version( - model_handle: ModelHandle, files_and_directories: UploadDirectoryInfo, version_notes: str = "" +async def _create_model_instance_version( + session: ClientSession, + model_handle: ModelHandle, + files_and_directories: UploadDirectoryInfo, + version_notes: str = "", ) -> None: serialized_data = files_and_directories.serialize() data = { @@ -44,8 +52,8 @@ def _create_model_instance_version( "files": [{"token": file_token} for file_token in 
files_and_directories.files], "directories": serialized_data["directories"], } - api_client = KaggleApiV1Client() - api_client.post( + api_client = KaggleApiV1Client(session) + await api_client.post( f"/models/{model_handle.owner}/{model_handle.model}/{model_handle.framework}/{model_handle.variation}/create/version", data, ) @@ -54,40 +62,43 @@ def _create_model_instance_version( ) -def create_model_instance_or_version( - model_handle: ModelHandle, files: UploadDirectoryInfo, license_name: Optional[str], version_notes: str = "" +async def create_model_instance_or_version( + session: ClientSession, + model_handle: ModelHandle, + files: UploadDirectoryInfo, + license_name: Optional[str], + version_notes: str = "", ) -> None: try: - _create_model_instance(model_handle, files, license_name) + await _create_model_instance(session, model_handle, files, license_name) except BackendError as e: if e.error_code == HTTPStatus.CONFLICT: # Instance already exist, creating a new version instead. - _create_model_instance_version(model_handle, files, version_notes) + await _create_model_instance_version(session, model_handle, files, version_notes) else: raise (e) -def create_model_if_missing(owner_slug: str, model_slug: str) -> None: +async def create_model_if_missing(session: ClientSession, owner_slug: str, model_slug: str) -> None: try: - api_client = KaggleApiV1Client() - api_client.get(f"/models/{owner_slug}/{model_slug}/get") - except KaggleApiHTTPError as e: + api_client = KaggleApiV1Client(session) + await api_client.get(f"/models/{owner_slug}/{model_slug}/get") + except KaggleApiHTTP1Error as e: if e.response is not None and ( - e.response.status_code == HTTPStatus.NOT_FOUND # noqa: PLR1714 - or e.response.status_code == HTTPStatus.FORBIDDEN + e.response.status == HTTPStatus.NOT_FOUND or e.response.status == HTTPStatus.FORBIDDEN # noqa: PLR1714 ): logger.info( f"Model '{model_slug}' does not exist or access is forbidden for user '{owner_slug}'. Creating or handling Model..." 
# noqa: E501
        )
-        _create_model(owner_slug, model_slug)
+        await _create_model(session, owner_slug, model_slug)
     else:
         raise (e)
 
 
-def delete_model(owner_slug: str, model_slug: str) -> None:
+async def delete_model(session: ClientSession, owner_slug: str, model_slug: str) -> None:
     try:
-        api_client = KaggleApiV1Client()
-        api_client.post(
+        api_client = KaggleApiV1Client(session)
+        await api_client.post(
             f"/models/{owner_slug}/{model_slug}/delete",
             {},
         )
diff --git a/src/kagglehub/registry.py b/src/kagglehub/registry.py
index 7f53190d..2b21099e 100644
--- a/src/kagglehub/registry.py
+++ b/src/kagglehub/registry.py
@@ -16,11 +16,11 @@ def __init__(self, name: str) -> None:
     def add_implementation(self, impl: Callable) -> None:
         self._impls += [impl]
 
-    def __call__(self, *args, **kwargs):  # noqa: ANN002, ANN003
+    async def __call__(self, *args, **kwargs):  # noqa: ANN002, ANN003
         fails = []
         for impl in reversed(self._impls):
-            if impl.is_supported(*args, **kwargs):
-                return impl(*args, **kwargs)
+            if await impl.is_supported(*args, **kwargs):
+                return await impl(*args, **kwargs)
             else:
                 fails.append(type(impl).__name__)
diff --git a/src/kagglehub/resolver.py b/src/kagglehub/resolver.py
index 9c16cd39..b5de5aa1 100644
--- a/src/kagglehub/resolver.py
+++ b/src/kagglehub/resolver.py
@@ -10,7 +10,7 @@ class Resolver(Generic[T]):
     __metaclass__ = abc.ABCMeta
 
     @abc.abstractmethod
-    def __call__(self, handle: T, path: Optional[str], *, force_download: Optional[bool] = False) -> str:
+    async def __call__(self, handle: T, path: Optional[str], *, force_download: Optional[bool] = False) -> str:
         """Resolves a handle into a path with the requested model files.
 
         Args:
@@ -25,6 +25,6 @@ def __call__(self, handle: T, path: Optional[str], *, force_download: Optional[b
         pass
 
     @abc.abstractmethod
-    def is_supported(self, handle: T, path: Optional[str]) -> bool:
+    async def is_supported(self, handle: T, path: Optional[str]) -> bool:
         """Returns whether the current environment supports this handle/path."""
         pass
diff --git a/tests/test_auth.py b/tests/test_auth.py
index d46b8042..bfee8bd4 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -22,10 +22,10 @@ def setUpClass(cls):
     def tearDownClass(cls):
         cls.server.shutdown()
 
-    def test_login_updates_global_credentials(self) -> None:
+    async def test_login_updates_global_credentials(self) -> None:
         with mock.patch("builtins.input") as mock_input:
             mock_input.side_effect = ["lastplacelarry", "some-key"]
-            kagglehub.login()
+            await kagglehub.login()
 
         # Verify that the global variable contains the updated credentials
         credentials = get_kaggle_credentials()
@@ -34,11 +34,11 @@ def test_login_updates_global_credentials(self) -> None:
         self.assertEqual("lastplacelarry", credentials.username)
         self.assertEqual("some-key", credentials.key)
 
-    def test_login_updates_global_credentials_no_validation(self) -> None:
+    async def test_login_updates_global_credentials_no_validation(self) -> None:
         # Simulate user input for credentials
         with mock.patch("builtins.input") as mock_input:
             mock_input.side_effect = ["lastplacelarry", "some-key"]
-            kagglehub.login(validate_credentials=False)
+            await kagglehub.login(validate_credentials=False)
 
         # Verify that the global variable contains the updated credentials
         credentials = get_kaggle_credentials()
@@ -47,31 +47,31 @@ def test_login_updates_global_credentials_no_validation(self) -> None:
         self.assertEqual("lastplacelarry", credentials.username)
         self.assertEqual("some-key", credentials.key)
 
-    def test_set_kaggle_credentials_raises_error_with_empty_username(self) -> None:
+    async def test_set_kaggle_credentials_raises_error_with_empty_username(self) -> None:
         with self.assertRaises(ValueError):
             with mock.patch("builtins.input") as mock_input:
                 mock_input.side_effect = ["", "some-key"]
-                kagglehub.login()
+                await kagglehub.login()
 
-    def test_set_kaggle_credentials_raises_error_with_empty_api_key(self) -> None:
+    async def test_set_kaggle_credentials_raises_error_with_empty_api_key(self) -> None:
         with self.assertRaises(ValueError):
             with mock.patch("builtins.input") as mock_input:
                 mock_input.side_effect = ["lastplacelarry", ""]
-                kagglehub.login()
+                await kagglehub.login()
 
-    def test_set_kaggle_credentials_raises_error_with_empty_username_api_key(self) -> None:
+    async def test_set_kaggle_credentials_raises_error_with_empty_username_api_key(self) -> None:
         with self.assertRaises(ValueError):
             with mock.patch("builtins.input") as mock_input:
                 mock_input.side_effect = ["", ""]
-                kagglehub.login()
+                await kagglehub.login()
 
-    def test_login_returns_403_for_bad_credentials(self) -> None:
+    async def test_login_returns_403_for_bad_credentials(self) -> None:
         output_stream = io.StringIO()
         handler = logging.StreamHandler(output_stream)
         logger.addHandler(handler)
         with mock.patch("builtins.input") as mock_input:
             mock_input.side_effect = ["invalid", "invalid"]
-            kagglehub.login()
+            await kagglehub.login()
 
             captured_output = output_stream.getvalue()
             self.assertEqual(
@@ -79,21 +79,21 @@ def test_login_returns_403_for_bad_credentials(self) -> None:
                 "Invalid Kaggle credentials. You can check your credentials on the [Kaggle settings page](https://www.kaggle.com/settings/account).\n",
             )
 
-    def test_capture_logger_output(self) -> None:
+    async def test_capture_logger_output(self) -> None:
         with _capture_logger_output() as output:
             logger.info("This is an info message")
             logger.error("This is an error message")
 
         self.assertEqual(output.getvalue(), "This is an info message\nThis is an error message\n")
 
-    def test_whoami_raises_unauthenticated_error(self) -> None:
+    async def test_whoami_raises_unauthenticated_error(self) -> None:
         with self.assertRaises(UnauthenticatedError):
             kagglehub.whoami()
 
-    def test_whoami_success(self) -> None:
+    async def test_whoami_success(self) -> None:
         with mock.patch("builtins.input") as mock_input:
             mock_input.side_effect = ["lastplacelarry", "some-key"]
-            kagglehub.login(validate_credentials=False)
+            await kagglehub.login(validate_credentials=False)
 
             result = kagglehub.whoami()
             self.assertEqual(result, {"username": "lastplacelarry"})
diff --git a/tests/test_colab_cache_model_download.py b/tests/test_colab_cache_model_download.py
index 698dc01e..2bf3b011 100644
--- a/tests/test_colab_cache_model_download.py
+++ b/tests/test_colab_cache_model_download.py
@@ -28,55 +28,55 @@ def setUpClass(cls):
     def tearDownClass(cls):
         cls.server.shutdown()
 
-    def test_unversioned_model_download(self) -> None:
+    async def test_unversioned_model_download(self) -> None:
         with stub.create_env():
-            model_path = kagglehub.model_download(UNVERSIONED_MODEL_HANDLE)
+            model_path = await kagglehub.model_download(UNVERSIONED_MODEL_HANDLE)
 
             self.assertTrue(model_path.endswith("/2"))
             self.assertEqual(["config.json", "model.keras"], sorted(os.listdir(model_path)))
 
-    def test_versioned_model_download(self) -> None:
+    async def test_versioned_model_download(self) -> None:
         with stub.create_env():
-            model_path = kagglehub.model_download(VERSIONED_MODEL_HANDLE)
+            model_path = await kagglehub.model_download(VERSIONED_MODEL_HANDLE)
 
             self.assertTrue(model_path.endswith("/1"))
             self.assertEqual(["config.json"],
sorted(os.listdir(model_path))) - def test_versioned_model_download_with_path(self) -> None: + async def test_versioned_model_download_with_path(self) -> None: with stub.create_env(): - model_file_path = kagglehub.model_download(VERSIONED_MODEL_HANDLE, "config.json") + model_file_path = await kagglehub.model_download(VERSIONED_MODEL_HANDLE, "config.json") self.assertTrue(model_file_path.endswith("config.json")) self.assertTrue(os.path.isfile(model_file_path)) - def test_unversioned_model_download_with_path(self) -> None: + async def test_unversioned_model_download_with_path(self) -> None: with stub.create_env(): - model_file_path = kagglehub.model_download(UNVERSIONED_MODEL_HANDLE, "config.json") + model_file_path = await kagglehub.model_download(UNVERSIONED_MODEL_HANDLE, "config.json") self.assertTrue(model_file_path.endswith("config.json")) self.assertTrue(os.path.isfile(model_file_path)) - def test_versioned_model_download_with_missing_file_raises(self) -> None: + async def test_versioned_model_download_with_missing_file_raises(self) -> None: with stub.create_env(): with self.assertRaises(ValueError): - kagglehub.model_download(VERSIONED_MODEL_HANDLE, "missing.txt") + await kagglehub.model_download(VERSIONED_MODEL_HANDLE, "missing.txt") - def test_unversioned_model_download_with_missing_file_raises(self) -> None: + async def test_unversioned_model_download_with_missing_file_raises(self) -> None: with stub.create_env(): with self.assertRaises(ValueError): - kagglehub.model_download(UNVERSIONED_MODEL_HANDLE, "missing.txt") + await kagglehub.model_download(UNVERSIONED_MODEL_HANDLE, "missing.txt") - def test_colab_resolver_skipped_when_disable_colab_cache_env_var_name(self) -> None: + async def test_colab_resolver_skipped_when_disable_colab_cache_env_var_name(self) -> None: with mock.patch.dict(os.environ, {DISABLE_COLAB_CACHE_ENV_VAR_NAME: "true"}): with stub.create_env(): # Assert that a ConnectionError is set (uses HTTP server which is not set) with self.assertRaises(requests.exceptions.ConnectionError): - kagglehub.model_download(VERSIONED_MODEL_HANDLE) + await kagglehub.model_download(VERSIONED_MODEL_HANDLE) - def test_versioned_model_download_bad_handle_raises(self) -> None: + async def test_versioned_model_download_bad_handle_raises(self) -> None: with self.assertRaises(ValueError): - kagglehub.model_download("bad handle") + await kagglehub.model_download("bad handle") class TestNoInternetColabCacheModelDownload(BaseTestCase): - def test_colab_resolver_skipped_when_model_not_present(self) -> None: + async def test_colab_resolver_skipped_when_model_not_present(self) -> None: with stub.create_env(): # Assert that a ConnectionError is set (uses HTTP server which is not set) with self.assertRaises(requests.exceptions.ConnectionError): - kagglehub.model_download(UNAVAILABLE_MODEL_HANDLE) + await kagglehub.model_download(UNAVAILABLE_MODEL_HANDLE) diff --git a/tests/test_http_dataset_download.py b/tests/test_http_dataset_download.py index 166c9ff9..3d157d4b 100644 --- a/tests/test_http_dataset_download.py +++ b/tests/test_http_dataset_download.py @@ -39,7 +39,7 @@ def setUpClass(cls): def tearDownClass(cls): cls.server.shutdown() - def _download_dataset_and_assert_downloaded( + async def _download_dataset_and_assert_downloaded( self, d: str, dataset_handle: str, @@ -48,7 +48,7 @@ def _download_dataset_and_assert_downloaded( **kwargs, # noqa: ANN003 ) -> None: # Download the full datasets and ensure all files are there. 
- dataset_path = kagglehub.dataset_download(dataset_handle, **kwargs) + dataset_path = await kagglehub.dataset_download(dataset_handle, **kwargs) self.assertEqual(os.path.join(d, expected_subdir_or_subpath), dataset_path) @@ -60,8 +60,8 @@ def _download_dataset_and_assert_downloaded( archive_path = get_cached_archive_path(parse_dataset_handle(dataset_handle)) self.assertFalse(os.path.exists(archive_path)) - def _download_test_file_and_assert_downloaded(self, d: str, dataset_handle: str, **kwargs) -> None: # noqa: ANN003 - dataset_path = kagglehub.dataset_download(dataset_handle, path=TEST_FILEPATH, **kwargs) + async def _download_test_file_and_assert_downloaded(self, d: str, dataset_handle: str, **kwargs) -> None: # noqa: ANN003 + dataset_path = await kagglehub.dataset_download(dataset_handle, path=TEST_FILEPATH, **kwargs) self.assertEqual(os.path.join(d, EXPECTED_DATASET_SUBPATH), dataset_path) with zipfile.ZipFile(dataset_path, "r") as zip_ref: @@ -71,52 +71,52 @@ def _download_test_file_and_assert_downloaded(self, d: str, dataset_handle: str, self.assertEqual(TEST_CONTENTS, contents) - def test_unversioned_dataset_download(self) -> None: + async def test_unversioned_dataset_download(self) -> None: with create_test_cache() as d: - self._download_dataset_and_assert_downloaded(d, UNVERSIONED_DATASET_HANDLE, EXPECTED_DATASET_SUBDIR) + await self._download_dataset_and_assert_downloaded(d, UNVERSIONED_DATASET_HANDLE, EXPECTED_DATASET_SUBDIR) - def test_versioned_dataset_download(self) -> None: + async def test_versioned_dataset_download(self) -> None: with create_test_cache() as d: - self._download_dataset_and_assert_downloaded(d, VERSIONED_DATASET_HANDLE, EXPECTED_DATASET_SUBDIR) + await self._download_dataset_and_assert_downloaded(d, VERSIONED_DATASET_HANDLE, EXPECTED_DATASET_SUBDIR) - def test_versioned_dataset_targz_archive_download(self) -> None: + async def test_versioned_dataset_targz_archive_download(self) -> None: with create_test_cache() as d: - self._download_dataset_and_assert_downloaded( + await self._download_dataset_and_assert_downloaded( d, stub.TARGZ_ARCHIVE_HANDLE, f"{DATASETS_CACHE_SUBFOLDER}/{stub.TARGZ_ARCHIVE_HANDLE}", expected_files=[f"{i}.txt" for i in range(1, 51)], ) - def test_versioned_dataset_download_bad_archive(self) -> None: + async def test_versioned_dataset_download_bad_archive(self) -> None: with create_test_cache(): with self.assertRaises(ValueError): - kagglehub.dataset_download(INVALID_ARCHIVE_DATASET_HANDLE) + await kagglehub.dataset_download(INVALID_ARCHIVE_DATASET_HANDLE) - def test_versioned_dataset_download_with_path(self) -> None: + async def test_versioned_dataset_download_with_path(self) -> None: with create_test_cache() as d: - self._download_test_file_and_assert_downloaded(d, VERSIONED_DATASET_HANDLE) + await self._download_test_file_and_assert_downloaded(d, VERSIONED_DATASET_HANDLE) - def test_unversioned_dataset_download_with_force_download(self) -> None: + async def test_unversioned_dataset_download_with_force_download(self) -> None: with create_test_cache() as d: - self._download_dataset_and_assert_downloaded( + await self._download_dataset_and_assert_downloaded( d, UNVERSIONED_DATASET_HANDLE, EXPECTED_DATASET_SUBDIR, force_download=True ) - def test_versioned_dataset_download_with_force_download(self) -> None: + async def test_versioned_dataset_download_with_force_download(self) -> None: with create_test_cache() as d: - self._download_dataset_and_assert_downloaded( + await self._download_dataset_and_assert_downloaded( d, 
VERSIONED_DATASET_HANDLE, EXPECTED_DATASET_SUBDIR, force_download=True ) - def test_versioned_dataset_full_download_with_file_already_cached(self) -> None: + async def test_versioned_dataset_full_download_with_file_already_cached(self) -> None: with create_test_cache() as d: # Download a single file first - kagglehub.dataset_download(VERSIONED_DATASET_HANDLE, path=TEST_FILEPATH) - self._download_dataset_and_assert_downloaded(d, VERSIONED_DATASET_HANDLE, EXPECTED_DATASET_SUBDIR) + await kagglehub.dataset_download(VERSIONED_DATASET_HANDLE, path=TEST_FILEPATH) + await self._download_dataset_and_assert_downloaded(d, VERSIONED_DATASET_HANDLE, EXPECTED_DATASET_SUBDIR) - def test_unversioned_dataset_full_download_with_file_already_cached(self) -> None: + async def test_unversioned_dataset_full_download_with_file_already_cached(self) -> None: with create_test_cache() as d: # Download a single file first - kagglehub.dataset_download(UNVERSIONED_DATASET_HANDLE, path=TEST_FILEPATH) - self._download_dataset_and_assert_downloaded(d, UNVERSIONED_DATASET_HANDLE, EXPECTED_DATASET_SUBDIR) + await kagglehub.dataset_download(UNVERSIONED_DATASET_HANDLE, path=TEST_FILEPATH) + await self._download_dataset_and_assert_downloaded(d, UNVERSIONED_DATASET_HANDLE, EXPECTED_DATASET_SUBDIR) diff --git a/tests/test_http_model_download.py b/tests/test_http_model_download.py index 137be45f..55f7706a 100644 --- a/tests/test_http_model_download.py +++ b/tests/test_http_model_download.py @@ -39,7 +39,7 @@ def setUpClass(cls): def tearDownClass(cls): cls.server.shutdown() - def _download_model_and_assert_downloaded( + async def _download_model_and_assert_downloaded( self, d: str, model_handle: str, @@ -48,7 +48,7 @@ def _download_model_and_assert_downloaded( **kwargs, # noqa: ANN003 ) -> None: # Download the full model and ensure all files are there. 
- model_path = kagglehub.model_download(model_handle, **kwargs) + model_path = await kagglehub.model_download(model_handle, **kwargs) self.assertEqual(os.path.join(d, expected_subdir_or_subpath), model_path) if not expected_files: expected_files = ["config.json", "model.keras"] @@ -58,154 +58,156 @@ def _download_model_and_assert_downloaded( archive_path = get_cached_archive_path(parse_model_handle(model_handle)) self.assertFalse(os.path.exists(archive_path)) - def _download_test_file_and_assert_downloaded(self, d: str, model_handle: str, **kwargs) -> None: # noqa: ANN003 - model_path = kagglehub.model_download(model_handle, path=TEST_FILEPATH, **kwargs) + async def _download_test_file_and_assert_downloaded(self, d: str, model_handle: str, **kwargs) -> None: # noqa: ANN003 + model_path = await kagglehub.model_download(model_handle, path=TEST_FILEPATH, **kwargs) self.assertEqual(os.path.join(d, EXPECTED_MODEL_SUBPATH), model_path) with open(model_path) as model_file: self.assertEqual(TEST_CONTENTS, model_file.readline()) - def test_unversioned_model_download(self) -> None: + async def test_unversioned_model_download(self) -> None: with create_test_cache() as d: - self._download_model_and_assert_downloaded(d, UNVERSIONED_MODEL_HANDLE, EXPECTED_MODEL_SUBDIR) + await self._download_model_and_assert_downloaded(d, UNVERSIONED_MODEL_HANDLE, EXPECTED_MODEL_SUBDIR) - def test_versioned_model_download(self) -> None: + async def test_versioned_model_download(self) -> None: with create_test_cache() as d: - self._download_model_and_assert_downloaded(d, VERSIONED_MODEL_HANDLE, EXPECTED_MODEL_SUBDIR) + await self._download_model_and_assert_downloaded(d, VERSIONED_MODEL_HANDLE, EXPECTED_MODEL_SUBDIR) - def test_model_archive_targz_download(self) -> None: + async def test_model_archive_targz_download(self) -> None: with create_test_cache() as d: - self._download_model_and_assert_downloaded( + await self._download_model_and_assert_downloaded( d, stub.TOO_MANY_FILES_FOR_PARALLEL_DOWNLOAD_HANDLE, f"{MODELS_CACHE_SUBFOLDER}/{stub.TOO_MANY_FILES_FOR_PARALLEL_DOWNLOAD_HANDLE}", expected_files=[f"{i}.txt" for i in range(1, 51)], ) - def test_model_archive_zip_download(self) -> None: + async def test_model_archive_zip_download(self) -> None: with create_test_cache() as d: - self._download_model_and_assert_downloaded( + await self._download_model_and_assert_downloaded( d, stub.ZIP_ARCHIVE_HANDLE, f"{MODELS_CACHE_SUBFOLDER}/{stub.ZIP_ARCHIVE_HANDLE}", expected_files=[f"model-{i}.txt" for i in range(1, 27)], ) - def test_versioned_model_full_download_with_file_already_cached(self) -> None: + async def test_versioned_model_full_download_with_file_already_cached(self) -> None: with create_test_cache() as d: # Download a single file first - kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) - self._download_model_and_assert_downloaded(d, VERSIONED_MODEL_HANDLE, EXPECTED_MODEL_SUBDIR) + await kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) + await self._download_model_and_assert_downloaded(d, VERSIONED_MODEL_HANDLE, EXPECTED_MODEL_SUBDIR) - def test_unversioned_model_full_download_with_file_already_cached(self) -> None: + async def test_unversioned_model_full_download_with_file_already_cached(self) -> None: with create_test_cache() as d: # Download a single file first - kagglehub.model_download(UNVERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) - self._download_model_and_assert_downloaded(d, UNVERSIONED_MODEL_HANDLE, EXPECTED_MODEL_SUBDIR) + await 
kagglehub.model_download(UNVERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) + await self._download_model_and_assert_downloaded(d, UNVERSIONED_MODEL_HANDLE, EXPECTED_MODEL_SUBDIR) - def test_unversioned_model_download_with_force_download(self) -> None: + async def test_unversioned_model_download_with_force_download(self) -> None: with create_test_cache() as d: - self._download_model_and_assert_downloaded( + await self._download_model_and_assert_downloaded( d, UNVERSIONED_MODEL_HANDLE, EXPECTED_MODEL_SUBDIR, force_download=True ) - def test_versioned_model_download_with_force_download(self) -> None: + async def test_versioned_model_download_with_force_download(self) -> None: with create_test_cache() as d: - self._download_model_and_assert_downloaded( + await self._download_model_and_assert_downloaded( d, VERSIONED_MODEL_HANDLE, EXPECTED_MODEL_SUBDIR, force_download=True ) - def test_versioned_model_full_download_with_file_already_cached_and_force_download(self) -> None: + async def test_versioned_model_full_download_with_file_already_cached_and_force_download(self) -> None: with create_test_cache() as d: # Download a single file first - kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) - self._download_model_and_assert_downloaded( + await kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) + await self._download_model_and_assert_downloaded( d, VERSIONED_MODEL_HANDLE, EXPECTED_MODEL_SUBDIR, force_download=True ) - def test_unversioned_model_full_download_with_file_already_cached_and_force_download(self) -> None: + async def test_unversioned_model_full_download_with_file_already_cached_and_force_download(self) -> None: with create_test_cache() as d: # Download a single file first - kagglehub.model_download(UNVERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) - self._download_model_and_assert_downloaded( + await kagglehub.model_download(UNVERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) + await self._download_model_and_assert_downloaded( d, UNVERSIONED_MODEL_HANDLE, EXPECTED_MODEL_SUBDIR, force_download=True ) - def test_versioned_model_download_bad_archive(self) -> None: + async def test_versioned_model_download_bad_archive(self) -> None: with create_test_cache(): with self.assertRaises(ValueError): - kagglehub.model_download(stub.INVALID_ARCHIVE_HANDLE) + await kagglehub.model_download(stub.INVALID_ARCHIVE_HANDLE) - def test_versioned_model_download_with_path(self) -> None: + async def test_versioned_model_download_with_path(self) -> None: with create_test_cache() as d: - self._download_test_file_and_assert_downloaded(d, VERSIONED_MODEL_HANDLE) + await self._download_test_file_and_assert_downloaded(d, VERSIONED_MODEL_HANDLE) - def test_unversioned_model_download_with_path(self) -> None: + async def test_unversioned_model_download_with_path(self) -> None: with create_test_cache() as d: - self._download_test_file_and_assert_downloaded(d, UNVERSIONED_MODEL_HANDLE) + await self._download_test_file_and_assert_downloaded(d, UNVERSIONED_MODEL_HANDLE) - def test_versioned_model_download_with_path_with_force_download(self) -> None: + async def test_versioned_model_download_with_path_with_force_download(self) -> None: with create_test_cache() as d: - self._download_test_file_and_assert_downloaded(d, VERSIONED_MODEL_HANDLE, force_download=True) + await self._download_test_file_and_assert_downloaded(d, VERSIONED_MODEL_HANDLE, force_download=True) - def test_unversioned_model_download_with_path_with_force_download(self) -> None: + async def 
test_unversioned_model_download_with_path_with_force_download(self) -> None: with create_test_cache() as d: - self._download_test_file_and_assert_downloaded(d, UNVERSIONED_MODEL_HANDLE, force_download=True) + await self._download_test_file_and_assert_downloaded(d, UNVERSIONED_MODEL_HANDLE, force_download=True) - def test_versioned_model_download_already_cached(self) -> None: + async def test_versioned_model_download_already_cached(self) -> None: with create_test_cache() as d: # Download from server. - kagglehub.model_download(VERSIONED_MODEL_HANDLE) + await kagglehub.model_download(VERSIONED_MODEL_HANDLE) # No internet, cache hit. - model_path = kagglehub.model_download(VERSIONED_MODEL_HANDLE) + model_path = await kagglehub.model_download(VERSIONED_MODEL_HANDLE) self.assertEqual(os.path.join(d, EXPECTED_MODEL_SUBDIR), model_path) - def test_versioned_model_download_with_path_already_cached(self) -> None: + async def test_versioned_model_download_with_path_already_cached(self) -> None: with create_test_cache() as d: - kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) + await kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) # No internet, cache hit. - model_path = kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) + model_path = await kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) self.assertEqual(os.path.join(d, EXPECTED_MODEL_SUBPATH), model_path) - def test_versioned_model_download_already_cached_with_force_download_explicit_false(self) -> None: + async def test_versioned_model_download_already_cached_with_force_download_explicit_false(self) -> None: with create_test_cache() as d: - kagglehub.model_download(VERSIONED_MODEL_HANDLE) + await kagglehub.model_download(VERSIONED_MODEL_HANDLE) # Not force downloaded, cache hit. - model_path = kagglehub.model_download(VERSIONED_MODEL_HANDLE, force_download=False) + model_path = await kagglehub.model_download(VERSIONED_MODEL_HANDLE, force_download=False) self.assertEqual(os.path.join(d, EXPECTED_MODEL_SUBDIR), model_path) - def test_versioned_model_download_with_path_already_cached_with_force_download_explicit_false(self) -> None: + async def test_versioned_model_download_with_path_already_cached_with_force_download_explicit_false(self) -> None: with create_test_cache() as d: - kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) + await kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) # Not force downloaded, cache hit. - model_path = kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH, force_download=False) + model_path = await kagglehub.model_download( + VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH, force_download=False + ) self.assertEqual(os.path.join(d, EXPECTED_MODEL_SUBPATH), model_path) class TestHttpNoInternet(BaseTestCase): - def test_versioned_model_download_already_cached_with_force_download(self) -> None: + async def test_versioned_model_download_already_cached_with_force_download(self) -> None: with create_test_cache(): server = serv.start_server(stub.app) - kagglehub.model_download(VERSIONED_MODEL_HANDLE) + await kagglehub.model_download(VERSIONED_MODEL_HANDLE) server.shutdown() # No internet should throw an error. 
with self.assertRaises(requests.exceptions.ConnectionError): - kagglehub.model_download(VERSIONED_MODEL_HANDLE, force_download=True) + await kagglehub.model_download(VERSIONED_MODEL_HANDLE, force_download=True) - def test_versioned_model_download_with_path_already_cached_with_force_download(self) -> None: + async def test_versioned_model_download_with_path_already_cached_with_force_download(self) -> None: with create_test_cache(): server = serv.start_server(stub.app) - kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) + await kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH) server.shutdown() # No internet should throw an error. with self.assertRaises(requests.exceptions.ConnectionError): - kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH, force_download=True) + await kagglehub.model_download(VERSIONED_MODEL_HANDLE, path=TEST_FILEPATH, force_download=True) diff --git a/tests/test_kaggle_api_client.py b/tests/test_kaggle_api_client.py index 4b8a6cf6..6ab3b976 100644 --- a/tests/test_kaggle_api_client.py +++ b/tests/test_kaggle_api_client.py @@ -8,6 +8,8 @@ from .server_stubs import kaggle_api_stub as stub from .server_stubs import serv +from aiohttp import ClientSession + class TestKaggleApiV1Client(BaseTestCase): @classmethod @@ -18,17 +20,18 @@ def setUpClass(cls): def tearDownClass(cls): cls.server.shutdown() - def test_download_with_integrity_check(self) -> None: + async def test_download_with_integrity_check(self) -> None: with TemporaryDirectory() as d: out_file = os.path.join(d, "out") - api_client = KaggleApiV1Client() - api_client.download_file("good", out_file) + async with ClientSession() as session: + api_client = KaggleApiV1Client(session) + await api_client.download_file("good", out_file) - with open(out_file) as f: - self.assertEqual("foo", f.read()) + with open(out_file) as f: + self.assertEqual("foo", f.read()) - def test_resumable_download_with_integrity_check(self) -> None: + async def test_resumable_download_with_integrity_check(self) -> None: with TemporaryDirectory() as d: out_file = os.path.join(d, "out") @@ -36,31 +39,34 @@ def test_resumable_download_with_integrity_check(self) -> None: with open(out_file, "w") as f: f.write("fo") # Should download the remaining "o".
- api_client = KaggleApiV1Client() - with self.assertLogs("kagglehub", level="INFO") as cm: - api_client.download_file("good", out_file) - self.assertIn("INFO:kagglehub.clients:Resuming download from 2 bytes (1 bytes left)...", cm.output) + async with ClientSession() as session: + api_client = KaggleApiV1Client(session) + with self.assertLogs("kagglehub", level="INFO") as cm: + await api_client.download_file("good", out_file) + self.assertIn("INFO:kagglehub.clients:Resuming download from 2 bytes (1 bytes left)...", cm.output) - with open(out_file) as f: - self.assertEqual("foo", f.read()) + with open(out_file) as f: + self.assertEqual("foo", f.read()) - def test_download_no_integrity_check(self) -> None: + async def test_download_no_integrity_check(self) -> None: with TemporaryDirectory() as d: out_file = os.path.join(d, "out") - api_client = KaggleApiV1Client() - api_client.download_file("no-integrity", out_file) + async with ClientSession() as session: + api_client = KaggleApiV1Client(session) + await api_client.download_file("no-integrity", out_file) - with open(out_file) as f: - self.assertEqual("foo", f.read()) + with open(out_file) as f: + self.assertEqual("foo", f.read()) - def test_download_corrupted_file_fail_integrity_check(self) -> None: + async def test_download_corrupted_file_fail_integrity_check(self) -> None: with TemporaryDirectory() as d: out_file = os.path.join(d, "out") - api_client = KaggleApiV1Client() - with self.assertRaises(DataCorruptionError): - api_client.download_file("corrupted", out_file) + async with ClientSession() as session: + api_client = KaggleApiV1Client(session) + with self.assertRaises(DataCorruptionError): + await api_client.download_file("corrupted", out_file) - # Assert the corrupted file has been deleted. - self.assertFalse(os.path.exists(out_file)) + # Assert the corrupted file has been deleted.
+ self.assertFalse(os.path.exists(out_file)) diff --git a/tests/test_kaggle_cache_model_download.py b/tests/test_kaggle_cache_model_download.py index b66ec1cd..3ca9d853 100644 --- a/tests/test_kaggle_cache_model_download.py +++ b/tests/test_kaggle_cache_model_download.py @@ -28,63 +28,63 @@ def setUpClass(cls): def tearDownClass(cls): cls.server.shutdown() - def test_unversioned_model_download(self) -> None: + async def test_unversioned_model_download(self) -> None: with stub.create_env(): - model_path = kagglehub.model_download(UNVERSIONED_MODEL_HANDLE) + model_path = await kagglehub.model_download(UNVERSIONED_MODEL_HANDLE) self.assertTrue(model_path.endswith("/2")) self.assertEqual(["config.json", "model.keras"], sorted(os.listdir(model_path))) - def test_versioned_model_download(self) -> None: + async def test_versioned_model_download(self) -> None: with stub.create_env(): - model_path = kagglehub.model_download(VERSIONED_MODEL_HANDLE) + model_path = await kagglehub.model_download(VERSIONED_MODEL_HANDLE) self.assertTrue(model_path.endswith("/1")) self.assertEqual(["config.json"], sorted(os.listdir(model_path))) - def test_versioned_model_download_with_path(self) -> None: + async def test_versioned_model_download_with_path(self) -> None: with stub.create_env(): - model_file_path = kagglehub.model_download(VERSIONED_MODEL_HANDLE, "config.json") + model_file_path = await kagglehub.model_download(VERSIONED_MODEL_HANDLE, "config.json") self.assertTrue(model_file_path.endswith("config.json")) self.assertTrue(os.path.isfile(model_file_path)) - def test_unversioned_model_download_with_path(self) -> None: + async def test_unversioned_model_download_with_path(self) -> None: with stub.create_env(): - model_file_path = kagglehub.model_download(UNVERSIONED_MODEL_HANDLE, "config.json") + model_file_path = await kagglehub.model_download(UNVERSIONED_MODEL_HANDLE, "config.json") self.assertTrue(model_file_path.endswith("config.json")) self.assertTrue(os.path.isfile(model_file_path)) - def test_versioned_model_download_with_missing_file_raises(self) -> None: + async def test_versioned_model_download_with_missing_file_raises(self) -> None: with stub.create_env(): with self.assertRaises(ValueError): - kagglehub.model_download(VERSIONED_MODEL_HANDLE, "missing.txt") + await kagglehub.model_download(VERSIONED_MODEL_HANDLE, "missing.txt") - def test_unversioned_model_download_with_missing_file_raises(self) -> None: + async def test_unversioned_model_download_with_missing_file_raises(self) -> None: with stub.create_env(): with self.assertRaises(ValueError): - kagglehub.model_download(UNVERSIONED_MODEL_HANDLE, "missing.txt") + await kagglehub.model_download(UNVERSIONED_MODEL_HANDLE, "missing.txt") - def test_kaggle_resolver_skipped(self) -> None: + async def test_kaggle_resolver_skipped(self) -> None: with mock.patch.dict(os.environ, {DISABLE_KAGGLE_CACHE_ENV_VAR_NAME: "true"}): with stub.create_env(): # Assert that a ConnectionError is set (uses HTTP server which is not set) with self.assertRaises(requests.exceptions.ConnectionError): - kagglehub.model_download(VERSIONED_MODEL_HANDLE) + await kagglehub.model_download(VERSIONED_MODEL_HANDLE) - def test_versioned_model_download_bad_handle_raises(self) -> None: + async def test_versioned_model_download_bad_handle_raises(self) -> None: with self.assertRaises(ValueError): - kagglehub.model_download("bad handle") + await kagglehub.model_download("bad handle") - def test_versioned_model_download_with_force_download(self) -> None: + async def 
test_versioned_model_download_with_force_download(self) -> None: with stub.create_env(): - model_path = kagglehub.model_download(VERSIONED_MODEL_HANDLE) - model_path_forced = kagglehub.model_download(VERSIONED_MODEL_HANDLE, force_download=True) + model_path = await kagglehub.model_download(VERSIONED_MODEL_HANDLE) + model_path_forced = await kagglehub.model_download(VERSIONED_MODEL_HANDLE, force_download=True) # Using force_download shouldn't change the expected output of model_download. self.assertTrue(model_path_forced.endswith("/1")) self.assertEqual(["config.json"], sorted(os.listdir(model_path_forced))) self.assertEqual(model_path, model_path_forced) - def test_versioned_model_download_with_force_download_explicitly_false(self) -> None: + async def test_versioned_model_download_with_force_download_explicitly_false(self) -> None: with stub.create_env(): - model_path = kagglehub.model_download(VERSIONED_MODEL_HANDLE, force_download=False) + model_path = await kagglehub.model_download(VERSIONED_MODEL_HANDLE, force_download=False) self.assertTrue(model_path.endswith("/1")) self.assertEqual(["config.json"], sorted(os.listdir(model_path))) diff --git a/tests/test_model_upload.py b/tests/test_model_upload.py index b2bfea5e..92ecb11e 100644 --- a/tests/test_model_upload.py +++ b/tests/test_model_upload.py @@ -26,22 +26,22 @@ def setUpClass(cls): def tearDownClass(cls): cls.server.shutdown() - def test_model_upload_with_invalid_handle(self) -> None: + async def test_model_upload_with_invalid_handle(self) -> None: with self.assertRaises(ValueError): with TemporaryDirectory() as temp_dir: test_filepath = Path(temp_dir) / TEMP_TEST_FILE test_filepath.touch() # Create a temporary file in the temporary directory - model_upload("invalid/invalid/invalid", temp_dir, APACHE_LICENSE, "model_type") + await model_upload("invalid/invalid/invalid", temp_dir, APACHE_LICENSE, "model_type") - def test_model_upload_instance_with_valid_handle(self) -> None: + async def test_model_upload_instance_with_valid_handle(self) -> None: with TemporaryDirectory() as temp_dir: test_filepath = Path(temp_dir) / TEMP_TEST_FILE test_filepath.touch() # Create a temporary file in the temporary directory - model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type") + await model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type") self.assertEqual(len(stub.shared_data.files), 1) self.assertIn(TEMP_TEST_FILE, stub.shared_data.files) - def test_model_upload_instance_with_nested_directories(self) -> None: + async def test_model_upload_instance_with_nested_directories(self) -> None: with TemporaryDirectory() as temp_dir: # Create a nested directory structure nested_dir = Path(temp_dir) / "nested" @@ -49,77 +49,77 @@ def test_model_upload_instance_with_nested_directories(self) -> None: # Create a temporary file in the nested directory test_filepath = nested_dir / TEMP_TEST_FILE test_filepath.touch() - model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type") + await model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type") self.assertEqual(len(stub.shared_data.files), 1) self.assertIn(TEMP_TEST_FILE, stub.shared_data.files) - def test_model_upload_version_with_valid_handle(self) -> None: + async def test_model_upload_version_with_valid_handle(self) -> None: with TemporaryDirectory() as temp_dir: test_filepath = Path(temp_dir) / TEMP_TEST_FILE test_filepath.touch() # 
Create a temporary file in the temporary directory - model_upload("metaresearch/llama-2/pyTorch/7b", temp_dir, APACHE_LICENSE, "model_type") + await model_upload("metaresearch/llama-2/pyTorch/7b", temp_dir, APACHE_LICENSE, "model_type") self.assertEqual(len(stub.shared_data.files), 1) self.assertIn(TEMP_TEST_FILE, stub.shared_data.files) - def test_model_upload_with_too_many_files(self) -> None: + async def test_model_upload_with_too_many_files(self) -> None: with TemporaryDirectory() as temp_dir: # Create more than 50 temporary files in the directory for i in range(MAX_FILES_TO_UPLOAD + 1): test_filepath = Path(temp_dir) / f"temp_test_file_{i}" test_filepath.touch() - model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type") + await model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type") self.assertEqual(len(stub.shared_data.files), 1) self.assertIn(TEMP_ARCHIVE_FILE, stub.shared_data.files) - def test_model_upload_resumable(self) -> None: + async def test_model_upload_resumable(self) -> None: stub.simulate_308(state=True) # Enable simulation of 308 response for this test with TemporaryDirectory() as temp_dir: test_filepath = Path(temp_dir) / TEMP_TEST_FILE test_filepath.touch() with open(test_filepath, "wb") as f: f.write(os.urandom(1000)) - model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type") + await model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type") self.assertGreaterEqual(stub.shared_data.blob_request_count, 1) self.assertEqual(len(stub.shared_data.files), 1) self.assertIn(TEMP_TEST_FILE, stub.shared_data.files) - def test_model_upload_with_none_license(self) -> None: + async def test_model_upload_with_none_license(self) -> None: with TemporaryDirectory() as temp_dir: test_filepath = Path(temp_dir) / TEMP_TEST_FILE test_filepath.touch() # Create a temporary file in the temporary directory - model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, None, "model_type") + await model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, None, "model_type") self.assertEqual(len(stub.shared_data.files), 1) self.assertIn(TEMP_TEST_FILE, stub.shared_data.files) - def test_model_upload_without_license(self) -> None: + async def test_model_upload_without_license(self) -> None: with TemporaryDirectory() as temp_dir: test_filepath = Path(temp_dir) / TEMP_TEST_FILE test_filepath.touch() # Create a temporary file in the temporary directory - model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, version_notes="model_type") + await model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, version_notes="model_type") self.assertEqual(len(stub.shared_data.files), 1) self.assertIn(TEMP_TEST_FILE, stub.shared_data.files) - def test_model_upload_with_invalid_license_fails(self) -> None: + async def test_model_upload_with_invalid_license_fails(self) -> None: with TemporaryDirectory() as temp_dir: test_filepath = Path(temp_dir) / TEMP_TEST_FILE test_filepath.touch() # Create a temporary file in the temporary directory with self.assertRaises(BackendError): - model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, "Invalid License") + await model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, "Invalid License") - def test_single_file_upload(self) -> None: + async def test_single_file_upload(self) -> None: with TemporaryDirectory() as 
temp_dir: test_filepath = Path(temp_dir) / "single_dummy_file.txt" with open(test_filepath, "wb") as f: f.write(os.urandom(100)) - model_upload( + await model_upload( "metaresearch/new-model/pyTorch/new-variation", str(test_filepath), APACHE_LICENSE, "model_type" ) self.assertEqual(len(stub.shared_data.files), 1) self.assertIn("single_dummy_file.txt", stub.shared_data.files) - def test_model_upload_with_directory_structure(self) -> None: + async def test_model_upload_with_directory_structure(self) -> None: with TemporaryDirectory() as temp_dir: base_path = Path(temp_dir) (base_path / "dir1").mkdir() @@ -133,7 +133,7 @@ def test_model_upload_with_directory_structure(self) -> None: (base_path / "dir1" / "subdir1").mkdir() (base_path / "dir1" / "subdir1" / "file4.txt").touch() - model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type") + await model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type") self.assertEqual(len(stub.shared_data.files), 4) expected_files = {"file1.txt", "file2.txt", "file3.txt", "file4.txt"} diff --git a/tests/test_registry.py b/tests/test_registry.py index 0f07bf68..eee5bb31 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -16,34 +16,35 @@ def __init__(self, is_supported_fn: Callable[..., bool], call_fn: Callable[..., self._is_supported_fn = is_supported_fn self._call_fn = call_fn - def is_supported(self, *args: Any, **kwargs: Any) -> bool: # noqa: ANN401 + async def is_supported(self, *args: Any, **kwargs: Any) -> bool: # noqa: ANN401 return self._is_supported_fn(*args, **kwargs) - def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 + async def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 return self._call_fn(*args, **kwargs) class RegistryTest(BaseTestCase): - def test_calls_only_supported(self) -> None: + async def test_calls_only_supported(self) -> None: r = registry.MultiImplRegistry("test") r.add_implementation(FakeImpl(lambda _: True, lambda _: SOME_VALUE)) r.add_implementation(FakeImpl(lambda _: False, fail_fn)) - val = r(SOME_VALUE) + val = await r(SOME_VALUE) self.assertEqual(SOME_VALUE, val) - def test_calls_first_supported_reverse(self) -> None: + async def test_calls_first_supported_reverse(self) -> None: r = registry.MultiImplRegistry("test") r.add_implementation(FakeImpl(lambda _: True, fail_fn)) r.add_implementation(FakeImpl(lambda _: True, lambda _: SOME_VALUE)) - val = r(SOME_VALUE) + val = await r(SOME_VALUE) self.assertEqual(SOME_VALUE, val) - def test_calls_throw_not_supported(self) -> None: + async def test_calls_throw_not_supported(self) -> None: r = registry.MultiImplRegistry("test") r.add_implementation(FakeImpl(lambda _: False, fail_fn)) - self.assertRaisesRegex(RuntimeError, r"Missing implementation", r, SOME_VALUE) + with self.assertRaisesRegex(RuntimeError, r"Missing implementation"): + await r(SOME_VALUE)
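
Note on the conversions above: an async def test_* method is only awaited by a coroutine-aware runner. Below is a minimal sketch, assuming BaseTestCase is rebased onto unittest.IsolatedAsyncioTestCase (Python 3.8+), of the async dispatch that the test_registry.py changes exercise. This MultiImplRegistry is a hypothetical stand-in consistent with the assertions in those tests, not necessarily the library's actual implementation.

import unittest
from typing import Any


class MultiImplRegistry:
    """Hypothetical async registry matching the behavior the tests assert."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._impls = []

    def add_implementation(self, impl: Any) -> None:
        self._impls.append(impl)

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Later registrations take precedence, matching
        # test_calls_first_supported_reverse above.
        for impl in reversed(self._impls):
            if await impl.is_supported(*args, **kwargs):
                return await impl(*args, **kwargs)
        # Message matches the r"Missing implementation" regex asserted in
        # test_calls_throw_not_supported.
        msg = f"Missing implementation that supports {self._name}: {args}"
        raise RuntimeError(msg)


class BaseTestCase(unittest.IsolatedAsyncioTestCase):
    # IsolatedAsyncioTestCase runs each async def test_* method on an event
    # loop; under a plain unittest.TestCase the same methods would return
    # un-awaited coroutines and their bodies would never execute.
    pass


# Outside the test suite, the now-async public API would be driven with,
# e.g., asyncio.run(kagglehub.model_download(VERSIONED_MODEL_HANDLE)).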