Skip to content

Commit

Permalink
fix more
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Dec 15, 2024
1 parent ae7243d commit d071196
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 39 deletions.
1 change: 0 additions & 1 deletion pioreactor/actions/od_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,6 @@ def curve_to_functional_form(curve_type: str, curve_data) -> str:
raise ValueError()



def display(name: str | None) -> None:
def display_from_calibration_blob(data_blob: dict) -> None:
voltages = data_blob["voltages"]
Expand Down
2 changes: 1 addition & 1 deletion pioreactor/calibrations/stirring_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def run_stirring_calibration(min_dc: float | None = None, max_dc: float | None =

# go up and down to observe any hysteresis.
dcs = (
list(range(round(max_dc), round(min_dc), -3))
list(range(round(max_dc), round(min_dc), -3))
+ list(range(round(min_dc), round(max_dc), 3))
+ list(range(round(max_dc), round(min_dc) - 3, -3))
)
Expand Down
5 changes: 3 additions & 2 deletions pioreactor/calibrations/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Callable
# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import Callable


def curve_to_callable(curve_type: str, curve_data: list[float]) -> Callable:
Expand Down Expand Up @@ -46,7 +48,6 @@ def plot_data(

plt.plot_size(105, 22)


plt.xlim(x_min, x_max)
plt.yfrequency(6)
plt.xfrequency(6)
Expand Down
32 changes: 17 additions & 15 deletions pioreactor/cli/calibrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from msgspec.yaml import encode as yaml_encode

from pioreactor import structs
from pioreactor.whoami import is_testing_env
from pioreactor.calibrations.utils import plot_data
from pioreactor.calibrations.utils import curve_to_callable
from pioreactor.calibrations.utils import plot_data
from pioreactor.utils import local_persistant_storage
from pioreactor.whoami import is_testing_env

if not is_testing_env():
CALIBRATION_PATH = Path("/home/pioreactor/.pioreactor/storage/calibrations/")
Expand Down Expand Up @@ -55,10 +55,11 @@ def __init__(self):
pass

def run(self, min_dc: str | None = None, max_dc: str | None = None) -> structs.StirringCalibration:

from pioreactor.calibrations.stirring_calibration import run_stirring_calibration

return run_stirring_calibration(min_dc=float(min_dc) if min_dc is not None else None, max_dc=float(max_dc) if max_dc else None)
return run_stirring_calibration(
min_dc=float(min_dc) if min_dc is not None else None, max_dc=float(max_dc) if max_dc else None
)


@click.group(short_help="calibration utils")
Expand All @@ -82,14 +83,11 @@ def list_calibrations(cal_type: str):

assistant = CALIBRATION_ASSISTANTS.get(cal_type)



header = f"{'Name':<50}{'Created At':<25}{'Subtype':<15}{'Current?':<15}"
click.echo(header)
click.echo('-' * len(header))
click.echo("-" * len(header))

with local_persistant_storage("current_calibrations") as c:

for file in calibration_dir.glob("*.yaml"):
try:
data = yaml_decode(file.read_bytes(), type=assistant.calibration_struct)
Expand All @@ -100,6 +98,7 @@ def list_calibrations(cal_type: str):
error_message = f"Error reading {file.stem}: {e}"
click.echo(f"{error_message:<60}")


@calibration.command(name="run", context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.option("--type", "cal_type", required=True, help="Type of calibration (e.g. od, pump, stirring).")
@click.pass_context
Expand All @@ -116,7 +115,9 @@ def run_calibration(ctx, cal_type: str):
raise click.Abort()

# Run the assistant function to get the final calibration data
calibration_data = assistant().run(**{ctx.args[i][2:].replace("-", "_"): ctx.args[i + 1] for i in range(0, len(ctx.args), 2)},)
calibration_data = assistant().run(
**{ctx.args[i][2:].replace("-", "_"): ctx.args[i + 1] for i in range(0, len(ctx.args), 2)},
)
calibration_name = calibration_data.calibration_name

calibration_dir = CALIBRATION_PATH / cal_type
Expand Down Expand Up @@ -154,12 +155,13 @@ def display_calibration(cal_type: str, calibration_name: str):
click.echo()
curve = curve_to_callable(data.curve_type, data.curve_data_)
plot_data(
data.recorded_data['x'],
data.recorded_data['y'],
calibration_name,
data.x,
data.y,
interpolation_curve=curve)
data.recorded_data["x"],
data.recorded_data["y"],
calibration_name,
data.x,
data.y,
interpolation_curve=curve,
)

click.echo()
click.echo()
Expand Down
30 changes: 24 additions & 6 deletions pioreactor/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,34 @@ def test_caches_will_delete_when_asked() -> None:


def test_caches_can_have_tuple_or_singleton_keys() -> None:

with local_persistant_storage("test_caches_can_have_tuple_keys") as c:
c[(1,2)] = 1
c[("a","b")] = 2
c[("a",None)] = 3
c[(1, 2)] = 1
c[("a", "b")] = 2
c[("a", None)] = 3
c[4] = 4
c["5e"] = 5
c["5"] = 5

with local_persistant_storage("test_caches_can_have_tuple_keys") as c:
assert list(c.iterkeys()) == [4, '5e', ['a', 'b'], ['a', None], [1, 2]]
assert list(c.iterkeys()) == [4, "5", ["a", "b"], ["a", None], [1, 2]]


def test_caches_integer_keys() -> None:
with local_persistant_storage("test_caches_integer_keys") as c:
c[1] = "a"
c[2] = "b"

with local_persistant_storage("test_caches_integer_keys") as c:
assert list(c.iterkeys()) == [1, 2]


def test_caches_str_keys_as_ints_stay_as_str() -> None:
with local_persistant_storage("test_caches_str_keys_as_ints_stay_as_str") as c:
c["1"] = "a"
c["2"] = "b"

with local_persistant_storage("test_caches_str_keys_as_ints_stay_as_str") as c:
assert list(c.iterkeys()) == ["1", "2"]


def test_is_pio_job_running_single() -> None:
experiment = "test_is_pio_job_running_single"
Expand Down
32 changes: 18 additions & 14 deletions pioreactor/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from typing import Sequence
from typing import TYPE_CHECKING

from msgspec import Struct
from msgspec import DecodeError
from msgspec.json import encode as dumps
from msgspec import Struct
from msgspec.json import decode as loads
from msgspec.json import encode as dumps

from pioreactor import structs
from pioreactor import types as pt
Expand Down Expand Up @@ -264,30 +264,32 @@ def publish_setting(self, setting: str, value: Any) -> None:


class cache:
# keys can be tuples!
@staticmethod
def adapt_key(key):
# keys can be tuples!
return dumps(key)

@staticmethod
def convert_key(s):
try:
return loads(s)
except DecodeError:
return s.decode()
if isinstance(s, bytes):
try:
return loads(s)
except DecodeError:
return s.decode()
else:
return s

def __init__(self, table_name, db_path):
self.table_name = f"cache_{table_name}"
self.db_path = db_path

def __enter__(self):

sqlite3.register_adapter(tuple, self.adapt_key)
sqlite3.register_converter("_key", self.convert_key)
# sqlite3.register_converter("_key_BLOB", self.convert_key)

self.conn = sqlite3.connect(self.db_path, isolation_level=None, detect_types=sqlite3.PARSE_DECLTYPES)

self.conn.execute('pragma journal_mode=wal')
self.conn.execute("pragma journal_mode=wal")
self.cursor = self.conn.cursor()
self._initialize_table()
return self
Expand All @@ -299,7 +301,7 @@ def _initialize_table(self):
self.cursor.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
key _key PRIMARY KEY,
key _key_BLOB PRIMARY KEY,
value BLOB
)
"""
Expand All @@ -322,7 +324,7 @@ def get(self, key, default=None):

def iterkeys(self):
self.cursor.execute(f"SELECT key FROM {self.table_name}")
return (row[0] for row in self.cursor.fetchall())
return (self.convert_key(row[0]) for row in self.cursor.fetchall())

def pop(self, key, default=None):
self.cursor.execute(f"SELECT value FROM {self.table_name} WHERE key = ?", (key,))
Expand Down Expand Up @@ -370,7 +372,9 @@ def local_intermittent_storage(
Opening the same cache in a context manager is tricky, and should be avoided.
"""
with cache(f"{cache_name}", db_path=f"{tempfile.gettempdir()}/local_intermittent_pioreactor_metadata.sqlite") as c:
with cache(
f"{cache_name}", db_path=f"{tempfile.gettempdir()}/local_intermittent_pioreactor_metadata.sqlite"
) as c:
yield c


Expand Down Expand Up @@ -595,7 +599,7 @@ class JobManager:
def __init__(self) -> None:
db_path = f"{tempfile.gettempdir()}/local_intermittent_pioreactor_metadata.sqlite"
self.conn = sqlite3.connect(db_path, isolation_level=None)
self.conn.execute('pragma journal_mode=wal')
self.conn.execute("pragma journal_mode=wal")
self.cursor = self.conn.cursor()
self._create_tables()

Expand Down

0 comments on commit d071196

Please sign in to comment.