Skip to content

Commit

Permalink
Greatly improve database performance for esgpull update (#47)
Browse files Browse the repository at this point in the history
* feat(db): faster insertions for new items

* feat(cli.update): faster link/unlink query & file

new(db): Database.commit_context for bulk transactions
changed(cli.update): direct insert/delete into query_file table
changed(cli.update): bulk insert/delete instead of per file

* feat(cli.update): show current query in progress bar

* test(cli.update): less than 30 seconds to fetch ~6k files

* fix: renamed

* test(cli.update): proper setup for test installs

* test(selection): proper teardown fixtures, avoids messing other tests
  • Loading branch information
svenrdz authored Jul 17, 2024
1 parent e4f5df1 commit 3a1da98
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 27 deletions.
35 changes: 19 additions & 16 deletions esgpull/cli/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,23 @@ def update(
if choice == "y":
legacy = esg.legacy_query
has_legacy = legacy.state.persistent
for file in new_files:
file_db = esg.db.get(File, file.sha)
if file_db is None:
if esg.db.has_file_id(file):
logger.error(
"File id already exists in database, "
"there might be an error with its checksum"
f"\n{file}"
)
continue
file.status = FileStatus.Queued
file_db = esg.db.merge(file)
elif has_legacy and legacy in file_db.queries:
file_db.queries.remove(legacy)
file_db.queries.append(qf.query)
esg.db.add(file_db)
with esg.db.commit_context():
for file in esg.ui.track(
new_files,
description=qf.query.rich_name,
):
file_db = esg.db.get(File, file.sha)
if file_db is None:
if esg.db.has_file_id(file):
logger.error(
"File id already exists in database, "
"there might be an error with its checksum"
f"\n{file}"
)
continue
file.status = FileStatus.Queued
esg.db.session.add(file)
elif has_legacy and legacy in file_db.queries:
esg.db.unlink(query=legacy, file=file_db)
esg.db.link(query=qf.query, file=file)
esg.ui.raise_maybe_record(Exit(0))
1 change: 1 addition & 0 deletions esgpull/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
CONFIG_FILENAME = "config.toml"
INSTALLS_PATH_ENV = "ESGPULL_INSTALLS_PATH"
ROOT_ENV = "ESGPULL_CURRENT"

IDP = "/esgf-idp/openid/"
Expand Down
23 changes: 21 additions & 2 deletions esgpull/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@

from esgpull import __file__
from esgpull.config import Config
from esgpull.models import File, Table, sql
from esgpull.models import File, Query, Table, sql
from esgpull.version import __version__

# from esgpull.exceptions import NoClauseError
# from esgpull.models import Query

T = TypeVar("T")

Expand All @@ -42,8 +41,16 @@ def from_config(config: Config, run_migrations: bool = True) -> Database:
url = f"sqlite:///{config.paths.db / config.db.filename}"
return Database(url, run_migrations=run_migrations)

def _setup_sqlite(self, conn, record):
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode = WAL;")
cursor.execute("PRAGMA synchronous = NORMAL;")
cursor.execute("PRAGMA cache_size = 20000;")
cursor.close()

def __post_init__(self, run_migrations: bool) -> None:
self._engine = sa.create_engine(self.url)
sa.event.listen(self._engine, "connect", self._setup_sqlite)
self.session = Session(self._engine)
if run_migrations:
self._update()
Expand Down Expand Up @@ -80,6 +87,12 @@ def safe(self) -> Iterator[None]:
self.session.rollback()
raise

@contextmanager
def commit_context(self) -> Iterator[None]:
with self.safe:
yield
self.session.commit()

def get(
self,
table: type[Table],
Expand Down Expand Up @@ -132,6 +145,12 @@ def delete(self, *items: Table) -> None:
for item in items:
make_transient(item)

def link(self, query: Query, file: File):
self.session.execute(sql.query_file.link(query, file))

def unlink(self, query: Query, file: File):
self.session.execute(sql.query_file.unlink(query, file))

def __contains__(self, item: Table) -> bool:
return self.scalars(sql.count(item))[0] > 0

Expand Down
12 changes: 10 additions & 2 deletions esgpull/install_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import platformdirs
from typing_extensions import NotRequired, TypedDict

from esgpull.constants import ROOT_ENV
from esgpull.constants import INSTALLS_PATH_ENV, ROOT_ENV
from esgpull.exceptions import AlreadyInstalledName, AlreadyInstalledPath


Expand Down Expand Up @@ -45,7 +45,15 @@ class _InstallConfig:
installs: list[Install]

def __init__(self) -> None:
user_config_dir = platformdirs.user_config_path("esgpull")
self.setup()

def setup(self, install_path: Path | None = None):
if install_path is not None:
user_config_dir = install_path
elif (env := os.environ.get(INSTALLS_PATH_ENV)) is not None:
user_config_dir = Path(env)
else:
user_config_dir = platformdirs.user_config_path("esgpull")
self.path = user_config_dir / "installs.json"
if self.path.is_file():
with self.path.open() as f:
Expand Down
6 changes: 5 additions & 1 deletion esgpull/models/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def setter(self: Selection, values: FacetValues):

setattr(cls, name, property(getter, setter))

@classmethod
def reset(cls) -> None:
cls.configure(*DefaultFacets, *BaseFacets, replace=True)

@classmethod
def configure(cls, *names: str, replace: bool = True) -> None:
nameset = set(names) | {f"!{name}" for name in names}
Expand Down Expand Up @@ -198,4 +202,4 @@ def __repr__(self) -> str:
]


Selection.configure(*DefaultFacets, *BaseFacets, replace=True)
Selection.reset()
16 changes: 16 additions & 0 deletions esgpull/models/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,19 @@ def ids() -> sa.Select[tuple[int]]:
@staticmethod
def with_ids(*ids: int) -> sa.Select[tuple[SyndaFile]]:
return sa.select(SyndaFile).where(SyndaFile.file_id.in_(ids))


class query_file:
@staticmethod
def link(query: Query, file: File) -> sa.Insert:
return sa.insert(query_file_proxy).values(
query_sha=query.sha, file_sha=file.sha
)

@staticmethod
def unlink(query: Query, file: File) -> sa.Delete:
return (
sa.delete(query_file_proxy)
.where(query_file_proxy.c.query_sha == query.sha)
.where(query_file_proxy.c.file_sha == file.sha)
)
7 changes: 3 additions & 4 deletions esgpull/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ class DummyLive:
def __enter__(self) -> DummyLive:
return self

def __exit__(self, *args):
...
def __exit__(self, *args): ...

@property
def console(self) -> DummyConsole:
Expand Down Expand Up @@ -259,9 +258,9 @@ def live(
# use _console to avoid recording the progress bar
return Live(renderables, console=_console)

def track(self, iterable: Iterable[T]) -> Iterable[T]:
def track(self, iterable: Iterable[T], **kwargs) -> Iterable[T]:
# use _console to avoid recording the progress bar
return track(iterable, console=_console)
return track(iterable, console=_console, **kwargs)

def make_progress(
self,
Expand Down
Empty file added tests/cli/__init__.py
Empty file.
36 changes: 36 additions & 0 deletions tests/cli/test_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from time import perf_counter

from click.testing import CliRunner

from esgpull.cli.add import add
from esgpull.cli.config import config
from esgpull.cli.self import install
from esgpull.cli.update import update
from esgpull.install_config import InstallConfig


def test_fast_update(tmp_path):
InstallConfig.setup(tmp_path)
install_path = tmp_path / "esgpull"
runner = CliRunner()
result_install = runner.invoke(install, [f"{install_path}"])
assert result_install.exit_code == 0
result_config = runner.invoke(config, ["api.page_limit", "10000"])
assert result_config.exit_code == 0
result_add = runner.invoke(
add,
[
"table_id:fx",
"experiment_id:dcpp*",
"--distrib",
"false",
"--track",
],
)
assert result_add.exit_code == 0
start = perf_counter()
result_update = runner.invoke(update, ["--yes"])
stop = perf_counter()
assert result_update.exit_code == 0
assert stop - start < 30 # 30 seconds to fetch ~6k files is plenty enough
InstallConfig.setup()
6 changes: 4 additions & 2 deletions tests/test_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
@pytest.fixture
def selection():
Selection.configure("a", "b", "c", "d", replace=True)
return Selection()
yield Selection()
Selection.reset()


def test_configure():
Expand All @@ -16,11 +17,12 @@ def test_configure():
assert new_names <= Selection._facet_names
Selection.configure("some", "thing", replace=True)
assert new_names == Selection._facet_names
sel = Selection()
with pytest.raises(KeyError):
sel = Selection()
assert sel["a"] == []
Selection.configure("a") # add 'a' to facets
assert sel["a"] == [] # no more raise
Selection.reset()


def test_basic(selection):
Expand Down

0 comments on commit 3a1da98

Please sign in to comment.