Skip to content

Commit

Permalink
use the gather method to store data
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 16, 2024
1 parent 539b98a commit 51ed98a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 25 deletions.
12 changes: 6 additions & 6 deletions CI/unit_tests/recorders/test_base_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def test_store(self):
name,
storage_path,
[self.measurement_1, self.measurement_2],
1,
4,
)

# Test storing
Expand All @@ -208,15 +208,15 @@ def test_store(self):
# _results should be empty after storing
assert recorder._results == {"dummy_1": [], "dummy_2": []}

# test overwriting
recorder.overwrite = True
# Test storing with ignore_chunk_size=False
recorder._measure(**self.neural_state)
recorder.store()
recorder._measure(**self.neural_state)
recorder.store(ignore_chunk_size=False)
data = recorder.load()

assert set(data.keys()) == {"dummy_1", "dummy_2"}
assert_array_equal(data["dummy_1"], np.ones(shape=(1, 3, 10, 5)))
assert_array_equal(data["dummy_2"], 10 * np.ones(shape=(1, 3, 10)))
assert_array_equal(data["dummy_1"], np.ones(shape=(2, 3, 10, 5)))
assert_array_equal(data["dummy_2"], 10 * np.ones(shape=(2, 3, 10)))

# Delete temporary directory
temp_dir.cleanup()
Expand Down
33 changes: 14 additions & 19 deletions papyrus/recorders/base_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,28 +190,16 @@ def store(self, ignore_chunk_size=True):
TODO: Change this method to use another type of storage.
"""
if self._counter % self.chunk_size == 0 or ignore_chunk_size:
# Load the data from the database
try:
data = self.load()
# Append the new data
if self.overwrite:
data = self._results
else:
for key in self._results.keys():
data[key] = np.append(data[key], self._results[key], axis=0)
# If the file does not exist, create a new one
except FileNotFoundError:
data = self._results

# Gather the data
data = self.gather()
# Write the data back to the database
self._write(data)
# Reinitialize the temporary storage
self._init_internals()

def gather(self):
"""
Gather the results that can be stored in the database or are still in the
temporary storage.
Gather the results from the temporary storage and the database.
Returns
-------
Expand All @@ -220,13 +208,20 @@ def gather(self):
"""
# Load the data from the database
try:
data = self.load()
loaded_data = self.load()
# If the counter is 0, the temporary storage is empty
if self._counter == 0:
return data
data = loaded_data
# Check if loaded data is empty
elif all([len(v) == 0 for v in loaded_data.values()]):
data = self._results
# Append the new data to the loaded data
else:
# Append the new data
for key in self._results.keys():
data[key] = np.append(data[key], self._results[key], axis=0)
loaded_data[key] = np.append(
loaded_data[key], self._results[key], axis=0
)
data = loaded_data
# If the file does not exist, create a new one
except FileNotFoundError:
data = self._results
Expand Down

0 comments on commit 51ed98a

Please sign in to comment.