Finish testing for recorders using h5py data storage
knikolaou committed Jun 8, 2024
1 parent 1662a66 commit edc5ba1
Showing 3 changed files with 87 additions and 34 deletions.
72 changes: 41 additions & 31 deletions CI/unit_tests/recorders/test_base_recorder.py
@@ -24,7 +24,7 @@
import tempfile

import numpy as np
from numpy.testing import assert_array_equal
from numpy.testing import assert_array_equal, assert_raises

from papyrus.measurements import BaseMeasurement
from papyrus.recorders import BaseRecorder
@@ -150,30 +150,6 @@ def test_measure(self):
)
assert_array_equal(recorder._results["dummy_2"], 10 * np.ones(shape=(2, 3, 10)))

# def test_write_read(self):
# """
# Test the write and read methods of the BaseRecorder class.
# """
# # Create a temporary directory
# temp_dir = tempfile.TemporaryDirectory()
# name = "test"
# storage_path = temp_dir.name
# recorder = BaseRecorder(
# name, storage_path, [self.measurement_1, self.measurement_2], 10
# )

# # Test writing and reading
# recorder._measure(**self.neural_state)
# recorder._write(recorder._results)
# data = recorder.load()

# assert set(data.keys()) == {"dummy_1", "dummy_2"}
# assert_array_equal(data["dummy_1"], np.ones(shape=(1, 3, 10, 5)))
# assert_array_equal(data["dummy_2"], 10 * np.ones(shape=(1, 3, 10)))

# # Delete temporary directory
# temp_dir.cleanup()

def test_store(self):
"""
Test the store method of the BaseRecorder class.
@@ -356,11 +332,7 @@ def test_overwrite(self):
2,
overwrite=True,
)
data = recorder.load()
print(data)
assert set(data.keys()) == {"dummy_1", "dummy_2"}
assert_array_equal(data["dummy_1"], [])
assert_array_equal(data["dummy_2"], [])
assert_raises(KeyError, recorder.load)

# Measure and save data again
recorder._measure(**self.neural_state)
@@ -379,4 +351,42 @@ def test_recoding_order(self):
"""
Test the order of the recordings.
"""
pass
# Create a temporary directory
temp_dir = tempfile.TemporaryDirectory()
name = "test"
storage_path = temp_dir.name
recorder = BaseRecorder(
name,
storage_path,
[self.measurement_1, self.measurement_2],
3,
)

# Prepare distinct neural states
neural_state_1 = {
"a": np.ones(shape=(3, 10, 5)),
"b": np.ones(shape=(3, 10, 5)),
"c": np.ones(shape=(3, 10, 5)),
}
neural_state_2 = {k: v * 2 for k, v in neural_state_1.items()}
neural_state_3 = {k: v * 3 for k, v in neural_state_1.items()}

# Measure and store data
recorder._measure(**neural_state_1)
recorder.store()
recorder._measure(**neural_state_2)
recorder.store()
recorder._measure(**neural_state_3)
recorder.store()

# Gather data
data = recorder.gather()
print(data)

# Check the order of the recordings
assert_array_equal(data["dummy_1"][0], 2 * np.ones(shape=(3, 10, 5)))
assert_array_equal(data["dummy_1"][1], 4 * np.ones(shape=(3, 10, 5)))
assert_array_equal(data["dummy_1"][2], 6 * np.ones(shape=(3, 10, 5)))

# Clear the temporary directory
temp_dir.cleanup()
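
The ordering check above depends on the recorder's h5py-backed storage appending each stored chunk to a resizable dataset, so entries come back in the order they were written (the _resize_dataset helper in data_storage.py below grows datasets along axis 0). A minimal plain-h5py sketch of that append-and-read-back pattern (the file name and dataset key are illustrative, not part of the papyrus API):

import os
import tempfile

import h5py
import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example.h5")

    # Create a resizable dataset and append three "recordings" in order.
    with h5py.File(path, "a") as db:
        db.create_dataset(
            "dummy_1", shape=(0, 3, 10, 5), maxshape=(None, 3, 10, 5)
        )
        for factor in (2, 4, 6):
            ds = db["dummy_1"]
            ds.resize(ds.shape[0] + 1, axis=0)
            ds[-1] = factor * np.ones((3, 10, 5))

    # Reading back preserves the append order.
    with h5py.File(path, "r") as db:
        data = db["dummy_1"][:]
    assert np.array_equal(data[0], 2 * np.ones((3, 10, 5)))
    assert np.array_equal(data[2], 6 * np.ones((3, 10, 5)))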
6 changes: 3 additions & 3 deletions papyrus/recorders/base_recorder.py
@@ -105,9 +105,9 @@ def __init__(
# Check for existing data and overwrite if necessary
if self.overwrite:
try:
self.load()
# If overwrite is True, delete the existing data
self._data_storage.write(self._results)
keys = self._data_storage.read_keys()
for key in keys:
self._data_storage.del_dataset(key)
except FileNotFoundError:
pass

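With this change, overwrite=True no longer writes an empty results dict over the old data; it reads the existing dataset keys and deletes them one by one, which is why the updated test above expects recorder.load() to raise a KeyError. A rough h5py-level sketch of what the new branch amounts to (the helper name and file path are illustrative):

import h5py


def wipe_existing_data(database_path: str) -> None:
    """Delete every top-level dataset, mirroring the new overwrite branch."""
    try:
        # read_keys: list the dataset names currently in the file.
        with h5py.File(database_path, "r") as db:
            keys = list(db.keys())
        # del_dataset: remove each dataset so later loads find no data.
        with h5py.File(database_path, "a") as db:
            for key in keys:
                del db[key]
    except FileNotFoundError:
        # No database file yet, nothing to overwrite (matches the except above).
        pass
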
43 changes: 43 additions & 0 deletions papyrus/recorders/data_storage.py
@@ -62,6 +62,49 @@ def _resize_dataset(self, _data_group: str, _chunk_size: int):
current_size = len(db[_data_group])
db[_data_group].resize(int(_chunk_size + current_size), axis=0)

    def del_dataset(self, data_group: str):
        """
        Delete a dataset from the database.

        Parameters
        ----------
        data_group : str
            Name of the dataset to delete.

        Returns
        -------
        None
            The dataset is removed from the HDF5 file.
        """
        with hf.File(self.database_path, "a") as db:
            del db[data_group]

    def clear_dataset(self, data_group: str):
        """
        Clear a dataset by setting all of its entries to zero.

        Parameters
        ----------
        data_group : str
            Name of the dataset to clear.

        Returns
        -------
        None
            The dataset keeps its shape; its entries are zeroed in place.
        """
        with hf.File(self.database_path, "a") as db:
            # A single zeroed slice is broadcast over the full dataset by h5py.
            db[data_group][:] = np.zeros_like(db[data_group][:1])

    def read_keys(self):
        """
        Read the keys of all datasets in the database.

        Returns
        -------
        keys : list of str
            Names of the datasets currently stored in the HDF5 file.
        """
        with hf.File(self.database_path, "r") as db:
            return list(db.keys())

def _write_to_dataset(self, data_group: str, data: np.ndarray):
"""
Write a numpy array to a dataset.
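The three helpers added to data_storage.py wrap standard h5py operations: read_keys lists the stored dataset names, clear_dataset zeroes a dataset in place, and del_dataset removes it entirely. A small self-contained sketch of the same operations against a throwaway file (the paths and dataset names are illustrative, not part of papyrus):

import os
import tempfile

import h5py
import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "storage.h5")

    # Seed the file with two small datasets.
    with h5py.File(path, "a") as db:
        db.create_dataset("dummy_1", data=np.ones((2, 3)))
        db.create_dataset("dummy_2", data=10 * np.ones((2, 3)))

    # read_keys equivalent: list the stored dataset names.
    with h5py.File(path, "r") as db:
        print(list(db.keys()))  # ['dummy_1', 'dummy_2']

    # clear_dataset equivalent: zero the entries in place
    # (the single zeroed row is broadcast, mirroring clear_dataset above).
    with h5py.File(path, "a") as db:
        db["dummy_2"][:] = np.zeros_like(db["dummy_2"][:1])

    # del_dataset equivalent: remove a dataset entirely.
    with h5py.File(path, "a") as db:
        del db["dummy_1"]

    with h5py.File(path, "r") as db:
        print(list(db.keys()))  # ['dummy_2']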
