fix(pu): fix padded action_mask
dyyoungg committed Dec 13, 2024
1 parent 494921d commit 0540c20
Showing 4 changed files with 14 additions and 6 deletions.
7 changes: 6 additions & 1 deletion lzero/mcts/buffer/game_buffer.py
@@ -401,8 +401,13 @@ def _preprocess_to_play_and_action_mask(
                     unroll_steps + 1]
             )
             if len(action_mask_tmp) < unroll_steps + 1:
+                # action_mask_tmp += [
+                #     list(np.ones(self._cfg.model.action_space_size, dtype=np.int8))
+                #     for _ in range(unroll_steps + 1 - len(action_mask_tmp))
+                # ]
+                # TODO
                 action_mask_tmp += [
-                    list(np.ones(self._cfg.model.action_space_size, dtype=np.int8))
+                    list(np.zeros(self._cfg.model.action_space_size, dtype=np.int8))
                     for _ in range(unroll_steps + 1 - len(action_mask_tmp))
                 ]
             action_mask.append(action_mask_tmp)
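For intuition, here is a minimal standalone sketch of the padding behaviour this hunk switches to; the helper name `pad_action_masks` and the example values are hypothetical, not part of the buffer class. Steps appended beyond the end of a game segment should expose no legal actions: padding with `np.ones` would claim every action is legal on those fake steps, while `np.zeros` lets downstream code recognize padded positions and give them an all-zero target policy, so their cross-entropy term is zero.

```python
import numpy as np

def pad_action_masks(action_mask_tmp, unroll_steps, action_space_size):
    """Pad a per-step action-mask list to unroll_steps + 1 entries.

    Steps beyond the end of the game segment get an all-zero mask,
    i.e. "no legal action", rather than "every action legal".
    """
    missing = unroll_steps + 1 - len(action_mask_tmp)
    if missing > 0:
        action_mask_tmp = action_mask_tmp + [
            list(np.zeros(action_space_size, dtype=np.int8))
            for _ in range(missing)
        ]
    return action_mask_tmp

# Example: a 3-step segment unrolled for 5 steps in a 4-action space.
masks = [[1, 1, 0, 1], [1, 0, 0, 1], [0, 1, 1, 1]]
padded = pad_action_masks(masks, unroll_steps=5, action_space_size=4)
assert len(padded) == 6 and padded[-1] == [0, 0, 0, 0]
```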
6 changes: 4 additions & 2 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -711,7 +711,9 @@ def _compute_target_policy_non_reanalyzed(
             ]
         else:
             legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]

+        # print(f'='*20)
+        # print(f'buffer_muzero: action_mask:{action_mask}')
         with torch.no_grad():
             policy_index = 0
             # 0 -> Invalid target policy for padding outside of game segments,
@@ -740,7 +742,7 @@
                     except Exception as e:
                         print('='*20)
                         print(f'Exception:{e}, distributions:{distributions}, legal_action:{legal_actions[policy_index]}')
-                        # TODO
+                        # TODO: this happens because the tail of a sampled sequence may be padding, whose action_mask is padded with np.zeros(self._cfg.model.action_space_size, dtype=np.int8)
                     target_policies.append(policy_tmp)
                 else:
                     # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0
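The all-zero padded mask is what the `except` branch above runs into: for such a step the derived `legal_actions` entry is empty, and the safe outcome is an all-zero target policy so that position contributes nothing to the loss. Below is a minimal sketch of that behaviour; `target_policy_for_step` is a hypothetical helper, not the buffer's actual method, and the visit-count layout is simplified.

```python
import numpy as np

def target_policy_for_step(distributions, action_mask_row, action_space_size):
    """Build a target policy for one step from MCTS visit counts.

    Padded steps (all-zero mask, hence no legal actions) or steps without
    search statistics get an all-zero policy, so the corresponding
    cross-entropy loss is zero.
    """
    legal_actions = [i for i, x in enumerate(action_mask_row) if x == 1]
    policy = np.zeros(action_space_size, dtype=np.float32)
    if not legal_actions or not distributions:
        return policy  # padded / invalid position: keep the all-zero target
    visits = np.asarray(distributions[:len(legal_actions)], dtype=np.float32)
    policy[legal_actions] = visits / visits.sum()
    return policy

print(target_policy_for_step(None, [0, 0, 0, 0], action_space_size=4))    # [0. 0. 0. 0.]
print(target_policy_for_step([3, 1], [1, 0, 0, 1], action_space_size=4))  # [0.75 0. 0. 0.25]
```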
5 changes: 3 additions & 2 deletions zoo/jericho/configs/jericho_unizero_config.py
@@ -19,7 +19,7 @@ def main(env_id='detective.z5', seed=0):
     num_unroll_steps = 10
     infer_context_length = 4
     num_layers = 2
-    replay_ratio = 0.25
+    replay_ratio = 0.1
     embed_dim = 768
     # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
     # buffer_reanalyze_freq = 1/10
@@ -75,6 +75,7 @@ def main(env_id='detective.z5', seed=0):
             encoder_url='/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594',
             # The input of the model is text, whose shape is identical to the mlp model.
             model_type='mlp',
+            continuous_action_space=False,
             world_model_cfg=dict(
                 policy_entropy_weight=5e-3,
                 continuous_action_space=False,
@@ -91,7 +92,7 @@ def main(env_id='detective.z5', seed=0):
                 env_num=max(collector_env_num, evaluator_env_num),
             ),
         ),
-        action_type = 'varied_action_space',
+        action_type='varied_action_space',
         model_path=None,
         num_unroll_steps=num_unroll_steps,
         reanalyze_ratio=0,
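The `replay_ratio` drop from 0.25 to 0.1 lowers the update-to-data ratio. Assuming the usual LightZero meaning of this knob (training updates performed per collected environment transition), the back-of-the-envelope effect looks like this; the numbers are illustrative only.

```python
# Illustrative only: replay_ratio treated as an update-to-data ratio.
collected_transitions_per_iteration = 200

updates_old = int(collected_transitions_per_iteration * 0.25)  # 50 gradient updates per collection round
updates_new = int(collected_transitions_per_iteration * 0.1)   # 20 gradient updates per collection round
```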
2 changes: 1 addition & 1 deletion zoo/jericho/envs/jericho_env.py
@@ -47,7 +47,7 @@ def prepare_obs(self, obs, return_str: bool = False):
         if not return_str:
             full_obs = JerichoEnv.tokenizer(
                 [full_obs], truncation=True, padding="max_length", max_length=self.max_seq_len)
-            # obs_attn_mask = full_obs['attn_mask']
+            # obs_attn_mask = full_obs['attention_mask']
             full_obs = np.array(full_obs['input_ids'][0], dtype=np.int32)  # TODO: attn_mask
         if len(self._action_list) <= self.max_action_num:
             action_mask = [1] * len(self._action_list) + [0] * \
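Two things interact in this env: the observation is tokenized with a Hugging Face tokenizer, whose output key is `attention_mask` (the commented-out line previously used a non-existent `attn_mask` key), and the per-step `action_mask` is a fixed-size vector with 1s for currently valid actions and 0-padding up to `max_action_num`. A small sketch under those assumptions follows; the tokenizer checkpoint and the example actions are placeholders, not the exact values used by the env.

```python
import numpy as np
from transformers import AutoTokenizer

def build_action_mask(action_list, max_action_num):
    """1 for each currently valid action, 0-padded up to the fixed mask size."""
    n = min(len(action_list), max_action_num)
    return np.array([1] * n + [0] * (max_action_num - n), dtype=np.int8)

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
enc = tokenizer(["You are standing in an open field."], truncation=True,
                padding="max_length", max_length=16)
obs_ids = np.array(enc["input_ids"][0], dtype=np.int32)
obs_attn_mask = enc["attention_mask"][0]  # key is 'attention_mask', not 'attn_mask'

print(build_action_mask(["go north", "open mailbox"], max_action_num=5))  # [1 1 0 0 0]
```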

