diff --git a/memsave_torch/nn/Dropout.py b/memsave_torch/nn/Dropout.py
index 247cef8..da19f48 100644
--- a/memsave_torch/nn/Dropout.py
+++ b/memsave_torch/nn/Dropout.py
@@ -4,6 +4,7 @@
 """
 
 import torch.nn as nn
+
 from memsave_torch.nn.functional import dropoutMemSave
 
 
diff --git a/memsave_torch/nn/Linear.py b/memsave_torch/nn/Linear.py
index ddca31a..d2535a5 100644
--- a/memsave_torch/nn/Linear.py
+++ b/memsave_torch/nn/Linear.py
@@ -6,6 +6,7 @@
 import sys
 
 import torch.nn as nn
+
 from memsave_torch.nn.functional import linearMemSave
 
 transformers_imported = False
diff --git a/memsave_torch/nn/MaxPool.py b/memsave_torch/nn/MaxPool.py
index 7272e1b..9f5fd77 100644
--- a/memsave_torch/nn/MaxPool.py
+++ b/memsave_torch/nn/MaxPool.py
@@ -8,6 +8,7 @@
 
 import torch
 import torch.nn as nn
+
 from memsave_torch.nn.functional import maxpool2dMemSave
 
 
diff --git a/memsave_torch/nn/ReLU.py b/memsave_torch/nn/ReLU.py
index a7e0b29..6b24343 100644
--- a/memsave_torch/nn/ReLU.py
+++ b/memsave_torch/nn/ReLU.py
@@ -4,6 +4,7 @@
 """
 
 import torch.nn as nn
+
 from memsave_torch.nn.functional import reluMemSave
 
 
diff --git a/memsave_torch/nn/__init__.py b/memsave_torch/nn/__init__.py
index acd3217..65ee93c 100644
--- a/memsave_torch/nn/__init__.py
+++ b/memsave_torch/nn/__init__.py
@@ -9,6 +9,7 @@
 import sys
 
 import torch.nn as nn
+
 from memsave_torch.nn import functional  # noqa: F401
 from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d
 from memsave_torch.nn.Conv1d import MemSaveConv1d
@@ -181,7 +182,7 @@ def recursive_setattr(obj: nn.Module, attr: str, value: nn.Module, clone_params:
         setattr(obj, attr_split[0], value)
         if clone_params:
             # value.load_state_dict(value.state_dict())  # makes a copy
-            for name,param in value._parameters.items():
+            for name, param in value._parameters.items():
                 value._parameters[name] = nn.Parameter(param.clone().detach())
     else:
         recursive_setattr(