How to iterate deterministically with RLDS? What am I doing wrong? #131

Open · peter-mitrano-bg opened this issue Sep 13, 2024 · 3 comments

@peter-mitrano-bg

I'm unable to iterate deterministically. I've looked at this page and tried to set up the config so that it iterates deterministically, but the order is still random.

Here's my test script.

#!/usr/bin/env python

from pathlib import Path

import tensorflow as tf

from octo.data.dataset import make_single_dataset
from octo.utils.spec import ModuleSpec


def main():
    # Hide GPUs from TensorFlow so the data pipeline runs entirely on CPU.
    tf.config.set_visible_devices([], "GPU")

    dataset_kwargs = {
        "skip_norm": True,
        "shuffle": False,
        "batch_size": 1,
        "name": "bg_p2_dataset",
        "data_dir": str(Path("~/tensorflow_datasets").expanduser()),
        "image_obs_keys": {"primary": "scene", "wrist": "left"},
        "proprio_obs_key": "state",
        "language_key": "language_instruction",
        "action_proprio_normalization_type": "normal",
        "action_normalization_mask": [True, True, True, True, True, True, False],
        "standardize_fn": ModuleSpec.create(
            "octo.data.oxe.oxe_standardization_transforms:bg_p2_dataset_transform",
        ),
    }
    traj_transform_kwargs = {
        "window_size": 2,
        "action_horizon": 4,
        "goal_relabeling_strategy": None,
        "task_augment_strategy": "delete_task_conditioning",
        "task_augment_kwargs": {
            "keep_image_prob": 0,
        },
        "num_parallel_calls": 1,
    }
    primary_img_augment_kwargs = {
        "random_resized_crop": {"scale": [0.8, 1.0], "ratio": [0.9, 1.1]},
        "random_brightness": [0.1],
        "random_contrast": [0.9, 1.1],
        "random_saturation": [0.9, 1.1],
        "random_hue": [0.05],
        "augment_order": [
            "random_resized_crop",
            "random_brightness",
            "random_contrast",
            "random_saturation",
            "random_hue",
        ],
    }
    wrist_img_augment_kwargs = {
        "random_brightness": [0.1],
        "random_contrast": [0.9, 1.1],
        "random_saturation": [0.9, 1.1],
        "random_hue": [0.05],
        "augment_order": [
            "random_brightness",
            "random_contrast",
            "random_saturation",
            "random_hue",
        ],
    }
    frame_transform_kwargs = {
        "crop_size": {
            "wrist": (int(480 / 2) - 128, int(640 / 2) - 128, 256, 256),
        },
        "resize_size": {
            "primary": (256, 256),
            "wrist": (128, 128),
        },
        "image_augment_kwargs": {
            "primary": primary_img_augment_kwargs,
            "wrist": wrist_img_augment_kwargs,
        },
        "num_parallel_calls": 1,
    }

    # Force single-threaded reads and maps in an attempt to remove any
    # parallelism-induced reordering.
    dataset_kwargs["num_parallel_calls"] = 1
    dataset_kwargs["num_parallel_reads"] = 1
    dataset, full_dataset_name = make_single_dataset(
        dataset_kwargs,
        traj_transform_kwargs=traj_transform_kwargs,
        frame_transform_kwargs=frame_transform_kwargs,
        train=False,
    )

    # Iterate the dataset several times; if iteration were deterministic,
    # the per-batch action sums would be identical across passes.
    for k in range(5):
        print(f"--- {k} ---")
        train_data_iter = dataset.iterator()
        for batch in train_data_iter:
            print(batch['action'].sum())

if __name__ == "__main__":
    main()

And the output:

WARNING:tensorflow:AutoGraph could not transform <function _gcd_import at 0x7906b5963400> and will run it as-is.
Cause: Unable to locate the source code of <function _gcd_import at 0x7906b5963400>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function _gcd_import at 0x7906b5963400> and will run it as-is.
Cause: Unable to locate the source code of <function _gcd_import at 0x7906b5963400>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:absl:Dataset normalization turned off -- set skip_norm=False to apply normalization.
--- 0 ---
184.0448
177.0222
198.7096
189.02994
183.54613
--- 1 ---
184.0448
177.0222
198.7096
189.02994
183.54613
--- 2 ---
184.0448
198.7096
177.0222
189.02994
183.54613
@peter-mitrano-bg (Author)

The motivation here is to make validation-set iteration deterministic so that the plots and other visualizations can be compared. Right now the freeze_trajs argument doesn't actually seem to work, so you can't easily compare the visualizations over the course of training.
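
For reference, here's roughly what forcing determinism looks like in a plain tf.data pipeline. This is a minimal sketch, not Octo's API; `dataset` below is just a stand-in tf.data.Dataset, not what make_single_dataset returns:

import tensorflow as tf

# Stand-in pipeline; in the real script this role is played by the
# dataset returned by make_single_dataset.
dataset = tf.data.Dataset.range(10)
dataset = dataset.map(
    lambda x: x * 2,
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=True,  # keep the output order of the parallel map stable
)

options = tf.data.Options()
options.deterministic = True  # forbid non-deterministic reordering downstream
dataset = dataset.with_options(options)

# Two passes should now yield identical orderings.
print(list(dataset.as_numpy_iterator()))
print(list(dataset.as_numpy_iterator()))

If the Octo pipeline still reorders batches with the equivalent settings, the non-determinism is presumably being introduced upstream, e.g. in how the underlying TFRecord files are read or interleaved.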

@peter-mitrano-bg (Author)

Just coming back to say I've given this another go and still can't figure it out. Being able to iterate deterministically is also pretty essential for making comparisons between different visualization and analysis scripts, so I'm still very interested in help figuring this out!

@peter-mitrano-bg (Author)

It seems like applying kvablack/dlimp#4 solves this problem! But I found it's only a partial implementation.
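
I can't speak to exactly what that PR covers, but in plain tf.data the general shape of this kind of fix is to pass an explicit seed and deterministic=True everywhere shuffling or parallelism gets introduced. A minimal sketch, with a made-up file pattern and buffer sizes:

import tensorflow as tf

SEED = 0

# Hypothetical file pattern -- substitute the actual RLDS/TFRecord shards.
files = tf.data.Dataset.list_files("data/*.tfrecord", shuffle=True, seed=SEED)

dataset = files.interleave(
    tf.data.TFRecordDataset,
    cycle_length=4,
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=True,  # preserve round-robin order despite parallel reads
)

# A seeded shuffle with reshuffle_each_iteration=False gives the same order every pass.
dataset = dataset.shuffle(buffer_size=1000, seed=SEED, reshuffle_each_iteration=False)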
