From c41e06fd835eac6026819849773cb0d0da4d1fe0 Mon Sep 17 00:00:00 2001 From: Plutonium-239 Date: Thu, 29 Aug 2024 12:54:11 +0530 Subject: [PATCH] remove Linear (no benefits) --- docs_src/index.rst | 3 +- experiments/util/measurements.py | 3 +- experiments/util/models.py | 42 +++----------- memsave_torch/nn/Linear.py | 74 ------------------------- memsave_torch/nn/__init__.py | 20 +------ memsave_torch/nn/functional/Linear.py | 65 ---------------------- memsave_torch/nn/functional/__init__.py | 1 - test/test_layers_cases.py | 15 ----- 8 files changed, 12 insertions(+), 211 deletions(-) delete mode 100644 memsave_torch/nn/Linear.py delete mode 100644 memsave_torch/nn/functional/Linear.py diff --git a/docs_src/index.rst b/docs_src/index.rst index f227fee..447ed84 100644 --- a/docs_src/index.rst +++ b/docs_src/index.rst @@ -41,12 +41,11 @@ They are as fast as their built-in equivalents, but more memory-efficient whenev * :class:`memsave_torch.nn.MemSaveConv1d` * :class:`memsave_torch.nn.MemSaveConv2d` * :class:`memsave_torch.nn.MemSaveConv3d` -* :class:`memsave_torch.nn.MemSaveLinear` * :class:`memsave_torch.nn.MemSaveReLU` +* :class:`memsave_torch.nn.MemSaveMaxPool2d` * :class:`memsave_torch.nn.MemSaveConvTranspose1d` * :class:`memsave_torch.nn.MemSaveConvTranspose2d` * :class:`memsave_torch.nn.MemSaveConvTranspose3d` -* :class:`memsave_torch.nn.MemSaveMaxPool2d` * :class:`memsave_torch.nn.MemSaveBatchNorm2d` .. raw:: html diff --git a/experiments/util/measurements.py b/experiments/util/measurements.py index 991f0c3..91841e7 100644 --- a/experiments/util/measurements.py +++ b/experiments/util/measurements.py @@ -31,7 +31,6 @@ from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d from memsave_torch.nn.Conv2d import MemSaveConv2d -from memsave_torch.nn.Linear import MemSaveLinear def maybe_synchronize(dev: device): @@ -301,7 +300,7 @@ def separate_grad_arguments( Raises: NotImplementedError: If an unknown layer with parameters is encountered. 
""" - linear = (Linear, MemSaveLinear, Conv1D) + linear = (Linear, Conv1D) conv = ( Conv1d, Conv2d, diff --git a/experiments/util/models.py b/experiments/util/models.py index 8fcfe91..46169f9 100644 --- a/experiments/util/models.py +++ b/experiments/util/models.py @@ -34,8 +34,6 @@ from memsave_torch.nn import ( MemSaveBatchNorm2d, - MemSaveConv2d, - MemSaveLinear, convert_to_memory_saving, ) @@ -64,38 +62,38 @@ def prefix_in_pairs(prefix: str, it: List[str]) -> List[str]: def convert_to_memory_saving_defaultsoff( model: Module, - linear=False, conv2d=False, conv1d=False, + conv3d=False, batchnorm2d=False, relu=False, maxpool2d=False, - layernorm=False, + dropout=False, ) -> Module: """Extension of the `convert_to_memory_saving` function with all defaults as off Args: model (Module): Input model - linear (bool, optional): Whether to replace linear layers conv2d (bool, optional): Whether to replace conv2d layers conv1d (bool, optional): Whether to replace conv1d layers + conv3d (bool, optional): Whether to replace conv3d layers batchnorm2d (bool, optional): Whether to replace batchnorm2d layers relu (bool, optional): Whether to replace relu layers maxpool2d (bool, optional): Whether to replace maxpool2d layers - layernorm (bool, optional): Whether to replace layernorm layers + dropout (bool, optional): Whether to replace dropout layers Returns: Module: The converted memory saving model """ return convert_to_memory_saving( model, - linear=linear, conv2d=conv2d, conv1d=conv1d, + conv3d=conv3d, batchnorm2d=batchnorm2d, relu=relu, maxpool2d=maxpool2d, - layernorm=layernorm, + dropout=dropout, ) @@ -167,21 +165,6 @@ def _conv_model1() -> Module: ) # (H/8)*(W/8)*64 (filters) -> / 8 because maxpool -def _conv_model2() -> Module: - return Sequential( - MemSaveConv2d(conv_input_shape[0], 64, kernel_size=3, padding=1, bias=False), - MaxPool2d(kernel_size=3, stride=2, padding=1), - ReLU(), - *[ - MemSaveConv2d(64, 64, kernel_size=3, padding=1, bias=False) - for _ in range(10) - ], - MaxPool2d(kernel_size=4, stride=4, padding=1), - Flatten(start_dim=1, end_dim=-1), - MemSaveLinear(conv_input_shape[1] * conv_input_shape[2], num_classes), - ) - - def _convrelu_model1() -> Module: return Sequential( Conv2d(conv_input_shape[0], 64, kernel_size=3, padding=1, bias=False), @@ -242,7 +225,7 @@ def _convrelupool_model1(num_blocks=5) -> Module: conv_model_fns = { "deepmodel": _conv_model1, - "memsave_deepmodel": _conv_model2, + "memsave_deepmodel": lambda: convert_to_memory_saving(_conv_model1()), "deeprelumodel": _convrelu_model1, "memsave_deeprelumodel": lambda: convert_to_memory_saving(_convrelu_model1()), "deeprelupoolmodel": _convrelupool_model1, @@ -563,15 +546,8 @@ def _linear_model1() -> Module: ) -def _linear_model2() -> Module: - return Sequential( - MemSaveLinear(linear_input_shape, 1024), - *[MemSaveLinear(1024, 1024) for _ in range(12)], - MemSaveLinear(1024, num_classes), - ) - - linear_model_fns = { "deeplinearmodel": _linear_model1, - "memsave_deeplinearmodel": _linear_model2, + # Doesn't do anything, just kept for consistency: + "memsave_deeplinearmodel": lambda: convert_to_memory_saving(_linear_model1()), } diff --git a/memsave_torch/nn/Linear.py b/memsave_torch/nn/Linear.py deleted file mode 100644 index 1ba78c6..0000000 --- a/memsave_torch/nn/Linear.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Implementation of a memory saving Linear layer. - -This is done by not saving the inputs/weights if weight/inputs dont require grad. 
-""" - -import sys - -import torch.nn as nn - -from memsave_torch.nn.functional import linearMemSave - -transformers_imported = False -if "transformers" in sys.modules: - import transformers - - transformers_imported = True - - -class MemSaveLinear(nn.Linear): - """MemSaveLinear.""" - - def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): - """Inits a MemSaveLinear layer with the given params. - - Args: - in_features: in_features - out_features: out_features - bias: bias - device: device - dtype: dtype - """ - super().__init__(in_features, out_features, bias, device, dtype) - - def forward(self, x): - """Forward pass. - - Args: - x: Input to the network [B, F_in] - - Returns: - torch.Tensor: Output [B, F_out] - """ - return linearMemSave(x, self.weight, self.bias) - - @classmethod - def from_nn_Linear(cls, linear: nn.Linear): - """Converts a nn.Linear/transformers.Conv1D layer to MemSaveLinear. - - Args: - linear : The nn.Linear/transformers.Conv1D layer - - Returns: - obj: The MemSaveLinear object - """ - isTransformersConv1D = False - if transformers_imported: - isTransformersConv1D = isinstance(linear, transformers.Conv1D) - if isTransformersConv1D: - # it only saves output features in the model (linear.nf); need to take input features from weight anyway - # weight and bias are still defined - linear.in_features, linear.out_features = linear.weight.shape - obj = cls( - linear.in_features, - linear.out_features, - True if linear.bias is not None else False, - device=getattr(linear, "device", None), - dtype=getattr(linear, "dtype", None), - ) - if isTransformersConv1D: - obj.weight = nn.Parameter(linear.weight.T) - else: - obj.weight = linear.weight - obj.bias = linear.bias - return obj diff --git a/memsave_torch/nn/__init__.py b/memsave_torch/nn/__init__.py index e1d65ef..48dd028 100644 --- a/memsave_torch/nn/__init__.py +++ b/memsave_torch/nn/__init__.py @@ -6,8 +6,6 @@ - BatchNorm2d """ -import sys - import torch.nn as nn from memsave_torch.nn import functional # noqa: F401 @@ -19,20 +17,13 @@ from memsave_torch.nn.ConvTranspose2d import MemSaveConvTranspose2d from memsave_torch.nn.ConvTranspose3d import MemSaveConvTranspose3d from memsave_torch.nn.Dropout import MemSaveDropout -from memsave_torch.nn.Linear import MemSaveLinear from memsave_torch.nn.MaxPool import MemSaveMaxPool2d from memsave_torch.nn.ReLU import MemSaveReLU -transformers_imported = False -if "transformers" in sys.modules: - import transformers - - transformers_imported = True - def convert_to_memory_saving( model: nn.Module, - linear=True, + *, conv2d=True, conv1d=True, conv3d=True, @@ -51,7 +42,6 @@ def convert_to_memory_saving( Args: model (nn.Module): The input model - linear (bool, optional): Whether to replace `nn.Linear` layers conv2d (bool, optional): Whether to replace `nn.Conv2d` layers conv1d (bool, optional): Whether to replace `nn.Conv1d` layers conv3d (bool, optional): Whether to replace `nn.Conv3d` layers @@ -65,15 +55,7 @@ def convert_to_memory_saving( Returns: memsavemodel (nn.Module): The converted memory saving model """ - linear_cls = nn.Linear - if transformers_imported: - linear_cls = (nn.Linear, transformers.Conv1D) layers = [ - { - "allowed": linear, - "cls": linear_cls, - "convert_fn": MemSaveLinear.from_nn_Linear, - }, {"allowed": relu, "cls": nn.ReLU, "convert_fn": MemSaveReLU.from_nn_ReLU}, { "allowed": maxpool2d, diff --git a/memsave_torch/nn/functional/Linear.py b/memsave_torch/nn/functional/Linear.py deleted file mode 100644 index d36870e..0000000 --- 
a/memsave_torch/nn/functional/Linear.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Implementation of a memory saving Linear layer. - -This is done by not saving the inputs/weights if weight/inputs dont require grad. -""" - -import torch -import torch.nn as nn - - -class _MemSaveLinear(torch.autograd.Function): - @staticmethod - def forward(x, weight, bias): - return nn.functional.linear(x, weight, bias) - - @staticmethod - def setup_context(ctx, inputs, output): - x, weight, bias = inputs - need_grad = [] - if ctx.needs_input_grad[0]: - need_grad.append(weight) - if ctx.needs_input_grad[1]: - need_grad.append(x) - # bias doesnt need anything for calc - - ctx.save_for_backward(*need_grad) - - @staticmethod - def backward(ctx, grad_output): - x = weight = None - current_idx = 0 - if ctx.needs_input_grad[0]: - # print('0 needs weight') - weight = ctx.saved_tensors[current_idx] - current_idx += 1 - if ctx.needs_input_grad[1]: - # print('1 needs x') - x = ctx.saved_tensors[current_idx] - current_idx += 1 - - # print(current_idx) - - grad_x = grad_weight = grad_bias = None - - if ctx.needs_input_grad[0]: - grad_x = grad_output @ weight - if ctx.needs_input_grad[1]: - grad_weight = grad_output.mT @ x - if ctx.needs_input_grad[2]: - grad_bias = grad_output.sum(0) - - return grad_x, grad_weight, grad_bias - - -def linearMemSave(x, weight, bias=None) -> torch.Tensor: - """Functional form of the memory saving linear. - - Args: - x: Input to the network [B, F_in] - weight: weight - bias: bias - - Returns: - torch.Tensor: Output of the network [B, F_out] - """ - return _MemSaveLinear.apply(x, weight, bias) diff --git a/memsave_torch/nn/functional/__init__.py b/memsave_torch/nn/functional/__init__.py index 3f2d96b..3139a49 100644 --- a/memsave_torch/nn/functional/__init__.py +++ b/memsave_torch/nn/functional/__init__.py @@ -10,6 +10,5 @@ from memsave_torch.nn.functional.BatchNorm import batch_normMemSave # noqa: F401 from memsave_torch.nn.functional.Conv import convMemSave # noqa: F401 from memsave_torch.nn.functional.Dropout import dropoutMemSave # noqa: F401 -from memsave_torch.nn.functional.Linear import linearMemSave # noqa: F401 from memsave_torch.nn.functional.MaxPool import maxpool2dMemSave # noqa: F401 from memsave_torch.nn.functional.ReLU import reluMemSave # noqa: F401 diff --git a/test/test_layers_cases.py b/test/test_layers_cases.py index 835366d..2000f8a 100644 --- a/test/test_layers_cases.py +++ b/test/test_layers_cases.py @@ -25,21 +25,6 @@ class Case: cases = [ - Case( - name="Linear1dims", - layer_fn=lambda: torch.nn.Linear(3, 5), - data_fn=lambda: torch.rand(7, 3), - ), - Case( - name="Linear2dims", - layer_fn=lambda: torch.nn.Linear(3, 5), - data_fn=lambda: torch.rand(7, 12, 3), # weight sharing - ), - Case( - name="Linear3dims", - layer_fn=lambda: torch.nn.Linear(3, 5), - data_fn=lambda: torch.rand(7, 12, 12, 3), # weight sharing - ), Case( name="Conv1d", layer_fn=lambda: torch.nn.Conv1d(3, 5, 3),
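
Why dropping MemSaveLinear should be safe (a sketch, not part of the patch): the MemSave layers only pay off when autograd would otherwise store tensors that are never needed, e.g. the layer input while the weight is frozen. For linear layers the built-in PyTorch implementation appears to already avoid storing such unneeded tensors, so a MemSaveLinear wrapper brings no measurable saving -- which is the presumed reading of "no benefits" in the subject line and of the "Doesn't do anything, just kept for consistency" comment above. The snippet below is one self-contained way to check that claim; the helper name saved_tensor_bytes is made up for illustration, and it assumes a PyTorch version that provides torch.autograd.graph.saved_tensors_hooks.

    import torch

    def saved_tensor_bytes(layer, x):
        """Total bytes of tensors autograd packs away during one forward/backward."""
        total = 0

        def pack(t):
            nonlocal total
            total += t.numel() * t.element_size()
            return t  # identity pack/unpack: we only count, we don't move anything

        with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
            layer(x).sum().backward()
        return total

    # Frozen weights: only the gradient w.r.t. the input (and bias) is required,
    # so storing the layer input would be wasted memory.
    lin = torch.nn.Linear(256, 256)
    lin.weight.requires_grad_(False)
    conv = torch.nn.Conv2d(16, 16, 3, padding=1)
    conv.weight.requires_grad_(False)

    x_lin = torch.rand(32, 256, requires_grad=True)
    x_conv = torch.rand(32, 16, 32, 32, requires_grad=True)

    print("Linear stores:", saved_tensor_bytes(lin, x_lin), "bytes")
    print("Conv2d stores:", saved_tensor_bytes(conv, x_conv), "bytes")

If the Linear figure does not shrink further when the removed MemSaveLinear is swapped in, the layer indeed offers nothing, whereas convolutions (which store their input regardless of what requires grad) remain the case the surviving MemSave layers target.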