feature(pu): add eval_offline option in unizero multitask pipeline
puyuan1996 committed Dec 26, 2024
1 parent 2ebeff1 commit c007917
Showing 7 changed files with 779 additions and 101 deletions.
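The feature is controlled by a new cfg.policy.eval_offline flag read in the training entry (see the diff below). A minimal sketch of how the flag could be switched on in a LightZero-style EasyDict config; everything here except the eval_offline field itself is illustrative and not taken from this commit:

from easydict import EasyDict

main_config = EasyDict(dict(
    policy=dict(
        # When True, the entry skips online evaluation during training, records the
        # (train_iter, envstep) pairs where evaluation would have happened, and
        # evaluates the corresponding saved checkpoints after training finishes.
        eval_offline=True,
    ),
))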
94 changes: 85 additions & 9 deletions lzero/entry/train_unizero_multitask_segment_ddp.py
@@ -280,6 +280,10 @@ def train_unizero_multitask_segment_ddp(
reanalyze_batch_size = cfg.policy.reanalyze_batch_size
update_per_collect = cfg.policy.update_per_collect

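# Offline evaluation bookkeeping: record the train_iter / envstep at each point where an
# online evaluation would have run, so the saved checkpoints can be evaluated after training.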
if cfg.policy.eval_offline:
eval_train_iter_list = []
eval_train_envstep_list = []

while True:
# Dynamically adjust the batch size
if cfg.policy.allocated_batch_sizes:
@@ -320,16 +324,21 @@

# Determine whether evaluation should be performed
if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
print('=' * 20)
print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...')

# Perform a safe evaluation
stop, reward = safe_eval(evaluator, learner, collector, rank, world_size)
# Check whether the evaluation succeeded
if stop is None or reward is None:
print(f"Rank {rank} encountered a problem during evaluation, continuing training...")
print(f'cfg.policy.eval_offline:{cfg.policy.eval_offline}')
if cfg.policy.eval_offline:
eval_train_iter_list.append(learner.train_iter)
eval_train_envstep_list.append(collector.envstep)
else:
print(f"Evaluation succeeded: stop={stop}, reward={reward}")
print('=' * 20)
print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...')

# Perform a safe evaluation
stop, reward = safe_eval(evaluator, learner, collector, rank, world_size)
# Check whether the evaluation succeeded
if stop is None or reward is None:
print(f"Rank {rank} encountered a problem during evaluation, continuing training...")
else:
print(f"Evaluation succeeded: stop={stop}, reward={reward}")

print('=' * 20)
print(f'Starting collection for Rank {rank}, task_id: {cfg.policy.task_id}...')
@@ -468,14 +477,81 @@ def train_unizero_multitask_segment_ddp(

max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter)

# Synchronize all ranks to ensure every rank has finished training
try:
dist.barrier()
logging.info(f'Rank {rank}: passed the post-training synchronization barrier')
except Exception as e:
logging.error(f'Rank {rank}: synchronization barrier failed, error: {e}')
break

if max_envstep_reached.item() or max_train_iter_reached.item():
logging.info(f'Rank {rank}: termination condition reached')

if cfg.policy.eval_offline:
# For each task of the current process, perform data collection and evaluation
for idx, (cfg, collector, evaluator, replay_buffer) in enumerate(
zip(cfgs, collectors, evaluators, game_buffers)):

logging.info(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}: eval offline beginning...')

# ========= Note: currently only rank 0 stores the ckpt =========
# ckpt_dirname = './data_unizero_mt_ddp-8gpu_20241226/8games_brf0.02_seed0/Pong_seed0/ckpt'

# Let rank 0 generate ckpt_dirname; the other ranks wait to receive it
if rank == 0:
ckpt_dirname = './{}/ckpt'.format(learner.exp_name)
logging.info(f'Rank {rank}: generated ckpt_dirname: {ckpt_dirname}')
else:
ckpt_dirname = None

# Use a list to hold ckpt_dirname
ckpt_dirname_list = [ckpt_dirname]
# Broadcast ckpt_dirname
dist.broadcast_object_list(ckpt_dirname_list, src=0)
# Extract the updated ckpt_dirname from the list
ckpt_dirname = ckpt_dirname_list[0]

# Confirm that all ranks have received the correct ckpt_dirname
logging.info(f'Rank {rank}: received ckpt_dirname: {ckpt_dirname}')

# Check whether ckpt_dirname is valid
if not isinstance(ckpt_dirname, str):
logging.error(f'Rank {rank}: received an invalid ckpt_dirname')
continue

# Evaluate the performance of the pretrained model.
for train_iter, collector_envstep in zip(eval_train_iter_list, eval_train_envstep_list):
# if train_iter==0:
# continue
ckpt_name = 'iteration_{}.pth.tar'.format(train_iter)
ckpt_path = os.path.join(ckpt_dirname, ckpt_name)
try:
# Load the ckpt of the pretrained model
policy.learn_mode.load_state_dict(torch.load(ckpt_path, map_location=cfg.policy.device))
except Exception as e:
logging.error(f'Rank {rank}: load_state_dict failed, error: {e}')
continue

stop, reward = evaluator.eval(learner.save_checkpoint, train_iter, collector_envstep)
logging.info(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}: eval offline at train_iter: {train_iter}, collector_envstep: {collector_envstep}, reward: {reward}')

logging.info(f'eval_train_envstep_list: {eval_train_envstep_list}, eval_train_iter_list: {eval_train_iter_list}')

logging.info(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}: eval offline finished!')



dist.barrier()  # Ensure all processes are synchronized
# After evaluation ends, explicitly close all evaluators
for evaluator in evaluators:
evaluator.close()
break
except Exception as e:
logging.error(f'Rank {rank}: termination check failed, error: {e}')
break


# Call the learner's after_run hook
learner.call_hook('after_run')
return policy
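In the offline-evaluation branch above, the checkpoint directory is created only by rank 0 and shared with the other ranks through torch.distributed.broadcast_object_list. A minimal standalone sketch of that pattern, assuming an already-initialized default process group; the helper name and arguments are illustrative, not part of the commit:

import torch.distributed as dist

def broadcast_ckpt_dirname(rank: int, exp_name: str) -> str:
    # Only rank 0 knows where the learner saved its checkpoints, so it fills the
    # one-element list; the other ranks start with None.
    obj_list = ['./{}/ckpt'.format(exp_name) if rank == 0 else None]
    # broadcast_object_list overwrites the list in place on every non-source rank.
    dist.broadcast_object_list(obj_list, src=0)
    return obj_list[0]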
156 changes: 78 additions & 78 deletions lzero/policy/unizero_multitask.py
@@ -1144,88 +1144,88 @@ def _state_dict_learn(self) -> Dict[str, Any]:
}

# ========== TODO: original version: load all parameters ==========
# def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
# """
# Overview:
# Load the state_dict variable into policy learn mode.
# Arguments:
# - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
# """
# self._learn_model.load_state_dict(state_dict['model'])
# self._target_model.load_state_dict(state_dict['target_model'])
# self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model'])

# ========== TODO: pretrain-finetune version: only load encoder and transformer-backbone parameters ==========
def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
"""
Overview:
Load the state_dict variable into policy learn mode, excluding multi-task related parameters.
Load the state_dict variable into policy learn mode.
Arguments:
- state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously.
- state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
"""
# Define the parameter prefixes to exclude
exclude_prefixes = [
'_orig_mod.world_model.head_policy_multi_task.',
'_orig_mod.world_model.head_value_multi_task.',
'_orig_mod.world_model.head_rewards_multi_task.',
'_orig_mod.world_model.head_observations_multi_task.',
'_orig_mod.world_model.task_emb.'
]
self._learn_model.load_state_dict(state_dict['model'])
self._target_model.load_state_dict(state_dict['target_model'])
self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model'])

# # ========== TODO: pretrain-finetune version: only load encoder and transformer-backbone parameters; heads use re-initialized weights ==========
# def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
# """
# Overview:
# Load the state_dict variable into policy learn mode, excluding multi-task related parameters.
# Arguments:
# - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously.
# """
# # Define the parameter prefixes to exclude
# exclude_prefixes = [
# '_orig_mod.world_model.head_policy_multi_task.',
# '_orig_mod.world_model.head_value_multi_task.',
# '_orig_mod.world_model.head_rewards_multi_task.',
# '_orig_mod.world_model.head_observations_multi_task.',
# '_orig_mod.world_model.task_emb.'
# ]

# Define specific parameters to exclude (for special cases)
exclude_keys = [
'_orig_mod.world_model.task_emb.weight',
'_orig_mod.world_model.task_emb.bias', # add if it exists
# Add other specific parameter names to exclude here
]
# # Define specific parameters to exclude (for special cases)
# exclude_keys = [
# '_orig_mod.world_model.task_emb.weight',
# '_orig_mod.world_model.task_emb.bias', # add if it exists
# # Add other specific parameter names to exclude here
# ]

def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]:
"""
Filter out the parameters that should be excluded.
"""
filtered = {}
for k, v in state_dict_loader.items():
if any(k.startswith(prefix) for prefix in exclude_prefixes):
print(f"Excluding parameter: {k}") # for debugging: shows which parameters are excluded
continue
if k in exclude_keys:
print(f"Excluding specific parameter: {k}") # for debugging
continue
filtered[k] = v
return filtered

# Filter and load the 'model' part
if 'model' in state_dict:
model_state_dict = state_dict['model']
filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys)
missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False)
if missing_keys:
print(f"Missing keys when loading _learn_model: {missing_keys}")
if unexpected_keys:
print(f"Unexpected keys when loading _learn_model: {unexpected_keys}")
else:
print("No 'model' key found in the state_dict.")

# Filter and load the 'target_model' part
if 'target_model' in state_dict:
target_model_state_dict = state_dict['target_model']
filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys)
missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False)
if missing_keys:
print(f"Missing keys when loading _target_model: {missing_keys}")
if unexpected_keys:
print(f"Unexpected keys when loading _target_model: {unexpected_keys}")
else:
print("No 'target_model' key found in the state_dict.")

# Load the optimizer's state_dict; no filtering is needed because the optimizer usually does not contain model parameters
if 'optimizer_world_model' in state_dict:
optimizer_state_dict = state_dict['optimizer_world_model']
try:
self._optimizer_world_model.load_state_dict(optimizer_state_dict)
except Exception as e:
print(f"Error loading optimizer state_dict: {e}")
else:
print("No 'optimizer_world_model' key found in the state_dict.")

# If needed, other parts such as the scheduler can also be loaded
# def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]:
# """
# Filter out the parameters that should be excluded.
# """
# filtered = {}
# for k, v in state_dict_loader.items():
# if any(k.startswith(prefix) for prefix in exclude_prefixes):
# print(f"Excluding parameter: {k}") # for debugging: shows which parameters are excluded
# continue
# if k in exclude_keys:
# print(f"Excluding specific parameter: {k}") # for debugging
# continue
# filtered[k] = v
# return filtered

# # Filter and load the 'model' part
# if 'model' in state_dict:
# model_state_dict = state_dict['model']
# filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys)
# missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False)
# if missing_keys:
# print(f"Missing keys when loading _learn_model: {missing_keys}")
# if unexpected_keys:
# print(f"Unexpected keys when loading _learn_model: {unexpected_keys}")
# else:
# print("No 'model' key found in the state_dict.")

# # Filter and load the 'target_model' part
# if 'target_model' in state_dict:
# target_model_state_dict = state_dict['target_model']
# filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys)
# missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False)
# if missing_keys:
# print(f"Missing keys when loading _target_model: {missing_keys}")
# if unexpected_keys:
# print(f"Unexpected keys when loading _target_model: {unexpected_keys}")
# else:
# print("No 'target_model' key found in the state_dict.")

# # Load the optimizer's state_dict; no filtering is needed because the optimizer usually does not contain model parameters
# if 'optimizer_world_model' in state_dict:
# optimizer_state_dict = state_dict['optimizer_world_model']
# try:
# self._optimizer_world_model.load_state_dict(optimizer_state_dict)
# except Exception as e:
# print(f"Error loading optimizer state_dict: {e}")
# else:
# print("No 'optimizer_world_model' key found in the state_dict.")

# # If needed, other parts such as the scheduler can also be loaded
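The commented-out pretrain-finetune loader above keeps only the encoder and transformer-backbone weights by dropping the multi-task heads and task embeddings before a non-strict load. A minimal sketch of that filtering pattern; the helper name and print messages are illustrative, while the prefix idea mirrors the exclude_prefixes list in the code above:

from typing import Any, Dict, List

import torch.nn as nn

def load_backbone_only(model: nn.Module, ckpt_state: Dict[str, Any], exclude_prefixes: List[str]) -> None:
    # Drop parameters whose names start with an excluded prefix (multi-task heads,
    # task embeddings), so those modules keep their fresh initialization for fine-tuning.
    filtered = {k: v for k, v in ckpt_state.items() if not any(k.startswith(p) for p in exclude_prefixes)}
    missing, unexpected = model.load_state_dict(filtered, strict=False)
    if missing:
        print(f"Keys left at their initial values: {missing}")
    if unexpected:
        print(f"Checkpoint keys that were ignored: {unexpected}")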