Skip to content

Commit

Permalink
Addressing homers comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dibyaghosh committed Nov 14, 2023
1 parent 09f2276 commit 81b29ac
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
14 changes: 8 additions & 6 deletions experiments/homer/bridge/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"Where to cache checkpoints downloaded from GCS",
)
flags.DEFINE_string(
"modality", "", "Either 'g' (for goals) or 'l' (for language) or '' to prompt"
"modality", "", "Either 'g', 'goal', 'l', 'language' (leave empty to prompt when running)"
)

flags.DEFINE_integer("im_size", None, "Image size", required=True)
Expand All @@ -66,10 +66,10 @@

##############################################################################

"""
Bridge data was collected with non-blocking control and a step duration of 0.2s.
However, we relabel the actions to make it look like the data was collected with blocking control and we evaluate with blocking control.
We also use a step duration of 0.4s to reduce the jerkiness of the policy.
STEP_DURATION_MESSAGE = """
Bridge data was collected with non-blocking control and a step duration of 0.2s.
However, we relabel the actions to make it look like the data was collected with blocking control and we evaluate with blocking control.
We also use a step duration of 0.4s to reduce the jerkiness of the policy.
Be sure to change the step duration back to 0.2 if evaluating with non-blocking control.
"""
STEP_DURATION = 0.4
Expand Down Expand Up @@ -210,6 +210,9 @@ def main(_):
assert len(FLAGS.checkpoint_weights_path) == len(FLAGS.checkpoint_step)
FLAGS.modality = FLAGS.modality[:1]
assert FLAGS.modality in ["g", "l", ""]
if not FLAGS.blocking:
assert STEP_DURATION == 0.2, STEP_DURATION_MESSAGE

# policies is a dict from run_name to policy function
policies = {}
for (checkpoint_weights_path, checkpoint_step,) in zip(
Expand Down Expand Up @@ -264,7 +267,6 @@ def main(_):
for i, name in enumerate(policies.keys()):
print(f"{i}) {name}")
policy_idx = click.prompt("Select policy", type=int)
# policy_idx = int(input("select policy: "))

policy_name = list(policies.keys())[policy_idx]
policy_fn, model = policies[policy_name]
Expand Down
2 changes: 1 addition & 1 deletion experiments/homer/scripts/eval.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
NAMES=(
PATHS=(
"gs://rail-dibya-central2/experiment_output/oxe_sweep/bridge_vits_20231111_165439"
"gs://rail-dibya-central2/experiment_output/oxe_sweep/bridge_baseline_20231112_025236"
"gs://rail-dibya-central2/experiment_output/oxe_sweep/bridge_jaxrlm_baseline_20231112_073307"
Expand Down

0 comments on commit 81b29ac

Please sign in to comment.