Skip to content

Commit

Permalink
Merge pull request octo-models#140 from rail-berkeley/easy_dataset_st…
Browse files Browse the repository at this point in the history
…aitstics

Convenience function to load dataset stats
  • Loading branch information
dibyaghosh authored Dec 8, 2023
2 parents 71658de + 12132db commit 60b688e
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions orca/utils/pretrained_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,21 @@ def sample_actions(self, observations, tasks, pad_mask=None, train=False, **kwar
**kwargs,
)

@classmethod
@staticmethod
def load_dataset_statistics(
checkpoint_path: str, dataset_name: Optional[str] = None
):
if dataset_name is not None:
statistics_path = f"dataset_statistics_{dataset_name}.json"
else:
statistics_path = "dataset_statistics.json"
statistics_path = tf.io.gfile.join(checkpoint_path, statistics_path)
with tf.io.gfile.GFile(statistics_path, "r") as f:
statistics = json.load(f)
return statistics

@staticmethod
def load_config(
cls,
checkpoint_path: str,
):
config_path = tf.io.gfile.join(checkpoint_path, "config.json")
Expand Down

0 comments on commit 60b688e

Please sign in to comment.