-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix bug, support transpose convolution
- Loading branch information
Showing
8 changed files
with
288 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
"""Implementation of a memory saving 1d transpose convolution layer.""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
from memsave_torch.nn.functional import conv_transpose1dMemSave | ||
|
||
|
||
class MemSaveConvTranspose1d(nn.ConvTranspose1d):
    """Differentiability-agnostic 1d transpose convolution layer.

    Keeps the parameters and hyperparameters of ``nn.ConvTranspose1d`` and
    only replaces the forward pass with ``conv_transpose1dMemSave``.
    """

    def forward(self, input: torch.Tensor, output_size=None) -> torch.Tensor:
        """Forward pass.

        Args:
            input: Input to the network [B, C_in, W]
            output_size: Optional requested spatial size of the output, with
                the same meaning as in ``nn.ConvTranspose1d.forward``. When
                omitted, the layer's own ``output_padding`` is used.

        Returns:
            torch.Tensor: Output [B, C_out, W_out]
        """
        output_padding = self.output_padding
        if output_size is not None:
            # Let the parent class turn the requested size into the matching
            # output padding (the literal 1 is the number of spatial dims).
            output_padding = self._output_padding(
                input,
                output_size,
                self.stride,
                self.padding,
                self.kernel_size,
                1,
                self.dilation,
            )
        return conv_transpose1dMemSave(
            input,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.dilation,
            self.groups,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
"""Implementation of a memory saving 2d transpose convolution layer.""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
from memsave_torch.nn.functional import conv_transpose2dMemSave | ||
|
||
|
||
class MemSaveConvTranspose2d(nn.ConvTranspose2d):
    """Differentiability-agnostic 2d transpose convolution layer.

    Keeps the parameters and hyperparameters of ``nn.ConvTranspose2d`` and
    only replaces the forward pass with ``conv_transpose2dMemSave``.
    """

    def forward(self, input: torch.Tensor, output_size=None) -> torch.Tensor:
        """Forward pass.

        Args:
            input: Input to the network [B, C_in, H, W]
            output_size: Optional requested spatial size of the output, with
                the same meaning as in ``nn.ConvTranspose2d.forward``. When
                omitted, the layer's own ``output_padding`` is used.

        Returns:
            torch.Tensor: Output [B, C_out, H_out, W_out]
        """
        output_padding = self.output_padding
        if output_size is not None:
            # Let the parent class turn the requested size into the matching
            # output padding (the literal 2 is the number of spatial dims).
            output_padding = self._output_padding(
                input,
                output_size,
                self.stride,
                self.padding,
                self.kernel_size,
                2,
                self.dilation,
            )
        return conv_transpose2dMemSave(
            input,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.dilation,
            self.groups,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
"""Implementation of a memory saving 3d transpose convolution layer.""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
from memsave_torch.nn.functional import conv_transpose3dMemSave | ||
|
||
|
||
class MemSaveConvTranspose3d(nn.ConvTranspose3d):
    """Differentiability-agnostic 3d transpose convolution layer.

    Keeps the parameters and hyperparameters of ``nn.ConvTranspose3d`` and
    only replaces the forward pass with ``conv_transpose3dMemSave``.
    """

    def forward(self, input: torch.Tensor, output_size=None) -> torch.Tensor:
        """Forward pass.

        Args:
            input: Input to the network [B, C_in, D, H, W]
            output_size: Optional requested spatial size of the output, with
                the same meaning as in ``nn.ConvTranspose3d.forward``. When
                omitted, the layer's own ``output_padding`` is used.

        Returns:
            torch.Tensor: Output [B, C_out, D_out, H_out, W_out]
        """
        output_padding = self.output_padding
        if output_size is not None:
            # Let the parent class turn the requested size into the matching
            # output padding (the literal 3 is the number of spatial dims).
            output_padding = self._output_padding(
                input,
                output_size,
                self.stride,
                self.padding,
                self.kernel_size,
                3,
                self.dilation,
            )
        return conv_transpose3dMemSave(
            input,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.dilation,
            self.groups,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
"""Implementation of memory saving transpose convolution layers. | ||
This is done by not saving the inputs/weights if the weights/inputs don't require grad. | ||
""" | ||
|
||
import torch | ||
|
||
|
||
class _MemSaveConvTranspose(torch.autograd.Function):
    """Transpose convolution that stores only the tensors backward will need.

    The weight is saved only if the input requires a gradient, and the input
    is saved only if the weight requires a gradient. The bias gradient needs
    neither of them.
    """

    @staticmethod
    def forward(x, weight, bias, stride, padding, output_padding, dilation, groups):
        """Compute the transpose convolution.

        Args:
            x: input tensor
            weight: convolution weight
            bias: bias tensor or None
            stride: stride
            padding: padding
            output_padding: output padding
            dilation: dilation
            groups: groups

        Returns:
            The result of ``torch.ops.aten.convolution`` called with the
            ``transposed`` flag set to True.
        """
        return torch.ops.aten.convolution(
            x,
            weight,
            bias,
            stride,
            padding,
            dilation,
            True,  # transposed
            output_padding,
            groups,
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        """Record hyperparameters and save only the tensors backward needs.

        Args:
            ctx: autograd context
            inputs: tuple of the arguments given to ``forward``
            output: result of ``forward`` (unused)
        """
        x, weight, bias, stride, padding, output_padding, dilation, groups = inputs
        need_grad = []
        # The gradient w.r.t. x is computed from the weight ...
        if ctx.needs_input_grad[0]:
            need_grad.append(weight)
        # ... and the gradient w.r.t. the weight is computed from x.
        if ctx.needs_input_grad[1]:
            need_grad.append(x)
        # The bias gradient needs no saved tensor.
        ctx.bias_exists = bias is not None
        ctx.stride = stride
        ctx.padding = padding
        ctx.output_padding = output_padding
        ctx.dilation = dilation
        ctx.groups = groups
        # Metadata used to build placeholders for tensors that were not saved.
        ctx.x_shape = x.shape
        ctx.weight_shape = weight.shape
        ctx.x_dtype = x.dtype
        ctx.weight_dtype = weight.dtype
        ctx.device = x.device

        ctx.save_for_backward(*need_grad)

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the gradients w.r.t. x, weight and bias.

        Args:
            ctx: autograd context filled by ``setup_context``
            grad_output: gradient w.r.t. the output of ``forward``

        Returns:
            One entry per argument of ``forward``: the gradients for x,
            weight and bias, followed by None for the five hyperparameters.
        """
        x = weight = None

        # Saved tensors are stored in the order chosen by setup_context.
        current_idx = 0
        if ctx.needs_input_grad[0]:
            weight = ctx.saved_tensors[current_idx]
            current_idx += 1
        if ctx.needs_input_grad[1]:
            x = ctx.saved_tensors[current_idx]
            current_idx += 1

        # Tensors that were not saved are replaced by placeholders. They must
        # carry the original dtype, otherwise non-float32 inputs would mix
        # dtypes inside convolution_backward.
        if x is None:
            x = torch.zeros(ctx.x_shape, device=ctx.device, dtype=ctx.x_dtype)
        if weight is None:
            weight = torch.zeros(
                ctx.weight_shape, device=ctx.device, dtype=ctx.weight_dtype
            )

        # The bias has one entry per output channel, which is dim 1 of
        # grad_output (dim 0 of a transposed weight is the input channels).
        bias_sizes = [grad_output.shape[1]] if ctx.bias_exists else None

        grad_x, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
            grad_output,
            x,
            weight,
            bias_sizes,
            ctx.stride,
            ctx.padding,
            ctx.dilation,
            True,  # transposed
            ctx.output_padding,
            ctx.groups,
            ctx.needs_input_grad[:3],
        )

        return grad_x, grad_weight, grad_bias, None, None, None, None, None
|
||
|
||
def conv_transpose1dMemSave(
    input, weight, bias, stride, padding, output_padding, dilation, groups
) -> torch.Tensor:
    """Memory saving 1d transpose convolution in functional form.

    All arguments are forwarded unchanged to ``_MemSaveConvTranspose.apply``.

    Args:
        input: input [B, C_in, W]
        weight: weight
        bias: bias
        stride: stride
        padding: padding
        output_padding: output padding
        dilation: dilation
        groups: groups

    Returns:
        torch.Tensor: Output of the conv operation [B, C_out, W_out]
    """
    conv_args = (
        input,
        weight,
        bias,
        stride,
        padding,
        output_padding,
        dilation,
        groups,
    )
    return _MemSaveConvTranspose.apply(*conv_args)
|
||
|
||
def conv_transpose2dMemSave(
    input, weight, bias, stride, padding, output_padding, dilation, groups
) -> torch.Tensor:
    """Memory saving 2d transpose convolution in functional form.

    All arguments are forwarded unchanged to ``_MemSaveConvTranspose.apply``.

    Args:
        input: input [B, C_in, H, W]
        weight: weight
        bias: bias
        stride: stride
        padding: padding
        output_padding: output padding
        dilation: dilation
        groups: groups

    Returns:
        torch.Tensor: Output of the conv operation [B, C_out, H_out, W_out]
    """
    conv_args = (
        input,
        weight,
        bias,
        stride,
        padding,
        output_padding,
        dilation,
        groups,
    )
    return _MemSaveConvTranspose.apply(*conv_args)
|
||
|
||
def conv_transpose3dMemSave(
    input, weight, bias, stride, padding, output_padding, dilation, groups
) -> torch.Tensor:
    """Memory saving 3d transpose convolution in functional form.

    All arguments are forwarded unchanged to ``_MemSaveConvTranspose.apply``.

    Args:
        input: input [B, C_in, D, H, W]
        weight: weight
        bias: bias
        stride: stride
        padding: padding
        output_padding: output padding
        dilation: dilation
        groups: groups

    Returns:
        torch.Tensor: Output of the conv operation [B, C_out, D_out, H_out, W_out]
    """
    conv_args = (
        input,
        weight,
        bias,
        stride,
        padding,
        output_padding,
        dilation,
        groups,
    )
    return _MemSaveConvTranspose.apply(*conv_args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters