diff --git a/experiments/homer/bridge/eval.py b/experiments/homer/bridge/eval.py index 5a3c4b01..bd89e3b2 100644 --- a/experiments/homer/bridge/eval.py +++ b/experiments/homer/bridge/eval.py @@ -1,11 +1,14 @@ #!/usr/bin/env python3 +from datetime import datetime +from functools import partial import json import os +from pathlib import Path, PurePath import time -from datetime import datetime -from functools import partial +from absl import app, flags, logging +import click import cv2 import flax import imageio @@ -13,18 +16,13 @@ import jax.numpy as jnp import numpy as np import tensorflow as tf -from absl import app, flags, logging # bridge_data_robot imports -from widowx_envs.widowx_env_service import WidowXClient, WidowXStatus, WidowXConfigs +from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs, WidowXStatus +from widowx_wrapper import convert_obs, state_to_eep, wait_for_obs, WidowXGym +from orca.utils.gym_wrappers import HistoryWrapper, RHCWrapper, TemporalEnsembleWrapper from orca.utils.pretrained_utils import PretrainedModel -from widowx_wrapper import WidowXGym, convert_obs, wait_for_obs, state_to_eep -from orca.utils.gym_wrappers import ( - HistoryWrapper, - RHCWrapper, - TemporalEnsembleWrapper, -) np.set_printoptions(suppress=True) @@ -35,21 +33,19 @@ flags.DEFINE_multi_string( "checkpoint_weights_path", None, "Path to checkpoint", required=True ) -flags.DEFINE_multi_string( - "checkpoint_step", None, "Checkpoint step", required=True -) -flags.DEFINE_multi_string( - "checkpoint_config_path", None, "Path to checkpoint config JSON", required=True -) -flags.DEFINE_multi_string( - "checkpoint_metadata_path", None, "Path to checkpoint metadata JSON", required=True +flags.DEFINE_multi_integer("checkpoint_step", None, "Checkpoint step", required=True) +flags.DEFINE_bool("add_jaxrlm_baseline", False, "Also compare to jaxrl_m baseline") + + +flags.DEFINE_string( + "checkpoint_cache_dir", + "/tmp/", + "Where to cache checkpoints downloaded from GCS", ) -flags.DEFINE_multi_string( - "checkpoint_example_batch_path", - None, - "Path to checkpoint metadata JSON", - required=True, +flags.DEFINE_string( + "modality", "", "Either 'g', 'goal', 'l', 'language' (leave empty to prompt when running)" ) + flags.DEFINE_integer("im_size", None, "Image size", required=True) flags.DEFINE_string("video_save_path", None, "Path to save video") flags.DEFINE_integer("num_timesteps", 120, "num timesteps") @@ -67,8 +63,15 @@ # show image flag flags.DEFINE_bool("show_image", False, "Show image") + ############################################################################## +STEP_DURATION_MESSAGE = """ +Bridge data was collected with non-blocking control and a step duration of 0.2s. +However, we relabel the actions to make it look like the data was collected with blocking control and we evaluate with blocking control. +We also use a step duration of 0.4s to reduce the jerkiness of the policy. +Be sure to change the step duration back to 0.2 if evaluating with non-blocking control. +""" STEP_DURATION = 0.4 STICKY_GRIPPER_NUM_STEPS = 1 WORKSPACE_BOUNDS = [[0.1, -0.15, -0.01, -1.57, 0], [0.45, 0.25, 0.25, 1.57, 0]] @@ -82,6 +85,35 @@ ############################################################################## +def maybe_download_checkpoint_from_gcs(cloud_path, step, save_path): + if not cloud_path.startswith("gs://"): + return cloud_path, step # Actually on the local filesystem + + checkpoint_path = tf.io.gfile.join(cloud_path, f"{step}") + norm_path = tf.io.gfile.join(cloud_path, "action_proprio*") + config_path = tf.io.gfile.join(cloud_path, "config.json*") + example_batch_path = tf.io.gfile.join(cloud_path, "example_batch.msgpack*") + + run_name = Path(cloud_path).name + save_path = os.path.join(save_path, run_name) + + target_checkpoint_path = os.path.join(save_path, f"{step}") + if os.path.exists(target_checkpoint_path): + logging.warning( + "Checkpoint already exists at %s, skipping download", target_checkpoint_path + ) + return save_path, step + os.makedirs(save_path, exist_ok=True) + logging.warning("Downloading checkpoint and metadata to %s", save_path) + + os.system(f"gsutil cp -r {checkpoint_path} {save_path}/") + os.system(f"gsutil cp {norm_path} {save_path}/") + os.system(f"gsutil cp {config_path} {save_path}/") + os.system(f"gsutil cp {example_batch_path} {save_path}/") + + return save_path, step + + def supply_rng(f, rng=jax.random.PRNGKey(0)): def wrapped(*args, **kwargs): nonlocal rng @@ -120,11 +152,42 @@ def sample_actions( return actions[0] * std + mean -def load_checkpoint(weights_path, config_path, metadata_path, example_batch_path, step): - model = PretrainedModel.load_pretrained( - weights_path, config_path, example_batch_path, step - ) +def load_jaxrlm_checkpoint( + weights_path="/mount/harddrive/homer/bridgev2_packaged/bridgev2policies/gcbc_256/checkpoint_300000/", + config_path="/mount/harddrive/homer/bridgev2_packaged/bridgev2policies/gcbc_256/gcbc_256_config.json", + code_path="/mount/harddrive/homer/bridgev2_packaged/bridgev2policies/bridge_data_v2.zip", +): + from codesave import UniqueCodebase + + with UniqueCodebase(code_path) as cs: + pretrained_utils = cs.import_module("jaxrl_m.pretrained_utils") + loaded = pretrained_utils.load_checkpoint( + weights_path, config_path, im_size=256 + ) + # loaded contains: { + # "agent": jaxrlm Agent, + # "policy_fn": callable taking in observation and goal inputs and outputs **unnormalized** actions, + # "normalization_stats": {"action": {"mean": [7], "std": [7]}} + # "obs_horizon": int + # } + + class Dummy: + def create_tasks(self, goals): + return goals.copy() + + def new_policy_fn(observations, goals): + observations = {"image": observations["image_0"]} + goals = {"image": goals["image_0"]} + return loaded["policy_fn"](observations, goals) + + return new_policy_fn, Dummy() + +def load_checkpoint(weights_path, step): + model = PretrainedModel.load_pretrained(weights_path, step=int(step)) + metadata_path = os.path.join( + weights_path, "action_proprio_metadata_bridge_dataset.json" + ) with open(metadata_path, "r") as f: action_proprio_metadata = json.load(f) action_mean = jnp.array(action_proprio_metadata["action"]["mean"]) @@ -144,38 +207,31 @@ def load_checkpoint(weights_path, config_path, metadata_path, example_batch_path def main(_): - assert ( - len(FLAGS.checkpoint_weights_path) - == len(FLAGS.checkpoint_config_path) - == len(FLAGS.checkpoint_metadata_path) - == len(FLAGS.checkpoint_example_batch_path) - == len(FLAGS.checkpoint_step) - ) + assert len(FLAGS.checkpoint_weights_path) == len(FLAGS.checkpoint_step) + FLAGS.modality = FLAGS.modality[:1] + assert FLAGS.modality in ["g", "l", ""] + if not FLAGS.blocking: + assert STEP_DURATION == 0.2, STEP_DURATION_MESSAGE # policies is a dict from run_name to policy function policies = {} - for ( - checkpoint_weights_path, - checkpoint_config_path, - checkpoint_metadata_path, - checkpoint_example_batch_path, - checkpoint_step, - ) in zip( + for (checkpoint_weights_path, checkpoint_step,) in zip( FLAGS.checkpoint_weights_path, - FLAGS.checkpoint_config_path, - FLAGS.checkpoint_metadata_path, - FLAGS.checkpoint_example_batch_path, FLAGS.checkpoint_step, ): + checkpoint_weights_path, checkpoint_step = maybe_download_checkpoint_from_gcs( + checkpoint_weights_path, + checkpoint_step, + FLAGS.checkpoint_cache_dir, + ) assert tf.io.gfile.exists(checkpoint_weights_path), checkpoint_weights_path - run_name = checkpoint_config_path.split("/")[-2] + run_name = checkpoint_weights_path.rpartition("/")[2] policies[f"{run_name}-{checkpoint_step}"] = load_checkpoint( checkpoint_weights_path, - checkpoint_config_path, - checkpoint_metadata_path, - checkpoint_example_batch_path, - checkpoint_step + checkpoint_step, ) + if FLAGS.add_jaxrlm_baseline: + policies["jaxrl_gcbc"] = load_jaxrlm_checkpoint() if FLAGS.initial_eep is not None: assert isinstance(FLAGS.initial_eep, list) @@ -197,9 +253,8 @@ def main(_): # env = TemporalEnsembleWrapper(env, FLAGS.pred_horizon) env = RHCWrapper(env, FLAGS.pred_horizon, FLAGS.exec_horizon) - task = { - "image_0": jnp.zeros((FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8), - } + goal_image = jnp.zeros((FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8) + goal_instruction = "" # goal sampling loop while True: @@ -211,21 +266,20 @@ def main(_): print("policies:") for i, name in enumerate(policies.keys()): print(f"{i}) {name}") - policy_idx = int(input("select policy: ")) + policy_idx = click.prompt("Select policy", type=int) policy_name = list(policies.keys())[policy_idx] policy_fn, model = policies[policy_name] model: PretrainedModel # type hinting - modality = input("Language or goal image? [l/g]") + modality = FLAGS.modality + if not modality: + modality = click.prompt( + "Language or goal image?", type=click.Choice(["l", "g"]) + ) + if modality == "g": - # ask for new goal - if task["image_0"] is None: - print("Taking a new goal...") - ch = "y" - else: - ch = input("Take a new goal? [y/n]") - if ch == "y": + if click.confirm("Take a new goal?", default=True): assert isinstance(FLAGS.goal_eep, list) _eep = [float(e) for e in FLAGS.goal_eep] goal_eep = state_to_eep(_eep, 0) @@ -237,17 +291,21 @@ def main(_): input("Press [Enter] when ready for taking the goal image. ") obs = wait_for_obs(widowx_client) - goals = jax.tree_map(lambda x: x[None], convert_obs(obs, FLAGS.im_size)) - task = model.create_tasks(goals=goals) - else: - # ask for new instruction - if "language_instruction" not in task or ["language_instruction"] is None: - ch = "y" - else: - ch = input("New instruction? [y/n]") - if ch == "y": + goal = jax.tree_map(lambda x: x[None], convert_obs(obs, FLAGS.im_size)) + + task = model.create_tasks(goals=goal) + goal_image = goal["image_0"][0] + goal_instruction = "" + elif modality == "l": + print("Current instruction: ", goal_instruction) + if click.confirm("Take a new instruction?", default=True): text = input("Instruction?") - task = model.create_tasks(text=[text]) + + task = model.create_tasks(text=[text]) + goal_instruction = text + goal_image = jnp.zeros_like(goal_image) + else: + raise NotImplementedError() input("Press [Enter] to start.") @@ -267,7 +325,7 @@ def main(_): # save images images.append(obs["image_0"][-1]) - goals.append(task["image_0"][0]) + goals.append(goal_image) if FLAGS.show_image: bgr_img = cv2.cvtColor(obs["full_image"][-1], cv2.COLOR_RGB2BGR) diff --git a/experiments/homer/scripts/eval.sh b/experiments/homer/scripts/eval.sh index 09a8f5d9..4e87a7cc 100644 --- a/experiments/homer/scripts/eval.sh +++ b/experiments/homer/scripts/eval.sh @@ -1,12 +1,17 @@ -NAMES=( - "gc_bridge_match_old_20231026_193653" +PATHS=( + "gs://rail-dibya-central2/experiment_output/oxe_sweep/bridge_vits_20231111_165439" + "gs://rail-dibya-central2/experiment_output/oxe_sweep/bridge_baseline_20231112_025236" + "gs://rail-dibya-central2/experiment_output/oxe_sweep/bridge_jaxrlm_baseline_20231112_073307" ) STEPS=( - "345000" + "120000" + "500000" + "300000" ) -VIDEO_DIR="11-3" +CONDITIONING_MODE="goal" +VIDEO_DIR="11-12" TIMESTEPS="50" @@ -21,17 +26,16 @@ EXEC_HORIZON="1" CMD="python experiments/homer/bridge/eval.py \ --num_timesteps $TIMESTEPS \ --video_save_path /mount/harddrive/homer/videos/$VIDEO_DIR \ - $(for i in "${!NAMES[@]}"; do echo "--checkpoint_weights_path /mount/harddrive/homer/checkpoints/${NAMES[$i]} "; done) \ - $(for i in "${!NAMES[@]}"; do echo "--checkpoint_step /mount/harddrive/homer/checkpoints/${STEPS[$i]} "; done) \ - $(for i in "${!NAMES[@]}"; do echo "--checkpoint_config_path /mount/harddrive/homer/checkpoints/${NAMES[$i]}/config.json "; done) \ - $(for i in "${!NAMES[@]}"; do echo "--checkpoint_metadata_path /mount/harddrive/homer/checkpoints/${NAMES[$i]}/action_proprio_metadata_bridge_dataset.json "; done) \ - $(for i in "${!NAMES[@]}"; do echo "--checkpoint_example_batch_path /mount/harddrive/homer/checkpoints/${NAMES[$i]}/example_batch.msgpack "; done) \ + $(for i in "${!NAMES[@]}"; do echo "--checkpoint_weights_path ${NAMES[$i]} "; done) \ + $(for i in "${!NAMES[@]}"; do echo "--checkpoint_step ${STEPS[$i]} "; done) \ --im_size 256 \ --temperature $TEMPERATURE \ --horizon $HORIZON \ --pred_horizon $PRED_HORIZON \ --exec_horizon $EXEC_HORIZON \ - --blocking + --blocking \ + --modality $CONDITIONING_MODE \ + --checkpoint_cache_dir /mount/harddrive/homer/checkpoints/ " echo $CMD