Skip to content

Commit

Permalink
Updated print functions and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davenquinn committed Dec 24, 2024
1 parent e010ccb commit a0cf97b
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 18 deletions.
5 changes: 5 additions & 0 deletions database/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## [3.5.3] - 2024-12-23

- Fix errors and add tests for `run_sql` changes
- Rename `PrintMode` -> `OutputMode`

## [3.5.2] - 2024-12-23

- Add the ability to print less with the `run_sql` function
Expand Down
38 changes: 24 additions & 14 deletions database/macrostrat/database/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from contextlib import contextmanager
from enum import Enum
from pathlib import Path
Expand Down Expand Up @@ -118,11 +119,6 @@ def summarize_statement(sql):
return line.split("(")[0].strip().rstrip(";").replace(" AS", "")


class DevNull(object):
def write(self, *_):
pass


def get_sql_text(sql, interpret_as_file=None, echo_file_name=True):
if interpret_as_file:
sql = Path(sql).read_text()
Expand Down Expand Up @@ -239,13 +235,25 @@ def infer_has_server_binds(sql):
_default_statement_filter = lambda sql_text, params: True


class PrintMode(Enum):
class OutputMode(Enum):
NONE = "none"
ERRORS = "errors"
SUMMARY = "summary"
ALL = "all"


def _normalize_output_args(kwargs):
output_mode = kwargs.pop("output_mode", OutputMode.SUMMARY)
output_file = kwargs.pop("output_file", stderr)

if not isinstance(output_mode, OutputMode):
output_mode = OutputMode(output_mode)

if output_mode == OutputMode.NONE:
output_file = open(os.devnull, "w")
return output_mode, output_file


def _run_sql(connectable, sql, params=None, **kwargs):
"""
Internal function for running a query on a SQLAlchemy connectable,
Expand All @@ -262,11 +270,7 @@ def _run_sql(connectable, sql, params=None, **kwargs):
has_server_binds = kwargs.pop("has_server_binds", None)
ensure_single_query = kwargs.pop("ensure_single_query", False)
statement_filter = kwargs.pop("statement_filter", _default_statement_filter)
output_mode = kwargs.pop("output_mode", PrintMode.SUMMARY)
output_file = kwargs.pop("output_file", stderr)

if output_mode == PrintMode.NONE:
output_file = DevNull()
output_mode, output_file = _normalize_output_args(kwargs)

if stop_on_error:
raise_errors = True
Expand Down Expand Up @@ -312,7 +316,7 @@ def _run_sql(connectable, sql, params=None, **kwargs):
should_run = statement_filter(sql_text, params)

# Shorten summary text for printing
if output_mode != PrintMode.ALL:
if output_mode != OutputMode.ALL:
sql_text = summarize_statement(sql_text)

if not should_run:
Expand Down Expand Up @@ -448,11 +452,17 @@ def run_fixtures(connectable, fixtures: Union[Path, list[Path]], params=None, **
"""
recursive = kwargs.pop("recursive", False)
order_by_name = kwargs.pop("order_by_name", True)
console = kwargs.pop("console", Console(stderr=True))
output_mode, output_file = _normalize_output_args(kwargs)

console = kwargs.pop("console", Console(stderr=True, file=output_file))
files = get_sql_files(fixtures, recursive=recursive, order_by_name=order_by_name)

prefix = os.path.commonpath(files)

console.print(f"Running fixtures in [cyan bold]{prefix}[/]")
for fixture in files:
console.print(f"[cyan bold]{fixture}[/]")
fn = fixture.relative_to(prefix)
console.print(f"[cyan bold]{fn}[/]")
run_sql_file(connectable, fixture, params, **kwargs)
console.print()

Expand Down
2 changes: 1 addition & 1 deletion database/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ authors = ["Daven Quinn <[email protected]>"]
description = "A SQLAlchemy-based database toolkit."
name = "macrostrat.database"
packages = [{ include = "macrostrat" }]
version = "3.5.2"
version = "3.5.3"

[tool.poetry.dependencies]
GeoAlchemy2 = "^0.15.2"
Expand Down
39 changes: 38 additions & 1 deletion database/tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
NOTE: At the moment, these tests are not independent and must run in order.
"""

from io import StringIO, TextIOWrapper
from pathlib import Path
from sys import stdout

Expand All @@ -17,7 +18,12 @@

from macrostrat.database import Database, run_sql
from macrostrat.database.postgresql import table_exists
from macrostrat.database.utils import _print_error, infer_is_sql_text, temp_database
from macrostrat.database.utils import (
_print_error,
infer_is_sql_text,
temp_database,
run_fixtures,
)
from macrostrat.utils import get_logger, relative_path

load_dotenv()
Expand Down Expand Up @@ -412,3 +418,34 @@ def test_database_schema_refresh(db):

def test_print_error():
_print_error("SELECT * FROM test", Exception("Test error"))


def _check_text(_stdout: TextIOWrapper, _text: str):
_stdout.seek(0)
assert _stdout.read() == _text


def test_printing(db):
# Check that nothing was printed to stderr
# Collect printed statements
with StringIO() as _stdout:
run_sql(db.session, "SELECT 1", output_file=_stdout)
_check_text(_stdout, "SELECT 1\n")


def test_no_printing(db):
# Check that nothing was printed to stderr
# Collect printed statements
with StringIO() as _stdout:
run_sql(db.session, "SELECT 1", output_mode="none", output_file=_stdout)
_check_text(_stdout, "")


def test_no_printing_fixtures(db, capsys):
# Check that nothing was printed to stderr
# Collect printed statements

fd = Path(relative_path(__file__, "fixtures"))
with StringIO() as _stdout:
run_fixtures(db.session, fd, output_mode="none", output_file=_stdout)
_check_text(_stdout, "")
2 changes: 1 addition & 1 deletion dinosaur/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a0cf97b

Please sign in to comment.