Commit

add docstring
hukkai committed Oct 12, 2023
1 parent 788964f commit 5d935d4
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions mmaction/models/action_segmentors/asformer.py
@@ -60,8 +60,6 @@ def forward(self, inputs, data_samples, mode, **kwargs):
             - If ``mode="loss"``, return a dict of tensor.
         """
         input = torch.stack(inputs)
-        if mode == 'tensor':
-            return self._forward(inputs, **kwargs)
         if mode == 'predict':
             return self.predict(input, data_samples, **kwargs)
         elif mode == 'loss':
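
With the 'tensor' branch and its _forward helper removed, the unified forward now dispatches only on mode='predict' and mode='loss'. The toy class below is a runnable sketch of just that dispatch pattern; it is not the real ASFormer segmentor, and the feature sizes are made up for illustration.

import torch
import torch.nn as nn


class DispatchSketch(nn.Module):
    """Toy stand-in mirroring the mode dispatch above; not the real segmentor."""

    def forward(self, inputs, data_samples, mode):
        stacked = torch.stack(inputs)                          # (num_clips, channels, frames)
        if mode == 'predict':
            return [dict(recognition=stacked.argmax(dim=1))]   # placeholder prediction
        elif mode == 'loss':
            return dict(loss=stacked.mean())                   # placeholder loss
        raise RuntimeError(f'Invalid mode "{mode}".')


feats = [torch.randn(64, 100)]   # one clip: (channels, frames), illustrative sizes
print(DispatchSketch()(feats, data_samples=None, mode='loss'))
print(DispatchSketch()(feats, data_samples=None, mode='predict')[0]['recognition'].shape)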
@@ -169,19 +167,6 @@ def predict(self, batch_inputs, batch_data_samples, **kwargs):
         output = [dict(ground=ground, recognition=recognition)]
         return output

-    def _forward(self, x):
-        """Define the computation performed at every call.
-        Args:
-            x (torch.Tensor): The input data.
-        Returns:
-            torch.Tensor: The output of the module.
-        """
-        print(x.shape)
-
-        return x.shape


 def exponential_descrease(idx_decoder, p=3):
     return math.exp(-p * idx_decoder)
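
For reference, exponential_descrease (spelling as in the source) returns a coefficient that decays exponentially with the decoder index; in the reference ASFormer design this weight damps the attention contribution of later refinement decoders. A quick standalone check of its values:

import math


def exponential_descrease(idx_decoder, p=3):
    return math.exp(-p * idx_decoder)


# The coefficient drops quickly with decoder depth:
# idx 0 -> 1.0, idx 1 -> ~0.0498, idx 2 -> ~0.0025
for idx in range(3):
    print(idx, exponential_descrease(idx))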

@@ -448,6 +433,13 @@ def __init__(self, dilation, in_channels, out_channels):
                       dilation=dilation), nn.ReLU())

     def forward(self, x):
+        """Define the computation performed at every call.
+        Args:
+            x (torch.Tensor): The input data.
+        Returns:
+            torch.Tensor: The output of the module.
+        """
         return self.layer(x)
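
The layer documented in this hunk is a dilated temporal convolution followed by a ReLU; with padding equal to the dilation and the usual kernel size of 3 (the kernel size is not visible here and is assumed), the number of frames is preserved. A self-contained sketch under those assumptions, with a hypothetical class name:

import torch
import torch.nn as nn


class DilatedConvBlock(nn.Module):
    """Hypothetical name for the dilated conv + ReLU block shown above."""

    def __init__(self, dilation, in_channels, out_channels):
        super().__init__()
        # kernel_size=3 is an assumption; padding == dilation keeps the frame count unchanged.
        self.layer = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3,
                      padding=dilation, dilation=dilation),
            nn.ReLU())

    def forward(self, x):
        return self.layer(x)


x = torch.randn(1, 64, 100)                                       # (batch, channels, frames)
print(DilatedConvBlock(dilation=4, in_channels=64, out_channels=64)(x).shape)  # still 100 frames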


@@ -579,7 +571,7 @@ def forward(self, x, fencoder, mask):


 class MyTransformer(nn.Module):
-
+    """An encoder-decoder transformer"""
     def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps, input_dim,
                  num_classes, channel_masking_rate):
         super(MyTransformer, self).__init__()
@@ -608,6 +600,13 @@ def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps, input_dim,
         ]) # num_decoders

     def forward(self, x, mask):
+        """Define the computation performed at every call.
+        Args:
+            x (torch.Tensor): The input data.
+        Returns:
+            torch.Tensor: The output of the module.
+        """
         out, feature = self.encoder(x, mask)
         outputs = out.unsqueeze(0)

@@ -617,4 +616,4 @@ def forward(self, x, mask):
                           feature * mask[:, 0:1, :], mask)
             outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)

-        return outputs
+        return outputs
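
MyTransformer.forward stacks the encoder output and each decoder refinement along a new leading dimension, so outputs has shape (num_decoders + 1, batch, num_classes, num_frames), and frame-wise labels are normally read from the last slice. A hedged usage sketch; the hyperparameter values and the all-ones mask are illustrative assumptions, not taken from this commit:

import torch

# Import path follows the file shown above.
from mmaction.models.action_segmentors.asformer import MyTransformer

num_classes, input_dim, frames = 11, 2048, 100     # illustrative sizes
model = MyTransformer(num_decoders=3, num_layers=10, r1=2, r2=2, num_f_maps=64,
                      input_dim=input_dim, num_classes=num_classes,
                      channel_masking_rate=0.3)

x = torch.randn(1, input_dim, frames)
mask = torch.ones(1, num_classes, frames)          # all frames valid

outputs = model(x, mask)                           # (num_decoders + 1, 1, num_classes, frames)
pred = outputs[-1].argmax(dim=1)                   # frame-wise labels from the last refinement
print(outputs.shape, pred.shape)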
