diff --git a/orca/utils/pretrained_utils.py b/orca/utils/pretrained_utils.py index a94e696c..54779502 100644 --- a/orca/utils/pretrained_utils.py +++ b/orca/utils/pretrained_utils.py @@ -137,9 +137,21 @@ def sample_actions(self, observations, tasks, pad_mask=None, train=False, **kwar **kwargs, ) - @classmethod + @staticmethod + def load_dataset_statistics( + checkpoint_path: str, dataset_name: Optional[str] = None + ): + if dataset_name is not None: + statistics_path = f"dataset_statistics_{dataset_name}.json" + else: + statistics_path = "dataset_statistics.json" + statistics_path = tf.io.gfile.join(checkpoint_path, statistics_path) + with tf.io.gfile.GFile(statistics_path, "r") as f: + statistics = json.load(f) + return statistics + + @staticmethod def load_config( - cls, checkpoint_path: str, ): config_path = tf.io.gfile.join(checkpoint_path, "config.json")