diff --git a/end/backbones/text_classifiers.py b/end/backbones/text_classifiers.py index c2e7034..f95fd94 100755 --- a/end/backbones/text_classifiers.py +++ b/end/backbones/text_classifiers.py @@ -12,7 +12,7 @@ class SoftCrossEntropy(torch.nn.Module): def __init__(self, reduction='mean'): super(SoftCrossEntropy, self).__init__() - self.reduction = reductione + self.reduction = reduction def forward(self, pred, target): lsm = pred.log_softmax(dim=1) loss = torch.sum(-target * lsm) @@ -122,4 +122,4 @@ def get_metrics(self, reset: bool = False): metric_dict['average_F1'] = average_f1 metric_dict['acc'] = self.accuracy.get_metric(reset) # metric_dict['acc'] = self.accuracy.get_metric(reset) - return metric_dict \ No newline at end of file + return metric_dict