polish(pu): polish suz dmc multitask configs
puyuan1996 authored and puyuan committed Dec 30, 2024
1 parent a5b38b6 commit 1c8c4fb
Showing 4 changed files with 110 additions and 115 deletions.
2 changes: 1 addition & 1 deletion lzero/entry/train_unizero_multitask_segment_ddp.py
@@ -358,7 +358,7 @@ def train_unizero_multitask_segment_ddp(
if cfg.policy.buffer_reanalyze_freq >= 1:
reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq
else:
if train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and \
if train_epoch > 0 and train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and \
replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(
reanalyze_batch_size / cfg.policy.reanalyze_partition):
with timer:
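
For reference, a simplified, standalone restatement of the reanalyze trigger this hunk adjusts (a sketch with flattened names, not the actual entry code): for buffer_reanalyze_freq >= 1 the reanalyze interval is spread over update steps, while for fractional frequencies reanalysis fires once every int(1 / buffer_reanalyze_freq) train epochs, and the newly added train_epoch > 0 guard prevents it from firing at epoch 0.

def should_reanalyze_this_epoch(train_epoch, buffer_reanalyze_freq,
                                num_transitions, num_unroll_steps,
                                reanalyze_batch_size, reanalyze_partition):
    # Frequencies >= 1 are handled per update step (update_per_collect // buffer_reanalyze_freq),
    # not per train epoch.
    if buffer_reanalyze_freq >= 1:
        return False
    # Fractional frequency: reanalyze once every int(1 / freq) train epochs.
    period = int(1 / buffer_reanalyze_freq)
    # The buffer must hold enough unrolled sequences to fill a reanalyze batch.
    enough_data = num_transitions // num_unroll_steps > int(reanalyze_batch_size / reanalyze_partition)
    # The guard added in this commit skips the trigger at train_epoch == 0.
    return train_epoch > 0 and train_epoch % period == 0 and enough_data
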
182 changes: 83 additions & 99 deletions lzero/policy/sampled_unizero_multitask.py
@@ -817,105 +817,89 @@ def _state_dict_learn(self) -> Dict[str, Any]:
'optimizer_world_model': self._optimizer_world_model.state_dict(),
}

# ========== TODO: original version: load all parameters ==========
def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
"""
Overview:
Load the state_dict variable into policy learn mode, excluding multi-task related parameters.
Arguments:
- state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
"""
exclude_prefixes = [
'_orig_mod.world_model.head_policy_multi_task.',
'_orig_mod.world_model.head_value_multi_task.',
'_orig_mod.world_model.head_rewards_multi_task.',
'_orig_mod.world_model.head_observations_multi_task.',
'_orig_mod.world_model.task_emb.'
]

exclude_keys = [
'_orig_mod.world_model.task_emb.weight',
'_orig_mod.world_model.task_emb.bias',
# Add more specific parameter names to exclude here if needed
]

def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]:
"""
Filter out the parameters that should be excluded.
"""
filtered = {}
for k, v in state_dict_loader.items():
if any(k.startswith(prefix) for prefix in exclude_prefixes):
logging.info(f"Excluding parameter: {k}")  # for debugging
continue
if k in exclude_keys:
logging.info(f"Excluding specific parameter: {k}")  # for debugging
continue
filtered[k] = v
return filtered

# Filter and load the 'model' part
if 'model' in state_dict:
model_state_dict = state_dict['model']
filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys)
missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False)
if missing_keys:
logging.info(f"Missing keys when loading _learn_model: {missing_keys}")
if unexpected_keys:
logging.info(f"Unexpected keys when loading _learn_model: {unexpected_keys}")
else:
logging.info("No 'model' key found in the state_dict.")

# Filter and load the 'target_model' part
if 'target_model' in state_dict:
target_model_state_dict = state_dict['target_model']
filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys)
missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False)
if missing_keys:
logging.info(f"Missing keys when loading _target_model: {missing_keys}")
if unexpected_keys:
logging.info(f"Unexpected keys when loading _target_model: {unexpected_keys}")
else:
logging.info("No 'target_model' key found in the state_dict.")

# Load the optimizer state_dict
if 'optimizer_world_model' in state_dict:
optimizer_state_dict = state_dict['optimizer_world_model']
try:
self._optimizer_world_model.load_state_dict(optimizer_state_dict)
except Exception as e:
logging.info(f"Error loading optimizer state_dict: {e}")
else:
logging.info("No 'optimizer_world_model' key found in the state_dict.")

def _monitor_vars_learn(self, num_tasks=2) -> List[str]:
"""
Register the variables to be monitored in learn mode, including multi-task specific variables.
"""
monitored_vars = [
'Current_GPU',
'Max_GPU',
'collect_mcts_temperature',
'collect_epsilon',
'cur_lr_world_model',
'weighted_total_loss',
'total_grad_norm_before_clip_wm',
]

task_specific_vars = [
'noreduce_obs_loss',
'noreduce_orig_policy_loss',
'noreduce_policy_loss',
'noreduce_latent_recon_loss',
'noreduce_perceptual_loss',
'noreduce_latent_state_l2_norms',
'noreduce_policy_entropy',
'noreduce_target_policy_entropy',
'noreduce_reward_loss',
'noreduce_value_loss',
'noreduce_lambd',
'noreduce_value_priority',
'noreduce_value_priority_mean',
]

for var in task_specific_vars:
for task_idx in range(num_tasks):
monitored_vars.append(f'{var}_task{self.task_id + task_idx}')

return monitored_vars
self._learn_model.load_state_dict(state_dict['model'])
self._target_model.load_state_dict(state_dict['target_model'])
self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model'])

# ========== TODO: pretrain-finetune version: only load encoder and transformer-backbone parameters ==========
# def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
# """
# Overview:
# Load the state_dict variable into policy learn mode, excluding multi-task related parameters.
# Arguments:
# - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously.
# """
# # Define the parameter prefixes to exclude
# exclude_prefixes = [
# '_orig_mod.world_model.head_policy_multi_task.',
# '_orig_mod.world_model.head_value_multi_task.',
# '_orig_mod.world_model.head_rewards_multi_task.',
# '_orig_mod.world_model.head_observations_multi_task.',
# '_orig_mod.world_model.task_emb.'
# ]

# # Define specific parameters to exclude (for special cases)
# exclude_keys = [
# '_orig_mod.world_model.task_emb.weight',
# '_orig_mod.world_model.task_emb.bias',  # add if it exists
# # Add other specific parameter names to exclude here
# ]

# def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]:
# """
# Filter out the parameters that should be excluded.
# """
# filtered = {}
# for k, v in state_dict_loader.items():
# if any(k.startswith(prefix) for prefix in exclude_prefixes):
# print(f"Excluding parameter: {k}")  # for debugging: shows which parameters are excluded
# continue
# if k in exclude_keys:
# print(f"Excluding specific parameter: {k}")  # for debugging
# continue
# filtered[k] = v
# return filtered

# # Filter and load the 'model' part
# if 'model' in state_dict:
# model_state_dict = state_dict['model']
# filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys)
# missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False)
# if missing_keys:
# print(f"Missing keys when loading _learn_model: {missing_keys}")
# if unexpected_keys:
# print(f"Unexpected keys when loading _learn_model: {unexpected_keys}")
# else:
# print("No 'model' key found in the state_dict.")

# # Filter and load the 'target_model' part
# if 'target_model' in state_dict:
# target_model_state_dict = state_dict['target_model']
# filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys)
# missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False)
# if missing_keys:
# print(f"Missing keys when loading _target_model: {missing_keys}")
# if unexpected_keys:
# print(f"Unexpected keys when loading _target_model: {unexpected_keys}")
# else:
# print("No 'target_model' key found in the state_dict.")

# # Load the optimizer state_dict; no filtering is needed, since the optimizer state does not itself contain model parameters
# if 'optimizer_world_model' in state_dict:
# optimizer_state_dict = state_dict['optimizer_world_model']
# try:
# self._optimizer_world_model.load_state_dict(optimizer_state_dict)
# except Exception as e:
# print(f"Error loading optimizer state_dict: {e}")
# else:
# print("No 'optimizer_world_model' key found in the state_dict.")

# # If needed, other parts (e.g., a scheduler) can be loaded here as well
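
As a reading aid for _monitor_vars_learn above: every task-specific loss name is expanded into one monitored variable per task handled by this rank, offset by self.task_id. A tiny illustration with hypothetical values (the numbers below are examples, not values from this commit):

task_id = 2        # assumed: global index of the first task assigned to this rank
num_tasks = 2      # assumed: number of tasks trained on this rank
task_specific_vars = ['noreduce_obs_loss', 'noreduce_value_loss']

monitored = [f'{var}_task{task_id + i}' for var in task_specific_vars for i in range(num_tasks)]
# -> ['noreduce_obs_loss_task2', 'noreduce_obs_loss_task3',
#     'noreduce_value_loss_task2', 'noreduce_value_loss_task3']
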
29 changes: 20 additions & 9 deletions zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py
@@ -63,6 +63,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
device='cuda',
# device='cpu', # TODO
# num_layers=2,
# num_layers=8, # TODO
num_layers=4, # TODO
num_heads=8,
embed_dim=768,
@@ -79,14 +80,16 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
),
total_batch_size=total_batch_size,
allocated_batch_sizes=False,
train_start_after_envsteps=int(2e3),
# train_start_after_envsteps=int(2e3),
train_start_after_envsteps=int(0),
use_priority=False,
print_task_priority_logs=False,
cuda=True,
model_path=None,
num_unroll_steps=num_unroll_steps,
# update_per_collect=2, # TODO: 80
update_per_collect=80, # TODO: 80
# update_per_collect=200, # TODO: 8*100*0.25=200
update_per_collect=80, # TODO: 8*100*0.1=80
replay_ratio=reanalyze_ratio,
batch_size=batch_size,
optim_type='AdamW',
@@ -96,6 +99,14 @@
n_episode=n_episode,
replay_buffer_size=int(1e6),
eval_freq=int(5e3),
grad_clip_value=5,
learning_rate=1e-4,
discount_factor=0.99,
td_steps=5,
piecewise_decay_lr_scheduler=False,
manual_temperature_decay=True,
threshold_training_steps_for_final_temperature=int(2.5e4),
cos_lr_scheduler=True,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
buffer_reanalyze_freq=buffer_reanalyze_freq,
@@ -122,7 +133,7 @@ def generate_configs(env_id_list: List[str],
num_segments: int,
total_batch_size: int):
configs = []
exp_name_prefix = f'data_suz_mt_20241224/debug_ddp_8gpu_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_seed{seed}/'
exp_name_prefix = f'data_suz_mt_20241230/ddp_8gpu_nlayer8_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_seed{seed}/'
action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list]
observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list]

@@ -237,12 +248,12 @@ def create_env_manager():
reanalyze_partition = 0.75

# ======== TODO: only for debug ========
collector_env_num = 2
num_segments = 2
n_episode = 2
evaluator_env_num = 2
num_simulations = 2
batch_size = [4 for _ in range(len(env_id_list))]
# collector_env_num = 2
# num_segments = 2
# n_episode = 2
# evaluator_env_num = 2
# num_simulations = 2
# batch_size = [4 for _ in range(len(env_id_list))]
# =======================================

seed = 0 # You can iterate over multiple seeds if needed
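
On the update_per_collect TODO comments in this config (8*100*0.25=200, 8*100*0.1=80): the factors are not spelled out, but they read as segments-per-collect x segment length x replay ratio. A quick sanity check under that assumption; the three variable names below are assumptions for illustration, not config keys from this file:

segments_per_collect = 8     # assumed: number of segments gathered per collection phase
game_segment_length = 100    # assumed: env steps per collected segment
replay_ratio = 0.1           # assumed: gradient updates per collected env step

update_per_collect = int(segments_per_collect * game_segment_length * replay_ratio)
assert update_per_collect == 80   # matches the value set in the config
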
12 changes: 6 additions & 6 deletions zoo/dmc2gym/config/dmc2gym_state_suz_multitask_serial_config.py
@@ -31,7 +31,7 @@ def create_config(env_id, action_space_size_list, observation_shape_list, collec
replay_path_gif='./replay_gif',
),
policy=dict(
multi_gpu=False,
multi_gpu=False, # TODO: enable multi-GPU for DDP
learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))),
model=dict(
observation_shape_list=observation_shape_list,
@@ -87,11 +87,6 @@ def create_config(env_id, action_space_size_list, observation_shape_list, collec
num_simulations=num_simulations,
n_episode=n_episode,
replay_buffer_size=int(1e6),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
buffer_reanalyze_freq=buffer_reanalyze_freq,
reanalyze_batch_size=reanalyze_batch_size,
reanalyze_partition=reanalyze_partition,
grad_clip_value=5,
learning_rate=1e-4,
discount_factor=0.99,
@@ -100,6 +95,11 @@
manual_temperature_decay=True,
threshold_training_steps_for_final_temperature=int(2.5e4),
cos_lr_scheduler=True,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
buffer_reanalyze_freq=buffer_reanalyze_freq,
reanalyze_batch_size=reanalyze_batch_size,
reanalyze_partition=reanalyze_partition,
),
seed=seed,
))
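
The serial config now groups the optimization options (grad_clip_value, learning_rate, cos_lr_scheduler, manual_temperature_decay, ...) ahead of the collector/evaluator settings, and the same keys are added to the DDP config above. As a rough illustration of what cos_lr_scheduler=True usually implies, here is a minimal cosine-annealing sketch in PyTorch; the module, decay horizon, and step count are placeholders rather than values taken from this repository:

import torch

model = torch.nn.Linear(8, 8)                               # placeholder module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # mirrors optim_type='AdamW', learning_rate=1e-4
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100_000)  # assumed decay horizon

for _ in range(3):
    loss = model(torch.randn(4, 8)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # learning rate follows a cosine curve from 1e-4 toward eta_min (default 0)
print(scheduler.get_last_lr())
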
