Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
plutonium-239 committed Jul 31, 2024
1 parent e797c47 commit 40bf4ca
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions memsave_torch/nn/Dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import torch.nn as nn

from memsave_torch.nn.functional import dropoutMemSave


Expand Down
1 change: 1 addition & 0 deletions memsave_torch/nn/Linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys

import torch.nn as nn

from memsave_torch.nn.functional import linearMemSave

transformers_imported = False
Expand Down
1 change: 1 addition & 0 deletions memsave_torch/nn/MaxPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import torch
import torch.nn as nn

from memsave_torch.nn.functional import maxpool2dMemSave


Expand Down
1 change: 1 addition & 0 deletions memsave_torch/nn/ReLU.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import torch.nn as nn

from memsave_torch.nn.functional import reluMemSave


Expand Down
3 changes: 2 additions & 1 deletion memsave_torch/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sys

import torch.nn as nn

from memsave_torch.nn import functional # noqa: F401
from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d
from memsave_torch.nn.Conv1d import MemSaveConv1d
Expand Down Expand Up @@ -181,7 +182,7 @@ def recursive_setattr(obj: nn.Module, attr: str, value: nn.Module, clone_params:
setattr(obj, attr_split[0], value)
if clone_params:
# value.load_state_dict(value.state_dict()) # makes a copy
for name,param in value._parameters.items():
for name, param in value._parameters.items():
value._parameters[name] = nn.Parameter(param.clone().detach())
else:
recursive_setattr(
Expand Down

0 comments on commit 40bf4ca

Please sign in to comment.