Skip to content

Commit

Permalink
Make store a context manager
Browse files Browse the repository at this point in the history
Keeps it very clear that it must be closed
  • Loading branch information
WardLT committed Oct 20, 2023
1 parent d3d5d7e commit f26fc2e
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 31 deletions.
13 changes: 8 additions & 5 deletions examol/store/db/base.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
"""Base classes for storage utilities"""
import gzip
from typing import Iterable
from abc import ABC
from pathlib import Path
from typing import Iterable
from contextlib import AbstractContextManager


from examol.store.models import MoleculeRecord


class MoleculeStore:
class MoleculeStore(AbstractContextManager, ABC):
"""Base class defining how to interface with a dataset of molecule records.
Data stores provide the ability to persist the data collected by ExaMol to disk during a run.
The :meth:`update_record` call need not imemdaitely
The :meth:`update_record` call need not immediately persist the data but should ensure that the data
is stored on disk eventually.
In fact, it is actually better for the update operation to not block until the resulting write has completed.
Stores do not need to support concurrent access from multiple clients, which is why this documentation avoids the word "database."
"""

def __getitem__(self, mol_key: str) -> MoleculeRecord:
Expand Down
36 changes: 19 additions & 17 deletions examol/store/db/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,23 @@ def __init__(self, path: Path, write_freq: float = 10.):
# Start by loading the molecules
self._load_molecules()

def __enter__(self):
# Entering the context launches the background writer: a Thread whose target is self._writer
logger.info('Start the writing thread')
self._write_thread = Thread(target=self._writer)
self._write_thread.start()
# Return the store itself so that `with ... as store` binds this instance
return self

def __exit__(self, exc_type, exc_val, exc_tb):
# Leaving the context sets the closing flag and blocks until the writer thread has finished.
# Nothing is returned, so an exception raised inside the `with` body is not suppressed.
# Trigger a last write
logger.info('Triggering a last write to the database')
self._closing.set()
# Guard against there being no thread to join
if self._write_thread is not None:
self._write_thread.join()

# Mark that we're closed
self._write_thread = None
# Return the closing event to its unset state
self._closing.clear()

def _load_molecules(self):
"""Load molecules from disk"""
if not self.path.is_file():
Expand All @@ -44,6 +61,7 @@ def _load_molecules(self):
for line in fp:
record = MoleculeRecord.from_json(line)
self.db[record.key] = record
logger.info(f'Loaded {len(self.db)} molecule records')

def iterate_over_records(self) -> Iterable[MoleculeRecord]:
# Yields every record held in self.db, taken from a snapshot made at the moment iteration starts
yield from list(self.db.values()) # Use `list` to copy the current state of the db and avoid errors due to concurrent writes
Expand All @@ -56,7 +74,7 @@ def __len__(self):

def _writer(self):
next_write = 0
while not self._closing.is_set():
while not (self._closing.is_set() or self._updates_available.is_set()): # Loop until closing and no updates are available
# Wait until updates are available and the standoff is not met, or if we're closing
while (monotonic() < next_write or not self._updates_available.is_set()) and not self._closing.is_set():
self._updates_available.wait(timeout=1)
Expand All @@ -70,20 +88,4 @@ def _writer(self):

def update_record(self, record: MoleculeRecord):
# Store the record in the in-memory dict under its key, replacing any existing entry for that key
self.db[record.key] = record

# NOTE(review): the hunk header for this region (-70,20 +88,4) suggests the lazy thread start below
# is the deleted side of the diff, with `__enter__` now starting the thread instead — confirm against the real file
# Start the write thread, if needed, and trigger it
if self._write_thread is None:
logger.info('Start the writing thread')
self._write_thread = Thread(target=self._writer)
self._write_thread.start()
# Signal the writer loop, which waits on this event, that there is new data to write
self._updates_available.set()

def close(self):
# NOTE(review): this method appears to be the deleted side of the diff hunk (-70,20 +88,4);
# `__exit__` in this same diff carries the same shutdown statements — confirm against the real file
# Trigger a last write
self._closing.set()
# Guard against the writer thread never having been started
if self._write_thread is not None:
self._write_thread.join()

# Mark that we're closed
self._write_thread = None
self._closing.clear()
11 changes: 2 additions & 9 deletions tests/store/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,14 @@ def records() -> list[MoleculeRecord]:
def test_store(tmpdir, records):
# Round-trip test: add records to a store, close it, then reopen the same path and check the count.
# NOTE(review): the diff markers are lost here, so the old `try/finally` + `store.close()` lines
# are shown interleaved with the new `with InMemoryStore(...)` lines; the hunk header reports
# 2 additions and 9 deletions — confirm against the real file
# Open the database
db_path = tmpdir / 'db.json.gz'
store = InMemoryStore(db_path)
try:
with InMemoryStore(db_path) as store:
assert len(store) == 0

# Add the records
for record in records:
store.update_record(record)
assert len(store) == 3

finally:
store.close()

# Load database back in
store = InMemoryStore(db_path)
try:
with InMemoryStore(db_path) as store:
assert len(store) == 3
finally:
store.close()

0 comments on commit f26fc2e

Please sign in to comment.