Skip to content

Commit

Permalink
Convenience function to load dataset stats
Browse files Browse the repository at this point in the history
  • Loading branch information
dibyaghosh committed Dec 8, 2023
1 parent 71658de commit 12132db
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions orca/utils/pretrained_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,21 @@ def sample_actions(self, observations, tasks, pad_mask=None, train=False, **kwar
**kwargs,
)

@classmethod
@staticmethod
def load_dataset_statistics(
checkpoint_path: str, dataset_name: Optional[str] = None
):
if dataset_name is not None:
statistics_path = f"dataset_statistics_{dataset_name}.json"
else:
statistics_path = "dataset_statistics.json"
statistics_path = tf.io.gfile.join(checkpoint_path, statistics_path)
with tf.io.gfile.GFile(statistics_path, "r") as f:
statistics = json.load(f)
return statistics

@staticmethod
def load_config(
cls,
checkpoint_path: str,
):
config_path = tf.io.gfile.join(checkpoint_path, "config.json")
Expand Down

0 comments on commit 12132db

Please sign in to comment.