Skip to content

Commit

Permalink
py: Detailed datasets metadata in gguf kv store
Browse files Browse the repository at this point in the history
  • Loading branch information
mofosyne committed Aug 28, 2024
1 parent 7c3f55c commit d32c74d
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 25 deletions.
26 changes: 24 additions & 2 deletions examples/convert_legacy_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,8 @@ def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None
self.gguf.add_base_model_version(key, base_model_entry["version"])
if "organization" in base_model_entry:
self.gguf.add_base_model_organization(key, base_model_entry["organization"])
if "description" in base_model_entry:
self.gguf.add_base_model_description(key, base_model_entry["description"])
if "url" in base_model_entry:
self.gguf.add_base_model_url(key, base_model_entry["url"])
if "doi" in base_model_entry:
Expand All @@ -849,12 +851,32 @@ def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None
if "repo_url" in base_model_entry:
self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"])

if metadata.datasets is not None:
self.gguf.add_dataset_count(len(metadata.datasets))
for key, dataset_entry in enumerate(metadata.datasets):
if "name" in dataset_entry:
self.gguf.add_dataset_name(key, dataset_entry["name"])
if "author" in dataset_entry:
self.gguf.add_dataset_author(key, dataset_entry["author"])
if "version" in dataset_entry:
self.gguf.add_dataset_version(key, dataset_entry["version"])
if "organization" in dataset_entry:
self.gguf.add_dataset_organization(key, dataset_entry["organization"])
if "description" in dataset_entry:
self.gguf.add_dataset_description(key, dataset_entry["description"])
if "url" in dataset_entry:
self.gguf.add_dataset_url(key, dataset_entry["url"])
if "doi" in dataset_entry:
self.gguf.add_dataset_doi(key, dataset_entry["doi"])
if "uuid" in dataset_entry:
self.gguf.add_dataset_uuid(key, dataset_entry["uuid"])
if "repo_url" in dataset_entry:
self.gguf.add_dataset_repo_url(key, dataset_entry["repo_url"])

if metadata.tags is not None:
self.gguf.add_tags(metadata.tags)
if metadata.languages is not None:
self.gguf.add_languages(metadata.languages)
if metadata.datasets is not None:
self.gguf.add_datasets(metadata.datasets)

def add_meta_arch(self, params: Params) -> None:
# Metadata About The Neural Architecture Itself
Expand Down
14 changes: 13 additions & 1 deletion gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,27 @@ class General:
BASE_MODEL_AUTHOR = "general.base_model.{id}.author"
BASE_MODEL_VERSION = "general.base_model.{id}.version"
BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization"
BASE_MODEL_DESCRIPTION = "general.base_model.{id}.description"
BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper
BASE_MODEL_DOI = "general.base_model.{id}.doi"
BASE_MODEL_UUID = "general.base_model.{id}.uuid"
BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)

# Dataset Source
DATASET_COUNT = "general.dataset.count"
DATASET_NAME = "general.dataset.{id}.name"
DATASET_AUTHOR = "general.dataset.{id}.author"
DATASET_VERSION = "general.dataset.{id}.version"
DATASET_ORGANIZATION = "general.dataset.{id}.organization"
DATASET_DESCRIPTION = "general.dataset.{id}.description"
DATASET_URL = "general.dataset.{id}.url" # Model Website/Paper
DATASET_DOI = "general.dataset.{id}.doi"
DATASET_UUID = "general.dataset.{id}.uuid"
DATASET_REPO_URL = "general.dataset.{id}.repo_url" # Model Source Repository (git/svn/etc...)

# Array based KV stores
TAGS = "general.tags"
LANGUAGES = "general.languages"
DATASETS = "general.datasets"

class LLM:
VOCAB_SIZE = "{arch}.vocab_size"
Expand Down
36 changes: 33 additions & 3 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,9 @@ def add_base_model_version(self, source_id: int, version: str) -> None:
def add_base_model_organization(self, source_id: int, organization: str) -> None:
self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)

def add_base_model_description(self, source_id: int, description: str) -> None:
self.add_string(Keys.General.BASE_MODEL_DESCRIPTION.format(id=source_id), description)

def add_base_model_url(self, source_id: int, url: str) -> None:
self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)

Expand All @@ -580,15 +583,42 @@ def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)

def add_dataset_count(self, source_count: int) -> None:
self.add_uint32(Keys.General.DATASET_COUNT, source_count)

def add_dataset_name(self, source_id: int, name: str) -> None:
self.add_string(Keys.General.DATASET_NAME.format(id=source_id), name)

def add_dataset_author(self, source_id: int, author: str) -> None:
self.add_string(Keys.General.DATASET_AUTHOR.format(id=source_id), author)

def add_dataset_version(self, source_id: int, version: str) -> None:
self.add_string(Keys.General.DATASET_VERSION.format(id=source_id), version)

def add_dataset_organization(self, source_id: int, organization: str) -> None:
self.add_string(Keys.General.DATASET_ORGANIZATION.format(id=source_id), organization)

def add_dataset_description(self, source_id: int, description: str) -> None:
self.add_string(Keys.General.DATASET_DESCRIPTION.format(id=source_id), description)

def add_dataset_url(self, source_id: int, url: str) -> None:
self.add_string(Keys.General.DATASET_URL.format(id=source_id), url)

def add_dataset_doi(self, source_id: int, doi: str) -> None:
self.add_string(Keys.General.DATASET_DOI.format(id=source_id), doi)

def add_dataset_uuid(self, source_id: int, uuid: str) -> None:
self.add_string(Keys.General.DATASET_UUID.format(id=source_id), uuid)

def add_dataset_repo_url(self, source_id: int, repo_url: str) -> None:
self.add_string(Keys.General.DATASET_REPO_URL.format(id=source_id), repo_url)

def add_tags(self, tags: Sequence[str]) -> None:
self.add_array(Keys.General.TAGS, tags)

def add_languages(self, languages: Sequence[str]) -> None:
self.add_array(Keys.General.LANGUAGES, languages)

def add_datasets(self, datasets: Sequence[str]) -> None:
self.add_array(Keys.General.DATASETS, datasets)

def add_tensor_data_layout(self, layout: str) -> None:
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)

Expand Down
137 changes: 119 additions & 18 deletions gguf-py/gguf/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Metadata:
base_models: Optional[list[dict]] = None
tags: Optional[list[str]] = None
languages: Optional[list[str]] = None
datasets: Optional[list[str]] = None
datasets: Optional[list[dict]] = None

@staticmethod
def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
Expand Down Expand Up @@ -91,9 +91,11 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat
# Base Models is received here as an array of models
metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)

# Datasets is received here as an array of datasets
metadata.datasets = metadata_override.get("general.datasets", metadata.datasets)

metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)

# Direct Metadata Override (via direct cli argument)
if model_name is not None:
Expand Down Expand Up @@ -346,12 +348,12 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
use_model_card_metadata("author", "model_creator")
use_model_card_metadata("basename", "model_type")

if "base_model" in model_card:
if "base_model" in model_card or "base_models" in model_card:
# This represents the parent models that this is based on
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
# Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
metadata_base_models = []
base_model_value = model_card.get("base_model", None)
base_model_value = model_card.get("base_model", model_card.get("base_models", None))

if base_model_value is not None:
if isinstance(base_model_value, str):
Expand All @@ -364,18 +366,98 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):

for model_id in metadata_base_models:
# NOTE: model size of base model is assumed to be similar to the size of the current model
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
base_model = {}
if model_full_name_component is not None:
base_model["name"] = Metadata.id_to_title(model_full_name_component)
if org_component is not None:
base_model["organization"] = Metadata.id_to_title(org_component)
if version is not None:
base_model["version"] = version
if org_component is not None and model_full_name_component is not None:
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
if isinstance(model_id, str):
if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"):
base_model["repo_url"] = model_id

# Check if Hugging Face ID is present in URL
if "huggingface.co" in model_id:
match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id)
if match:
model_id_component = match.group(1)
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params)

# Populate model dictionary with extracted components
if model_full_name_component is not None:
base_model["name"] = Metadata.id_to_title(model_full_name_component)
if org_component is not None:
base_model["organization"] = Metadata.id_to_title(org_component)
if version is not None:
base_model["version"] = version

else:
# Likely a Hugging Face ID
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)

# Populate model dictionary with extracted components
if model_full_name_component is not None:
base_model["name"] = Metadata.id_to_title(model_full_name_component)
if org_component is not None:
base_model["organization"] = Metadata.id_to_title(org_component)
if version is not None:
base_model["version"] = version
if org_component is not None and model_full_name_component is not None:
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"

else:
logger.error(f"base model entry '{str(model_id)}' not in a known format")
metadata.base_models.append(base_model)

if "datasets" in model_card or "dataset" in model_card:
# This represents the datasets that this was trained from
metadata_datasets = []
dataset_value = model_card.get("datasets", model_card.get("dataset", None))

if dataset_value is not None:
if isinstance(dataset_value, str):
metadata_datasets.append(dataset_value)
elif isinstance(dataset_value, list):
metadata_datasets.extend(dataset_value)

if metadata.datasets is None:
metadata.datasets = []

for dataset_id in metadata_datasets:
# NOTE: model size of base model is assumed to be similar to the size of the current model
dataset = {}
if isinstance(dataset_id, str):
if dataset_id.startswith(("http://", "https://", "ssh://")):
dataset["repo_url"] = dataset_id

# Check if Hugging Face ID is present in URL
if "huggingface.co" in dataset_id:
match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id)
if match:
dataset_id_component = match.group(1)
dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params)

# Populate dataset dictionary with extracted components
if dataset_name_component is not None:
dataset["name"] = Metadata.id_to_title(dataset_name_component)
if org_component is not None:
dataset["organization"] = Metadata.id_to_title(org_component)
if version is not None:
dataset["version"] = version

else:
# Likely a Hugging Face ID
dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)

# Populate dataset dictionary with extracted components
if dataset_name_component is not None:
dataset["name"] = Metadata.id_to_title(dataset_name_component)
if org_component is not None:
dataset["organization"] = Metadata.id_to_title(org_component)
if version is not None:
dataset["version"] = version
if org_component is not None and dataset_name_component is not None:
dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"

else:
logger.error(f"dataset entry '{str(dataset_id)}' not in a known format")
metadata.datasets.append(dataset)

use_model_card_metadata("license", "license")
use_model_card_metadata("license_name", "license_name")
use_model_card_metadata("license_link", "license_link")
Expand All @@ -386,9 +468,6 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
use_array_model_card_metadata("languages", "languages")
use_array_model_card_metadata("languages", "language")

use_array_model_card_metadata("datasets", "datasets")
use_array_model_card_metadata("datasets", "dataset")

# Hugging Face Parameter Heuristics
####################################

Expand Down Expand Up @@ -493,6 +572,8 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
gguf_writer.add_base_model_version(key, base_model_entry["version"])
if "organization" in base_model_entry:
gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
if "description" in base_model_entry:
gguf_writer.add_base_model_description(key, base_model_entry["description"])
if "url" in base_model_entry:
gguf_writer.add_base_model_url(key, base_model_entry["url"])
if "doi" in base_model_entry:
Expand All @@ -502,9 +583,29 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
if "repo_url" in base_model_entry:
gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])

if self.datasets is not None:
gguf_writer.add_dataset_count(len(self.datasets))
for key, dataset_entry in enumerate(self.datasets):
if "name" in dataset_entry:
gguf_writer.add_dataset_name(key, dataset_entry["name"])
if "author" in dataset_entry:
gguf_writer.add_dataset_author(key, dataset_entry["author"])
if "version" in dataset_entry:
gguf_writer.add_dataset_version(key, dataset_entry["version"])
if "organization" in dataset_entry:
gguf_writer.add_dataset_organization(key, dataset_entry["organization"])
if "description" in dataset_entry:
gguf_writer.add_dataset_description(key, dataset_entry["description"])
if "url" in dataset_entry:
gguf_writer.add_dataset_url(key, dataset_entry["url"])
if "doi" in dataset_entry:
gguf_writer.add_dataset_doi(key, dataset_entry["doi"])
if "uuid" in dataset_entry:
gguf_writer.add_dataset_uuid(key, dataset_entry["uuid"])
if "repo_url" in dataset_entry:
gguf_writer.add_dataset_repo_url(key, dataset_entry["repo_url"])

if self.tags is not None:
gguf_writer.add_tags(self.tags)
if self.languages is not None:
gguf_writer.add_languages(self.languages)
if self.datasets is not None:
gguf_writer.add_datasets(self.datasets)
Loading

0 comments on commit d32c74d

Please sign in to comment.