Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Greatly improve database performance for esgpull update #47

Merged
merged 7 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions esgpull/cli/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,23 @@ def update(
if choice == "y":
legacy = esg.legacy_query
has_legacy = legacy.state.persistent
for file in new_files:
file_db = esg.db.get(File, file.sha)
if file_db is None:
if esg.db.has_file_id(file):
logger.error(
"File id already exists in database, "
"there might be an error with its checksum"
f"\n{file}"
)
continue
file.status = FileStatus.Queued
file_db = esg.db.merge(file)
elif has_legacy and legacy in file_db.queries:
file_db.queries.remove(legacy)
file_db.queries.append(qf.query)
esg.db.add(file_db)
with esg.db.commit_context():
for file in esg.ui.track(
new_files,
description=qf.query.rich_name,
):
file_db = esg.db.get(File, file.sha)
if file_db is None:
if esg.db.has_file_id(file):
logger.error(
"File id already exists in database, "
"there might be an error with its checksum"
f"\n{file}"
)
continue
file.status = FileStatus.Queued
esg.db.session.add(file)
elif has_legacy and legacy in file_db.queries:
esg.db.unlink(query=legacy, file=file_db)
esg.db.link(query=qf.query, file=file)
esg.ui.raise_maybe_record(Exit(0))
1 change: 1 addition & 0 deletions esgpull/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
CONFIG_FILENAME = "config.toml"
INSTALLS_PATH_ENV = "ESGPULL_INSTALLS_PATH"
ROOT_ENV = "ESGPULL_CURRENT"

IDP = "/esgf-idp/openid/"
Expand Down
23 changes: 21 additions & 2 deletions esgpull/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@

from esgpull import __file__
from esgpull.config import Config
from esgpull.models import File, Table, sql
from esgpull.models import File, Query, Table, sql
from esgpull.version import __version__

# from esgpull.exceptions import NoClauseError
# from esgpull.models import Query

T = TypeVar("T")

Expand All @@ -42,8 +41,16 @@ def from_config(config: Config, run_migrations: bool = True) -> Database:
url = f"sqlite:///{config.paths.db / config.db.filename}"
return Database(url, run_migrations=run_migrations)

def _setup_sqlite(self, conn, record):
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode = WAL;")
cursor.execute("PRAGMA synchronous = NORMAL;")
cursor.execute("PRAGMA cache_size = 20000;")
cursor.close()

def __post_init__(self, run_migrations: bool) -> None:
self._engine = sa.create_engine(self.url)
sa.event.listen(self._engine, "connect", self._setup_sqlite)
self.session = Session(self._engine)
if run_migrations:
self._update()
Expand Down Expand Up @@ -80,6 +87,12 @@ def safe(self) -> Iterator[None]:
self.session.rollback()
raise

@contextmanager
def commit_context(self) -> Iterator[None]:
with self.safe:
yield
self.session.commit()

def get(
self,
table: type[Table],
Expand Down Expand Up @@ -132,6 +145,12 @@ def delete(self, *items: Table) -> None:
for item in items:
make_transient(item)

def link(self, query: Query, file: File):
self.session.execute(sql.query_file.link(query, file))

def unlink(self, query: Query, file: File):
self.session.execute(sql.query_file.unlink(query, file))

def __contains__(self, item: Table) -> bool:
return self.scalars(sql.count(item))[0] > 0

Expand Down
12 changes: 10 additions & 2 deletions esgpull/install_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import platformdirs
from typing_extensions import NotRequired, TypedDict

from esgpull.constants import ROOT_ENV
from esgpull.constants import INSTALLS_PATH_ENV, ROOT_ENV
from esgpull.exceptions import AlreadyInstalledName, AlreadyInstalledPath


Expand Down Expand Up @@ -45,7 +45,15 @@ class _InstallConfig:
installs: list[Install]

def __init__(self) -> None:
user_config_dir = platformdirs.user_config_path("esgpull")
self.setup()

def setup(self, install_path: Path | None = None):
if install_path is not None:
user_config_dir = install_path
elif (env := os.environ.get(INSTALLS_PATH_ENV)) is not None:
user_config_dir = Path(env)
else:
user_config_dir = platformdirs.user_config_path("esgpull")
self.path = user_config_dir / "installs.json"
if self.path.is_file():
with self.path.open() as f:
Expand Down
6 changes: 5 additions & 1 deletion esgpull/models/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def setter(self: Selection, values: FacetValues):

setattr(cls, name, property(getter, setter))

@classmethod
def reset(cls) -> None:
cls.configure(*DefaultFacets, *BaseFacets, replace=True)

@classmethod
def configure(cls, *names: str, replace: bool = True) -> None:
nameset = set(names) | {f"!{name}" for name in names}
Expand Down Expand Up @@ -198,4 +202,4 @@ def __repr__(self) -> str:
]


Selection.configure(*DefaultFacets, *BaseFacets, replace=True)
Selection.reset()
16 changes: 16 additions & 0 deletions esgpull/models/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,19 @@ def ids() -> sa.Select[tuple[int]]:
@staticmethod
def with_ids(*ids: int) -> sa.Select[tuple[SyndaFile]]:
return sa.select(SyndaFile).where(SyndaFile.file_id.in_(ids))


class query_file:
@staticmethod
def link(query: Query, file: File) -> sa.Insert:
return sa.insert(query_file_proxy).values(
query_sha=query.sha, file_sha=file.sha
)

@staticmethod
def unlink(query: Query, file: File) -> sa.Delete:
return (
sa.delete(query_file_proxy)
.where(query_file_proxy.c.query_sha == query.sha)
.where(query_file_proxy.c.file_sha == file.sha)
)
7 changes: 3 additions & 4 deletions esgpull/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ class DummyLive:
def __enter__(self) -> DummyLive:
return self

def __exit__(self, *args):
...
def __exit__(self, *args): ...

@property
def console(self) -> DummyConsole:
Expand Down Expand Up @@ -259,9 +258,9 @@ def live(
# use _console to avoid recording the progress bar
return Live(renderables, console=_console)

def track(self, iterable: Iterable[T]) -> Iterable[T]:
def track(self, iterable: Iterable[T], **kwargs) -> Iterable[T]:
# use _console to avoid recording the progress bar
return track(iterable, console=_console)
return track(iterable, console=_console, **kwargs)

def make_progress(
self,
Expand Down
Empty file added tests/cli/__init__.py
Empty file.
36 changes: 36 additions & 0 deletions tests/cli/test_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from time import perf_counter

from click.testing import CliRunner

from esgpull.cli.add import add
from esgpull.cli.config import config
from esgpull.cli.self import install
from esgpull.cli.update import update
from esgpull.install_config import InstallConfig


def test_fast_update(tmp_path):
InstallConfig.setup(tmp_path)
install_path = tmp_path / "esgpull"
runner = CliRunner()
result_install = runner.invoke(install, [f"{install_path}"])
assert result_install.exit_code == 0
result_config = runner.invoke(config, ["api.page_limit", "10000"])
assert result_config.exit_code == 0
result_add = runner.invoke(
add,
[
"table_id:fx",
"experiment_id:dcpp*",
"--distrib",
"false",
"--track",
],
)
assert result_add.exit_code == 0
start = perf_counter()
result_update = runner.invoke(update, ["--yes"])
stop = perf_counter()
assert result_update.exit_code == 0
assert stop - start < 30 # 30 seconds to fetch ~6k files is plenty enough
InstallConfig.setup()
6 changes: 4 additions & 2 deletions tests/test_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
@pytest.fixture
def selection():
Selection.configure("a", "b", "c", "d", replace=True)
return Selection()
yield Selection()
Selection.reset()


def test_configure():
Expand All @@ -16,11 +17,12 @@ def test_configure():
assert new_names <= Selection._facet_names
Selection.configure("some", "thing", replace=True)
assert new_names == Selection._facet_names
sel = Selection()
with pytest.raises(KeyError):
sel = Selection()
assert sel["a"] == []
Selection.configure("a") # add 'a' to facets
assert sel["a"] == [] # no more raise
Selection.reset()


def test_basic(selection):
Expand Down
Loading