From efefceb68c6948faccfd393f78ca46dbf9211b11 Mon Sep 17 00:00:00 2001 From: "r.jaepel" Date: Fri, 22 Mar 2024 12:41:13 +0100 Subject: [PATCH] Add .save_as_python_script method and test --- cadet/cadet.py | 94 ++++++++++++++++++++++++++++++++++++ tests/test_save_as_python.py | 64 ++++++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 tests/test_save_as_python.py diff --git a/cadet/cadet.py b/cadet/cadet.py index 1a3530f..f834f3d 100644 --- a/cadet/cadet.py +++ b/cadet/cadet.py @@ -77,6 +77,32 @@ def load_json(self, filename, update=False): else: self.root = data + def save_as_python_script(self, filename: str, only_return_pythonic_representation=False): + if not filename.endswith(".py"): + raise Warning(f"The filename given to .save_as_python_script isn't a python file name.") + + code_lines_list = [ + "import numpy", + "from cadet import Cadet", + "", + "sim = Cadet()", + "root = sim.root", + ] + + code_lines_list = recursively_turn_dict_to_python_list(dictionary=self.root, + current_lines_list=code_lines_list, + prefix="root") + + filename_for_reproduced_h5_file = filename.replace(".py", ".h5") + code_lines_list.append(f"sim.filename = '{filename_for_reproduced_h5_file}'") + code_lines_list.append("sim.save()") + + if not only_return_pythonic_representation: + with open(filename, "w") as handle: + handle.writelines([line + "\n" for line in code_lines_list]) + else: + return code_lines_list + def append(self, lock=False): "This can only be used to write new keys to the system, this is faster than having to read the data before writing it" if self.filename is not None: @@ -347,3 +373,71 @@ def recursively_save(h5file, path, dic, func): raise KeyError(f'Name conflict with upper and lower case entries for key "{path}{key}".') else: raise + + +def recursively_turn_dict_to_python_list(dictionary: dict, current_lines_list: list = None, prefix: str = None): + """ + Recursively turn a nested dictionary or addict.Dict into a list of Python code that + can generate the nested dictionary. + + :param dictionary: + :param current_lines_list: + :param prefix_list: + :return: list of Python code lines + """ + + def merge_to_absolute_key(prefix, key): + """ + Combine key and prefix to "prefix.key" except if there is no prefix, then return key + """ + if prefix is None: + return key + else: + return f"{prefix}.{key}" + + def clean_up_key(absolute_key: str): + """ + Remove problematic phrases from key, such as blank "return" + + :param absolute_key: + :return: + """ + absolute_key = absolute_key.replace(".return", "['return']") + return absolute_key + + def get_pythonic_representation_of_value(value): + """ + Use repr() to get a pythonic representation of the value + and add "np." to "array" and "float64" + + """ + value_representation = repr(value) + value_representation = value_representation.replace("array", "numpy.array") + value_representation = value_representation.replace("float64", "numpy.float64") + try: + eval(value_representation) + except NameError as e: + raise ValueError( + f"Encountered a value of '{value_representation}' that can't be directly reproduced in python.\n" + f"Please report this to the CADET-Python developers.") from e + + return value_representation + + if current_lines_list is None: + current_lines_list = [] + + for key in sorted(dictionary.keys()): + value = dictionary[key] + + absolute_key = merge_to_absolute_key(prefix, key) + + if type(value) in (dict, Dict): + current_lines_list = recursively_turn_dict_to_python_list(value, current_lines_list, prefix=absolute_key) + else: + value_representation = get_pythonic_representation_of_value(value) + + absolute_key = clean_up_key(absolute_key) + + current_lines_list.append(f"{absolute_key} = {value_representation}") + + return current_lines_list diff --git a/tests/test_save_as_python.py b/tests/test_save_as_python.py new file mode 100644 index 0000000..c59b52f --- /dev/null +++ b/tests/test_save_as_python.py @@ -0,0 +1,64 @@ +import tempfile + +import numpy as np +import pytest +from addict import Dict + +from cadet import Cadet + + +@pytest.fixture +def temp_cadet_file(): + """ + Create a new Cadet object for use in tests. + """ + model = Cadet() + + with tempfile.NamedTemporaryFile() as temp: + model.filename = temp + yield model + + +def test_save_as_python(temp_cadet_file): + """ + Test that the Cadet class raises a KeyError exception when duplicate keys are set on it. + """ + # initialize "sim" variable to be overwritten by the exec lines later + sim = Cadet() + + # Populate temp_cadet_file with all tricky cases currently known + temp_cadet_file.root.input.foo = 1 + temp_cadet_file.root.input.bar.baryon = np.arange(10) + temp_cadet_file.root.input.bar.barometer = np.linspace(0, 10, 9) + temp_cadet_file.root.input.bar.init_q = np.array([], dtype=np.float64) + temp_cadet_file.root.input["return"].split_foobar = 1 + + code_lines = temp_cadet_file.save_as_python_script(filename="temp.py", only_return_pythonic_representation=True) + + # remove code lines that save the file + code_lines = code_lines[:-2] + + # populate "sim" variable using the generated code lines + for line in code_lines: + exec(line) + + # test that "sim" is equal to "temp_cadet_file" + recursive_equality_check(sim.root, temp_cadet_file.root) + + +def recursive_equality_check(dict_a: dict, dict_b: dict): + assert dict_a.keys() == dict_b.keys() + for key in dict_a.keys(): + value_a = dict_a[key] + value_b = dict_b[key] + if type(value_a) in (dict, Dict): + recursive_equality_check(value_a, value_b) + elif type(value_a) == np.ndarray: + np.testing.assert_array_equal(value_a, value_b) + else: + assert value_a == value_b + return True + + +if __name__ == "__main__": + pytest.main()