From d07119670b705796ebfead8aa0d240c5ddf96c07 Mon Sep 17 00:00:00 2001 From: CamDavidsonPilon Date: Sat, 14 Dec 2024 22:35:02 -0500 Subject: [PATCH] fix more --- pioreactor/actions/od_calibration.py | 1 - .../calibrations/stirring_calibration.py | 2 +- pioreactor/calibrations/utils.py | 5 +-- pioreactor/cli/calibrations.py | 32 ++++++++++--------- pioreactor/tests/test_utils.py | 30 +++++++++++++---- pioreactor/utils/__init__.py | 32 +++++++++++-------- 6 files changed, 63 insertions(+), 39 deletions(-) diff --git a/pioreactor/actions/od_calibration.py b/pioreactor/actions/od_calibration.py index d1514f1f..2fb27d36 100644 --- a/pioreactor/actions/od_calibration.py +++ b/pioreactor/actions/od_calibration.py @@ -617,7 +617,6 @@ def curve_to_functional_form(curve_type: str, curve_data) -> str: raise ValueError() - def display(name: str | None) -> None: def display_from_calibration_blob(data_blob: dict) -> None: voltages = data_blob["voltages"] diff --git a/pioreactor/calibrations/stirring_calibration.py b/pioreactor/calibrations/stirring_calibration.py index 894b410f..8a08a47c 100644 --- a/pioreactor/calibrations/stirring_calibration.py +++ b/pioreactor/calibrations/stirring_calibration.py @@ -51,7 +51,7 @@ def run_stirring_calibration(min_dc: float | None = None, max_dc: float | None = # go up and down to observe any hysteresis. dcs = ( - list(range(round(max_dc), round(min_dc), -3)) + list(range(round(max_dc), round(min_dc), -3)) + list(range(round(min_dc), round(max_dc), 3)) + list(range(round(max_dc), round(min_dc) - 3, -3)) ) diff --git a/pioreactor/calibrations/utils.py b/pioreactor/calibrations/utils.py index 3eb4e4ad..bdf38e27 100644 --- a/pioreactor/calibrations/utils.py +++ b/pioreactor/calibrations/utils.py @@ -1,5 +1,7 @@ -from typing import Callable +# -*- coding: utf-8 -*- +from __future__ import annotations +from typing import Callable def curve_to_callable(curve_type: str, curve_data: list[float]) -> Callable: @@ -46,7 +48,6 @@ def plot_data( plt.plot_size(105, 22) - plt.xlim(x_min, x_max) plt.yfrequency(6) plt.xfrequency(6) diff --git a/pioreactor/cli/calibrations.py b/pioreactor/cli/calibrations.py index cde2dfca..1fe2bf46 100644 --- a/pioreactor/cli/calibrations.py +++ b/pioreactor/cli/calibrations.py @@ -8,10 +8,10 @@ from msgspec.yaml import encode as yaml_encode from pioreactor import structs -from pioreactor.whoami import is_testing_env -from pioreactor.calibrations.utils import plot_data from pioreactor.calibrations.utils import curve_to_callable +from pioreactor.calibrations.utils import plot_data from pioreactor.utils import local_persistant_storage +from pioreactor.whoami import is_testing_env if not is_testing_env(): CALIBRATION_PATH = Path("/home/pioreactor/.pioreactor/storage/calibrations/") @@ -55,10 +55,11 @@ def __init__(self): pass def run(self, min_dc: str | None = None, max_dc: str | None = None) -> structs.StirringCalibration: - from pioreactor.calibrations.stirring_calibration import run_stirring_calibration - return run_stirring_calibration(min_dc=float(min_dc) if min_dc is not None else None, max_dc=float(max_dc) if max_dc else None) + return run_stirring_calibration( + min_dc=float(min_dc) if min_dc is not None else None, max_dc=float(max_dc) if max_dc else None + ) @click.group(short_help="calibration utils") @@ -82,14 +83,11 @@ def list_calibrations(cal_type: str): assistant = CALIBRATION_ASSISTANTS.get(cal_type) - - header = f"{'Name':<50}{'Created At':<25}{'Subtype':<15}{'Current?':<15}" click.echo(header) - click.echo('-' * len(header)) + click.echo("-" * len(header)) with local_persistant_storage("current_calibrations") as c: - for file in calibration_dir.glob("*.yaml"): try: data = yaml_decode(file.read_bytes(), type=assistant.calibration_struct) @@ -100,6 +98,7 @@ def list_calibrations(cal_type: str): error_message = f"Error reading {file.stem}: {e}" click.echo(f"{error_message:<60}") + @calibration.command(name="run", context_settings=dict(ignore_unknown_options=True, allow_extra_args=True)) @click.option("--type", "cal_type", required=True, help="Type of calibration (e.g. od, pump, stirring).") @click.pass_context @@ -116,7 +115,9 @@ def run_calibration(ctx, cal_type: str): raise click.Abort() # Run the assistant function to get the final calibration data - calibration_data = assistant().run(**{ctx.args[i][2:].replace("-", "_"): ctx.args[i + 1] for i in range(0, len(ctx.args), 2)},) + calibration_data = assistant().run( + **{ctx.args[i][2:].replace("-", "_"): ctx.args[i + 1] for i in range(0, len(ctx.args), 2)}, + ) calibration_name = calibration_data.calibration_name calibration_dir = CALIBRATION_PATH / cal_type @@ -154,12 +155,13 @@ def display_calibration(cal_type: str, calibration_name: str): click.echo() curve = curve_to_callable(data.curve_type, data.curve_data_) plot_data( - data.recorded_data['x'], - data.recorded_data['y'], - calibration_name, - data.x, - data.y, - interpolation_curve=curve) + data.recorded_data["x"], + data.recorded_data["y"], + calibration_name, + data.x, + data.y, + interpolation_curve=curve, + ) click.echo() click.echo() diff --git a/pioreactor/tests/test_utils.py b/pioreactor/tests/test_utils.py index 2fd4da50..2fce2f14 100644 --- a/pioreactor/tests/test_utils.py +++ b/pioreactor/tests/test_utils.py @@ -75,16 +75,34 @@ def test_caches_will_delete_when_asked() -> None: def test_caches_can_have_tuple_or_singleton_keys() -> None: - with local_persistant_storage("test_caches_can_have_tuple_keys") as c: - c[(1,2)] = 1 - c[("a","b")] = 2 - c[("a",None)] = 3 + c[(1, 2)] = 1 + c[("a", "b")] = 2 + c[("a", None)] = 3 c[4] = 4 - c["5e"] = 5 + c["5"] = 5 with local_persistant_storage("test_caches_can_have_tuple_keys") as c: - assert list(c.iterkeys()) == [4, '5e', ['a', 'b'], ['a', None], [1, 2]] + assert list(c.iterkeys()) == [4, "5", ["a", "b"], ["a", None], [1, 2]] + + +def test_caches_integer_keys() -> None: + with local_persistant_storage("test_caches_integer_keys") as c: + c[1] = "a" + c[2] = "b" + + with local_persistant_storage("test_caches_integer_keys") as c: + assert list(c.iterkeys()) == [1, 2] + + +def test_caches_str_keys_as_ints_stay_as_str() -> None: + with local_persistant_storage("test_caches_str_keys_as_ints_stay_as_str") as c: + c["1"] = "a" + c["2"] = "b" + + with local_persistant_storage("test_caches_str_keys_as_ints_stay_as_str") as c: + assert list(c.iterkeys()) == ["1", "2"] + def test_is_pio_job_running_single() -> None: experiment = "test_is_pio_job_running_single" diff --git a/pioreactor/utils/__init__.py b/pioreactor/utils/__init__.py index 3371a0ad..4d557c1f 100644 --- a/pioreactor/utils/__init__.py +++ b/pioreactor/utils/__init__.py @@ -19,10 +19,10 @@ from typing import Sequence from typing import TYPE_CHECKING -from msgspec import Struct from msgspec import DecodeError -from msgspec.json import encode as dumps +from msgspec import Struct from msgspec.json import decode as loads +from msgspec.json import encode as dumps from pioreactor import structs from pioreactor import types as pt @@ -264,30 +264,32 @@ def publish_setting(self, setting: str, value: Any) -> None: class cache: - # keys can be tuples! @staticmethod def adapt_key(key): + # keys can be tuples! return dumps(key) @staticmethod def convert_key(s): - try: - return loads(s) - except DecodeError: - return s.decode() + if isinstance(s, bytes): + try: + return loads(s) + except DecodeError: + return s.decode() + else: + return s def __init__(self, table_name, db_path): self.table_name = f"cache_{table_name}" self.db_path = db_path def __enter__(self): - sqlite3.register_adapter(tuple, self.adapt_key) - sqlite3.register_converter("_key", self.convert_key) + # sqlite3.register_converter("_key_BLOB", self.convert_key) self.conn = sqlite3.connect(self.db_path, isolation_level=None, detect_types=sqlite3.PARSE_DECLTYPES) - self.conn.execute('pragma journal_mode=wal') + self.conn.execute("pragma journal_mode=wal") self.cursor = self.conn.cursor() self._initialize_table() return self @@ -299,7 +301,7 @@ def _initialize_table(self): self.cursor.execute( f""" CREATE TABLE IF NOT EXISTS {self.table_name} ( - key _key PRIMARY KEY, + key _key_BLOB PRIMARY KEY, value BLOB ) """ @@ -322,7 +324,7 @@ def get(self, key, default=None): def iterkeys(self): self.cursor.execute(f"SELECT key FROM {self.table_name}") - return (row[0] for row in self.cursor.fetchall()) + return (self.convert_key(row[0]) for row in self.cursor.fetchall()) def pop(self, key, default=None): self.cursor.execute(f"SELECT value FROM {self.table_name} WHERE key = ?", (key,)) @@ -370,7 +372,9 @@ def local_intermittent_storage( Opening the same cache in a context manager is tricky, and should be avoided. """ - with cache(f"{cache_name}", db_path=f"{tempfile.gettempdir()}/local_intermittent_pioreactor_metadata.sqlite") as c: + with cache( + f"{cache_name}", db_path=f"{tempfile.gettempdir()}/local_intermittent_pioreactor_metadata.sqlite" + ) as c: yield c @@ -595,7 +599,7 @@ class JobManager: def __init__(self) -> None: db_path = f"{tempfile.gettempdir()}/local_intermittent_pioreactor_metadata.sqlite" self.conn = sqlite3.connect(db_path, isolation_level=None) - self.conn.execute('pragma journal_mode=wal') + self.conn.execute("pragma journal_mode=wal") self.cursor = self.conn.cursor() self._create_tables()