
Commit

Fix bug, support transpose convolution
f-dangel committed Jul 8, 2024
1 parent 813c97a commit e1e49c2
Showing 8 changed files with 288 additions and 13 deletions.
61 changes: 50 additions & 11 deletions experiments/visual_abstract/run.py
@@ -6,17 +6,30 @@
from os import makedirs, path

from memory_profiler import memory_usage
from torch import allclose, manual_seed, rand, rand_like
from torch.autograd import grad
from torch.nn import BatchNorm2d, Conv1d, Conv2d, Conv3d, Linear, Sequential

from memsave_torch.nn import (
MemSaveBatchNorm2d,
MemSaveConv1d,
MemSaveConv2d,
MemSaveConv3d,
MemSaveConvTranspose1d,
MemSaveConvTranspose2d,
MemSaveConvTranspose3d,
MemSaveLinear,
)
from memsave_torch.nn.ConvTranspose1d import MemSaveConvTranspose1d

[GitHub Actions / ruff] Check failure: experiments/visual_abstract/run.py:19:46: F811 Redefinition of unused `MemSaveConvTranspose1d` from line 14
from torch import allclose, manual_seed, rand, rand_like
from torch.autograd import grad
from torch.nn import (
BatchNorm2d,
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
Linear,
Sequential,
)

HEREDIR = path.dirname(path.abspath(__file__))

[GitHub Actions / ruff] Check failure: experiments/visual_abstract/run.py:3:1: I001 Import block is un-sorted or un-formatted
DATADIR = path.join(HEREDIR, "raw")
@@ -36,11 +49,11 @@ def main( # noqa: C901
# create the input
if architecture == "linear":
X = rand(512, 1024, 256)
elif architecture == "conv1d":
elif architecture in {"conv1d", "conv_transpose1d"}:
X = rand(4096, 8, 4096)
elif architecture in {"conv2d", "bn2d"}:
elif architecture in {"conv2d", "bn2d", "conv_transpose2d"}:
X = rand(256, 8, 256, 256)
elif architecture == "conv3d":
elif architecture in {"conv3d", "conv_transpose3d"}:
X = rand(64, 8, 64, 64, 64)
else:
raise ValueError(f"Invalid argument for architecture: {architecture}.")
@@ -66,6 +79,21 @@ def main( # noqa: C901
implementation
]
layers[f"{architecture}{i}"] = layer_cls(8)
elif architecture == "conv_transpose1d":
layer_cls = {"ours": MemSaveConvTranspose1d, "torch": ConvTranspose1d}[
implementation
]
layers[f"{architecture}{i}"] = layer_cls(8, 8, 3, padding=1, bias=False)
elif architecture == "conv_transpose2d":
layer_cls = {"ours": MemSaveConvTranspose2d, "torch": ConvTranspose2d}[
implementation
]
layers[f"{architecture}{i}"] = layer_cls(8, 8, 3, padding=1, bias=False)
elif architecture == "conv_transpose3d":
layer_cls = {"ours": MemSaveConvTranspose3d, "torch": ConvTranspose3d}[
implementation
]
layers[f"{architecture}{i}"] = layer_cls(8, 8, 3, padding=1, bias=False)
else:
raise ValueError(f"Invalid argument for architecture: {architecture}.")

@@ -144,22 +172,33 @@ def check_equality(architecture: str, mode: str, num_layers: int, requires_grad:
parser.add_argument(
"--requires_grad",
type=str,
choices=["all", "none", "4", "4+"],
choices={"all", "none", "4", "4+"},
help="Which layers are differentiable.",
)
parser.add_argument(
"--implementation",
type=str,
choices=["torch", "ours"],
choices={"torch", "ours"},
help="Which implementation to use.",
)
parser.add_argument(
"--architecture",
type=str,
choices=["linear", "conv1d", "conv2d", "conv3d", "bn2d"],
choices={
"linear",
"conv1d",
"conv2d",
"conv3d",
"bn2d",
"conv_transpose1d",
"conv_transpose2d",
"conv_transpose3d",
},
help="Which architecture to use.",
)
parser.add_argument("--mode", type=str, help="Mode of the network.")
parser.add_argument(
"--mode", type=str, help="Mode of the network.", choices={"train", "eval"}
)
parser.add_argument(
"--skip_existing", action="store_true", help="Skip existing files."
)
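For context, here is a minimal, hypothetical sketch of the peak-memory measurement that run.py performs, applied to the new conv_transpose2d case. The layer stack and sizes are scaled down from the script's settings, and the memory_profiler return value is handled defensively, since max_usage=True returns a float in recent versions but a list in older ones.

# Hypothetical, scaled-down sketch of what run.py measures for "conv_transpose2d":
# peak memory while running a forward pass through a stack of transpose convolutions.
from memory_profiler import memory_usage
from torch import manual_seed, rand
from torch.nn import ConvTranspose2d, Sequential


def forward_only():
    manual_seed(0)
    X = rand(16, 8, 64, 64)  # much smaller than the script's rand(256, 8, 256, 256)
    net = Sequential(*[ConvTranspose2d(8, 8, 3, padding=1, bias=False) for _ in range(4)])
    return net(X)


peak = memory_usage((forward_only, (), {}), interval=1e-3, max_usage=True)
peak = peak if isinstance(peak, float) else max(peak)  # older memory_profiler returns a list
print(f"Peak memory: {peak:.1f} MiB")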
29 changes: 29 additions & 0 deletions memsave_torch/nn/ConvTranspose1d.py
@@ -0,0 +1,29 @@
"""Implementation of a memory saving 1d transpose convolution layer."""

import torch
import torch.nn as nn
from memsave_torch.nn.functional import conv_transpose1dMemSave


class MemSaveConvTranspose1d(nn.ConvTranspose1d):

[GitHub Actions / ruff] Check failure: memsave_torch/nn/ConvTranspose1d.py:3:1: I001 Import block is un-sorted or un-formatted
"""Differentiability-agnostic 1d transpose convolution layer."""

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
input: Input to the network [B, C_in, W]
Returns:
torch.Tensor: Output [B, C_out, W_out]
"""
return conv_transpose1dMemSave(
input,
self.weight,
self.bias,
self.stride,
self.padding,
self.output_padding,
self.dilation,
self.groups,
)
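A possible drop-in usage sketch (not part of the commit): since MemSaveConvTranspose1d subclasses nn.ConvTranspose1d, it accepts the same constructor arguments and can load a stock layer's state dict. The sizes and tolerance below are illustrative.

import torch
from torch.nn import ConvTranspose1d

from memsave_torch.nn import MemSaveConvTranspose1d

torch.manual_seed(0)
x = torch.rand(2, 8, 32)

reference = ConvTranspose1d(8, 4, 3, padding=1, bias=False)
memsave = MemSaveConvTranspose1d(8, 4, 3, padding=1, bias=False)
memsave.load_state_dict(reference.state_dict())  # copy the weights

# Both layers should produce the same output (up to floating-point error).
assert torch.allclose(reference(x), memsave(x), atol=1e-7)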
29 changes: 29 additions & 0 deletions memsave_torch/nn/ConvTranspose2d.py
@@ -0,0 +1,29 @@
"""Implementation of a memory saving 1d transpose convolution layer."""

import torch
import torch.nn as nn
from memsave_torch.nn.functional import conv_transpose2dMemSave


class MemSaveConvTranspose2d(nn.ConvTranspose2d):

[GitHub Actions / ruff] Check failure: memsave_torch/nn/ConvTranspose2d.py:3:1: I001 Import block is un-sorted or un-formatted
"""Differentiability-agnostic 2d transpose convolution layer."""

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
input: Input to the network [B, C_in, H, W]
Returns:
torch.Tensor: Output [B, C_out, H_out, W_out]
"""
return conv_transpose2dMemSave(
input,
self.weight,
self.bias,
self.stride,
self.padding,
self.output_padding,
self.dilation,
self.groups,
)
29 changes: 29 additions & 0 deletions memsave_torch/nn/ConvTranspose3d.py
@@ -0,0 +1,29 @@
"""Implementation of a memory saving 1d transpose convolution layer."""

import torch
import torch.nn as nn
from memsave_torch.nn.functional import conv_transpose3dMemSave


class MemSaveConvTranspose3d(nn.ConvTranspose3d):

[GitHub Actions / ruff] Check failure: memsave_torch/nn/ConvTranspose3d.py:3:1: I001 Import block is un-sorted or un-formatted
"""Differentiability-agnostic 3d transpose convolution layer."""

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
input: Input to the network [B, C_in, D, H, W]
Returns:
torch.Tensor: Output [B, C_out, D_out, H_out, W_out]
"""
return conv_transpose3dMemSave(
input,
self.weight,
self.bias,
self.stride,
self.padding,
self.output_padding,
self.dilation,
self.groups,
)
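A hedged sketch of how the three new layers could be swapped into an existing model; the _REPLACEMENTS mapping and the swap_conv_transpose helper below are illustrative and not part of the memsave_torch API.

import torch.nn as nn

from memsave_torch.nn import (
    MemSaveConvTranspose1d,
    MemSaveConvTranspose2d,
    MemSaveConvTranspose3d,
)

# Illustrative mapping from stock layers to their memory-saving counterparts.
_REPLACEMENTS = {
    nn.ConvTranspose1d: MemSaveConvTranspose1d,
    nn.ConvTranspose2d: MemSaveConvTranspose2d,
    nn.ConvTranspose3d: MemSaveConvTranspose3d,
}


def swap_conv_transpose(model: nn.Module) -> nn.Module:
    """Replace stock transpose convolutions with memory-saving ones, in place."""
    for name, child in model.named_children():
        cls = _REPLACEMENTS.get(type(child))
        if cls is None:
            swap_conv_transpose(child)
            continue
        new = cls(
            child.in_channels,
            child.out_channels,
            child.kernel_size,
            stride=child.stride,
            padding=child.padding,
            output_padding=child.output_padding,
            groups=child.groups,
            bias=child.bias is not None,
            dilation=child.dilation,
        )
        new.load_state_dict(child.state_dict())  # keep the trained parameters
        setattr(model, name, new)
    return model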
4 changes: 3 additions & 1 deletion memsave_torch/nn/__init__.py
@@ -9,12 +9,14 @@
import sys

import torch.nn as nn

from memsave_torch.nn import functional # noqa: F401
from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d
from memsave_torch.nn.Conv1d import MemSaveConv1d
from memsave_torch.nn.Conv2d import MemSaveConv2d
from memsave_torch.nn.Conv3d import MemSaveConv3d
from memsave_torch.nn.ConvTranspose1d import MemSaveConvTranspose1d

[GitHub Actions / ruff] Check failure: memsave_torch/nn/__init__.py:17:46: F401 `memsave_torch.nn.ConvTranspose1d.MemSaveConvTranspose1d` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
from memsave_torch.nn.ConvTranspose2d import MemSaveConvTranspose2d

[GitHub Actions / ruff] Check failure: memsave_torch/nn/__init__.py:18:46: F401 `memsave_torch.nn.ConvTranspose2d.MemSaveConvTranspose2d` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
from memsave_torch.nn.ConvTranspose3d import MemSaveConvTranspose3d

[GitHub Actions / ruff] Check failure: memsave_torch/nn/__init__.py:19:46: F401 `memsave_torch.nn.ConvTranspose3d.MemSaveConvTranspose3d` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
from memsave_torch.nn.Dropout import MemSaveDropout
from memsave_torch.nn.LayerNorm import (
MemSaveLayerNorm,
2 changes: 1 addition & 1 deletion memsave_torch/nn/functional/Conv.py
@@ -49,7 +49,7 @@ def backward(ctx, grad_output):
if ctx.needs_input_grad[0]:
weight = ctx.saved_tensors[current_idx]
current_idx += 1
elif ctx.needs_input_grad[1]:
if ctx.needs_input_grad[1]:
x = ctx.saved_tensors[current_idx]
current_idx += 1

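The elif-to-if change above is the "Fix bug" part of the commit message: the saved tensors are unpacked positionally, and the old elif skipped restoring the layer input x whenever both the input gradient and the weight gradient were requested, so the weight gradient was then computed against the zero placeholder created further down. A short sketch of the corrected unpacking, with names mirroring the surrounding code:

# Both branches must be independent `if`s, because both tensors may have been saved.
current_idx = 0
x = weight = None
if ctx.needs_input_grad[0]:  # input grad requested -> the weight was saved
    weight = ctx.saved_tensors[current_idx]
    current_idx += 1
if ctx.needs_input_grad[1]:  # weight grad requested -> the input was saved
    x = ctx.saved_tensors[current_idx]
    current_idx += 1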
142 changes: 142 additions & 0 deletions memsave_torch/nn/functional/ConvTranspose.py
@@ -0,0 +1,142 @@
"""Implementation of memory saving transpose convolution layers.
This is done by not saving the inputs/weights if the weights/inputs don't require grad.
"""

import torch


class _MemSaveConvTranspose(torch.autograd.Function):
@staticmethod
def forward(x, weight, bias, stride, padding, output_padding, dilation, groups):
return torch.ops.aten.convolution(
x,
weight,
bias,
stride,
padding,
dilation,
True,
output_padding,
groups,
)

@staticmethod
def setup_context(ctx, inputs, output):
x, weight, bias, stride, padding, output_padding, dilation, groups = inputs
need_grad = []
if ctx.needs_input_grad[0]:
need_grad.append(weight)
if ctx.needs_input_grad[1]:
need_grad.append(x)
# bias doesn't need anything saved for its gradient
ctx.bias_exists = bias is not None
ctx.stride = stride
ctx.padding = padding
ctx.output_padding = output_padding
ctx.dilation = dilation
ctx.groups = groups
ctx.x_shape = x.shape
ctx.weight_shape = weight.shape
ctx.device = x.device

ctx.save_for_backward(*need_grad)

@staticmethod
def backward(ctx, grad_output):
x = weight = None

current_idx = 0
if ctx.needs_input_grad[0]:
weight = ctx.saved_tensors[current_idx]
current_idx += 1
if ctx.needs_input_grad[1]:
x = ctx.saved_tensors[current_idx]
current_idx += 1

if x is None:
x = torch.zeros(ctx.x_shape, device=ctx.device)
if weight is None:
weight = torch.zeros(ctx.weight_shape, device=ctx.device)

grad_x, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
grad_output,
x,
weight,
[weight.shape[0]] if ctx.bias_exists else None,
ctx.stride,
ctx.padding,
ctx.dilation,
True,
ctx.output_padding,
ctx.groups,
ctx.needs_input_grad[:3],
)

return grad_x, grad_weight, grad_bias, None, None, None, None, None, None


def conv_transpose1dMemSave(

[GitHub Actions / ruff] Check failure: memsave_torch/nn/functional/ConvTranspose.py:79:5: D417 Missing argument description in the docstring for `conv_transpose1dMemSave`: `output_padding`
input, weight, bias, stride, padding, output_padding, dilation, groups
) -> torch.Tensor:
"""Functional form of the memory saving transpose convolution.
Args:
input: input [B, C_in, W]
weight: weight
bias: bias
stride: stride
padding: padding
output_padding: output padding
dilation: dilation
groups: groups
Returns:
torch.Tensor: Output of the conv operation [B, C_out, W_out]
"""
return _MemSaveConvTranspose.apply(
input, weight, bias, stride, padding, output_padding, dilation, groups
)


def conv_transpose2dMemSave(
input, weight, bias, stride, padding, output_padding, dilation, groups
) -> torch.Tensor:
"""Functional form of the memory saving transpose convolution.
Args:
input: input [B, C_in, H, W]
weight: weight
bias: bias
stride: stride
padding: padding
output_padding: output padding
dilation: dilation
groups: groups
Returns:
torch.Tensor: Output of the conv operation [B, C_out, H_out, W_out]
"""
return _MemSaveConvTranspose.apply(
input, weight, bias, stride, padding, output_padding, dilation, groups
)


def conv_transpose3dMemSave(
input, weight, bias, stride, padding, output_padding, dilation, groups
) -> torch.Tensor:
"""Functional form of the memory saving transpose convolution.
Args:
input: input [B, C_in, D, H, W]
weight: weight
bias: bias
stride: stride
padding: padding
output_padding: output padding
dilation: dilation
groups: groups
Returns:
torch.Tensor: Output of the conv operation [B, C_out, D_out, H_out, W_out]
"""
return _MemSaveConvTranspose.apply(
input, weight, bias, stride, padding, output_padding, dilation, groups
)
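A hedged usage sketch of the functional form (not part of the commit): with a frozen weight, setup_context stores only the weight, which is needed for the input gradient, and skips the input activation, which is where the memory saving comes from. The sizes below are illustrative, and the result is checked against torch.nn.functional.conv_transpose2d.

import torch

from memsave_torch.nn.functional import conv_transpose2dMemSave

torch.manual_seed(0)
x = torch.rand(2, 8, 16, 16, requires_grad=True)
weight = torch.rand(8, 4, 3, 3)  # frozen: requires_grad=False, so x is not stored

# Memory-saving functional (positional args: stride, padding, output_padding, dilation, groups).
y = conv_transpose2dMemSave(x, weight, None, (1, 1), (1, 1), (0, 0), (1, 1), 1)
(grad_x,) = torch.autograd.grad(y.sum(), x)

# Reference from the stock PyTorch functional.
y_ref = torch.nn.functional.conv_transpose2d(x, weight, None, stride=1, padding=1)
(grad_x_ref,) = torch.autograd.grad(y_ref.sum(), x)

assert torch.allclose(y, y_ref) and torch.allclose(grad_x, grad_x_ref)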
5 changes: 5 additions & 0 deletions memsave_torch/nn/functional/__init__.py
@@ -13,6 +13,11 @@
conv2dMemSave,
conv3dMemSave,
)
from memsave_torch.nn.functional.ConvTranspose import ( # noqa: F401
conv_transpose1dMemSave,
conv_transpose2dMemSave,
conv_transpose3dMemSave,
)
from memsave_torch.nn.functional.Dropout import dropoutMemSave # noqa: F401
from memsave_torch.nn.functional.LayerNorm import ( # noqa: F401
layer_normMemSave,
