diff --git a/main.py b/main.py index 84fa42c..9687d2c 100644 --- a/main.py +++ b/main.py @@ -55,6 +55,7 @@ def __init__(self, args, outpath): self.parameters = None self.num_params = None self.optimizer = None + self.netdir = None def build_input(self): # build a noise tensor diff --git a/utils/metrics.py b/utils/metrics.py index eec1460..a9b43a3 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -5,7 +5,7 @@ def snr(output: np.ndarray or torch.Tensor, target: np.ndarray or torch.Tensor - ) -> np.float or torch.Tensor: + ) -> float or torch.Tensor: """Compute the Signal-to-Noise Ratio in dB""" if target.shape != output.shape: @@ -19,7 +19,7 @@ def snr(output: np.ndarray or torch.Tensor, def pcorr(output: np.ndarray or torch.Tensor, target: np.ndarray or torch.Tensor - ) -> np.float or torch.Tensor: + ) -> float or torch.Tensor: """ Compute the Pearson Correlation Coefficient https://en.wikipedia.org/wiki/Pearson_correlation_coefficient