Skip to content

Commit

Permalink
Merge pull request octo-models#77 from rail-berkeley/dibya-fix-bridge…
Browse files Browse the repository at this point in the history
…-eval

Updates to Bridge Evaluation
  • Loading branch information
dibyaghosh authored Nov 14, 2023
2 parents 7bac65d + 81b29ac commit 156cd9c
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 82 deletions.
202 changes: 130 additions & 72 deletions experiments/homer/bridge/eval.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,28 @@
#!/usr/bin/env python3

from datetime import datetime
from functools import partial
import json
import os
from pathlib import Path, PurePath
import time
from datetime import datetime
from functools import partial

from absl import app, flags, logging
import click
import cv2
import flax
import imageio
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from absl import app, flags, logging

# bridge_data_robot imports
from widowx_envs.widowx_env_service import WidowXClient, WidowXStatus, WidowXConfigs
from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs, WidowXStatus
from widowx_wrapper import convert_obs, state_to_eep, wait_for_obs, WidowXGym

from orca.utils.gym_wrappers import HistoryWrapper, RHCWrapper, TemporalEnsembleWrapper
from orca.utils.pretrained_utils import PretrainedModel
from widowx_wrapper import WidowXGym, convert_obs, wait_for_obs, state_to_eep
from orca.utils.gym_wrappers import (
HistoryWrapper,
RHCWrapper,
TemporalEnsembleWrapper,
)

np.set_printoptions(suppress=True)

Expand All @@ -35,21 +33,19 @@
flags.DEFINE_multi_string(
"checkpoint_weights_path", None, "Path to checkpoint", required=True
)
flags.DEFINE_multi_string(
"checkpoint_step", None, "Checkpoint step", required=True
)
flags.DEFINE_multi_string(
"checkpoint_config_path", None, "Path to checkpoint config JSON", required=True
)
flags.DEFINE_multi_string(
"checkpoint_metadata_path", None, "Path to checkpoint metadata JSON", required=True
flags.DEFINE_multi_integer("checkpoint_step", None, "Checkpoint step", required=True)
flags.DEFINE_bool("add_jaxrlm_baseline", False, "Also compare to jaxrl_m baseline")


flags.DEFINE_string(
"checkpoint_cache_dir",
"/tmp/",
"Where to cache checkpoints downloaded from GCS",
)
flags.DEFINE_multi_string(
"checkpoint_example_batch_path",
None,
"Path to checkpoint metadata JSON",
required=True,
flags.DEFINE_string(
"modality", "", "Either 'g', 'goal', 'l', 'language' (leave empty to prompt when running)"
)

flags.DEFINE_integer("im_size", None, "Image size", required=True)
flags.DEFINE_string("video_save_path", None, "Path to save video")
flags.DEFINE_integer("num_timesteps", 120, "num timesteps")
Expand All @@ -67,8 +63,15 @@
# show image flag
flags.DEFINE_bool("show_image", False, "Show image")


##############################################################################

STEP_DURATION_MESSAGE = """
Bridge data was collected with non-blocking control and a step duration of 0.2s.
However, we relabel the actions to make it look like the data was collected with blocking control and we evaluate with blocking control.
We also use a step duration of 0.4s to reduce the jerkiness of the policy.
Be sure to change the step duration back to 0.2 if evaluating with non-blocking control.
"""
STEP_DURATION = 0.4
STICKY_GRIPPER_NUM_STEPS = 1
WORKSPACE_BOUNDS = [[0.1, -0.15, -0.01, -1.57, 0], [0.45, 0.25, 0.25, 1.57, 0]]
Expand All @@ -82,6 +85,35 @@
##############################################################################


def maybe_download_checkpoint_from_gcs(cloud_path, step, save_path):
if not cloud_path.startswith("gs://"):
return cloud_path, step # Actually on the local filesystem

checkpoint_path = tf.io.gfile.join(cloud_path, f"{step}")
norm_path = tf.io.gfile.join(cloud_path, "action_proprio*")
config_path = tf.io.gfile.join(cloud_path, "config.json*")
example_batch_path = tf.io.gfile.join(cloud_path, "example_batch.msgpack*")

run_name = Path(cloud_path).name
save_path = os.path.join(save_path, run_name)

target_checkpoint_path = os.path.join(save_path, f"{step}")
if os.path.exists(target_checkpoint_path):
logging.warning(
"Checkpoint already exists at %s, skipping download", target_checkpoint_path
)
return save_path, step
os.makedirs(save_path, exist_ok=True)
logging.warning("Downloading checkpoint and metadata to %s", save_path)

os.system(f"gsutil cp -r {checkpoint_path} {save_path}/")
os.system(f"gsutil cp {norm_path} {save_path}/")
os.system(f"gsutil cp {config_path} {save_path}/")
os.system(f"gsutil cp {example_batch_path} {save_path}/")

return save_path, step


def supply_rng(f, rng=jax.random.PRNGKey(0)):
def wrapped(*args, **kwargs):
nonlocal rng
Expand Down Expand Up @@ -120,11 +152,42 @@ def sample_actions(
return actions[0] * std + mean


def load_checkpoint(weights_path, config_path, metadata_path, example_batch_path, step):
model = PretrainedModel.load_pretrained(
weights_path, config_path, example_batch_path, step
)
def load_jaxrlm_checkpoint(
weights_path="/mount/harddrive/homer/bridgev2_packaged/bridgev2policies/gcbc_256/checkpoint_300000/",
config_path="/mount/harddrive/homer/bridgev2_packaged/bridgev2policies/gcbc_256/gcbc_256_config.json",
code_path="/mount/harddrive/homer/bridgev2_packaged/bridgev2policies/bridge_data_v2.zip",
):
from codesave import UniqueCodebase

with UniqueCodebase(code_path) as cs:
pretrained_utils = cs.import_module("jaxrl_m.pretrained_utils")
loaded = pretrained_utils.load_checkpoint(
weights_path, config_path, im_size=256
)
# loaded contains: {
# "agent": jaxrlm Agent,
# "policy_fn": callable taking in observation and goal inputs and outputs **unnormalized** actions,
# "normalization_stats": {"action": {"mean": [7], "std": [7]}}
# "obs_horizon": int
# }

class Dummy:
def create_tasks(self, goals):
return goals.copy()

def new_policy_fn(observations, goals):
observations = {"image": observations["image_0"]}
goals = {"image": goals["image_0"]}
return loaded["policy_fn"](observations, goals)

return new_policy_fn, Dummy()


def load_checkpoint(weights_path, step):
model = PretrainedModel.load_pretrained(weights_path, step=int(step))
metadata_path = os.path.join(
weights_path, "action_proprio_metadata_bridge_dataset.json"
)
with open(metadata_path, "r") as f:
action_proprio_metadata = json.load(f)
action_mean = jnp.array(action_proprio_metadata["action"]["mean"])
Expand All @@ -144,38 +207,31 @@ def load_checkpoint(weights_path, config_path, metadata_path, example_batch_path


def main(_):
assert (
len(FLAGS.checkpoint_weights_path)
== len(FLAGS.checkpoint_config_path)
== len(FLAGS.checkpoint_metadata_path)
== len(FLAGS.checkpoint_example_batch_path)
== len(FLAGS.checkpoint_step)
)
assert len(FLAGS.checkpoint_weights_path) == len(FLAGS.checkpoint_step)
FLAGS.modality = FLAGS.modality[:1]
assert FLAGS.modality in ["g", "l", ""]
if not FLAGS.blocking:
assert STEP_DURATION == 0.2, STEP_DURATION_MESSAGE

# policies is a dict from run_name to policy function
policies = {}
for (
checkpoint_weights_path,
checkpoint_config_path,
checkpoint_metadata_path,
checkpoint_example_batch_path,
checkpoint_step,
) in zip(
for (checkpoint_weights_path, checkpoint_step,) in zip(
FLAGS.checkpoint_weights_path,
FLAGS.checkpoint_config_path,
FLAGS.checkpoint_metadata_path,
FLAGS.checkpoint_example_batch_path,
FLAGS.checkpoint_step,
):
checkpoint_weights_path, checkpoint_step = maybe_download_checkpoint_from_gcs(
checkpoint_weights_path,
checkpoint_step,
FLAGS.checkpoint_cache_dir,
)
assert tf.io.gfile.exists(checkpoint_weights_path), checkpoint_weights_path
run_name = checkpoint_config_path.split("/")[-2]
run_name = checkpoint_weights_path.rpartition("/")[2]
policies[f"{run_name}-{checkpoint_step}"] = load_checkpoint(
checkpoint_weights_path,
checkpoint_config_path,
checkpoint_metadata_path,
checkpoint_example_batch_path,
checkpoint_step
checkpoint_step,
)
if FLAGS.add_jaxrlm_baseline:
policies["jaxrl_gcbc"] = load_jaxrlm_checkpoint()

if FLAGS.initial_eep is not None:
assert isinstance(FLAGS.initial_eep, list)
Expand All @@ -197,9 +253,8 @@ def main(_):
# env = TemporalEnsembleWrapper(env, FLAGS.pred_horizon)
env = RHCWrapper(env, FLAGS.pred_horizon, FLAGS.exec_horizon)

task = {
"image_0": jnp.zeros((FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8),
}
goal_image = jnp.zeros((FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8)
goal_instruction = ""

# goal sampling loop
while True:
Expand All @@ -211,21 +266,20 @@ def main(_):
print("policies:")
for i, name in enumerate(policies.keys()):
print(f"{i}) {name}")
policy_idx = int(input("select policy: "))
policy_idx = click.prompt("Select policy", type=int)

policy_name = list(policies.keys())[policy_idx]
policy_fn, model = policies[policy_name]
model: PretrainedModel # type hinting

modality = input("Language or goal image? [l/g]")
modality = FLAGS.modality
if not modality:
modality = click.prompt(
"Language or goal image?", type=click.Choice(["l", "g"])
)

if modality == "g":
# ask for new goal
if task["image_0"] is None:
print("Taking a new goal...")
ch = "y"
else:
ch = input("Take a new goal? [y/n]")
if ch == "y":
if click.confirm("Take a new goal?", default=True):
assert isinstance(FLAGS.goal_eep, list)
_eep = [float(e) for e in FLAGS.goal_eep]
goal_eep = state_to_eep(_eep, 0)
Expand All @@ -237,17 +291,21 @@ def main(_):

input("Press [Enter] when ready for taking the goal image. ")
obs = wait_for_obs(widowx_client)
goals = jax.tree_map(lambda x: x[None], convert_obs(obs, FLAGS.im_size))
task = model.create_tasks(goals=goals)
else:
# ask for new instruction
if "language_instruction" not in task or ["language_instruction"] is None:
ch = "y"
else:
ch = input("New instruction? [y/n]")
if ch == "y":
goal = jax.tree_map(lambda x: x[None], convert_obs(obs, FLAGS.im_size))

task = model.create_tasks(goals=goal)
goal_image = goal["image_0"][0]
goal_instruction = ""
elif modality == "l":
print("Current instruction: ", goal_instruction)
if click.confirm("Take a new instruction?", default=True):
text = input("Instruction?")
task = model.create_tasks(text=[text])

task = model.create_tasks(text=[text])
goal_instruction = text
goal_image = jnp.zeros_like(goal_image)
else:
raise NotImplementedError()

input("Press [Enter] to start.")

Expand All @@ -267,7 +325,7 @@ def main(_):

# save images
images.append(obs["image_0"][-1])
goals.append(task["image_0"][0])
goals.append(goal_image)

if FLAGS.show_image:
bgr_img = cv2.cvtColor(obs["full_image"][-1], cv2.COLOR_RGB2BGR)
Expand Down
24 changes: 14 additions & 10 deletions experiments/homer/scripts/eval.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
NAMES=(
"gc_bridge_match_old_20231026_193653"
PATHS=(
"gs://rail-dibya-central2/experiment_output/oxe_sweep/bridge_vits_20231111_165439"
"gs://rail-dibya-central2/experiment_output/oxe_sweep/bridge_baseline_20231112_025236"
"gs://rail-dibya-central2/experiment_output/oxe_sweep/bridge_jaxrlm_baseline_20231112_073307"
)

STEPS=(
"345000"
"120000"
"500000"
"300000"
)

VIDEO_DIR="11-3"
CONDITIONING_MODE="goal"
VIDEO_DIR="11-12"

TIMESTEPS="50"

Expand All @@ -21,17 +26,16 @@ EXEC_HORIZON="1"
CMD="python experiments/homer/bridge/eval.py \
--num_timesteps $TIMESTEPS \
--video_save_path /mount/harddrive/homer/videos/$VIDEO_DIR \
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_weights_path /mount/harddrive/homer/checkpoints/${NAMES[$i]} "; done) \
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_step /mount/harddrive/homer/checkpoints/${STEPS[$i]} "; done) \
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_config_path /mount/harddrive/homer/checkpoints/${NAMES[$i]}/config.json "; done) \
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_metadata_path /mount/harddrive/homer/checkpoints/${NAMES[$i]}/action_proprio_metadata_bridge_dataset.json "; done) \
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_example_batch_path /mount/harddrive/homer/checkpoints/${NAMES[$i]}/example_batch.msgpack "; done) \
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_weights_path ${NAMES[$i]} "; done) \
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_step ${STEPS[$i]} "; done) \
--im_size 256 \
--temperature $TEMPERATURE \
--horizon $HORIZON \
--pred_horizon $PRED_HORIZON \
--exec_horizon $EXEC_HORIZON \
--blocking
--blocking \
--modality $CONDITIONING_MODE \
--checkpoint_cache_dir /mount/harddrive/homer/checkpoints/
"

echo $CMD
Expand Down

0 comments on commit 156cd9c

Please sign in to comment.