Skip to content

Commit

Permalink
Support for git-based dataset repos (#14)
Browse files Browse the repository at this point in the history
* Export more constants

* unrelated change: do not gobble certain kinds of requests.ConnectionError

* datasets: file_download

cc @lhoestq

* Exposing this as a constant

* Implement hf_api for datasets + Split constants into their own file

* integration test + CLI
  • Loading branch information
julien-c authored Feb 12, 2021
1 parent 407c838 commit 2bfc9dd
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 44 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Integration inside a library is super simple. We expose two functions, `hf_hub_u
### `hf_hub_url`

`hf_hub_url()` takes:
- a model id (like `julien-c/EsperBERTo-small` i.e. a user or organization name and a repo name, separated by `/`),
- a repo id (e.g. a model id like `julien-c/EsperBERTo-small`, i.e. a user or organization name and a repo name, separated by `/`),
- a filename (like `pytorch_model.bin`),
- and an optional git revision id (can be a branch name, a tag, or a commit hash)

Expand Down
12 changes: 11 additions & 1 deletion src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,15 @@

__version__ = "0.0.1"

from .file_download import HUGGINGFACE_CO_URL_TEMPLATE, cached_download, hf_hub_url
from .constants import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_URL_HOME,
HUGGINGFACE_CO_URL_TEMPLATE,
PYTORCH_WEIGHTS_NAME,
REPO_TYPE_DATASET,
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
)
from .file_download import cached_download, hf_hub_url
from .hf_api import HfApi, HfFolder
32 changes: 25 additions & 7 deletions src/huggingface_hub/commands/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from typing import List, Union

from huggingface_hub.commands import BaseHuggingfaceCLICommand
from huggingface_hub.constants import (
REPO_TYPE_DATASET,
REPO_TYPE_DATASET_URL_PREFIX,
REPO_TYPES,
)
from huggingface_hub.hf_api import HfApi, HfFolder
from requests.exceptions import HTTPError

Expand Down Expand Up @@ -57,7 +62,12 @@ def register_subcommand(parser: ArgumentParser):
repo_create_parser.add_argument(
"name",
type=str,
help="Name for your model's repo. Will be namespaced under your username to build the model id.",
help="Name for your repo. Will be namespaced under your username to build the repo id.",
)
repo_create_parser.add_argument(
"--type",
type=str,
help='Optional: repo_type: set to "dataset" if creating a dataset, default is model.',
)
repo_create_parser.add_argument(
"--organization", type=str, help="Optional: organization namespace."
Expand Down Expand Up @@ -223,11 +233,16 @@ def run(self):
self.args.organization if self.args.organization is not None else user
)

print(
"You are about to create {}".format(
ANSI.bold(namespace + "/" + self.args.name)
)
)
repo_id = f"{namespace}/{self.args.name}"

if self.args.type not in REPO_TYPES:
print("Invalid repo --type")
exit(1)

if self.args.type == REPO_TYPE_DATASET:
repo_id = REPO_TYPE_DATASET_URL_PREFIX + repo_id

print("You are about to create {}".format(ANSI.bold(repo_id)))

if not self.args.yes:
choice = input("Proceed? [Y/n] ").lower()
Expand All @@ -236,7 +251,10 @@ def run(self):
exit()
try:
url = self._api.create_repo(
token, name=self.args.name, organization=self.args.organization
token,
name=self.args.name,
organization=self.args.organization,
repo_type=self.args.type,
)
except HTTPError as e:
print(e)
Expand Down
32 changes: 32 additions & 0 deletions src/huggingface_hub/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os


# Constants for file downloads

PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
CONFIG_NAME = "config.json"

HUGGINGFACE_CO_URL_HOME = "https://huggingface.co/"

# Filled in with `.format(repo_id=..., revision=..., filename=...)`.
HUGGINGFACE_CO_URL_TEMPLATE = (
    "https://huggingface.co/{repo_id}/resolve/{revision}/{filename}"
)

# Accepted values for a `repo_type` argument; `None` selects the default
# repo type.
REPO_TYPE_DATASET = "dataset"
REPO_TYPES = [None, REPO_TYPE_DATASET]

# Prepended to a repo id when the repo type is `REPO_TYPE_DATASET`.
REPO_TYPE_DATASET_URL_PREFIX = "datasets/"


# default cache
# Resolution order for the cache home: $HF_HOME, else
# $XDG_CACHE_HOME/huggingface, else ~/.cache/huggingface ("~" is expanded).
hf_cache_home = os.path.expanduser(
    os.getenv(
        "HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")
    )
)
default_cache_path = os.path.join(hf_cache_home, "hub")

# $HUGGINGFACE_HUB_CACHE overrides the default location. Expand "~" here as
# well, so that a value such as "~/hub-cache" resolves to the user's home
# directory instead of a literal "~" directory, consistent with HF_HOME above.
HUGGINGFACE_HUB_CACHE = os.path.expanduser(
    os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path)
)
55 changes: 27 additions & 28 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@
from filelock import FileLock

from . import __version__
from .constants import (
HUGGINGFACE_CO_URL_TEMPLATE,
HUGGINGFACE_HUB_CACHE,
REPO_TYPE_DATASET,
REPO_TYPE_DATASET_URL_PREFIX,
REPO_TYPES,
)
from .hf_api import HfFolder


Expand Down Expand Up @@ -55,34 +62,11 @@ def is_tf_available():
return _tf_available


# Constants for file downloads

PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
CONFIG_NAME = "config.json"

HUGGINGFACE_CO_URL_TEMPLATE = (
"https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
)


# default cache
hf_cache_home = os.path.expanduser(
os.getenv(
"HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")
)
)
default_cache_path = os.path.join(hf_cache_home, "hub")

HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path)


def hf_hub_url(
model_id: str,
repo_id: str,
filename: str,
subfolder: Optional[str] = None,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
) -> str:
"""
Expand All @@ -103,10 +87,16 @@ def hf_hub_url(
if subfolder is not None:
filename = f"{subfolder}/{filename}"

if repo_type not in REPO_TYPES:
raise ValueError("Invalid repo type")

if repo_type == REPO_TYPE_DATASET:
repo_id = REPO_TYPE_DATASET_URL_PREFIX + repo_id

if revision is None:
revision = "main"
return HUGGINGFACE_CO_URL_TEMPLATE.format(
model_id=model_id, revision=revision, filename=filename
repo_id=repo_id, revision=revision, filename=filename
)


Expand Down Expand Up @@ -286,8 +276,17 @@ def cached_download(
# between the HEAD and the GET (unlikely, but hey).
if 300 <= r.status_code <= 399:
url_to_download = r.headers["Location"]
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
# etag is already None
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
) as exc:
# Actually raise for those subclasses of ConnectionError:
if isinstance(exc, requests.exceptions.SSLError) or isinstance(
exc, requests.exceptions.ProxyError
):
raise exc
# Otherwise, our Internet connection is down.
# etag is None
pass

filename = url_to_filename(url, etag)
Expand Down
29 changes: 27 additions & 2 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import requests

from .constants import REPO_TYPES


ENDPOINT = "https://huggingface.co"

Expand Down Expand Up @@ -139,6 +141,7 @@ def create_repo(
name: str,
organization: Optional[str] = None,
private: Optional[bool] = None,
repo_type: Optional[str] = None,
exist_ok=False,
lfsmultipartthresh: Optional[int] = None,
) -> str:
Expand All @@ -150,12 +153,20 @@ def create_repo(
Params:
private: Whether the model repo should be private (requires a paid huggingface.co account)
repo_type: Set to "dataset" if creating a dataset, default is model
exist_ok: Do not raise an error if repo already exists
lfsmultipartthresh: Optional: internal param for testing purposes.
"""
path = "{}/api/repos/create".format(self.endpoint)

if repo_type not in REPO_TYPES:
raise ValueError("Invalid repo type")

json = {"name": name, "organization": organization, "private": private}
if repo_type is not None:
json["type"] = repo_type
if lfsmultipartthresh is not None:
json["lfsmultipartthresh"] = lfsmultipartthresh
r = requests.post(
Expand All @@ -169,7 +180,13 @@ def create_repo(
d = r.json()
return d["url"]

def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
def delete_repo(
self,
token: str,
name: str,
organization: Optional[str] = None,
repo_type: Optional[str] = None,
):
"""
HuggingFace git-based system, used for models.
Expand All @@ -178,10 +195,18 @@ def delete_repo(self, token: str, name: str, organization: Optional[str] = None)
CAUTION(this is irreversible).
"""
path = "{}/api/repos/delete".format(self.endpoint)

if repo_type not in REPO_TYPES:
raise ValueError("Invalid repo type")

json = {"name": name, "organization": organization}
if repo_type is not None:
json["type"] = repo_type

r = requests.delete(
path,
headers={"authorization": "Bearer {}".format(token)},
json={"name": name, "organization": organization},
json=json,
)
r.raise_for_status()

Expand Down
50 changes: 45 additions & 5 deletions tests/test_file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,23 @@
import unittest

import requests
from huggingface_hub.file_download import (
from huggingface_hub.constants import (
CONFIG_NAME,
PYTORCH_WEIGHTS_NAME,
cached_download,
filename_to_url,
hf_hub_url,
REPO_TYPE_DATASET,
)
from huggingface_hub.file_download import cached_download, filename_to_url, hf_hub_url

from .testing_utils import DUMMY_UNKWOWN_IDENTIFIER
from .testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SAMPLE_DATASET_IDENTIFIER


MODEL_ID = DUMMY_UNKWOWN_IDENTIFIER
# An actual model hosted on huggingface.co

DATASET_ID = SAMPLE_DATASET_IDENTIFIER
# An actual dataset hosted on huggingface.co


REVISION_ID_DEFAULT = "main"
# Default branch name
REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
Expand All @@ -41,6 +44,10 @@
PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes

DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT = "e25d55a1c4933f987c46cc75d8ffadd67f257c61"
# One particular commit for DATASET_ID
DATASET_SAMPLE_PY_FILE = "custom_squad.py"


class CachedDownloadTests(unittest.TestCase):
def test_bogus_url(self):
Expand Down Expand Up @@ -86,3 +93,36 @@ def test_lfs_object(self):
filepath = cached_download(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))

def test_dataset_standard_object_rev(self):
    """Resolve a dataset file at a pinned commit two ways, then download it.

    The URL built with ``repo_type=REPO_TYPE_DATASET`` must equal the URL
    built by prefixing ``datasets/`` to the repo id by hand.
    """
    common = dict(
        filename=DATASET_SAMPLE_PY_FILE,
        revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT,
    )
    typed_url = hf_hub_url(DATASET_ID, repo_type=REPO_TYPE_DATASET, **common)
    # Same URL, obtained by prefixing "datasets" to repo_id ourselves.
    prefixed_url = hf_hub_url(repo_id=f"datasets/{DATASET_ID}", **common)
    self.assertEqual(typed_url, prefixed_url)

    # Now actually fetch the file, bypassing any cached copy.
    local_path = cached_download(typed_url, force_download=True)
    second_entry = filename_to_url(local_path)[1]
    # The stored value must differ from the one pinned as PINNED_SHA1.
    self.assertNotEqual(second_entry, f'"{PINNED_SHA1}"')

def test_dataset_lfs_object(self):
    """Download ``dev-v1.1.json`` from the dataset repo at a pinned commit.

    The metadata stored next to the cached file must be the download URL
    paired with the pinned quoted hash below.
    """
    pinned_value = (
        '"95aa6a52d5d6a735563366753ca50492a658031da74f301ac5238b03966972c9"'
    )
    lfs_url = hf_hub_url(
        DATASET_ID,
        filename="dev-v1.1.json",
        repo_type=REPO_TYPE_DATASET,
        revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT,
    )
    cached_path = cached_download(lfs_url, force_download=True)
    self.assertEqual(filename_to_url(cached_path), (lfs_url, pinned_value))
10 changes: 10 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import time
import unittest

from huggingface_hub.constants import REPO_TYPE_DATASET
from huggingface_hub.hf_api import HfApi, HfFolder, ModelInfo, RepoObj
from requests.exceptions import HTTPError

Expand All @@ -33,6 +34,7 @@

REPO_NAME = "my-model-{}".format(int(time.time() * 10e3))
REPO_NAME_LARGE_FILE = "my-model-largefiles-{}".format(int(time.time() * 10e3))
DATASET_REPO_NAME = "my-dataset-{}".format(int(time.time() * 10e3))
WORKING_REPO_DIR = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/working_repo"
)
Expand Down Expand Up @@ -78,6 +80,14 @@ def test_create_and_delete_repo(self):
self._api.create_repo(token=self._token, name=REPO_NAME)
self._api.delete_repo(token=self._token, name=REPO_NAME)

def test_create_and_delete_dataset_repo(self):
    """Create a dataset repo through the API, then delete it again.

    Uses the dataset-specific ``DATASET_REPO_NAME`` rather than the model
    ``REPO_NAME``, so this test does not share a repo name with the
    model-repo test.
    """
    self._api.create_repo(
        token=self._token, name=DATASET_REPO_NAME, repo_type=REPO_TYPE_DATASET
    )
    self._api.delete_repo(
        token=self._token, name=DATASET_REPO_NAME, repo_type=REPO_TYPE_DATASET
    )


class HfApiPublicTest(unittest.TestCase):
def test_staging_model_list(self):
Expand Down
3 changes: 3 additions & 0 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
# Example model ids

SAMPLE_DATASET_IDENTIFIER = "lhoestq/custom_squad"
# Example dataset ids


def parse_flag_from_env(key, default=False):
try:
Expand Down

0 comments on commit 2bfc9dd

Please sign in to comment.