Skip to content

Commit

Permalink
feat: add names to clients (#239)
Browse files Browse the repository at this point in the history
* feat: add name to client

* feat: add client override to cli

* chore: update changelog

* feat: dynamically look up subclasses
  • Loading branch information
gadomski authored Nov 5, 2024
1 parent fa1d83d commit c164de2
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Added

- `--http-header` CLI argument ([#238](https://github.com/stac-utils/stac-asset/pull/238))
- `--client` CLI argument, names to clients ([#239](https://github.com/stac-utils/stac-asset/pull/239))

## [0.4.5] - 2024-10-29

Expand Down
3 changes: 2 additions & 1 deletion src/stac_asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
open_href,
read_href,
)
from .client import Client
from .client import Client, get_client_classes
from .config import Config
from .earthdata_client import EarthdataClient
from .errors import (
Expand Down Expand Up @@ -65,6 +65,7 @@
"download_item",
"download_item_collection",
"download_file",
"get_client_classes",
"open_href",
"read_href",
]
13 changes: 12 additions & 1 deletion src/stac_asset/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pystac import Asset, Item, ItemCollection

from . import Config, ErrorStrategy, _functions
from .client import Clients
from .client import Clients, get_client_classes
from .config import (
DEFAULT_HTTP_CLIENT_TIMEOUT,
DEFAULT_HTTP_MAX_ATTEMPTS,
Expand Down Expand Up @@ -61,6 +61,13 @@ def cli() -> None:
@cli.command()
@click.argument("href", required=False)
@click.argument("directory", required=False)
@click.option(
"-c",
"--client",
type=Choice([c.name for c in get_client_classes()]),
help="Set the client to use for all downloads. If not "
"provided, the client will be guessed from the asset href.",
)
@click.option(
"-p",
"--path-template",
Expand Down Expand Up @@ -164,6 +171,7 @@ def cli() -> None:
def download(
href: str | None,
directory: str | None,
client: str | None,
path_template: str | None,
alternate_assets: list[str],
include: list[str],
Expand Down Expand Up @@ -213,6 +221,7 @@ def download(
download_async(
href,
directory,
client,
path_template,
alternate_assets,
include,
Expand All @@ -237,6 +246,7 @@ def download(
async def download_async(
href: str | None,
directory: str | None,
client: str | None,
path_template: str | None,
alternate_assets: list[str],
include: list[str],
Expand Down Expand Up @@ -275,6 +285,7 @@ async def download_async(
warn=not fail_fast,
fail_fast=fail_fast,
overwrite=overwrite,
client_override=client,
)

input_dict = await read_as_dict(href, config)
Expand Down
42 changes: 35 additions & 7 deletions src/stac_asset/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
class Client(ABC):
"""An abstract base class for all clients."""

name: str
"""The name of this client."""

@classmethod
async def from_config(cls: type[T], config: Config) -> T:
"""Creates a client using the provided configuration.
Expand Down Expand Up @@ -184,7 +187,7 @@ class Clients:
"""An async-safe cache of clients."""

lock: Lock
clients: dict[type[Client], Client]
clients: dict[str, Client]
config: Config

def __init__(self, config: Config, clients: list[Client] | None = None) -> None:
Expand All @@ -193,7 +196,7 @@ def __init__(self, config: Config, clients: list[Client] | None = None) -> None:
if clients:
# TODO check for duplicate types in clients list
for client in clients:
self.clients[type(client)] = client
self.clients[client.name] = client
self.config = config

async def get_client(self, href: str) -> Client:
Expand All @@ -205,15 +208,21 @@ async def get_client(self, href: str) -> Client:
Returns:
Client: An instance of that client.
"""
# TODO allow dynamic registration of new clients, e.g. via a plugin mechanism

from .earthdata_client import EarthdataClient
from .filesystem_client import FilesystemClient
from .http_client import HttpClient
from .planetary_computer_client import PlanetaryComputerClient
from .s3_client import S3Client

url = URL(href)
if not url.host:
client_class: type[Client] = FilesystemClient
if self.config.client_override:
client_class: type[Client] = _get_client_class_by_name(
self.config.client_override
)
elif not url.host:
client_class = FilesystemClient
elif url.scheme == "s3":
client_class = S3Client
elif url.host.endswith("blob.core.windows.net"):
Expand All @@ -226,15 +235,34 @@ async def get_client(self, href: str) -> Client:
raise ValueError(f"could not guess client class for href: {href}")

async with self.lock:
if client_class in self.clients:
return self.clients[client_class]
if client_class.name in self.clients:
return self.clients[client_class.name]
else:
client = await client_class.from_config(self.config)
self.clients[client_class] = client
self.clients[client_class.name] = client
return client

async def close_all(self) -> None:
"""Close all clients."""
async with self.lock:
for client in self.clients.values():
await client.close()


def _get_client_class_by_name(name: str) -> type[Client]:
for client_class in get_client_classes():
if client_class.name == name:
return client_class
raise ValueError(f"no client with name: {name}")


def get_client_classes() -> list[type[Client]]:
"""Returns a list of all known subclasses of Client."""

# https://stackoverflow.com/questions/3862310/how-to-find-all-the-subclasses-of-a-class-given-its-name
def all_subclasses(cls: type[Client]) -> set[type[Client]]:
return set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in all_subclasses(c)]
)

return list(all_subclasses(Client)) # type: ignore
6 changes: 6 additions & 0 deletions src/stac_asset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ class Config:
overwrite: bool = False
"""Download files even if they already exist locally."""

client_override: str | None = None
"""Use the same client for all asset requests.
If not set, each asset's client will be guessed from its href.
"""

http_client_timeout: float | None = DEFAULT_HTTP_CLIENT_TIMEOUT
"""Total number of seconds for the whole request."""

Expand Down
2 changes: 2 additions & 0 deletions src/stac_asset/earthdata_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class EarthdataClient(HttpClient):
3. Use :py:meth:`EarthdataClient.from_config()` to create a new client.
"""

name = "earthdata"

@classmethod
async def from_config(cls, config: Config) -> EarthdataClient:
"""Logs in to Earthdata and returns the default earthdata client.
Expand Down
2 changes: 2 additions & 0 deletions src/stac_asset/filesystem_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class FilesystemClient(Client):
Mostly used for testing, but could be useful in some real-world cases.
"""

name = "filesystem"

async def open_url(
self,
url: URL,
Expand Down
2 changes: 2 additions & 0 deletions src/stac_asset/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class HttpClient(Client):
Configure the session to customize its behavior.
"""

name = "http"

@classmethod
async def from_config(cls: type[T], config: Config) -> T:
"""Creates an HTTP client with an aiohttp session object.
Expand Down
2 changes: 2 additions & 0 deletions src/stac_asset/planetary_computer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class PlanetaryComputerClient(HttpClient):
thanks Tom Augspurger!
"""

name = "planetary-computer"

def __init__(
self,
session: ClientSession,
Expand Down
2 changes: 2 additions & 0 deletions src/stac_asset/s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class S3Client(Client):
for instructions.
"""

name = "s3"

@classmethod
async def from_config(cls, config: Config) -> S3Client:
"""Creates an s3 client from a config.
Expand Down

0 comments on commit c164de2

Please sign in to comment.