Commit c41e06f: remove Linear (no benefits)

plutonium-239 committed Aug 29, 2024
1 parent 5cb41b6
Showing 8 changed files with 12 additions and 211 deletions.
3 changes: 1 addition & 2 deletions docs_src/index.rst
@@ -41,12 +41,11 @@ They are as fast as their built-in equivalents, but more memory-efficient whenever
* :class:`memsave_torch.nn.MemSaveConv1d`
* :class:`memsave_torch.nn.MemSaveConv2d`
* :class:`memsave_torch.nn.MemSaveConv3d`
* :class:`memsave_torch.nn.MemSaveLinear`
* :class:`memsave_torch.nn.MemSaveReLU`
* :class:`memsave_torch.nn.MemSaveMaxPool2d`
* :class:`memsave_torch.nn.MemSaveConvTranspose1d`
* :class:`memsave_torch.nn.MemSaveConvTranspose2d`
* :class:`memsave_torch.nn.MemSaveConvTranspose3d`
* :class:`memsave_torch.nn.MemSaveMaxPool2d`
* :class:`memsave_torch.nn.MemSaveBatchNorm2d`

.. raw:: html
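The layer classes listed above are drop-in replacements for their torch.nn counterparts. A minimal usage sketch, assuming the constructors mirror torch.nn (as the experiment code further down in this commit does):

import torch
from memsave_torch.nn import MemSaveConv2d, MemSaveReLU

# Same constructor arguments as torch.nn.Conv2d / torch.nn.ReLU
layer = MemSaveConv2d(3, 64, kernel_size=3, padding=1, bias=False)
act = MemSaveReLU()

x = torch.rand(8, 3, 32, 32, requires_grad=True)
y = act(layer(x))      # forward behaves like Conv2d + ReLU
y.sum().backward()     # backward works as usual, with less stored activation memory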
3 changes: 1 addition & 2 deletions experiments/util/measurements.py
@@ -31,7 +31,6 @@

from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d
from memsave_torch.nn.Conv2d import MemSaveConv2d
from memsave_torch.nn.Linear import MemSaveLinear


def maybe_synchronize(dev: device):
@@ -301,7 +300,7 @@ def separate_grad_arguments(
Raises:
NotImplementedError: If an unknown layer with parameters is encountered.
"""
linear = (Linear, MemSaveLinear, Conv1D)
linear = (Linear, Conv1D)
conv = (
Conv1d,
Conv2d,
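The `linear` and `conv` tuples above group layer classes so that separate_grad_arguments can route parameters by layer type. Its body is not shown in this diff, so the following dispatch is only an assumed sketch of how such tuples are typically used:

def classify(layer):
    # Hypothetical helper: route a layer via the class tuples defined above.
    if isinstance(layer, linear):
        return "linear"
    if isinstance(layer, conv):
        return "conv"
    # Mirrors the documented behaviour for unknown layers with parameters.
    raise NotImplementedError(f"Unknown layer with parameters: {layer}")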
42 changes: 9 additions & 33 deletions experiments/util/models.py
@@ -34,8 +34,6 @@

from memsave_torch.nn import (
MemSaveBatchNorm2d,
MemSaveConv2d,
MemSaveLinear,
convert_to_memory_saving,
)

@@ -64,38 +62,38 @@ def prefix_in_pairs(prefix: str, it: List[str]) -> List[str]:

def convert_to_memory_saving_defaultsoff(
model: Module,
linear=False,
conv2d=False,
conv1d=False,
conv3d=False,
batchnorm2d=False,
relu=False,
maxpool2d=False,
layernorm=False,
dropout=False,
) -> Module:
"""Extension of the `convert_to_memory_saving` function with all defaults as off
Args:
model (Module): Input model
linear (bool, optional): Whether to replace linear layers
conv2d (bool, optional): Whether to replace conv2d layers
conv1d (bool, optional): Whether to replace conv1d layers
conv3d (bool, optional): Whether to replace conv3d layers
batchnorm2d (bool, optional): Whether to replace batchnorm2d layers
relu (bool, optional): Whether to replace relu layers
maxpool2d (bool, optional): Whether to replace maxpool2d layers
layernorm (bool, optional): Whether to replace layernorm layers
dropout (bool, optional): Whether to replace dropout layers
Returns:
Module: The converted memory saving model
"""
return convert_to_memory_saving(
model,
linear=linear,
conv2d=conv2d,
conv1d=conv1d,
conv3d=conv3d,
batchnorm2d=batchnorm2d,
relu=relu,
maxpool2d=maxpool2d,
layernorm=layernorm,
dropout=dropout,
)
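A short usage sketch for the helper above: enable only the conversions you want and leave every other layer type untouched. The model here is illustrative, not part of this diff:

import torchvision.models as tvm  # illustrative model source

model = tvm.resnet18()
# Only Conv2d and ReLU layers are swapped; all other layers stay as-is.
memsave_model = convert_to_memory_saving_defaultsoff(model, conv2d=True, relu=True)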


@@ -167,21 +165,6 @@ def _conv_model1() -> Module:
) # (H/8)*(W/8)*64 (filters) -> / 8 because maxpool


def _conv_model2() -> Module:
return Sequential(
MemSaveConv2d(conv_input_shape[0], 64, kernel_size=3, padding=1, bias=False),
MaxPool2d(kernel_size=3, stride=2, padding=1),
ReLU(),
*[
MemSaveConv2d(64, 64, kernel_size=3, padding=1, bias=False)
for _ in range(10)
],
MaxPool2d(kernel_size=4, stride=4, padding=1),
Flatten(start_dim=1, end_dim=-1),
MemSaveLinear(conv_input_shape[1] * conv_input_shape[2], num_classes),
)


def _convrelu_model1() -> Module:
return Sequential(
Conv2d(conv_input_shape[0], 64, kernel_size=3, padding=1, bias=False),
@@ -242,7 +225,7 @@ def _convrelupool_model1(num_blocks=5) -> Module:

conv_model_fns = {
"deepmodel": _conv_model1,
"memsave_deepmodel": _conv_model2,
"memsave_deepmodel": lambda: convert_to_memory_saving(_conv_model1()),
"deeprelumodel": _convrelu_model1,
"memsave_deeprelumodel": lambda: convert_to_memory_saving(_convrelu_model1()),
"deeprelupoolmodel": _convrelupool_model1,
@@ -563,15 +546,8 @@ def _linear_model1() -> Module:
)


def _linear_model2() -> Module:
return Sequential(
MemSaveLinear(linear_input_shape, 1024),
*[MemSaveLinear(1024, 1024) for _ in range(12)],
MemSaveLinear(1024, num_classes),
)


linear_model_fns = {
"deeplinearmodel": _linear_model1,
"memsave_deeplinearmodel": _linear_model2,
# Doesn't do anything, just kept for consistency:
"memsave_deeplinearmodel": lambda: convert_to_memory_saving(_linear_model1()),
}
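The "Doesn't do anything" comment above follows directly from this commit: with Linear conversion removed, convert_to_memory_saving leaves nn.Linear layers untouched. A hedged sketch (the body of _linear_model1 is truncated in this diff, so the stand-in model below is hypothetical):

import torch.nn as nn
from memsave_torch.nn import convert_to_memory_saving

# Hypothetical stand-in for a purely linear model like _linear_model1
plain = nn.Sequential(nn.Linear(512, 1024), nn.Linear(1024, 10))
converted = convert_to_memory_saving(plain)
# With Linear conversion gone, every submodule keeps its original type.
assert all(isinstance(m, (nn.Sequential, nn.Linear)) for m in converted.modules())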
74 changes: 0 additions & 74 deletions memsave_torch/nn/Linear.py

This file was deleted.

20 changes: 1 addition & 19 deletions memsave_torch/nn/__init__.py
@@ -6,8 +6,6 @@
- BatchNorm2d
"""

import sys

import torch.nn as nn

from memsave_torch.nn import functional # noqa: F401
@@ -19,20 +17,13 @@
from memsave_torch.nn.ConvTranspose2d import MemSaveConvTranspose2d
from memsave_torch.nn.ConvTranspose3d import MemSaveConvTranspose3d
from memsave_torch.nn.Dropout import MemSaveDropout
from memsave_torch.nn.Linear import MemSaveLinear
from memsave_torch.nn.MaxPool import MemSaveMaxPool2d
from memsave_torch.nn.ReLU import MemSaveReLU

transformers_imported = False
if "transformers" in sys.modules:
import transformers

transformers_imported = True


def convert_to_memory_saving(
model: nn.Module,
linear=True,
*,
conv2d=True,
conv1d=True,
conv3d=True,
@@ -51,7 +42,6 @@ def convert_to_memory_saving(
Args:
model (nn.Module): The input model
linear (bool, optional): Whether to replace `nn.Linear` layers
conv2d (bool, optional): Whether to replace `nn.Conv2d` layers
conv1d (bool, optional): Whether to replace `nn.Conv1d` layers
conv3d (bool, optional): Whether to replace `nn.Conv3d` layers
@@ -65,15 +55,7 @@
Returns:
memsavemodel (nn.Module): The converted memory saving model
"""
linear_cls = nn.Linear
if transformers_imported:
linear_cls = (nn.Linear, transformers.Conv1D)
layers = [
{
"allowed": linear,
"cls": linear_cls,
"convert_fn": MemSaveLinear.from_nn_Linear,
},
{"allowed": relu, "cls": nn.ReLU, "convert_fn": MemSaveReLU.from_nn_ReLU},
{
"allowed": maxpool2d,
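Each entry in the `layers` table above pairs an "allowed" flag with a target class and a conversion function. The replacement loop itself is not shown in this diff; a hypothetical sketch of how such a table typically drives module replacement:

def replace_modules(model, layers):
    # Hypothetical illustration of table-driven conversion, not the library's actual loop.
    for name, child in model.named_children():
        for entry in layers:
            if entry["allowed"] and isinstance(child, entry["cls"]):
                setattr(model, name, entry["convert_fn"](child))
                break
        else:
            replace_modules(child, layers)  # recurse into children that were not converted
    return model

Note that `linear` is no longer accepted by convert_to_memory_saving; the remaining flags sit behind the bare `*` in the signature shown above, i.e. they are keyword-only.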
65 changes: 0 additions & 65 deletions memsave_torch/nn/functional/Linear.py

This file was deleted.

1 change: 0 additions & 1 deletion memsave_torch/nn/functional/__init__.py
@@ -10,6 +10,5 @@
from memsave_torch.nn.functional.BatchNorm import batch_normMemSave # noqa: F401
from memsave_torch.nn.functional.Conv import convMemSave # noqa: F401
from memsave_torch.nn.functional.Dropout import dropoutMemSave # noqa: F401
from memsave_torch.nn.functional.Linear import linearMemSave # noqa: F401
from memsave_torch.nn.functional.MaxPool import maxpool2dMemSave # noqa: F401
from memsave_torch.nn.functional.ReLU import reluMemSave # noqa: F401
15 changes: 0 additions & 15 deletions test/test_layers_cases.py
@@ -25,21 +25,6 @@ class Case:


cases = [
Case(
name="Linear1dims",
layer_fn=lambda: torch.nn.Linear(3, 5),
data_fn=lambda: torch.rand(7, 3),
),
Case(
name="Linear2dims",
layer_fn=lambda: torch.nn.Linear(3, 5),
data_fn=lambda: torch.rand(7, 12, 3), # weight sharing
),
Case(
name="Linear3dims",
layer_fn=lambda: torch.nn.Linear(3, 5),
data_fn=lambda: torch.rand(7, 12, 12, 3), # weight sharing
),
Case(
name="Conv1d",
layer_fn=lambda: torch.nn.Conv1d(3, 5, 3),
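Each Case above bundles a name with lazily constructed layer and input factories. The actual assertions in test_layers_cases.py are not shown in this diff; a hypothetical smoke test consuming the cases might look like:

import pytest
import torch

@pytest.mark.parametrize("case", cases, ids=lambda c: c.name)
def test_forward_backward_runs(case):
    # Hypothetical check: build the layer, push its matching data through, backprop.
    layer = case.layer_fn()
    x = case.data_fn().requires_grad_(True)
    y = layer(x)
    y.sum().backward()
    assert x.grad is not None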
