polish(pu): polish reward/value/policy_head_hidden_channels
puyuan1996 committed Jan 3, 2025
1 parent 5614025 commit 79b1029
Showing 37 changed files with 580 additions and 499 deletions.
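
In the files shown below, the change is a rename of the MLP-head options: fc_reward_layers -> reward_head_hidden_channels, fc_value_layers -> value_head_hidden_channels, fc_policy_layers -> policy_head_hidden_channels, plus the matching flatten_output_size_for_*_head -> flatten_input_size_for_*_head rename inside the model. As a rough migration aid, here is a minimal sketch in plain Python (the helper name migrate_head_config is hypothetical and not part of this commit) for updating an old-style model config dict:

# Hypothetical migration helper illustrating the renames in this commit.
RENAMED_KEYS = {
    'fc_reward_layers': 'reward_head_hidden_channels',
    'fc_value_layers': 'value_head_hidden_channels',
    'fc_policy_layers': 'policy_head_hidden_channels',
}

def migrate_head_config(model_cfg: dict) -> dict:
    # Return a copy of the model config with the old head keys renamed to the new ones.
    migrated = dict(model_cfg)
    for old_key, new_key in RENAMED_KEYS.items():
        if old_key in migrated:
            migrated[new_key] = migrated.pop(old_key)
    return migrated

# Example with the small tictactoe-style model config used in the diffs below:
old_cfg = dict(num_res_blocks=1, num_channels=16, fc_value_layers=[8], fc_policy_layers=[8])
print(migrate_head_config(old_cfg))
# -> {'num_res_blocks': 1, 'num_channels': 16, 'value_head_hidden_channels': [8], 'policy_head_hidden_channels': [8]}
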
4 changes: 2 additions & 2 deletions lzero/agent/config/alphazero/tictactoe_play_with_bot.py
@@ -53,8 +53,8 @@
         # We use the small size model for tictactoe.
         num_res_blocks=1,
         num_channels=16,
-        fc_value_layers=[8],
-        fc_policy_layers=[8],
+        value_head_hidden_channels=[8],
+        policy_head_hidden_channels=[8],
     ),
     cuda=True,
     board_size=3,
6 changes: 3 additions & 3 deletions lzero/agent/config/gumbel_muzero/tictactoe_play_with_bot.py
@@ -35,9 +35,9 @@
         # We use the small size model for tictactoe.
         num_res_blocks=1,
         num_channels=16,
-        fc_reward_layers=[8],
-        fc_value_layers=[8],
-        fc_policy_layers=[8],
+        reward_head_hidden_channels=[8],
+        value_head_hidden_channels=[8],
+        policy_head_hidden_channels=[8],
         support_scale=10,
         reward_support_size=21,
         value_support_size=21,
6 changes: 3 additions & 3 deletions lzero/agent/config/muzero/tictactoe_play_with_bot.py
@@ -35,9 +35,9 @@
         # We use the small size model for tictactoe.
         num_res_blocks=1,
         num_channels=16,
-        fc_reward_layers=[8],
-        fc_value_layers=[8],
-        fc_policy_layers=[8],
+        reward_head_hidden_channels=[8],
+        value_head_hidden_channels=[8],
+        policy_head_hidden_channels=[8],
         support_scale=10,
         reward_support_size=21,
         value_support_size=21,
@@ -52,8 +52,8 @@
         # We use the small size model for tictactoe.
         num_res_blocks=1,
         num_channels=16,
-        fc_value_layers=[8],
-        fc_policy_layers=[8],
+        value_head_hidden_channels=[8],
+        policy_head_hidden_channels=[8],
     ),
     sampled_algo=True,
     mcts_ctree=mcts_ctree,
@@ -50,9 +50,9 @@
         num_channels=16,
         frame_stack_num=1,
         model_type='conv',
-        fc_reward_layers=[8],
-        fc_value_layers=[8],
-        fc_policy_layers=[8],
+        reward_head_hidden_channels=[8],
+        value_head_hidden_channels=[8],
+        policy_head_hidden_channels=[8],
         support_scale=10,
         reward_support_size=21,
         value_support_size=21,
64 changes: 32 additions & 32 deletions lzero/model/alphazero_model.py
@@ -32,8 +32,8 @@ def __init__(
         num_channels: int = 64,
         value_head_channels: int = 16,
         policy_head_channels: int = 16,
-        fc_value_layers: SequenceType = [32],
-        fc_policy_layers: SequenceType = [32],
+        value_head_hidden_channels: SequenceType = [32],
+        policy_head_hidden_channels: SequenceType = [32],
         value_support_size: int = 601,
         # ==============================================================
         # specific sampled related config
@@ -66,8 +66,8 @@ def __init__(
             - num_channels (:obj:`int`): The channels of hidden states.
             - value_head_channels (:obj:`int`): The channels of value head.
             - policy_head_channels (:obj:`int`): The channels of policy head.
-            - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
-            - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
+            - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
+            - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
             - value_support_size (:obj:`int`): The size of categorical value.
         """
         super(AlphaZeroModel, self).__init__()
@@ -102,14 +102,14 @@ def __init__(
         self.num_of_sampled_actions = num_of_sampled_actions
 
         # TODO use more adaptive way to get the flatten output size
-        flatten_output_size_for_value_head = (
+        flatten_input_size_for_value_head = (
             (
                 value_head_channels * math.ceil(self.observation_shape[1] / 16) *
                 math.ceil(self.observation_shape[2] / 16)
             ) if downsample else (value_head_channels * self.observation_shape[1] * self.observation_shape[2])
         )
 
-        flatten_output_size_for_policy_head = (
+        flatten_input_size_for_policy_head = (
             (
                 policy_head_channels * math.ceil(self.observation_shape[1] / 16) *
                 math.ceil(self.observation_shape[2] / 16)
@@ -123,11 +123,11 @@ def __init__(
             num_channels,
             value_head_channels,
             policy_head_channels,
-            fc_value_layers,
-            fc_policy_layers,
+            value_head_hidden_channels,
+            policy_head_hidden_channels,
             self.value_support_size,
-            flatten_output_size_for_value_head,
-            flatten_output_size_for_policy_head,
+            flatten_input_size_for_value_head,
+            flatten_input_size_for_policy_head,
             last_linear_layer_init_zero=self.last_linear_layer_init_zero,
             activation=activation,
             sigma_type=self.sigma_type,
@@ -216,11 +216,11 @@ def __init__(
         num_channels: int,
         value_head_channels: int,
         policy_head_channels: int,
-        fc_value_layers: SequenceType,
-        fc_policy_layers: SequenceType,
+        value_head_hidden_channels: SequenceType,
+        policy_head_hidden_channels: SequenceType,
         output_support_size: int,
-        flatten_output_size_for_value_head: int,
-        flatten_output_size_for_policy_head: int,
+        flatten_input_size_for_value_head: int,
+        flatten_input_size_for_policy_head: int,
         last_linear_layer_init_zero: bool = True,
         activation: Optional[nn.Module] = nn.ReLU(inplace=True),
         # ==============================================================
@@ -241,12 +241,12 @@ def __init__(
             - num_channels (:obj:`int`): The channels of hidden states.
             - value_head_channels (:obj:`int`): The channels of value head.
             - policy_head_channels (:obj:`int`): The channels of policy head.
-            - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
-            - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
+            - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
+            - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
             - output_support_size (:obj:`int`): The size of categorical value output.
-            - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+            - flatten_input_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
                 of the value head.
-            - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+            - flatten_input_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
                 of the policy head.
             - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \
                 value/policy mlp, default sets it to True.
@@ -255,8 +255,8 @@ def __init__(
"""
super().__init__()
self.continuous_action_space = continuous_action_space
self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
self.flatten_input_size_for_value_head = flatten_input_size_for_value_head
self.flatten_input_size_for_policy_head = flatten_input_size_for_policy_head
self.norm_type = norm_type
self.sigma_type = sigma_type
self.fixed_sigma_value = fixed_sigma_value
@@ -274,13 +274,13 @@ def __init__(
         self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)
         self.norm_value = nn.BatchNorm2d(value_head_channels)
         self.norm_policy = nn.BatchNorm2d(policy_head_channels)
-        self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
-        self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
+        self.flatten_input_size_for_value_head = flatten_input_size_for_value_head
+        self.flatten_input_size_for_policy_head = flatten_input_size_for_policy_head
         self.fc_value_head = MLP(
-            in_channels=self.flatten_output_size_for_value_head,
-            hidden_channels=fc_value_layers[0],
+            in_channels=self.flatten_input_size_for_value_head,
+            hidden_channels=value_head_hidden_channels[0],
             out_channels=output_support_size,
-            layer_num=len(fc_value_layers) + 1,
+            layer_num=len(value_head_hidden_channels) + 1,
             activation=activation,
             norm_type='LN',
             output_activation=False,
@@ -291,9 +291,9 @@
         # sampled related core code
         if self.continuous_action_space:
             self.fc_policy_head = ReparameterizationHead(
-                input_size=self.flatten_output_size_for_policy_head,
+                input_size=self.flatten_input_size_for_policy_head,
                 output_size=action_space_size,
-                layer_num=len(fc_policy_layers) + 1,
+                layer_num=len(policy_head_hidden_channels) + 1,
                 sigma_type=self.sigma_type,
                 fixed_sigma_value=self.fixed_sigma_value,
                 activation=nn.ReLU(),
@@ -302,10 +302,10 @@
             )
         else:
             self.fc_policy_head = MLP(
-                in_channels=self.flatten_output_size_for_policy_head,
-                hidden_channels=fc_policy_layers[0],
+                in_channels=self.flatten_input_size_for_policy_head,
+                hidden_channels=policy_head_hidden_channels[0],
                 out_channels=action_space_size,
-                layer_num=len(fc_policy_layers) + 1,
+                layer_num=len(policy_head_hidden_channels) + 1,
                 activation=activation,
                 norm_type='LN',
                 output_activation=False,
@@ -340,8 +340,8 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         policy = self.norm_policy(policy)
         policy = self.activation(policy)
 
-        value = value.reshape(-1, self.flatten_output_size_for_value_head)
-        policy = policy.reshape(-1, self.flatten_output_size_for_policy_head)
+        value = value.reshape(-1, self.flatten_input_size_for_value_head)
+        policy = policy.reshape(-1, self.flatten_input_size_for_policy_head)
 
         value = self.fc_value_head(value)
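
For context on what the renamed arguments control in the heads above: each *_head_hidden_channels list configures the head MLP, with its first element setting the hidden width passed to MLP and its length setting the depth via layer_num = len(...) + 1, as seen in the diff. Below is a minimal stand-alone sketch of such a head in plain PyTorch (not the ding MLP helper used by the actual code; value_head_channels=16 and the scalar output are assumptions for illustration):

import torch
import torch.nn as nn

def build_head_mlp(in_features: int, hidden_channels: list, out_features: int) -> nn.Sequential:
    # Mirrors the shape logic in the diff: hidden width = hidden_channels[0],
    # depth = len(hidden_channels) + 1 (hidden layers plus the output layer).
    width = hidden_channels[0]
    layers = []
    prev = in_features
    for _ in range(len(hidden_channels)):
        layers += [nn.Linear(prev, width), nn.LayerNorm(width), nn.ReLU(inplace=True)]
        prev = width
    layers.append(nn.Linear(prev, out_features))  # no output activation, mirroring output_activation=False
    return nn.Sequential(*layers)

# E.g. a 3x3 board with no downsampling and value_head_channels=16:
# flatten input size = 16 * 3 * 3 = 144, value_head_hidden_channels=[8] -> Linear(144, 8) + Linear(8, out).
value_head = build_head_mlp(in_features=16 * 3 * 3, hidden_channels=[8], out_features=1)
print(value_head(torch.randn(4, 144)).shape)  # torch.Size([4, 1])
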