From e1e49c29249fff60dcc025c7fbbbe156c985d409 Mon Sep 17 00:00:00 2001
From: Felix Dangel
Date: Mon, 8 Jul 2024 11:57:55 -0400
Subject: [PATCH] Fix bug, support transpose convolution

---
 experiments/visual_abstract/run.py           |  61 ++++++--
 memsave_torch/nn/ConvTranspose1d.py          |  29 ++++
 memsave_torch/nn/ConvTranspose2d.py          |  29 ++++
 memsave_torch/nn/ConvTranspose3d.py          |  29 ++++
 memsave_torch/nn/__init__.py                 |   4 +-
 memsave_torch/nn/functional/Conv.py          |   2 +-
 memsave_torch/nn/functional/ConvTranspose.py | 142 +++++++++++++++++++
 memsave_torch/nn/functional/__init__.py      |   5 +
 8 files changed, 288 insertions(+), 13 deletions(-)
 create mode 100644 memsave_torch/nn/ConvTranspose1d.py
 create mode 100644 memsave_torch/nn/ConvTranspose2d.py
 create mode 100644 memsave_torch/nn/ConvTranspose3d.py
 create mode 100644 memsave_torch/nn/functional/ConvTranspose.py

diff --git a/experiments/visual_abstract/run.py b/experiments/visual_abstract/run.py
index 3426e54..add5462 100644
--- a/experiments/visual_abstract/run.py
+++ b/experiments/visual_abstract/run.py
@@ -6,17 +6,30 @@
 from os import makedirs, path
 
 from memory_profiler import memory_usage
-from torch import allclose, manual_seed, rand, rand_like
-from torch.autograd import grad
-from torch.nn import BatchNorm2d, Conv1d, Conv2d, Conv3d, Linear, Sequential
-
 from memsave_torch.nn import (
     MemSaveBatchNorm2d,
     MemSaveConv1d,
     MemSaveConv2d,
     MemSaveConv3d,
+    MemSaveConvTranspose1d,
+    MemSaveConvTranspose2d,
+    MemSaveConvTranspose3d,
     MemSaveLinear,
 )
+from memsave_torch.nn.ConvTranspose1d import MemSaveConvTranspose1d
+from torch import allclose, manual_seed, rand, rand_like
+from torch.autograd import grad
+from torch.nn import (
+    BatchNorm2d,
+    Conv1d,
+    Conv2d,
+    Conv3d,
+    ConvTranspose1d,
+    ConvTranspose2d,
+    ConvTranspose3d,
+    Linear,
+    Sequential,
+)
 
 HEREDIR = path.dirname(path.abspath(__file__))
 DATADIR = path.join(HEREDIR, "raw")
@@ -36,11 +49,11 @@ def main(  # noqa: C901
     # create the input
     if architecture == "linear":
         X = rand(512, 1024, 256)
-    elif architecture == "conv1d":
+    elif architecture in {"conv1d", "conv_transpose1d"}:
         X = rand(4096, 8, 4096)
-    elif architecture in {"conv2d", "bn2d"}:
+    elif architecture in {"conv2d", "bn2d", "conv_transpose2d"}:
         X = rand(256, 8, 256, 256)
-    elif architecture == "conv3d":
+    elif architecture in {"conv3d", "conv_transpose3d"}:
         X = rand(64, 8, 64, 64, 64)
     else:
         raise ValueError(f"Invalid argument for architecture: {architecture}.")
@@ -66,6 +79,21 @@ def main(  # noqa: C901
                 implementation
             ]
             layers[f"{architecture}{i}"] = layer_cls(8)
+        elif architecture == "conv_transpose1d":
+            layer_cls = {"ours": MemSaveConvTranspose1d, "torch": ConvTranspose1d}[
+                implementation
+            ]
+            layers[f"{architecture}{i}"] = layer_cls(8, 8, 3, padding=1, bias=False)
+        elif architecture == "conv_transpose2d":
+            layer_cls = {"ours": MemSaveConvTranspose2d, "torch": ConvTranspose2d}[
+                implementation
+            ]
+            layers[f"{architecture}{i}"] = layer_cls(8, 8, 3, padding=1, bias=False)
+        elif architecture == "conv_transpose3d":
+            layer_cls = {"ours": MemSaveConvTranspose3d, "torch": ConvTranspose3d}[
+                implementation
+            ]
+            layers[f"{architecture}{i}"] = layer_cls(8, 8, 3, padding=1, bias=False)
         else:
             raise ValueError(f"Invalid argument for architecture: {architecture}.")
 
@@ -144,22 +172,33 @@ def check_equality(architecture: str, mode: str, num_layers: int, requires_grad:
     parser.add_argument(
         "--requires_grad",
         type=str,
-        choices=["all", "none", "4", "4+"],
+        choices={"all", "none", "4", "4+"},
         help="Which layers are differentiable.",
     )
     parser.add_argument(
         "--implementation",
         type=str,
-        choices=["torch", "ours"],
+        choices={"torch", "ours"},
         help="Which implementation to use.",
     )
     parser.add_argument(
         "--architecture",
         type=str,
-        choices=["linear", "conv1d", "conv2d", "conv3d", "bn2d"],
+        choices={
+            "linear",
+            "conv1d",
+            "conv2d",
+            "conv3d",
+            "bn2d",
+            "conv_transpose1d",
+            "conv_transpose2d",
+            "conv_transpose3d",
+        },
         help="Which architecture to use.",
     )
-    parser.add_argument("--mode", type=str, help="Mode of the network.")
+    parser.add_argument(
+        "--mode", type=str, help="Mode of the network.", choices={"train", "eval"}
+    )
     parser.add_argument(
         "--skip_existing", action="store_true", help="Skip existing files."
     )
diff --git a/memsave_torch/nn/ConvTranspose1d.py b/memsave_torch/nn/ConvTranspose1d.py
new file mode 100644
index 0000000..08e0425
--- /dev/null
+++ b/memsave_torch/nn/ConvTranspose1d.py
@@ -0,0 +1,29 @@
+"""Implementation of a memory saving 1d transpose convolution layer."""
+
+import torch
+import torch.nn as nn
+from memsave_torch.nn.functional import conv_transpose1dMemSave
+
+
+class MemSaveConvTranspose1d(nn.ConvTranspose1d):
+    """Differentiability-agnostic 1d transpose convolution layer."""
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            input: Input to the network [B, C_in, W]
+
+        Returns:
+            torch.Tensor: Output [B, C_out, W_out]
+        """
+        return conv_transpose1dMemSave(
+            input,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.output_padding,
+            self.dilation,
+            self.groups,
+        )
diff --git a/memsave_torch/nn/ConvTranspose2d.py b/memsave_torch/nn/ConvTranspose2d.py
new file mode 100644
index 0000000..49170f2
--- /dev/null
+++ b/memsave_torch/nn/ConvTranspose2d.py
@@ -0,0 +1,29 @@
+"""Implementation of a memory saving 2d transpose convolution layer."""
+
+import torch
+import torch.nn as nn
+from memsave_torch.nn.functional import conv_transpose2dMemSave
+
+
+class MemSaveConvTranspose2d(nn.ConvTranspose2d):
+    """Differentiability-agnostic 2d transpose convolution layer."""
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            input: Input to the network [B, C_in, H, W]
+
+        Returns:
+            torch.Tensor: Output [B, C_out, H_out, W_out]
+        """
+        return conv_transpose2dMemSave(
+            input,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.output_padding,
+            self.dilation,
+            self.groups,
+        )
diff --git a/memsave_torch/nn/ConvTranspose3d.py b/memsave_torch/nn/ConvTranspose3d.py
new file mode 100644
index 0000000..aa94678
--- /dev/null
+++ b/memsave_torch/nn/ConvTranspose3d.py
@@ -0,0 +1,29 @@
+"""Implementation of a memory saving 3d transpose convolution layer."""
+
+import torch
+import torch.nn as nn
+from memsave_torch.nn.functional import conv_transpose3dMemSave
+
+
+class MemSaveConvTranspose3d(nn.ConvTranspose3d):
+    """Differentiability-agnostic 3d transpose convolution layer."""
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            input: Input to the network [B, C_in, D, H, W]
+
+        Returns:
+            torch.Tensor: Output [B, C_out, D_out, H_out, W_out]
+        """
+        return conv_transpose3dMemSave(
+            input,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.output_padding,
+            self.dilation,
+            self.groups,
+        )
diff --git a/memsave_torch/nn/__init__.py b/memsave_torch/nn/__init__.py
index 72817bd..35b5514 100644
--- a/memsave_torch/nn/__init__.py
+++ b/memsave_torch/nn/__init__.py
@@ -9,12 +9,14 @@
 import sys
 
 import torch.nn as nn
-
 from memsave_torch.nn import functional  # noqa: F401
 from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d
 from memsave_torch.nn.Conv1d import MemSaveConv1d
 from memsave_torch.nn.Conv2d import MemSaveConv2d
 from memsave_torch.nn.Conv3d import MemSaveConv3d
+from memsave_torch.nn.ConvTranspose1d import MemSaveConvTranspose1d
+from memsave_torch.nn.ConvTranspose2d import MemSaveConvTranspose2d
+from memsave_torch.nn.ConvTranspose3d import MemSaveConvTranspose3d
 from memsave_torch.nn.Dropout import MemSaveDropout
 from memsave_torch.nn.LayerNorm import (
     MemSaveLayerNorm,
diff --git a/memsave_torch/nn/functional/Conv.py b/memsave_torch/nn/functional/Conv.py
index 2e513b4..2783482 100644
--- a/memsave_torch/nn/functional/Conv.py
+++ b/memsave_torch/nn/functional/Conv.py
@@ -49,7 +49,7 @@ def backward(ctx, grad_output):
         if ctx.needs_input_grad[0]:
             weight = ctx.saved_tensors[current_idx]
             current_idx += 1
-        elif ctx.needs_input_grad[1]:
+        if ctx.needs_input_grad[1]:
             x = ctx.saved_tensors[current_idx]
             current_idx += 1
 
diff --git a/memsave_torch/nn/functional/ConvTranspose.py b/memsave_torch/nn/functional/ConvTranspose.py
new file mode 100644
index 0000000..e44683c
--- /dev/null
+++ b/memsave_torch/nn/functional/ConvTranspose.py
@@ -0,0 +1,142 @@
+"""Implementation of memory saving transpose convolution layers.
+
+This is done by not saving the inputs/weights if the weights/inputs don't require grad.
+"""
+
+import torch
+
+
+class _MemSaveConvTranspose(torch.autograd.Function):
+    @staticmethod
+    def forward(x, weight, bias, stride, padding, output_padding, dilation, groups):
+        return torch.ops.aten.convolution(
+            x,
+            weight,
+            bias,
+            stride,
+            padding,
+            dilation,
+            True,
+            output_padding,
+            groups,
+        )
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        x, weight, bias, stride, padding, output_padding, dilation, groups = inputs
+        need_grad = []
+        if ctx.needs_input_grad[0]:
+            need_grad.append(weight)
+        if ctx.needs_input_grad[1]:
+            need_grad.append(x)
+        # bias doesn't need anything saved for its gradient
+        ctx.bias_exists = bias is not None
+        ctx.stride = stride
+        ctx.padding = padding
+        ctx.output_padding = output_padding
+        ctx.dilation = dilation
+        ctx.groups = groups
+        ctx.x_shape = x.shape
+        ctx.weight_shape = weight.shape
+        ctx.device = x.device
+
+        ctx.save_for_backward(*need_grad)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x = weight = None
+
+        current_idx = 0
+        if ctx.needs_input_grad[0]:
+            weight = ctx.saved_tensors[current_idx]
+            current_idx += 1
+        if ctx.needs_input_grad[1]:
+            x = ctx.saved_tensors[current_idx]
+            current_idx += 1
+
+        if x is None:
+            x = torch.zeros(ctx.x_shape, device=ctx.device)
+        if weight is None:
+            weight = torch.zeros(ctx.weight_shape, device=ctx.device)
+
+        grad_x, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
+            grad_output,
+            x,
+            weight,
+            [weight.shape[1] * ctx.groups] if ctx.bias_exists else None,
+            ctx.stride,
+            ctx.padding,
+            ctx.dilation,
+            True,
+            ctx.output_padding,
+            ctx.groups,
+            ctx.needs_input_grad[:3],
+        )
+
+        return grad_x, grad_weight, grad_bias, None, None, None, None, None, None
+
+
+def conv_transpose1dMemSave(
+    input, weight, bias, stride, padding, output_padding, dilation, groups
+) -> torch.Tensor:
+    """Functional form of the memory saving transpose convolution.
+
+    Args:
+        input: input [B, C_in, W]
+        weight: weight
+        bias: bias
+        stride: stride
+        padding: padding
+        dilation: dilation
+        groups: groups
+
+    Returns:
+        torch.Tensor: Output of the conv operation [B, C_out, W_out]
+    """
+    return _MemSaveConvTranspose.apply(
+        input, weight, bias, stride, padding, output_padding, dilation, groups
+    )
+
+
+def conv_transpose2dMemSave(
+    input, weight, bias, stride, padding, output_padding, dilation, groups
+) -> torch.Tensor:
+    """Functional form of the memory saving transpose convolution.
+
+    Args:
+        input: input [B, C_in, H, W]
+        weight: weight
+        bias: bias
+        stride: stride
+        padding: padding
+        dilation: dilation
+        groups: groups
+
+    Returns:
+        torch.Tensor: Output of the conv operation [B, C_out, H_out, W_out]
+    """
+    return _MemSaveConvTranspose.apply(
+        input, weight, bias, stride, padding, output_padding, dilation, groups
+    )
+
+
+def conv_transpose3dMemSave(
+    input, weight, bias, stride, padding, output_padding, dilation, groups
+) -> torch.Tensor:
+    """Functional form of the memory saving transpose convolution.
+
+    Args:
+        input: input [B, C_in, D, H, W]
+        weight: weight
+        bias: bias
+        stride: stride
+        padding: padding
+        dilation: dilation
+        groups: groups
+
+    Returns:
+        torch.Tensor: Output of the conv operation [B, C_out, D_out, H_out, W_out]
+    """
+    return _MemSaveConvTranspose.apply(
+        input, weight, bias, stride, padding, output_padding, dilation, groups
+    )
diff --git a/memsave_torch/nn/functional/__init__.py b/memsave_torch/nn/functional/__init__.py
index e53cc5f..f747d4d 100644
--- a/memsave_torch/nn/functional/__init__.py
+++ b/memsave_torch/nn/functional/__init__.py
@@ -13,6 +13,11 @@
     conv2dMemSave,
     conv3dMemSave,
 )
+from memsave_torch.nn.functional.ConvTranspose import (  # noqa: F401
+    conv_transpose1dMemSave,
+    conv_transpose2dMemSave,
+    conv_transpose3dMemSave,
+)
 from memsave_torch.nn.functional.Dropout import dropoutMemSave  # noqa: F401
 from memsave_torch.nn.functional.LayerNorm import (  # noqa: F401
     layer_normMemSave,
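
Usage sketch (illustrative, not part of the diff above): the new layers subclass their torch.nn counterparts, so an equality check in the spirit of experiments/visual_abstract/run.py can be written as follows. The constructor arguments mirror run.py's conv_transpose2d branch; the input size is arbitrary.

    from torch import allclose, manual_seed, rand
    from torch.nn import ConvTranspose2d

    from memsave_torch.nn import MemSaveConvTranspose2d

    manual_seed(0)
    X = rand(2, 8, 16, 16)  # arbitrary [B, C_in, H, W] input

    ours = MemSaveConvTranspose2d(8, 8, 3, padding=1, bias=False)
    theirs = ConvTranspose2d(8, 8, 3, padding=1, bias=False)
    theirs.load_state_dict(ours.state_dict())  # same parameters in both layers

    assert allclose(ours(X), theirs(X))  # forward pass is unchanged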