diff --git a/tensorflow_similarity/distances/distance.py b/tensorflow_similarity/distances/distance.py index ed8bcc87..ff318a6a 100644 --- a/tensorflow_similarity/distances/distance.py +++ b/tensorflow_similarity/distances/distance.py @@ -54,3 +54,20 @@ def get_config(self) -> dict[str, Any]: config = {"name": self.name} return config + + @classmethod + def from_config(cls, config: dict[str, Any]) -> Distance: + """Build a distance from a config. + + Args: + config: A Python dict containing the configuration of the distance. + + Returns: + A distance instance. + """ + try: + return cls(**config) + except Exception as e: + raise TypeError( + f"Error when deserializing '{cls.__name__}' using" f"config={config}.\n\nException encountered: {e}" + ) diff --git a/tensorflow_similarity/search/search.py b/tensorflow_similarity/search/search.py index eb9f42eb..45b6ea3d 100644 --- a/tensorflow_similarity/search/search.py +++ b/tensorflow_similarity/search/search.py @@ -135,3 +135,20 @@ def get_config(self) -> dict[str, Any]: def is_built(self): "Returns whether or not the index is built and ready for querying." "" return self.built + + @classmethod + def from_config(cls, config: dict[str, Any]) -> Search: + """Build a search from a config. + + Args: + config: A Python dict containing the configuration of the search. + + Returns: + A distance instance. + """ + try: + return cls(**config) + except Exception as e: + raise TypeError( + f"Error when deserializing '{cls.__name__}' using" f"config={config}.\n\nException encountered: {e}" + ) diff --git a/tensorflow_similarity/stores/store.py b/tensorflow_similarity/stores/store.py index ca0673dc..83fe8197 100644 --- a/tensorflow_similarity/stores/store.py +++ b/tensorflow_similarity/stores/store.py @@ -151,3 +151,20 @@ def get_config(self) -> dict[str, Any]: } return config + + @classmethod + def from_config(cls, config: dict[str, Any]) -> Store: + """Build a store from a config. + + Args: + config: A Python dict containing the configuration of the store. + + Returns: + A distance instance. + """ + try: + return cls(**config) + except Exception as e: + raise TypeError( + f"Error when deserializing '{cls.__name__}' using" f"config={config}.\n\nException encountered: {e}" + )