restructure convert_to_memory_saving (decreases complexity)
plutonium-239 committed Mar 30, 2024
1 parent 71a6082 commit 98cd37b
Showing 3 changed files with 43 additions and 48 deletions.
4 changes: 3 additions & 1 deletion memsave_torch/__init__.py
@@ -1 +1,3 @@
-"""memsave_torch package"""
+"""memsave_torch package"""
+import memsave_torch.nn  # noqa: F401
+import memsave_torch.util  # noqa: F401
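
Note: with both subpackages imported eagerly in __init__.py, a bare import of memsave_torch makes memsave_torch.nn and memsave_torch.util reachable without separate submodule imports. A minimal usage sketch, assuming the names defined in the diff below:

import memsave_torch

# No explicit `import memsave_torch.nn` is needed; the subpackage is loaded
# as a side effect of importing the top-level package.
convert = memsave_torch.nn.convert_to_memory_saving
print(convert.__name__)  # convert_to_memory_saving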
14 changes: 0 additions & 14 deletions memsave_torch/memsave_torch.egg-info/PKG-INFO

This file was deleted.

73 changes: 40 additions & 33 deletions memsave_torch/nn/__init__.py
@@ -16,7 +16,7 @@
 from memsave_torch.nn.ReLU import MemSaveReLU
 
 
-def convert_to_memory_saving(  # noqa: C901
+def convert_to_memory_saving(
     model: nn.Module,
     linear=True,
     conv2d=True,
@@ -41,42 +41,49 @@ def convert_to_memory_saving(  # noqa: C901
     Returns:
         memsavemodel (nn.Module): The converted memory saving model
     """
+    layers = [
+        {
+            "allowed": linear,
+            "cls": nn.Linear,
+            "convert_fn": MemSaveLinear.from_nn_Linear,
+        },
+        {"allowed": relu, "cls": nn.ReLU, "convert_fn": MemSaveReLU.from_nn_ReLU},
+        {
+            "allowed": maxpool2d,
+            "cls": nn.MaxPool2d,
+            "convert_fn": MemSaveMaxPool2d.from_nn_MaxPool2d,
+        },
+        {
+            "allowed": conv2d,
+            "cls": nn.Conv2d,
+            "convert_fn": MemSaveConv2d.from_nn_Conv2d,
+        },
+        {
+            "allowed": batchnorm2d,
+            "cls": nn.BatchNorm2d,
+            "convert_fn": MemSaveBatchNorm2d.from_nn_BatchNorm2d,
+        },
+        {
+            "allowed": layernorm,
+            "cls": nn.LayerNorm,
+            "convert_fn": MemSaveLayerNorm.from_nn_LayerNorm,
+        },
+    ]
+
     import copy
 
     memsavemodel = copy.deepcopy(model)
     # using named_modules because it automatically iterates on Sequential/BasicBlock(resnet) etc.
     for name, layer in model.named_modules():
-        if relu and isinstance(layer, nn.ReLU):
-            if verbose:
-                print(f"replaced {name}")
-            recursive_setattr(memsavemodel, name, MemSaveReLU.from_nn_ReLU(layer))
-        if maxpool2d and isinstance(layer, nn.MaxPool2d):
-            if verbose:
-                print(f"replaced {name}")
-            recursive_setattr(
-                memsavemodel, name, MemSaveMaxPool2d.from_nn_MaxPool2d(layer)
-            )
-        if linear and isinstance(layer, nn.Linear):
-            if verbose:
-                print(f"replaced {name}")
-            recursive_setattr(memsavemodel, name, MemSaveLinear.from_nn_Linear(layer))
-        if conv2d and isinstance(layer, nn.Conv2d):
-            if verbose:
-                print(f"replaced {name}")
-            recursive_setattr(memsavemodel, name, MemSaveConv2d.from_nn_Conv2d(layer))
-        if batchnorm2d and isinstance(layer, nn.BatchNorm2d):
-            if verbose:
-                print(f"replaced {name}")
-            recursive_setattr(
-                memsavemodel, name, MemSaveBatchNorm2d.from_nn_BatchNorm2d(layer)
-            )
-        if layernorm and isinstance(layer, nn.LayerNorm):
-            if verbose:
-                print(f"replaced {name}")
-            recursive_setattr(
-                memsavemodel, name, MemSaveLayerNorm.from_nn_LayerNorm(layer)
-            )
+        for replacement in layers:
+            # Skip disabled layer types and layers that do not match this entry's class.
+            if not replacement["allowed"] or not isinstance(layer, replacement["cls"]):
+                continue
+            if verbose:
+                print(f"replaced {name}")
+            if name == "":
+                # In case a module is directly passed without wrapping sequential/moduledict
+                return replacement["convert_fn"](layer)
+            recursive_setattr(memsavemodel, name, replacement["convert_fn"](layer))
     return memsavemodel
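
For reference, a short usage sketch of the converter (the toy model and its sizes are hypothetical and not part of this commit), assuming convert_to_memory_saving is imported from memsave_torch.nn, the module shown in this diff:

import torch.nn as nn

from memsave_torch.nn import convert_to_memory_saving

# Hypothetical toy model mixing several convertible layer types.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(8 * 16 * 16, 10),
)

# Each layer type whose flag is True is matched against the `layers` table and
# swapped for its MemSave counterpart; verbose=True prints one line per swap.
memsave_model = convert_to_memory_saving(model, verbose=True)

# Conversion can also be limited to specific layer types, e.g. convolutions only.
conv_only = convert_to_memory_saving(
    model, linear=False, relu=False, maxpool2d=False, batchnorm2d=False, layernorm=False
)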


@@ -90,8 +97,8 @@ def recursive_setattr(obj: nn.Module, attr: str, value: nn.Module):
         attr (str): The dot-indexed name of the leaf layer to replace (i.e. layer.0.conv2)
         value (nn.Module): The module to replace the leaf with
     """
-    attr = attr.split(".", 1)
-    if len(attr) == 1:
-        setattr(obj, attr[0], value)
+    attr_split = attr.split(".", 1)
+    if len(attr_split) == 1:
+        setattr(obj, attr_split[0], value)
     else:
-        recursive_setattr(getattr(obj, attr[0]), attr[1], value)
+        recursive_setattr(getattr(obj, attr_split[0]), attr_split[1], value)
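
The renamed attr_split keeps the original string separate from the result of splitting it, so each call peels off one level of the dotted name until only the leaf remains. A minimal hypothetical sketch of how the helper is used (toy model invented for illustration), assuming recursive_setattr is importable from memsave_torch.nn where it is defined:

import torch.nn as nn

from memsave_torch.nn import recursive_setattr

# Hypothetical nested model; named_modules() reports the Conv2d under "0.0".
model = nn.Sequential(nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()))

# "0.0" splits into "0" and "0": the helper descends into model[0] via getattr,
# then replaces its first child via setattr.
recursive_setattr(model, "0.0", nn.Identity())
assert isinstance(model[0][0], nn.Identity)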
