Skip to content

Commit

Permalink
Add from_config for abstract distance, search, and store (#366)
Browse files Browse the repository at this point in the history
* [nightly] Increase version to 0.18.0.dev12

* Change import to use TF keras legacy utils to access serialize and deserialize.

* Add try/except to handle updated import path in tf 2.13

* Add from_config to abstract classes to support deserialization.

* Fix typing on from_config for search and store

---------

Co-authored-by: Github Actions Bot <[email protected]>
  • Loading branch information
owenvallis and actions-user authored Oct 23, 2023
1 parent ad4f815 commit dc506d6
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tensorflow_similarity/distances/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,20 @@ def get_config(self) -> dict[str, Any]:
config = {"name": self.name}

return config

@classmethod
def from_config(cls, config: dict[str, Any]) -> Distance:
"""Build a distance from a config.
Args:
config: A Python dict containing the configuration of the distance.
Returns:
A distance instance.
"""
try:
return cls(**config)
except Exception as e:
raise TypeError(
f"Error when deserializing '{cls.__name__}' using" f"config={config}.\n\nException encountered: {e}"
)
17 changes: 17 additions & 0 deletions tensorflow_similarity/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,20 @@ def get_config(self) -> dict[str, Any]:
def is_built(self):
"Returns whether or not the index is built and ready for querying." ""
return self.built

@classmethod
def from_config(cls, config: dict[str, Any]) -> Search:
"""Build a search from a config.
Args:
config: A Python dict containing the configuration of the search.
Returns:
A distance instance.
"""
try:
return cls(**config)
except Exception as e:
raise TypeError(
f"Error when deserializing '{cls.__name__}' using" f"config={config}.\n\nException encountered: {e}"
)
17 changes: 17 additions & 0 deletions tensorflow_similarity/stores/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,20 @@ def get_config(self) -> dict[str, Any]:
}

return config

@classmethod
def from_config(cls, config: dict[str, Any]) -> Store:
"""Build a store from a config.
Args:
config: A Python dict containing the configuration of the store.
Returns:
A distance instance.
"""
try:
return cls(**config)
except Exception as e:
raise TypeError(
f"Error when deserializing '{cls.__name__}' using" f"config={config}.\n\nException encountered: {e}"
)

0 comments on commit dc506d6

Please sign in to comment.