Skip to content

Commit

Permalink
Merge branch 'Konsti_Measurements' into Konsti_Recorders
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 15, 2024
2 parents 6d15dd9 + 21722ac commit e5ff86a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
4 changes: 2 additions & 2 deletions examples/measurements.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"\n",
"loss = Loss(\n",
" name='loss', # Name of the measurement\n",
" loss_fn=loss_fn # The function that will be called to compute the loss\n",
" apply_fn=loss_fn # The function that will be called to compute the loss\n",
")\n",
"print(f\"Neural state keys: {loss.neural_state_keys}\")\n",
"\n",
Expand Down Expand Up @@ -140,7 +140,7 @@
"\n",
"loss = Loss(\n",
" name='loss', # Name of the measurement\n",
" loss_fn=loss_fn # The function that will be called to compute the loss\n",
" apply_fn=loss_fn # The function that will be called to compute the loss\n",
")\n",
"\n",
"# Defining the neural state\n",
Expand Down
28 changes: 15 additions & 13 deletions papyrus/measurements/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
name: str = "loss",
rank: int = 0,
public: bool = False,
loss_fn: Optional[Callable] = None,
apply_fn: Optional[Callable] = None,
):
"""
Constructor method of the Loss class.
Expand All @@ -76,21 +76,22 @@ def __init__(
public : bool (default=False)
Boolean flag to indicate whether the measurement resutls will be
accessible via a public attribute of the recorder.
loss_fn : Optional[Callable] (default=None)
apply_fn : Optional[Callable] (default=None)
The loss function to be used to compute the loss of the neural network.
If the loss function is not provided, the apply method will assume that
the loss is used as the input.
If the loss function is provided, the apply method will assume that the
neural network outputs and the target values are used as inputs.
"""
super().__init__(name, rank, public)
self.loss_fn = loss_fn

self.apply_fn = apply_fn

# Based on the provided loss function, set the apply method
if self.loss_fn is None:
if self.apply_fn is None:
self.apply = self._apply_no_computation
else:
self.apply = self.apply_computation
self.apply = self._apply_computation

self.neural_state_keys = self._get_apply_signature()

Expand Down Expand Up @@ -137,7 +138,7 @@ def _apply_computation(
loss : float
The loss of the neural network.
"""
return self.loss_fn(predictions, targets)
return self.apply_fn(predictions, targets)


class Accuracy(BaseMeasurement):
Expand All @@ -162,7 +163,7 @@ def __init__(
name: str = "accuracy",
rank: int = 0,
public: bool = False,
accuracy_fn: Optional[Callable] = None,
apply_fn: Optional[Callable] = None,
):
"""
Constructor method of the Accuracy class.
Expand All @@ -178,7 +179,7 @@ def __init__(
public : bool (default=False)
Boolean flag to indicate whether the measurement resutls will be
accessible via a public attribute of the recorder.
accuracy_fn : Optional[Callable] (default=None)
apply_fn : Optional[Callable] (default=None)
The accuracy function to be used to compute the accuracy of the neural
network.
# If the accuracy function is not provided, the apply method will assume
Expand All @@ -187,13 +188,14 @@ def __init__(
the neural network outputs and the target values are used as inputs.
"""
super().__init__(name, rank, public)
self.accuracy_fn = accuracy_fn

self.apply_fn = apply_fn

# Based on the provided accuracy function, set the apply method
if self.accuracy_fn is None:
if self.apply_fn is None:
self.apply = self._apply_no_computation
else:
self.apply = self.apply_computation
self.apply = self._apply_computation

self.neural_state_keys = self._get_apply_signature()

Expand All @@ -215,7 +217,7 @@ def _apply_no_computation(self, accuracy: Optional[float] = None) -> float:
"""
return accuracy

def apply_computation(
def _apply_computation(
self,
predictions: Optional[np.ndarray] = None,
targets: Optional[np.ndarray] = None,
Expand All @@ -239,7 +241,7 @@ def apply_computation(
accuracy : float
The accuracy of the neural network.
"""
return self.accuracy_fn(predictions, targets)
return self.apply_fn(predictions, targets)


class NTKTrace(BaseMeasurement):
Expand Down

0 comments on commit e5ff86a

Please sign in to comment.