diff --git a/pioreactor/structs.py b/pioreactor/structs.py index b7d24a76..13a7ec16 100644 --- a/pioreactor/structs.py +++ b/pioreactor/structs.py @@ -154,6 +154,10 @@ class CalibrationBase(Struct, tag_field="calibration_type", kw_only=True): y: str recorded_data: dict[t.Literal["x", "y"], list[X | Y]] + def __post_init__(self): + if len(self.recorded_data["x"]) != len(self.recorded_data["y"]): + raise ValueError("Lists in `recorded_data` should have the same lengths") + @property def calibration_type(self): return self.__struct_config__.tag @@ -212,6 +216,13 @@ def ipredict(self, y: Y, enforce_bounds=False) -> X: roots_ = roots(solve_for_poly).tolist() plausible_sols_: list[X] = sorted([real(r) for r in roots_ if (abs(imag(r)) < 1e-10)]) + if len(self.recorded_data["x"]) == 0: + from math import inf + + min_X, max_X = -inf, inf + else: + min_X, max_X = min(self.recorded_data["x"]), max(self.recorded_data["x"]) + if len(plausible_sols_) == 0: raise exc.NoSolutionsFoundError("No solutions found") elif len(plausible_sols_) == 1: @@ -220,7 +231,6 @@ def ipredict(self, y: Y, enforce_bounds=False) -> X: if not enforce_bounds: return sol - min_X, max_X = min(self.recorded_data["x"]), max(self.recorded_data["x"]) # if we are here, we let the downstream user decide how to proceed if min_X <= sol <= max_X: return sol diff --git a/pioreactor/tests/test_utils.py b/pioreactor/tests/test_utils.py index 95a63417..31a24b16 100644 --- a/pioreactor/tests/test_utils.py +++ b/pioreactor/tests/test_utils.py @@ -10,6 +10,7 @@ from pioreactor.background_jobs.stirring import start_stirring from pioreactor.tests.conftest import capture_requests +from pioreactor.utils import argextrema from pioreactor.utils import callable_stack from pioreactor.utils import ClusterJobManager from pioreactor.utils import is_pio_job_running @@ -374,3 +375,8 @@ def test_retrieve_setting(job_manager, job_id): job_manager.set_not_running(job_key) with pytest.raises(NameError): job_manager.get_setting_from_running_job("test_name", "my_setting_int") + + +def test_argextrema_with_empty_lists() -> None: + with pytest.raises(ValueError): + argextrema([]) diff --git a/pioreactor/utils/__init__.py b/pioreactor/utils/__init__.py index 620bfce7..54fa6656 100644 --- a/pioreactor/utils/__init__.py +++ b/pioreactor/utils/__init__.py @@ -456,6 +456,9 @@ def get_cpu_temperature() -> float: def argextrema(x: Sequence) -> tuple[int, int]: + if len(x) == 0: + raise ValueError("argextrema() arg is an empty sequence") + min_, max_ = float("inf"), float("-inf") argmin_, argmax_ = 0, 0 for i, value in enumerate(x): diff --git a/update_scripts/upcoming/cal_active.py b/update_scripts/upcoming/cal_active.py new file mode 100644 index 00000000..eb1263bf --- /dev/null +++ b/update_scripts/upcoming/cal_active.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import json +import sqlite3 + +import click + +from pioreactor.utils import local_persistent_storage + + +def old_active_calibrations(sql_database: str): + try: + conn = sqlite3.connect(sql_database) + except sqlite3.OperationalError: + return + cursor = conn.cursor() + cursor.execute("SELECT value FROM cache") + + for row in cursor.fetchall(): + data = json.loads(row[0]) + yield data + + +@click.command() +@click.argument("db_location") +def main(db_location: str) -> None: + with local_persistent_storage("active_calibrations") as c: + for old_cal in old_active_calibrations(db_location): + if "pump" in old_cal["type"]: + type_ = old_cal["type"] + elif "od" in old_cal["type"]: + type_ = "od" + else: + print(f"Unknown type {old_cal['type']}") + continue + + c[type_] = old_cal["name"] + print(f"Setting active calibration for {type_} to {old_cal['name']}") + + +if __name__ == "__main__": + main() diff --git a/update_scripts/upcoming/cal_convert.py b/update_scripts/upcoming/cal_convert.py index 5831be9c..89f62871 100644 --- a/update_scripts/upcoming/cal_convert.py +++ b/update_scripts/upcoming/cal_convert.py @@ -63,6 +63,7 @@ def old_calibrations(sql_database: str): def main(db_location: str) -> None: for old_cal in old_calibrations(db_location): try: + print(f"Converting {old_cal['name']} to new calibration format.") if "pump" in old_cal["type"]: new_cal, device = convert_old_to_new_pump(old_cal) elif "od" in old_cal["type"]: diff --git a/update_scripts/upcoming/update.sh b/update_scripts/upcoming/update.sh index a704105e..8619dcfc 100644 --- a/update_scripts/upcoming/update.sh +++ b/update_scripts/upcoming/update.sh @@ -36,6 +36,9 @@ sudo -u pioreactor python "$SCRIPT_DIR"/cal_convert.py "$STORAGE_DIR"/od_calibra sudo -u pioreactor python "$SCRIPT_DIR"/cal_convert.py "$STORAGE_DIR"/pump_calibrations/cache.db chown -R pioreactor:pioreactor "$STORAGE_DIR"/calibrations/*/*.yaml +sudo -u pioreactor python "$SCRIPT_DIR"/cal_active.py "$STORAGE_DIR"/current_pump_calibrations/cache.db +sudo -u pioreactor python "$SCRIPT_DIR"/cal_active.py "$STORAGE_DIR"/current_od_calibrations/cache.db + # if leader if [ "$HOSTNAME" = "$LEADER_HOSTNAME" ]; then