diff --git a/snowfakery/data_generator.py b/snowfakery/data_generator.py index 53c1b6a6..9d027707 100644 --- a/snowfakery/data_generator.py +++ b/snowfakery/data_generator.py @@ -19,7 +19,7 @@ from .data_gen_exceptions import DataGenError from .plugins import SnowfakeryPlugin, PluginOption -from .utils.yaml_utils import SnowfakeryDumper, hydrate +from .utils.yaml_utils import SnowfakeryContinuationDumper, hydrate from snowfakery.standard_plugins.UniqueId import UniqueId # This tool is essentially a three stage interpreter. @@ -95,9 +95,9 @@ def load_continuation_yaml(continuation_file: OpenFileLike): def save_continuation_yaml(continuation_data: Globals, continuation_file: OpenFileLike): """Save the global interpreter state from Globals into a continuation_file""" yaml.dump( - continuation_data.__getstate__(), + continuation_data, continuation_file, - Dumper=SnowfakeryDumper, + Dumper=SnowfakeryContinuationDumper, ) diff --git a/snowfakery/data_generator_runtime.py b/snowfakery/data_generator_runtime.py index 400cb9c0..d05c6d43 100644 --- a/snowfakery/data_generator_runtime.py +++ b/snowfakery/data_generator_runtime.py @@ -13,7 +13,7 @@ import yaml from .utils.template_utils import FakerTemplateLibrary -from .utils.yaml_utils import SnowfakeryDumper, hydrate +from .utils.yaml_utils import hydrate from .row_history import RowHistory from .template_funcs import StandardFuncs from .data_gen_exceptions import DataGenSyntaxError, DataGenNameError @@ -27,6 +27,7 @@ ) from snowfakery.plugins import PluginContext, SnowfakeryPlugin, ScalarTypes from snowfakery.utils.collections import OrderedSet +from snowfakery.utils.yaml_utils import register_for_continuation OutputStream = "snowfakery.output_streams.OutputStream" VariableDefinition = "snowfakery.data_generator_runtime_object_model.VariableDefinition" @@ -60,17 +61,15 @@ def generate_id(self, table_name: str) -> int: def __getitem__(self, table_name: str) -> int: return self.last_used_ids[table_name] - def __getstate__(self): + # TODO: Fix this to use the new convention of get_continuation_data + def get_continuation_state(self): return {"last_used_ids": dict(self.last_used_ids)} - def __setstate__(self, state): + def restore_from_continuation(self, state): self.last_used_ids = defaultdict(lambda: 0, state["last_used_ids"]) self.start_ids = {name: val + 1 for name, val in self.last_used_ids.items()} -SnowfakeryDumper.add_representer(defaultdict, SnowfakeryDumper.represent_dict) - - class Dependency(NamedTuple): table_name_from: str table_name_to: str @@ -195,29 +194,22 @@ def check_slots_filled(self): def first_new_id(self, tablename): return self.transients.first_new_id(tablename) - def __getstate__(self): - def serialize_dict_of_object_rows(dct): - return {k: v.__getstate__() for k, v in dct.items()} - - persistent_nicknames = serialize_dict_of_object_rows(self.persistent_nicknames) - persistent_objects_by_table = serialize_dict_of_object_rows( - self.persistent_objects_by_table - ) + def get_continuation_state(self): intertable_dependencies = [ dict(v._asdict()) for v in self.intertable_dependencies ] # converts ordered-dict to dict for Python 3.6 and 3.7 state = { - "persistent_nicknames": persistent_nicknames, - "persistent_objects_by_table": persistent_objects_by_table, - "id_manager": self.id_manager.__getstate__(), + "persistent_nicknames": self.persistent_nicknames, + "persistent_objects_by_table": self.persistent_objects_by_table, + "id_manager": self.id_manager.get_continuation_state(), "today": self.today, "nicknames_and_tables": 
self.nicknames_and_tables, "intertable_dependencies": intertable_dependencies, } return state - def __setstate__(self, state): + def restore_from_continuation(self, state): def deserialize_dict_of_object_rows(dct): return {k: hydrate(ObjectRow, v) for k, v in dct.items()} @@ -244,6 +236,9 @@ def deserialize_dict_of_object_rows(dct): self.reset_slots() +register_for_continuation(Globals, Globals.get_continuation_state) + + class JinjaTemplateEvaluatorFactory: def __init__(self, native_types: bool): if native_types: diff --git a/snowfakery/object_rows.py b/snowfakery/object_rows.py index 3e836a35..ac9696df 100644 --- a/snowfakery/object_rows.py +++ b/snowfakery/object_rows.py @@ -2,7 +2,7 @@ import yaml import snowfakery # noqa -from .utils.yaml_utils import SnowfakeryDumper +from .utils.yaml_utils import register_for_continuation from contextvars import ContextVar IdManager = "snowfakery.data_generator_runtime.IdManager" @@ -14,10 +14,6 @@ class ObjectRow: Uses __getattr__ so that the template evaluator can use dot-notation.""" - yaml_loader = yaml.SafeLoader - yaml_dumper = SnowfakeryDumper - yaml_tag = "!snowfakery_objectrow" - # be careful changing these slots because these objects must be serializable # to YAML and JSON __slots__ = ["_tablename", "_values", "_child_index"] @@ -49,19 +45,28 @@ def __repr__(self): except Exception: return super().__repr__() - def __getstate__(self): + def get_continuation_state(self): """Get the state of this ObjectRow for serialization. Do not include related ObjectRows because circular references in serialization formats cause problems.""" + + # If we decided to try to serialize hierarchies, we could + # do it like this: + # * keep track of if an object has already been serialized using a + # property of the SnowfakeryContinuationDumper + # * If so, output an ObjectReference instead of an ObjectRow values = {k: v for k, v in self._values.items() if not isinstance(v, ObjectRow)} return {"_tablename": self._tablename, "_values": values} - def __setstate__(self, state): + def restore_from_continuation(self, state): for slot, value in state.items(): setattr(self, slot, value) +register_for_continuation(ObjectRow, ObjectRow.get_continuation_state) + + class ObjectReference(yaml.YAMLObject): def __init__(self, tablename: str, id: int): self._tablename = tablename diff --git a/snowfakery/plugins.py b/snowfakery/plugins.py index c7c548ce..71b25429 100644 --- a/snowfakery/plugins.py +++ b/snowfakery/plugins.py @@ -8,13 +8,11 @@ from functools import wraps import typing as T -import yaml -from yaml.representer import Representer from faker.providers import BaseProvider as FakerProvider from dateutil.relativedelta import relativedelta import snowfakery.data_gen_exceptions as exc -from .utils.yaml_utils import SnowfakeryDumper +from snowfakery.utils.yaml_utils import register_for_continuation from .utils.collections import CaseInsensitiveDict from numbers import Number @@ -306,17 +304,7 @@ def _from_continuation(cls, args): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - _register_for_continuation(cls) - - -def _register_for_continuation(cls): - SnowfakeryDumper.add_representer(cls, Representer.represent_object) - yaml.SafeLoader.add_constructor( - f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}", - lambda loader, node: cls._from_continuation( - loader.construct_mapping(node.value[0]) - ), - ) + register_for_continuation(cls) class PluginResultIterator(PluginResult): @@ -372,4 +360,4 @@ def convert(self, value): # 
round-trip PluginResult objects through continuation YAML if needed.
-_register_for_continuation(PluginResult)
+register_for_continuation(PluginResult)
diff --git a/snowfakery/row_history.py b/snowfakery/row_history.py
index 5113a3bf..78fd5bf7 100644
--- a/snowfakery/row_history.py
+++ b/snowfakery/row_history.py
@@ -6,8 +6,8 @@
 from random import randint
 
 from snowfakery import data_gen_exceptions as exc
-from snowfakery.object_rows import LazyLoadedObjectReference
-from snowfakery.utils.pickle import restricted_dumps, restricted_loads
+from snowfakery.object_rows import LazyLoadedObjectReference, NicknameSlot, ObjectReference, ObjectRow
+from snowfakery.utils.pickle import RestrictedPickler
 
 
 class RowHistory:
@@ -35,6 +35,7 @@ def __init__(
         }
         for table in tables_to_keep_history_for:
             _make_history_table(self.conn, table)
+        self.pickler = RestrictedPickler(_DISPATCH_TABLE, _SAFE_CLASSES)
 
     def reset_locals(self):
         """Reset the minimum count that counts as "local" """
@@ -58,7 +59,7 @@ def save_row(self, tablename: str, nickname: T.Optional[str], row: dict):
         # "join" across multiple tables would have other costs (even if done lazily).
         # For now this seems best and simplest.
         # The data de-dupling algorithm would be slightly complex and slow.
-        data = restricted_dumps(row)
+        data = self.pickler.dumps(row)
         self.conn.execute(
             f'INSERT INTO "{tablename}" VALUES (?, ?, ?, ?)',
             (row_id, nickname, nickname_id, data),
@@ -95,8 +96,6 @@ def random_row_reference(self, name: str, scope: str, unique: bool):
                 self.already_warned = True
             min_id = 1
         elif nickname:
-            # nickname counters are reset every loop, so 1 is the right choice
-            # OR they are just_once in which case
             min_id = self.local_counters.get(nickname, 0) + 1
         else:
             min_id = self.local_counters.get(tablename, 0) + 1
@@ -126,7 +125,7 @@ def load_row(self, tablename: str, row_id: int):
         first_row = next(qr, None)
         assert first_row, f"Something went wrong: we cannot find {tablename}: {row_id}"
 
-        return restricted_loads(first_row[0])
+        return self.pickler.loads(first_row[0])
 
     def find_row_id_for_nickname_id(
         self, tablename: str, nickname: str, nickname_id: int
@@ -161,3 +160,26 @@ def _make_history_table(conn, tablename):
     c.execute(
         f'CREATE UNIQUE INDEX "{tablename}_nickname_id" ON "{tablename}" (nickname, nickname_id);'
     )
+
+
+_DISPATCH_TABLE = {
+    NicknameSlot: lambda n: (
+        ObjectReference,
+        (n._tablename, n.allocated_id),
+    ),
+    ObjectRow: lambda v: (
+        ObjectRow,
+        (v._tablename, v._values),
+    ),
+}
+
+_SAFE_CLASSES = {
+    ("snowfakery.object_rows", "ObjectRow"),
+    ("snowfakery.object_rows", "ObjectReference"),
+    ("snowfakery.object_rows", "LazyLoadedObjectReference"),
+    ("decimal", "Decimal"),
+    ("datetime", "date"),
+    ("datetime", "datetime"),
+    ("datetime", "timedelta"),
+    ("datetime", "timezone"),
+}
diff --git a/snowfakery/standard_plugins/datasets.py b/snowfakery/standard_plugins/datasets.py
index 51368a8b..c72a1796 100644
--- a/snowfakery/standard_plugins/datasets.py
+++ b/snowfakery/standard_plugins/datasets.py
@@ -17,7 +17,7 @@
     memorable,
 )
 from snowfakery.utils.files import FileLike, open_file_like
-from snowfakery.utils.yaml_utils import SnowfakeryDumper
+from snowfakery.utils.yaml_utils import SnowfakeryContinuationDumper
 
 
 def _open_db(db_url):
@@ -258,4 +258,4 @@ def chdir(path):
         os.chdir(cwd)
 
 
-SnowfakeryDumper.add_representer(quoted_name, Representer.represent_str)
+SnowfakeryContinuationDumper.add_representer(quoted_name, Representer.represent_str)
diff --git a/snowfakery/utils/pickle.py b/snowfakery/utils/pickle.py
index ad4001c1..7439710c 100644
--- 
a/snowfakery/utils/pickle.py +++ b/snowfakery/utils/pickle.py @@ -7,22 +7,31 @@ import warnings -from snowfakery.object_rows import NicknameSlot, ObjectReference -_DISPATCH_TABLE = copyreg.dispatch_table.copy() -_DISPATCH_TABLE[NicknameSlot] = lambda n: ( - ObjectReference, - (n._tablename, n.allocated_id), -) +DispatchDefinition = T.Callable[[T.Any], T.Tuple[type, T.Tuple]] -def restricted_dumps(data): - """Only allow saving "safe" classes""" - outs = io.BytesIO() - pickler = pickle.Pickler(outs) - pickler.dispatch_table = _DISPATCH_TABLE - pickler.dump(data) - return outs.getvalue() +class RestrictedPickler: + def __init__( + self, + dispatchers: T.Mapping[type, DispatchDefinition], + _SAFE_CLASSES: T.Set[T.Tuple[str, str]], + ) -> None: + self._DISPATCH_TABLE = copyreg.dispatch_table.copy() + self._DISPATCH_TABLE.update(dispatchers) + self.RestrictedUnpickler = _get_RestrictedUnpicklerClass(_SAFE_CLASSES) + + def dumps(self, data): + """Only allow saving "safe" classes""" + outs = io.BytesIO() + pickler = pickle.Pickler(outs) + pickler.dispatch_table = self._DISPATCH_TABLE + pickler.dump(data) + return outs.getvalue() + + def loads(self, s): + """Helper function analogous to pickle.loads().""" + return self.RestrictedUnpickler(io.BytesIO(s)).load() class Type_Cannot_Be_Used_With_Random_Reference(T.NamedTuple): @@ -32,35 +41,24 @@ class Type_Cannot_Be_Used_With_Random_Reference(T.NamedTuple): name: str -_SAFE_CLASSES = { - ("snowfakery.object_rows", "ObjectRow"), - ("snowfakery.object_rows", "ObjectReference"), - ("snowfakery.object_rows", "LazyLoadedObjectReference"), - ("decimal", "Decimal"), - ("datetime", "date"), - ("datetime", "datetime"), - ("datetime", "timedelta"), - ("datetime", "timezone"), -} - - -class RestrictedUnpickler(pickle.Unpickler): - """Safe unpickler with an allowed-list""" - - count = 0 +def _get_RestrictedUnpicklerClass(_SAFE_CLASSES): + class RestrictedUnpickler(pickle.Unpickler): + """Safe unpickler with an allowed-list""" - def find_class(self, module, name): - # Only allow safe classes from builtins. - if (module, name) in _SAFE_CLASSES: - return super().find_class(module, name) - else: - # Return a "safe" object that does nothing. - if RestrictedUnpickler.count < 10: - warnings.warn(f"Cannot save and refer to {module}, {name}") - RestrictedUnpickler.count += 1 - return lambda *args: Type_Cannot_Be_Used_With_Random_Reference(module, name) + count = 0 + def find_class(self, module, name): + # Only allow safe classes from builtins. + if (module, name) in _SAFE_CLASSES: + return super().find_class(module, name) + else: + # warn first 10 times + if RestrictedUnpickler.count < 10: + warnings.warn(f"Cannot save and refer to {module}, {name}") + RestrictedUnpickler.count += 1 + # Return a "safe" object that does nothing. 
+ return lambda *args: Type_Cannot_Be_Used_With_Random_Reference( + module, name + ) -def restricted_loads(s): - """Helper function analogous to pickle.loads().""" - return RestrictedUnpickler(io.BytesIO(s)).load() + return RestrictedUnpickler diff --git a/snowfakery/utils/yaml_utils.py b/snowfakery/utils/yaml_utils.py index 73a5a367..4125bffc 100644 --- a/snowfakery/utils/yaml_utils.py +++ b/snowfakery/utils/yaml_utils.py @@ -1,11 +1,44 @@ -from yaml import SafeDumper +from typing import Callable +from yaml import SafeDumper, SafeLoader +from yaml.representer import Representer +from collections import defaultdict -class SnowfakeryDumper(SafeDumper): +class SnowfakeryContinuationDumper(SafeDumper): pass +SnowfakeryContinuationDumper.add_representer( + defaultdict, SnowfakeryContinuationDumper.represent_dict +) + + def hydrate(cls, data): obj = cls.__new__(cls) - obj.__setstate__(data) + obj.restore_from_continuation(data) return obj + + +# Evaluate whether its cleaner for functions to bypass register_for_continuation +# and go directly to SnowfakeryContinuationDumper.add_representer. +# +# + + +def represent_continuation(dumper: SnowfakeryContinuationDumper, data): + if isinstance(data, dict): + return Representer.represent_dict(dumper, data) + else: + return Representer.represent_object(dumper, data) + + +def register_for_continuation(cls, dump_transformer: Callable = lambda x: x): + SnowfakeryContinuationDumper.add_representer( + cls, lambda self, data: represent_continuation(self, dump_transformer(data)) + ) + SafeLoader.add_constructor( + f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}", + lambda loader, node: cls._from_continuation( + loader.construct_mapping(node.value[0]) + ), + ) diff --git a/tests/test_references.py b/tests/test_references.py index 79a133d3..f680eafd 100644 --- a/tests/test_references.py +++ b/tests/test_references.py @@ -416,7 +416,6 @@ def test_reference_by_id(self, generated_rows): class TestRandomReferencesOriginal: - ## For reviewer: These tests were all moved and are not new def test_random_reference_simple(self, generated_rows): yaml = """ #1 - object: A #2 @@ -754,7 +753,7 @@ def test_random_reference_to_nickname__subsequent_iterations(self, generated_row nameref: ${{A_ref.name}} """ with mock.patch("snowfakery.row_history.randint") as randint: - randint.side_effect = lambda x,y: x + randint.side_effect = lambda x, y: x generate(StringIO(yaml), stopping_criteria=StoppingCriteria("B", 10)) assert generated_rows.table_values("B", 10, "A_ref") == "A(11)" assert generated_rows.table_values("B", 10, "nameref") == "nicky" @@ -913,3 +912,57 @@ def test_random_reference_to_objects_with_diverse_types(self, generated_rows): assert float(generated_rows.table_values("B", 1, "decimalref")) assert parse_date(generated_rows.table_values("B", 1, "datetime1ref")) assert parse_date(generated_rows.table_values("B", 1, "datetime2ref")) + + def test_random_references__nested(self, generated_rows): + yaml = """ + - object: Parent + count: 2 + fields: + child1: + - object: Child1 + fields: + child2: + - object: Child2 + fields: + name: TheName + - object: Child3 + fields: + A_ref: + random_reference: + to: Parent + nested_name: + ${{A_ref.child1.child2.name}} + """ + generate(StringIO(yaml)) + assert generated_rows.table_values("Child3", 1, "nested_name") == "TheName" + assert generated_rows.table_values("Child3", -1, "nested_name") == "TheName" + + def test_random_references__nested__with_continuation( + self, generate_data_with_continuation, 
generated_rows + ): + yaml = """ + - object: Parent + count: 2 + fields: + child1: + - object: Child1 + fields: + child2: + - object: Child2 + fields: + name: TheName + - object: Child3 + fields: + A_ref: + random_reference: + to: Parent + nested_name: + ${{A_ref.child1.child2.name}} + """ + generate_data_with_continuation( + yaml=yaml, + target_number=("Parent", 4), + times=1, + ) + assert generated_rows.table_values("Child3", 1, "nested_name") == "TheName" + assert generated_rows.table_values("Child3", -1, "nested_name") == "TheName"