Skip to content

Commit

Permalink
API changes: Coalesce Conv and ConvTranspose at functional level
Browse files Browse the repository at this point in the history
add convtranspose to converter, add conversion functions for each layer

format

update
  • Loading branch information
plutonium-239 committed Jul 31, 2024
1 parent 7854aac commit e797c47
Show file tree
Hide file tree
Showing 10 changed files with 178 additions and 227 deletions.
11 changes: 7 additions & 4 deletions memsave_torch/nn/Conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import torch
import torch.nn as nn
from memsave_torch.nn.functional import conv1dMemSave

from memsave_torch.nn.functional import convMemSave


class MemSaveConv1d(nn.Conv1d):
Expand All @@ -15,19 +16,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
input: Input to the network [B, C_in, H, W]
input (torch.Tensor): Input to the network [B, C_in, H, W]
Returns:
torch.Tensor: Output [B, C_out, H_out, W_out]
"""
return conv1dMemSave(
return convMemSave(
input,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
self.transposed,
self.output_padding,
)

@classmethod
Expand All @@ -38,7 +41,7 @@ def from_nn_Conv1d(cls, conv1d: nn.Conv1d):
conv1d : The nn.Conv1d layer
Returns:
obj: The MemSaveConv1d object
MemSaveConv1d: The MemSaveConv1d object
"""
obj = cls(
conv1d.in_channels,
Expand Down
11 changes: 7 additions & 4 deletions memsave_torch/nn/Conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import torch
import torch.nn as nn
from memsave_torch.nn.functional import conv2dMemSave

from memsave_torch.nn.functional import convMemSave


class MemSaveConv2d(nn.Conv2d):
Expand All @@ -20,25 +21,27 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: Output [B, C_out, H_out, W_out]
"""
return conv2dMemSave(
return convMemSave(
input,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
self.transposed,
self.output_padding,
)

@classmethod
def from_nn_Conv2d(cls, conv2d: nn.Conv2d):
"""Converts a nn.Conv2d layer to MemSaveConv2d.
Args:
conv2d : The nn.Conv2d layer
conv2d (nn.Conv2d): The nn.Conv2d layer
Returns:
obj: The MemSaveConv2d object
MemSaveConv2d: The MemSaveConv2d object
"""
obj = cls(
conv2d.in_channels,
Expand Down
13 changes: 8 additions & 5 deletions memsave_torch/nn/Conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import torch
import torch.nn as nn
from memsave_torch.nn.functional import conv3dMemSave

from memsave_torch.nn.functional import convMemSave


class MemSaveConv3d(nn.Conv3d):
Expand All @@ -15,30 +16,32 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
input: Input to the network [B, C_in, D, H, W]
input (torch.Tensor): Input to the network [B, C_in, D, H, W]
Returns:
torch.Tensor: Output [B, C_out, D_out, H_out, W_out]
"""
return conv3dMemSave(
return convMemSave(
input,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
self.transposed,
self.output_padding,
)

@classmethod
def from_nn_Conv3d(cls, conv3d: nn.Conv3d):
"""Converts a nn.Conv3d layer to MemSaveConv3d.
Args:
conv3d : The nn.Conv3d layer
conv3d (nn.Conv3d): The nn.Conv3d layer
Returns:
obj: The MemSaveConv3d object
MemSaveConv3d: The MemSaveConv3d object
"""
obj = cls(
conv3d.in_channels,
Expand Down
38 changes: 34 additions & 4 deletions memsave_torch/nn/ConvTranspose1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import torch
import torch.nn as nn
from memsave_torch.nn.functional import conv_transpose1dMemSave

from memsave_torch.nn.functional import convMemSave


class MemSaveConvTranspose1d(nn.ConvTranspose1d):
Expand All @@ -12,18 +13,47 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
input: Input to the network [B, C_in, W]
input (torch.Tensor): Input to the network [B, C_in, W]
Returns:
torch.Tensor: Output [B, C_out, W_out]
"""
return conv_transpose1dMemSave(
return convMemSave(
input,
self.weight,
self.bias,
self.stride,
self.padding,
self.output_padding,
self.dilation,
self.groups,
self.transposed,
self.output_padding,
)

@classmethod
def from_nn_ConvTranspose1d(cls, convT1d: nn.ConvTranspose1d):
"""Converts a nn.ConvTranspose1d layer to MemSaveConvTranspose1d.
Args:
convT1d (nn.ConvTranspose1d): The nn.ConvTranspose1d layer
Returns:
MemSaveConvTranspose1d: The MemSaveConvTranspose1d object
"""
obj = cls(
convT1d.in_channels,
convT1d.out_channels,
convT1d.kernel_size,
convT1d.stride,
convT1d.padding,
convT1d.output_padding,
convT1d.groups,
True if convT1d.bias is not None else False,
convT1d.dilation,
convT1d.padding_mode,
device=getattr(convT1d, "device", None),
dtype=getattr(convT1d, "dtype", None),
)
obj.weight = convT1d.weight
obj.bias = convT1d.bias
return obj
38 changes: 34 additions & 4 deletions memsave_torch/nn/ConvTranspose2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import torch
import torch.nn as nn
from memsave_torch.nn.functional import conv_transpose2dMemSave

from memsave_torch.nn.functional import convMemSave


class MemSaveConvTranspose2d(nn.ConvTranspose2d):
Expand All @@ -12,18 +13,47 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
input: Input to the network [B, C_in, H, W]
input (torch.Tensor): Input to the network [B, C_in, H, W]
Returns:
torch.Tensor: Output [B, C_out, H_out, W_out]
"""
return conv_transpose2dMemSave(
return convMemSave(
input,
self.weight,
self.bias,
self.stride,
self.padding,
self.output_padding,
self.dilation,
self.groups,
self.transposed,
self.output_padding,
)

@classmethod
def from_nn_ConvTranspose2d(cls, convT2d: nn.ConvTranspose2d):
"""Converts a nn.ConvTranspose2d layer to MemSaveConvTranspose2d.
Args:
convT2d (nn.ConvTranspose2d): The nn.ConvTranspose2d layer
Returns:
MemSaveConvTranspose2d: The MemSaveConvTranspose2d object
"""
obj = cls(
convT2d.in_channels,
convT2d.out_channels,
convT2d.kernel_size,
convT2d.stride,
convT2d.padding,
convT2d.output_padding,
convT2d.groups,
True if convT2d.bias is not None else False,
convT2d.dilation,
convT2d.padding_mode,
device=getattr(convT2d, "device", None),
dtype=getattr(convT2d, "dtype", None),
)
obj.weight = convT2d.weight
obj.bias = convT2d.bias
return obj
38 changes: 34 additions & 4 deletions memsave_torch/nn/ConvTranspose3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import torch
import torch.nn as nn
from memsave_torch.nn.functional import conv_transpose3dMemSave

from memsave_torch.nn.functional import convMemSave


class MemSaveConvTranspose3d(nn.ConvTranspose3d):
Expand All @@ -12,18 +13,47 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
input: Input to the network [B, C_in, D, H, W]
input (torch.Tensor): Input to the network [B, C_in, D, H, W]
Returns:
torch.Tensor: Output [B, C_out, D_out, H_out, W_out]
"""
return conv_transpose3dMemSave(
return convMemSave(
input,
self.weight,
self.bias,
self.stride,
self.padding,
self.output_padding,
self.dilation,
self.groups,
self.transposed,
self.output_padding,
)

@classmethod
def from_nn_ConvTranspose3d(cls, convT3d: nn.ConvTranspose3d):
"""Converts a nn.ConvTranspose3d layer to MemSaveConvTranspose3d.
Args:
convT3d (nn.ConvTranspose3d): The nn.ConvTranspose3d layer
Returns:
MemSaveConvTranspose3d: The MemSaveConvTranspose3d object
"""
obj = cls(
convT3d.in_channels,
convT3d.out_channels,
convT3d.kernel_size,
convT3d.stride,
convT3d.padding,
convT3d.output_padding,
convT3d.groups,
True if convT3d.bias is not None else False,
convT3d.dilation,
convT3d.padding_mode,
device=getattr(convT3d, "device", None),
dtype=getattr(convT3d, "dtype", None),
)
obj.weight = convT3d.weight
obj.bias = convT3d.bias
return obj
15 changes: 15 additions & 0 deletions memsave_torch/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,21 @@ def convert_to_memory_saving(
"cls": nn.Conv3d,
"convert_fn": MemSaveConv3d.from_nn_Conv3d,
},
{
"allowed": conv1d,
"cls": nn.ConvTranspose1d,
"convert_fn": MemSaveConvTranspose1d.from_nn_ConvTranspose1d,
},
{
"allowed": conv2d,
"cls": nn.ConvTranspose2d,
"convert_fn": MemSaveConvTranspose2d.from_nn_ConvTranspose2d,
},
{
"allowed": conv3d,
"cls": nn.ConvTranspose3d,
"convert_fn": MemSaveConvTranspose3d.from_nn_ConvTranspose3d,
},
{
"allowed": batchnorm2d,
"cls": nn.BatchNorm2d,
Expand Down
Loading

0 comments on commit e797c47

Please sign in to comment.