Commit
act on krec v1 (loss goes down)
alik-git committed Dec 5, 2024
1 parent e5e4a80 commit 40c96e7
Showing 4 changed files with 134 additions and 19 deletions.
2 changes: 2 additions & 0 deletions lerobot/common/datasets/factory.py
@@ -97,13 +97,15 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
video_backend=cfg.video_backend,
local_files_only=cfg.get('dataset_local_files_only', False)
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
video_backend=cfg.video_backend,
local_files_only=cfg.get('dataset_local_files_only', False)
)

if cfg.get("override_dataset_stats"):
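The new `local_files_only` plumbing only does something when the config defines `dataset_local_files_only`. A minimal sketch of the intended flow, not part of the commit; the config values here are illustrative stand-ins for the Hydra/OmegaConf config that `make_dataset()` receives:

```python
# Illustrative sketch only; `cfg` stands in for the Hydra/OmegaConf config passed to make_dataset().
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "dataset_repo_id": "gpr_test_krec",
        "dataset_local_files_only": True,  # new flag: read the dataset from disk, skip the Hub
        "video_backend": "pyav",
        "training": {"delta_timestamps": None},
    }
)

# make_dataset(cfg) then forwards the flag as
#   LeRobotDataset(..., local_files_only=cfg.get("dataset_local_files_only", False))
# so a repo id that only exists locally (like gpr_test_krec) can still be loaded.
print(cfg.get("dataset_local_files_only", False))  # True
```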
34 changes: 28 additions & 6 deletions lerobot/common/policies/act/modeling_act.py
@@ -292,16 +292,26 @@ def __init__(self, config: ACTConfig):
self.config = config
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.use_robot_state = "observation.state" in config.input_shapes
# Use our robot state instead of hardcoded "observation.state"
self.krec_robot_state_keys = ["observation.joint_pos", "observation.joint_vel", "observation.ang_vel", "observation.euler_rotation"]
self.use_robot_state = any(k in self.krec_robot_state_keys for k in config.input_shapes)
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
self.use_env_state = "observation.environment_state" in config.input_shapes

self.robot_state_dim = 0
if self.use_robot_state:
self.robot_state_dim = sum(
shape[0] for key, shape in config.input_shapes.items()
if key in self.krec_robot_state_keys
)

if self.config.use_vae:
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension.
if self.use_robot_state:
self.vae_encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
self.robot_state_dim, config.dim_model
)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(
@@ -339,7 +349,7 @@ def __init__(self, config: ACTConfig):
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
if self.use_robot_state:
self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
self.robot_state_dim, config.dim_model
)
if self.use_env_state:
self.encoder_env_state_input_proj = nn.Linear(
@@ -375,6 +385,14 @@ def _reset_parameters(self):
if p.dim() > 1:
nn.init.xavier_uniform_(p)

def _get_robot_state(self, batch: dict[str, Tensor]) -> Tensor:
"""Helper method to concatenate robot state tensors."""
robot_states = []
for key in batch.keys():
if key in self.krec_robot_state_keys:
robot_states.append(batch[key])
return torch.cat(robot_states, dim=-1)
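As a standalone illustration (not part of the diff), this is what `_get_robot_state` produces for the shapes declared in the updated `act.yaml`. Note that the concatenation order follows the iteration order of the batch dict, so the dataset is assumed to emit these keys in a consistent order:

```python
# Standalone illustration of the _get_robot_state concatenation; tensor values are dummies.
import torch

krec_robot_state_keys = [
    "observation.joint_pos",
    "observation.joint_vel",
    "observation.ang_vel",
    "observation.euler_rotation",
]
batch = {
    "observation.joint_pos": torch.zeros(8, 5),          # (B, 5)
    "observation.joint_vel": torch.zeros(8, 5),          # (B, 5)
    "observation.ang_vel": torch.zeros(8, 3),            # (B, 3)
    "observation.euler_rotation": torch.zeros(8, 3),     # (B, 3)
    "observation.images": torch.zeros(8, 3, 480, 640),   # ignored by the helper
}

robot_state = torch.cat([batch[k] for k in batch if k in krec_robot_state_keys], dim=-1)
print(robot_state.shape)  # torch.Size([8, 16]) == (B, robot_state_dim)
```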

def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
@@ -405,14 +423,17 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
else batch["observation.environment_state"]
).shape[0]

# Get device from the batch
device = next(iter(batch.values())).device
# Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.use_robot_state:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state = self._get_robot_state(batch)
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state)
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)

@@ -432,7 +453,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
cls_joint_is_pad = torch.full(
(batch_size, 2 if self.use_robot_state else 1),
False,
device=batch["observation.state"].device,
device=device
)
key_padding_mask = torch.cat(
[cls_joint_is_pad, batch["action_is_pad"]], axis=1
@@ -464,7 +485,8 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
# Robot state token.
if self.use_robot_state:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
robot_state = self._get_robot_state(batch)
encoder_in_tokens.append(self.encoder_robot_state_input_proj(robot_state))
# Environment state token.
if self.use_env_state:
encoder_in_tokens.append(
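To make the sizing concrete, here is a rough sketch of how `robot_state_dim` and the state projection layers come out under the shapes in this commit's `act.yaml` (shape values copied from the config below; the layer name mirrors the one in the diff):

```python
# Sketch of the projection sizing; input_shapes values are copied from act.yaml below.
import torch.nn as nn

input_shapes = {
    "observation.images": [3, 480, 640],
    "observation.joint_pos": [5],
    "observation.joint_vel": [5],
    "observation.ang_vel": [3],
    "observation.euler_rotation": [3],
}
krec_robot_state_keys = [
    "observation.joint_pos",
    "observation.joint_vel",
    "observation.ang_vel",
    "observation.euler_rotation",
]

robot_state_dim = sum(
    shape[0] for key, shape in input_shapes.items() if key in krec_robot_state_keys
)
print(robot_state_dim)  # 5 + 5 + 3 + 3 = 16

dim_model = 512  # policy.dim_model in the ACT config
# Both the VAE encoder and the main encoder project the 16-dim state to dim_model.
encoder_robot_state_input_proj = nn.Linear(robot_state_dim, dim_model)  # Linear(16, 512)
```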
30 changes: 17 additions & 13 deletions lerobot/configs/policy/act.yaml
@@ -1,23 +1,22 @@
# @package _global_

seed: 1000
dataset_repo_id: lerobot/aloha_sim_insertion_human

dataset_repo_id: gpr_test_krec # Match the env config
override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
observation.images:
# Using ImageNet stats since ACT uses pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)

training:
offline_steps: 100000
offline_steps: 500
online_steps: 0
eval_freq: 20000
save_freq: 20000
eval_freq: 0
save_freq: 200
save_checkpoint: true

batch_size: 8
lr: 1e-5
lr: 1e-4
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
@@ -40,16 +39,21 @@ policy:
n_action_steps: 100

input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.top: [3, 480, 640]
observation.state: ["${env.state_dim}"]
observation.images: [3, 480, 640]
observation.joint_pos: [5]
observation.joint_vel: [5]
observation.ang_vel: [3]
observation.euler_rotation: [3]
output_shapes:
action: ["${env.action_dim}"]

# Normalization / Unnormalization
input_normalization_modes:
observation.images.top: mean_std
observation.state: mean_std
observation.images: mean_std
observation.joint_pos: mean_std
observation.joint_vel: mean_std
observation.ang_vel: mean_std
observation.euler_rotation: mean_std
output_normalization_modes:
action: mean_std

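For reference, the `(c,1,1)` shape of the overridden stats is what lets `mean_std` normalization broadcast over a full image tensor; a quick sketch (image size taken from `observation.images: [3, 480, 640]`):

```python
# Sketch of how the (c,1,1) ImageNet stats broadcast during mean_std normalization.
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)

img = torch.rand(3, 480, 640)        # observation.images: [3, 480, 640]
normalized = (img - mean) / std      # broadcasts per channel over the 480x640 pixels
print(normalized.shape)              # torch.Size([3, 480, 640])
```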
87 changes: 87 additions & 0 deletions lerobot/configs/policy/act_gpr_real.yaml
@@ -0,0 +1,87 @@
defaults:
- act # Inherit from base ACT config

seed: 1000
dataset_repo_id: gpr_test_krec # Match the env config
override_dataset_stats:
observation.images:
# Using ImageNet stats since ACT uses pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)

training:
offline_steps: 500
online_steps: 0
eval_freq: 0
save_freq: 200
save_checkpoint: true

batch_size: 8
lr: 1e-4
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1

delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"

eval:
n_episodes: 50
batch_size: 50

# See `configuration_act.py` for more details.
policy:
name: act

# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100

input_shapes:
observation.images: [3, 480, 640]
observation.joint_pos: [5]
observation.joint_vel: [5]
observation.ang_vel: [3]
observation.euler_rotation: [3]
output_shapes:
action: ["${env.action_dim}"]

# Normalization / Unnormalization
input_normalization_modes:
observation.images: mean_std
observation.joint_pos: mean_std
observation.joint_vel: mean_std
observation.ang_vel: mean_std
observation.euler_rotation: mean_std
output_normalization_modes:
action: mean_std

# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4

# Inference.
temporal_ensemble_coeff: null

# Training and loss computation.
dropout: 0.1
kl_weight: 10.0
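
The `delta_timestamps` string above is evaluated against the env's `fps`; since the env config is not part of this diff, the frame rate below is assumed purely for illustration:

```python
# Sketch of what the delta_timestamps expression resolves to.
# fps comes from the env config, which is not in this diff; 50 Hz is assumed here.
fps = 50
chunk_size = 100  # policy.chunk_size

delta_timestamps = {"action": [i / fps for i in range(chunk_size)]}
print(delta_timestamps["action"][:3], "...", delta_timestamps["action"][-1])
# [0.0, 0.02, 0.04] ... 1.98
# i.e. every sample carries the next 100 actions, spanning roughly 2 seconds of motion.
```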
