From 90fbccdd75cefcacdb79c0bd162cfc3354b75cff Mon Sep 17 00:00:00 2001 From: Toradus Date: Sun, 10 Mar 2024 16:43:56 +0100 Subject: [PATCH 1/2] Fixed int * float multiplication error Fixed len call throwing error when attribute "sample_weights" is provided. Since "num_transitions" is an int value, multiplying it with a float threw an error. I simply switched it around, since "sample_weights" is either int / float and "num_transitions" is always int. --- examples/06_pytorch_oxe_dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/06_pytorch_oxe_dataloader.py b/examples/06_pytorch_oxe_dataloader.py index 046d0802..22bebb9f 100644 --- a/examples/06_pytorch_oxe_dataloader.py +++ b/examples/06_pytorch_oxe_dataloader.py @@ -39,7 +39,7 @@ def __len__(self): ] ) if hasattr(self._rlds_dataset, "sample_weights"): - lengths *= np.array(self._rlds_dataset.sample_weights) + lengths = np.array(self._rlds_dataset.sample_weights) * lengths total_len = lengths.sum() if self._is_train: return int(0.95 * total_len) From a1d37622e00299c93c63e5f98c863569e686b512 Mon Sep 17 00:00:00 2001 From: Marcel Date: Mon, 20 May 2024 16:11:27 +0200 Subject: [PATCH 2/2] fixed rt1_dataset_transform concat using wrong axis + added possibility to load from arbitrary data_dirs (adjustable in OXE_DATASET_CONFIGS now) --- octo/data/oxe/__init__.py | 7 +++++++ octo/data/oxe/oxe_standardization_transforms.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/octo/data/oxe/__init__.py b/octo/data/oxe/__init__.py index 2ec9555c..77a7d1a5 100755 --- a/octo/data/oxe/__init__.py +++ b/octo/data/oxe/__init__.py @@ -1,5 +1,6 @@ import copy import logging +import os from typing import Any, Dict, List, Sequence, Tuple, Union from octo.data.oxe.oxe_dataset_configs import ActionEncoding, OXE_DATASET_CONFIGS @@ -74,6 +75,12 @@ def make_oxe_dataset_kwargs( dataset_kwargs["standardize_fn"] = OXE_STANDARDIZATION_TRANSFORMS[name] + if "data_dir" in dataset_kwargs: + if dataset_kwargs["data_dir"][0] == "~": + dataset_kwargs["data_dir"] = os.path.expanduser("~") + dataset_kwargs["data_dir"][1:] + data_dir = dataset_kwargs["data_dir"] + del dataset_kwargs["data_dir"] + return {"name": name, "data_dir": data_dir, **dataset_kwargs} diff --git a/octo/data/oxe/oxe_standardization_transforms.py b/octo/data/oxe/oxe_standardization_transforms.py index f25eed28..43dbe421 100755 --- a/octo/data/oxe/oxe_standardization_transforms.py +++ b/octo/data/oxe/oxe_standardization_transforms.py @@ -32,7 +32,7 @@ def bridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: trajectory["action"][:, :6], binarize_gripper_actions(trajectory["action"][:, -1])[:, None], ], - axis=1, + axis=-1, ) trajectory = relabel_actions(trajectory) trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]