Merge pull request #7 from ai-med/loss_updates
17122020-CE & Dice loss updates
jyotirmay123 authored Dec 17, 2020
2 parents 9c1aefe + 590df23 commit fe5381c
Showing 8 changed files with 80 additions and 64 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -37,7 +37,7 @@ You need to have the following in order for this library to work as expected
Always use the latest release. Use the following command with the appropriate version number (v1.0 in this particular case) to install. You can find the link to the latest release in the releases section of this GitHub repo.

```
pip install https://github.com/ai-med/nn-common-modules/releases/download/v1.0/nn_common_modules-1.0-py2.py3-none-any.whl
pip install https://github.com/ai-med/nn-common-modules/releases/download/v1.1/nn_common_modules-1.3-py3-none-any.whl
```
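
A quick post-install check (a minimal sketch; the import path follows the usage example in `losses.py` below):

```
python -c "from nn_common_modules import losses; print(losses.DiceLoss())"
```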

## Authors
89 changes: 55 additions & 34 deletions build/lib/nn_common_modules/losses.py
@@ -10,6 +10,8 @@
from nn_common_modules import losses as additional_losses
loss = additional_losses.DiceLoss()
Note: If you use DiceLoss, insert a Softmax layer in the architecture. In case of the combined loss, do not add a softmax layer, as it is in-built
Members
++++++++++++++++++++++
"""
@@ -26,21 +28,20 @@ class DiceLoss(_WeightedLoss):
Dice Loss for a batch of samples
"""

def forward(self, output, target, weights=None, ignore_index=None, binary=False):
def forward(self, output, target, weights=None, binary=False):
"""
Forward pass
:param output: NxCxHxW logits
:param target: NxHxW LongTensor
:param weights: C FloatTensor
:param ignore_index: int index to ignore from loss
:param binary: bool for binarized one channel (C=1) input
:return: torch.tensor
"""
output = F.softmax(output, dim=1)
if binary:
return self._dice_loss_binary(output, target)
return self._dice_loss_multichannel(output, target, weights, ignore_index)
return self._dice_loss_multichannel(output, target, weights)

@staticmethod
def _dice_loss_binary(output, target):
@@ -62,43 +63,37 @@ def _dice_loss_binary(output, target):
return loss_per_channel.sum() / output.size(1)

@staticmethod
def _dice_loss_multichannel(output, target, weights=None, ignore_index=None):
def _dice_loss_multichannel(output, target, weights=None):
"""
Forward pass
:param output: NxCxHxW Variable
:param target: NxHxW LongTensor
:param weights: C FloatTensor
:param ignore_index: int index to ignore from loss
:param binary: bool for binarized one channel (C=1) input
:return:
"""
eps = 0.0001
encoded_target = output.detach() * 0

if ignore_index is not None:
mask = target == ignore_index
target = target.clone()
target[mask] = 0
encoded_target.scatter_(1, target.unsqueeze(1), 1)
mask = mask.unsqueeze(1).expand_as(encoded_target)
encoded_target[mask] = 0
else:
encoded_target.scatter_(1, target.unsqueeze(1), 1)
output = F.softmax(output, dim=1)
eps = 0.0001
target = target.unsqueeze(1)
encoded_target = torch.zeros_like(output)

if weights is None:
weights = 1
encoded_target = encoded_target.scatter(1, target, 1)

intersection = output * encoded_target
numerator = 2 * intersection.sum(0).sum(1).sum(1)
denominator = output + encoded_target
intersection = intersection.sum(2).sum(2)

if ignore_index is not None:
denominator[mask] = 0
denominator = denominator.sum(0).sum(1).sum(1) + eps
loss_per_channel = weights * (1 - (numerator / denominator))
num_union_pixels = output + encoded_target
num_union_pixels = num_union_pixels.sum(2).sum(2)

return loss_per_channel.sum() / output.size(1)
loss_per_class = 1 - ((2 * intersection) / (num_union_pixels + eps))
# loss_per_class = 1 - ((2 * intersection + 1) / (num_union_pixels + 1))
if weights is None:
weights = torch.ones_like(loss_per_class)
loss_per_class *= weights

return (loss_per_class.sum(1) / (num_union_pixels != 0).sum(1).float()).mean()
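
The rewritten multichannel path applies softmax inside `forward`, one-hot encodes the target via `scatter`, and averages the per-class dice over the classes with a non-empty union. A minimal calling sketch, with illustrative shapes that are not from the repo:

```
import torch
from nn_common_modules import losses as additional_losses

criterion = additional_losses.DiceLoss()
logits = torch.randn(2, 4, 8, 8)         # NxCxHxW raw scores; softmax is applied inside forward
target = torch.randint(0, 4, (2, 8, 8))  # NxHxW LongTensor of class indices
loss = criterion(logits, target)         # scalar: per-class dice averaged over non-empty classes
```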


class IoULoss(_WeightedLoss):
@@ -141,7 +136,7 @@ def forward(self, output, target, weights=None, ignore_index=None):

intersection = output * encoded_target
numerator = intersection.sum(0).sum(1).sum(1)
denominator = (output + encoded_target) - (output*encoded_target)
denominator = (output + encoded_target) - (output * encoded_target)

if ignore_index is not None:
denominator[mask] = 0
@@ -178,7 +173,7 @@ class CombinedLoss(_Loss):

def __init__(self):
super(CombinedLoss, self).__init__()
self.cross_entropy_loss = CrossEntropyLoss2d()
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none') # CrossEntropyLoss2d()
self.dice_loss = DiceLoss()
self.focal_loss = FocalLoss()
self.l2_loss = nn.MSELoss()
@@ -192,16 +187,42 @@ def forward(self, input, target, weight=None):
:param weight: torch.tensor (NxHxW)
:return: scalar
"""
input_soft = F.softmax(input, dim=1)
y_2 = torch.mean(self.dice_loss(input_soft, target))
# input_soft = F.softmax(input, dim=1)
y_2 = torch.mean(self.dice_loss(input, target))
if weight is None:
y_1 = torch.mean(self.cross_entropy_loss.forward(input, target))
else:
y_1 = torch.mean(
torch.mul(self.cross_entropy_loss.forward(input, target), weight.cuda()))
torch.mul(self.cross_entropy_loss.forward(input, target), weight))
return y_1 + y_2
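
With `reduction='none'`, the CE term comes back per-pixel, so the optional NxHxW weight map is multiplied in directly; since the hard-coded `weight.cuda()` is gone, the weight presumably needs to live on the same device as the input. A hedged calling sketch:

```
import torch
from nn_common_modules import losses as additional_losses

criterion = additional_losses.CombinedLoss()
logits = torch.randn(2, 4, 8, 8, requires_grad=True)  # raw scores; softmax is in-built, do not apply it here
target = torch.randint(0, 4, (2, 8, 8))               # NxHxW class indices
pixel_w = torch.ones(2, 8, 8)                         # optional per-pixel CE weights, same device as logits
loss = criterion(logits, target, pixel_w)             # mean(weighted CE) + mean(dice)
loss.backward()
```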


class CombinedLoss_KLdiv(_Loss):
"""
A combination of dice and cross entropy loss that additionally passes through a KL-divergence term
"""

def __init__(self):
super(CombinedLoss_KLdiv, self).__init__()
self.cross_entropy_loss = CrossEntropyLoss2d()
self.dice_loss = DiceLoss()

def forward(self, input, target, weight=None):
"""
Forward pass
"""
input, kl_div_loss = input
# input_soft = F.softmax(input, dim=1)
y_2 = torch.mean(self.dice_loss(input, target))
if weight is None:
y_1 = torch.mean(self.cross_entropy_loss.forward(input, target))
else:
y_1 = torch.mean(
torch.mul(self.cross_entropy_loss.forward(input, target), weight))
return y_1, y_2, kl_div_loss
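
The new `CombinedLoss_KLdiv` unpacks `input` as a `(logits, kl_div_loss)` tuple and returns the three terms separately rather than summing them. A hedged sketch of the calling convention; the KL value here is a stand-in for one produced elsewhere in the model:

```
import torch
from nn_common_modules import losses as additional_losses

criterion = additional_losses.CombinedLoss_KLdiv()
logits = torch.randn(2, 4, 8, 8)
target = torch.randint(0, 4, (2, 8, 8))
kl_term = torch.tensor(0.05)              # stand-in KL divergence from the network
ce, dice, kl = criterion((logits, kl_term), target)
total = ce + dice + kl                    # the caller decides how to weight the three terms
```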


# Credit to https://github.com/clcarwin/focal_loss_pytorch
class FocalLoss(nn.Module):
"""
@@ -214,7 +235,7 @@ def __init__(self, gamma=2, alpha=None, size_average=True):
self.gamma = gamma
self.alpha = alpha
if isinstance(alpha, (float, int)):
self.alpha = torch.Tensor([alpha, 1-alpha])
self.alpha = torch.Tensor([alpha, 1 - alpha])
if isinstance(alpha, list):
self.alpha = torch.Tensor(alpha)
self.size_average = size_average
@@ -233,8 +254,8 @@ def forward(self, input, target):
if input.dim() > 2:
# N,C,H,W => N,C,H*W
input = input.view(input.size(0), input.size(1), -1)
input = input.transpose(1, 2) # N,C,H*W => N,H*W,C
input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C
input = input.transpose(1, 2) # N,C,H*W => N,H*W,C
input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C
target = target.view(-1, 1)

logpt = F.log_softmax(input, dim=1)
Expand All @@ -248,7 +269,7 @@ def forward(self, input, target):
at = self.alpha.gather(0, target.data.view(-1))
logpt = logpt * Variable(at)

loss = -1 * (1-pt)**self.gamma * logpt
loss = -1 * (1 - pt) ** self.gamma * logpt
if self.size_average:
return loss.mean()
else:
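A scalar `alpha` builds a two-entry `[alpha, 1 - alpha]` table, so it only indexes correctly for binary (C=2) problems; pass a list for more classes. A hedged calling sketch:

```
import torch
from nn_common_modules import losses as additional_losses

criterion = additional_losses.FocalLoss(gamma=2, alpha=0.25)  # scalar alpha => [0.25, 0.75] table
logits = torch.randn(2, 2, 8, 8)          # two classes, matching the scalar-alpha table
target = torch.randint(0, 2, (2, 8, 8))
loss = criterion(logits, target)          # mean over pixels, since size_average=True
```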
3 changes: 2 additions & 1 deletion build/lib/nn_common_modules/modules.py
@@ -53,7 +53,8 @@ def __init__(self, params, se_block_type=None):
self.SELayer = se.ChannelSpatialSELayer(params['num_filters'])
else:
self.SELayer = None

print(se.SELayer.CSE.value, se_block_type)

padding_h = int((params['kernel_h'] - 1) / 2)
padding_w = int((params['kernel_w'] - 1) / 2)

Binary file added dist/nn_common_modules-1.2-py3-none-any.whl
Binary file not shown.
Binary file added dist/nn_common_modules-1.3-py3-none-any.whl
Binary file not shown.
45 changes: 19 additions & 26 deletions nn_common_modules/losses.py
@@ -28,21 +28,20 @@ class DiceLoss(_WeightedLoss):
Dice Loss for a batch of samples
"""

def forward(self, output, target, weights=None, ignore_index=None, binary=False):
def forward(self, output, target, weights=None, binary=False):
"""
Forward pass
:param output: NxCxHxW logits
:param target: NxHxW LongTensor
:param weights: C FloatTensor
:param ignore_index: int index to ignore from loss
:param binary: bool for binarized one channel (C=1) input
:return: torch.tensor
"""
output = F.softmax(output, dim=1)
if binary:
return self._dice_loss_binary(output, target)
return self._dice_loss_multichannel(output, target, weights, ignore_index)
return self._dice_loss_multichannel(output, target, weights)

@staticmethod
def _dice_loss_binary(output, target):
@@ -64,43 +63,37 @@ def _dice_loss_binary(output, target):
return loss_per_channel.sum() / output.size(1)

@staticmethod
def _dice_loss_multichannel(output, target, weights=None, ignore_index=None):
def _dice_loss_multichannel(output, target, weights=None):
"""
Forward pass
:param output: NxCxHxW Variable
:param target: NxHxW LongTensor
:param weights: C FloatTensor
:param ignore_index: int index to ignore from loss
:param binary: bool for binarized one channel (C=1) input
:return:
"""
eps = 0.0001
encoded_target = output.detach() * 0

if ignore_index is not None:
mask = target == ignore_index
target = target.clone()
target[mask] = 0
encoded_target.scatter_(1, target.unsqueeze(1), 1)
mask = mask.unsqueeze(1).expand_as(encoded_target)
encoded_target[mask] = 0
else:
encoded_target.scatter_(1, target.unsqueeze(1), 1)
output = F.softmax(output, dim=1)
eps = 0.0001
target = target.unsqueeze(1)
encoded_target = torch.zeros_like(output)

if weights is None:
weights = 1
encoded_target = encoded_target.scatter(1, target, 1)

intersection = output * encoded_target
numerator = 2 * intersection.sum(0).sum(1).sum(1)
denominator = output + encoded_target
intersection = intersection.sum(2).sum(2)

if ignore_index is not None:
denominator[mask] = 0
denominator = denominator.sum(0).sum(1).sum(1) + eps
loss_per_channel = weights * (1 - (numerator / denominator))
num_union_pixels = output + encoded_target
num_union_pixels = num_union_pixels.sum(2).sum(2)

return loss_per_channel.sum() / output.size(1)
loss_per_class = 1 - ((2 * intersection) / (num_union_pixels + eps))
# loss_per_class = 1 - ((2 * intersection + 1) / (num_union_pixels + 1))
if weights is None:
weights = torch.ones_like(loss_per_class)
loss_per_class *= weights

return (loss_per_class.sum(1) / (num_union_pixels != 0).sum(1).float()).mean()
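
The same rewrite lands in the installed package source. For the per-class `weights` argument (a C-length FloatTensor), the NxC per-class losses broadcast against it, which allows, e.g., up-weighting a rare class; a hedged sketch:

```
import torch
from nn_common_modules import losses as additional_losses

criterion = additional_losses.DiceLoss()
logits = torch.randn(2, 4, 8, 8)
target = torch.randint(0, 4, (2, 8, 8))
class_w = torch.tensor([0.5, 1.0, 1.0, 2.0])  # C FloatTensor; broadcasts over the NxC per-class losses
loss = criterion(logits, target, weights=class_w)
```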


class IoULoss(_WeightedLoss):
@@ -180,7 +173,7 @@ class CombinedLoss(_Loss):

def __init__(self):
super(CombinedLoss, self).__init__()
self.cross_entropy_loss = CrossEntropyLoss2d()
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none') # CrossEntropyLoss2d()
self.dice_loss = DiceLoss()
self.focal_loss = FocalLoss()
self.l2_loss = nn.MSELoss()
3 changes: 2 additions & 1 deletion nn_common_modules/modules.py
@@ -53,7 +53,8 @@ def __init__(self, params, se_block_type=None):
self.SELayer = se.ChannelSpatialSELayer(params['num_filters'])
else:
self.SELayer = None

print(se.SELayer.CSE.value, se_block_type)

padding_h = int((params['kernel_h'] - 1) / 2)
padding_w = int((params['kernel_w'] - 1) / 2)

2 changes: 1 addition & 1 deletion setup.py
@@ -1,7 +1,7 @@
import setuptools

setuptools.setup(name="nn-common-modules",
version="1.1",
version="1.3",
url="https://github.com/abhi4ssj/nn-common-modules",
author="Shayan Ahmad Siddiqui, Abhijit Guha Roy",
author_email="[email protected]",
