diff --git a/experiments/paper_demo.py b/experiments/paper_demo.py
index d68199a..b050653 100644
--- a/experiments/paper_demo.py
+++ b/experiments/paper_demo.py
@@ -99,6 +99,7 @@
     architecture,
     vjp_improvements,
     cases,
+    'results'
 )
 
 for model in models:
diff --git a/memsave_torch/nn/Conv1d.py b/memsave_torch/nn/Conv1d.py
index c463dac..3bf7c11 100644
--- a/memsave_torch/nn/Conv1d.py
+++ b/memsave_torch/nn/Conv1d.py
@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 
-from Conv2d import _MemSaveConv
+from memsave_torch.nn.Conv2d import _MemSaveConv
 
 
 class MemSaveConv1d(nn.Conv1d):
diff --git a/memsave_torch/nn/ReLU.py b/memsave_torch/nn/ReLU.py
index 38a1298..53f4265 100644
--- a/memsave_torch/nn/ReLU.py
+++ b/memsave_torch/nn/ReLU.py
@@ -11,15 +11,7 @@ class MemSaveReLU(nn.ReLU):
     """MemSaveReLU."""
 
     def __init__(self):
-        """Inits a MemSaveReLU layer with the given params.
-
-        Args:
-            in_features: in_features
-            out_features: out_features
-            bias: bias
-            device: device
-            dtype: dtype
-        """
+        """Inits a MemSaveReLU layer with the given params."""
         super().__init__()
 
     def forward(self, x):
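
Note (outside the patch): the Conv1d.py change replaces a bare `from Conv2d import ...`, which only resolves when `memsave_torch/nn/` itself happens to be on Python's path, with a fully qualified import, so the module also works when loaded as part of the installed package. A minimal smoke test of the fixed import, assuming `memsave_torch` is importable and that `MemSaveConv1d` keeps the standard `nn.Conv1d` constructor arguments (an assumption; the constructor is not shown in the patch):

    import torch

    from memsave_torch.nn.Conv1d import MemSaveConv1d
    from memsave_torch.nn.ReLU import MemSaveReLU

    # Hypothetical usage: MemSaveConv1d subclasses nn.Conv1d, so the usual
    # (in_channels, out_channels, kernel_size) arguments are assumed to apply.
    net = torch.nn.Sequential(
        MemSaveConv1d(3, 8, kernel_size=5),
        MemSaveReLU(),
    )
    out = net(torch.randn(2, 3, 32))  # (batch, channels, length)
    print(out.shape)                  # torch.Size([2, 8, 28]) with no padding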