Skip to content

Commit

Permalink
Make base recorder count internally
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 16, 2024
1 parent c3d4b64 commit ec7c3dd
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 170 deletions.
41 changes: 36 additions & 5 deletions CI/unit_tests/recorders/test_base_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,12 @@ def test_store(self):
name,
storage_path,
[self.measurement_1, self.measurement_2],
10,
1,
)

# Test storing
recorder._measure(**self.neural_state)
recorder._store(0)
recorder._store()
data = recorder.load()

assert set(data.keys()) == {"dummy_1", "dummy_2"}
Expand All @@ -199,11 +199,10 @@ def test_store(self):

# Test storing again
recorder._measure(**self.neural_state)
recorder._store(10)
recorder._store()
data = recorder.load()

assert set(data.keys()) == {"dummy_1", "dummy_2"}
print(data["dummy_1"].shape)
assert_array_equal(data["dummy_1"], np.ones(shape=(2, 3, 10, 5)))
assert_array_equal(data["dummy_2"], 10 * np.ones(shape=(2, 3, 10)))
# _results should be empty after storing
Expand All @@ -212,7 +211,7 @@ def test_store(self):
# test overwriting
recorder.overwrite = True
recorder._measure(**self.neural_state)
recorder._store(20)
recorder._store()
data = recorder.load()

assert set(data.keys()) == {"dummy_1", "dummy_2"}
Expand All @@ -221,3 +220,35 @@ def test_store(self):

# Delete temporary directory
os.system("rm -r temp/")

def test_counter(self):
"""
Test the counter attribute of the BaseRecorder class.
"""
# Create a temporary directory
os.makedirs("temp/", exist_ok=True)

name = "test"
storage_path = "temp/"
recorder = BaseRecorder(
name,
storage_path,
[self.measurement_1, self.measurement_2],
3,
)

# Test counter
assert recorder._counter == 0
recorder._measure(**self.neural_state)
assert recorder._counter == 1
recorder._measure(**self.neural_state)
assert recorder._counter == 2
recorder._store() # It should not story due to the chunk size
assert recorder._counter == 2
recorder._measure(**self.neural_state)
assert recorder._counter == 3
recorder._store() # It should store now
assert recorder._counter == 0

# Delete temporary directory
os.system("rm -r temp/")
187 changes: 31 additions & 156 deletions examples/mnist_flax.ipynb

Large diffs are not rendered by default.

21 changes: 12 additions & 9 deletions papyrus/recorders/base_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def __init__(
# Temporary storage for results
self._init_results()

# Initialize internal counter
self._counter = 0

def _read_neural_state_keys(self):
"""
Read the neural state keys from the measurements.
Expand Down Expand Up @@ -167,20 +170,18 @@ def _measure(self, **neural_state):
# Store the result in the temporary storage
self._results[measurement.name].append(result)

def _store(self, epoch: int):
# Increment the counter
self._counter += 1

def _store(self):
"""
Store the results of the measurements in the database.
This method loads and writes the data to the database in chunks.
TODO: Change this method to use another type of storage.
Parameters
----------
epoch : int
The epoch of recording.
"""
if epoch % self.chunk_size == 0:
if self._counter % self.chunk_size == 0:
# Load the data from the database
try:
data = self.load()
Expand All @@ -198,8 +199,10 @@ def _store(self, epoch: int):
self._write(data)
# Reinitialize the temporary storage
self._init_results()
# Reset the counter
self._counter = 0

def record(self, epoch: int, neural_state: dict):
def record(self, neural_state: dict):
"""
Perform the recording of a neural state.
Expand All @@ -218,4 +221,4 @@ def record(self, epoch: int, neural_state: dict):
The result of the recorder.
"""
self._measure(**neural_state)
self._store(epoch)
self._store()

0 comments on commit ec7c3dd

Please sign in to comment.