Refactor Continuation YAML Loading #680

Draft: wants to merge 5 commits into base: main
6 changes: 3 additions & 3 deletions snowfakery/data_generator.py
@@ -19,7 +19,7 @@
from .data_gen_exceptions import DataGenError
from .plugins import SnowfakeryPlugin, PluginOption

from .utils.yaml_utils import SnowfakeryDumper, hydrate
from .utils.yaml_utils import SnowfakeryContinuationDumper, hydrate
from snowfakery.standard_plugins.UniqueId import UniqueId

# This tool is essentially a three stage interpreter.
@@ -95,9 +95,9 @@ def load_continuation_yaml(continuation_file: OpenFileLike):
def save_continuation_yaml(continuation_data: Globals, continuation_file: OpenFileLike):
"""Save the global interpreter state from Globals into a continuation_file"""
yaml.dump(
continuation_data.__getstate__(),
continuation_data,
continuation_file,
Dumper=SnowfakeryDumper,
Dumper=SnowfakeryContinuationDumper,
)


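The new SnowfakeryContinuationDumper and register_for_continuation helpers come from snowfakery/utils/yaml_utils.py, which is not included in this diff. Below is a minimal sketch of what that module might now provide, inferred from the call sites in this PR; the tag format and the fallback for classes registered without a state function are assumptions, not the actual implementation.

import yaml


class SnowfakeryContinuationDumper(yaml.SafeDumper):
    """Dumper reserved for continuation files (successor to SnowfakeryDumper here)."""


def register_for_continuation(cls, get_state=None):
    """Teach the continuation dumper how to serialize instances of cls.

    When get_state is provided (e.g. Globals.get_continuation_state), the
    mapping it returns is what gets dumped for each instance."""

    def represent(dumper, obj):
        # Assumed fallback: dump the instance __dict__ when no state function is given.
        state = get_state(obj) if get_state else vars(obj)
        return dumper.represent_mapping(
            f"!snowfakery_{cls.__name__.lower()}", state
        )

    SnowfakeryContinuationDumper.add_representer(cls, represent)
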
31 changes: 13 additions & 18 deletions snowfakery/data_generator_runtime.py
@@ -13,7 +13,7 @@
import yaml

from .utils.template_utils import FakerTemplateLibrary
from .utils.yaml_utils import SnowfakeryDumper, hydrate
from .utils.yaml_utils import hydrate
from .row_history import RowHistory
from .template_funcs import StandardFuncs
from .data_gen_exceptions import DataGenSyntaxError, DataGenNameError
@@ -27,6 +27,7 @@
)
from snowfakery.plugins import PluginContext, SnowfakeryPlugin, ScalarTypes
from snowfakery.utils.collections import OrderedSet
from snowfakery.utils.yaml_utils import register_for_continuation

OutputStream = "snowfakery.output_streams.OutputStream"
VariableDefinition = "snowfakery.data_generator_runtime_object_model.VariableDefinition"
@@ -60,17 +61,15 @@ def generate_id(self, table_name: str) -> int:
def __getitem__(self, table_name: str) -> int:
return self.last_used_ids[table_name]

def __getstate__(self):
# TODO: Fix this to use the new convention of get_continuation_data
def get_continuation_state(self):
return {"last_used_ids": dict(self.last_used_ids)}

def __setstate__(self, state):
def restore_from_continuation(self, state):
self.last_used_ids = defaultdict(lambda: 0, state["last_used_ids"])
self.start_ids = {name: val + 1 for name, val in self.last_used_ids.items()}


SnowfakeryDumper.add_representer(defaultdict, SnowfakeryDumper.represent_dict)


class Dependency(NamedTuple):
table_name_from: str
table_name_to: str
@@ -195,29 +194,22 @@ def check_slots_filled(self):
def first_new_id(self, tablename):
return self.transients.first_new_id(tablename)

def __getstate__(self):
def serialize_dict_of_object_rows(dct):
return {k: v.__getstate__() for k, v in dct.items()}

persistent_nicknames = serialize_dict_of_object_rows(self.persistent_nicknames)
persistent_objects_by_table = serialize_dict_of_object_rows(
self.persistent_objects_by_table
)
def get_continuation_state(self):
intertable_dependencies = [
dict(v._asdict()) for v in self.intertable_dependencies
] # converts ordered-dict to dict for Python 3.6 and 3.7

state = {
"persistent_nicknames": persistent_nicknames,
"persistent_objects_by_table": persistent_objects_by_table,
"id_manager": self.id_manager.__getstate__(),
"persistent_nicknames": self.persistent_nicknames,
"persistent_objects_by_table": self.persistent_objects_by_table,
"id_manager": self.id_manager.get_continuation_state(),
"today": self.today,
"nicknames_and_tables": self.nicknames_and_tables,
"intertable_dependencies": intertable_dependencies,
}
return state

def __setstate__(self, state):
def restore_from_continuation(self, state):
def deserialize_dict_of_object_rows(dct):
return {k: hydrate(ObjectRow, v) for k, v in dct.items()}

@@ -244,6 +236,9 @@ def deserialize_dict_of_object_rows(dct):
self.reset_slots()


register_for_continuation(Globals, Globals.get_continuation_state)


class JinjaTemplateEvaluatorFactory:
def __init__(self, native_types: bool):
if native_types:
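The dunder methods on IdManager and Globals become explicit get_continuation_state / restore_from_continuation hooks, so continuation serialization no longer piggybacks on the pickle/YAML object protocols. A small, hedged illustration of the renamed IdManager hooks follows; constructing IdManager with no arguments is an assumption, while the state shape and restore behavior come directly from the hunk above.

from snowfakery.data_generator_runtime import IdManager

ids = IdManager()  # assumed no-arg construction
ids.restore_from_continuation({"last_used_ids": {"Account": 5}})

assert ids["Account"] == 5              # __getitem__ reads last_used_ids
assert ids.start_ids == {"Account": 6}  # set by restore_from_continuation
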
19 changes: 12 additions & 7 deletions snowfakery/object_rows.py
@@ -2,7 +2,7 @@

import yaml
import snowfakery # noqa
from .utils.yaml_utils import SnowfakeryDumper
from .utils.yaml_utils import register_for_continuation
from contextvars import ContextVar

IdManager = "snowfakery.data_generator_runtime.IdManager"
@@ -14,10 +14,6 @@ class ObjectRow:

Uses __getattr__ so that the template evaluator can use dot-notation."""

yaml_loader = yaml.SafeLoader
yaml_dumper = SnowfakeryDumper
yaml_tag = "!snowfakery_objectrow"

# be careful changing these slots because these objects must be serializable
# to YAML and JSON
__slots__ = ["_tablename", "_values", "_child_index"]
@@ -49,19 +45,28 @@ def __repr__(self):
except Exception:
return super().__repr__()

def __getstate__(self):
def get_continuation_state(self):
"""Get the state of this ObjectRow for serialization.

Do not include related ObjectRows because circular
references in serialization formats cause problems."""

# If we decided to try to serialize hierarchies, we could
# do it like this:
# * keep track of if an object has already been serialized using a
# property of the SnowfakeryContinuationDumper
# * If so, output an ObjectReference instead of an ObjectRow
values = {k: v for k, v in self._values.items() if not isinstance(v, ObjectRow)}
return {"_tablename": self._tablename, "_values": values}

def __setstate__(self, state):
def restore_from_continuation(self, state):
for slot, value in state.items():
setattr(self, slot, value)


register_for_continuation(ObjectRow, ObjectRow.get_continuation_state)


class ObjectReference(yaml.YAMLObject):
def __init__(self, tablename: str, id: int):
self._tablename = tablename
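The docstring above notes that serializing whole hierarchies would mean tracking already-dumped rows on the SnowfakeryContinuationDumper and emitting an ObjectReference for repeats. Here is a hedged sketch of that idea, which is not part of this PR: the representer, the seen_rows attribute, the tags, and the use of the row's "id" value are all hypothetical.

from snowfakery.object_rows import ObjectRow
from snowfakery.utils.yaml_utils import SnowfakeryContinuationDumper


def represent_object_row(dumper: SnowfakeryContinuationDumper, row: ObjectRow):
    seen = getattr(dumper, "seen_rows", None)
    if seen is None:
        seen = dumper.seen_rows = set()
    key = (row._tablename, row._values.get("id"))
    if key in seen:
        # Already emitted in this dump: write a lightweight reference
        # instead of recursing into the row again, breaking the cycle.
        return dumper.represent_mapping(
            "!snowfakery_objectreference",
            {"_tablename": row._tablename, "id": row._values.get("id")},
        )
    seen.add(key)
    # For hierarchies we would dump _values as-is; nested ObjectRows then pass
    # back through this representer and become references on repeat visits.
    return dumper.represent_mapping(
        "!snowfakery_objectrow",
        {"_tablename": row._tablename, "_values": row._values},
    )

# Wiring it up would override this PR's behavior, so it stays commented out:
# SnowfakeryContinuationDumper.add_representer(ObjectRow, represent_object_row)
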
18 changes: 3 additions & 15 deletions snowfakery/plugins.py
@@ -8,13 +8,11 @@
from functools import wraps
import typing as T

import yaml
from yaml.representer import Representer
from faker.providers import BaseProvider as FakerProvider
from dateutil.relativedelta import relativedelta

import snowfakery.data_gen_exceptions as exc
from .utils.yaml_utils import SnowfakeryDumper
from snowfakery.utils.yaml_utils import register_for_continuation
from .utils.collections import CaseInsensitiveDict

from numbers import Number
@@ -306,17 +304,7 @@ def _from_continuation(cls, args):

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
_register_for_continuation(cls)


def _register_for_continuation(cls):
SnowfakeryDumper.add_representer(cls, Representer.represent_object)
yaml.SafeLoader.add_constructor(
f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}",
lambda loader, node: cls._from_continuation(
loader.construct_mapping(node.value[0])
),
)
register_for_continuation(cls)


class PluginResultIterator(PluginResult):
Expand Down Expand Up @@ -372,4 +360,4 @@ def convert(self, value):


# round-trip PluginResult objects through continuation YAML if needed.
_register_for_continuation(PluginResult)
register_for_continuation(PluginResult)
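With the module-local helper removed, plugin classes and core classes now share the single register_for_continuation entry point. Because __init_subclass__ in the hunk above calls it automatically, plugin authors get continuation support without extra wiring. A minimal hedged example; the subclass is hypothetical.

from snowfakery.plugins import PluginResult


class WeatherInfo(PluginResult):
    """Hypothetical plugin result. Defining the subclass is enough:
    __init_subclass__ already called register_for_continuation(WeatherInfo)."""
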
34 changes: 28 additions & 6 deletions snowfakery/row_history.py
@@ -6,8 +6,8 @@
from random import randint

from snowfakery import data_gen_exceptions as exc
from snowfakery.object_rows import LazyLoadedObjectReference
from snowfakery.utils.pickle import restricted_dumps, restricted_loads
from snowfakery.object_rows import LazyLoadedObjectReference, NicknameSlot, ObjectRow
from snowfakery.utils.pickle import RestrictedPickler


class RowHistory:
@@ -35,6 +35,7 @@ def __init__(
}
for table in tables_to_keep_history_for:
_make_history_table(self.conn, table)
self.pickler = RestrictedPickler(_DISPATCH_TABLE, _SAFE_CLASSES)

def reset_locals(self):
"""Reset the minimum count that counts as "local" """
@@ -58,7 +59,7 @@ def save_row(self, tablename: str, nickname: T.Optional[str], row: dict):
# "join" across multiple tables would have other costs (even if done lazily).
# For now this seems best and simplest.
# The data de-dupling algorithm would be slightly complex and slow.
data = restricted_dumps(row)
data = self.pickler.dumps(row)
self.conn.execute(
f'INSERT INTO "{tablename}" VALUES (?, ?, ?, ?)',
(row_id, nickname, nickname_id, data),
@@ -95,8 +96,6 @@ def random_row_reference(self, name: str, scope: str, unique: bool):
self.already_warned = True
min_id = 1
elif nickname:
# nickname counters are reset every loop, so 1 is the right choice
# OR they are just_once in which case
min_id = self.local_counters.get(nickname, 0) + 1
else:
min_id = self.local_counters.get(tablename, 0) + 1
@@ -126,7 +125,7 @@ def load_row(self, tablename: str, row_id: int):
first_row = next(qr, None)
assert first_row, f"Something went wrong: we cannot find {tablename}: {row_id}"

return restricted_loads(first_row[0])
return self.pickler.loads(first_row[0])

def find_row_id_for_nickname_id(
self, tablename: str, nickname: str, nickname_id: int
@@ -161,3 +160,26 @@ def _make_history_table(conn, tablename):
c.execute(
f'CREATE UNIQUE INDEX "{tablename}_nickname_id" ON "{tablename}" (nickname, nickname_id);'
)


_DISPATCH_TABLE = {
NicknameSlot: lambda n: (
ObjectReference,
(n._tablename, n.allocated_id),
),
ObjectRow: lambda v: (
ObjectRow,
(v._tablename, v._values),
),
}

_SAFE_CLASSES = {
("snowfakery.object_rows", "ObjectRow"),
("snowfakery.object_rows", "ObjectReference"),
("snowfakery.object_rows", "LazyLoadedObjectReference"),
("decimal", "Decimal"),
("datetime", "date"),
("datetime", "datetime"),
("datetime", "timedelta"),
("datetime", "timezone"),
}
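RowHistory now owns a RestrictedPickler configured with the module-level dispatch table and allow-list above, instead of calling module-level restricted_dumps/restricted_loads. A hedged usage sketch with a toy allow-list (only a datetime entry, purely for illustration):

import datetime

from snowfakery.utils.pickle import RestrictedPickler

pickler = RestrictedPickler({}, {("datetime", "date")})  # toy dispatchers / allow-list
blob = pickler.dumps({"since": datetime.date(2023, 1, 1)})
assert pickler.loads(blob) == {"since": datetime.date(2023, 1, 1)}
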
4 changes: 2 additions & 2 deletions snowfakery/standard_plugins/datasets.py
@@ -17,7 +17,7 @@
memorable,
)
from snowfakery.utils.files import FileLike, open_file_like
from snowfakery.utils.yaml_utils import SnowfakeryDumper
from snowfakery.utils.yaml_utils import SnowfakeryContinuationDumper


def _open_db(db_url):
@@ -258,4 +258,4 @@ def chdir(path):
os.chdir(cwd)


SnowfakeryDumper.add_representer(quoted_name, Representer.represent_str)
SnowfakeryContinuationDumper.add_representer(quoted_name, Representer.represent_str)
82 changes: 40 additions & 42 deletions snowfakery/utils/pickle.py
@@ -7,22 +7,31 @@

import warnings

from snowfakery.object_rows import NicknameSlot, ObjectReference

_DISPATCH_TABLE = copyreg.dispatch_table.copy()
_DISPATCH_TABLE[NicknameSlot] = lambda n: (
ObjectReference,
(n._tablename, n.allocated_id),
)
DispatchDefinition = T.Callable[[T.Any], T.Tuple[type, T.Tuple]]


def restricted_dumps(data):
"""Only allow saving "safe" classes"""
outs = io.BytesIO()
pickler = pickle.Pickler(outs)
pickler.dispatch_table = _DISPATCH_TABLE
pickler.dump(data)
return outs.getvalue()
class RestrictedPickler:
def __init__(
self,
dispatchers: T.Mapping[type, DispatchDefinition],
_SAFE_CLASSES: T.Set[T.Tuple[str, str]],
) -> None:
self._DISPATCH_TABLE = copyreg.dispatch_table.copy()
self._DISPATCH_TABLE.update(dispatchers)
self.RestrictedUnpickler = _get_RestrictedUnpicklerClass(_SAFE_CLASSES)

def dumps(self, data):
"""Only allow saving "safe" classes"""
outs = io.BytesIO()
pickler = pickle.Pickler(outs)
pickler.dispatch_table = self._DISPATCH_TABLE
pickler.dump(data)
return outs.getvalue()

def loads(self, s):
"""Helper function analogous to pickle.loads()."""
return self.RestrictedUnpickler(io.BytesIO(s)).load()


class Type_Cannot_Be_Used_With_Random_Reference(T.NamedTuple):
@@ -32,35 +41,24 @@ class Type_Cannot_Be_Used_With_Random_Reference(T.NamedTuple):
name: str


_SAFE_CLASSES = {
("snowfakery.object_rows", "ObjectRow"),
("snowfakery.object_rows", "ObjectReference"),
("snowfakery.object_rows", "LazyLoadedObjectReference"),
("decimal", "Decimal"),
("datetime", "date"),
("datetime", "datetime"),
("datetime", "timedelta"),
("datetime", "timezone"),
}


class RestrictedUnpickler(pickle.Unpickler):
"""Safe unpickler with an allowed-list"""

count = 0
def _get_RestrictedUnpicklerClass(_SAFE_CLASSES):
class RestrictedUnpickler(pickle.Unpickler):
"""Safe unpickler with an allowed-list"""

def find_class(self, module, name):
# Only allow safe classes from builtins.
if (module, name) in _SAFE_CLASSES:
return super().find_class(module, name)
else:
# Return a "safe" object that does nothing.
if RestrictedUnpickler.count < 10:
warnings.warn(f"Cannot save and refer to {module}, {name}")
RestrictedUnpickler.count += 1
return lambda *args: Type_Cannot_Be_Used_With_Random_Reference(module, name)
count = 0

def find_class(self, module, name):
# Only allow safe classes from builtins.
if (module, name) in _SAFE_CLASSES:
return super().find_class(module, name)
else:
# warn first 10 times
if RestrictedUnpickler.count < 10:
warnings.warn(f"Cannot save and refer to {module}, {name}")
RestrictedUnpickler.count += 1
# Return a "safe" object that does nothing.
return lambda *args: Type_Cannot_Be_Used_With_Random_Reference(
module, name
)

def restricted_loads(s):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(s)).load()
return RestrictedUnpickler
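The allow-list behavior itself is unchanged by the refactor: classes outside _SAFE_CLASSES are never reconstructed; the unpickler warns for the first ten offenders and substitutes a placeholder named tuple. A hedged sketch of that fallback, using an empty allow-list purely for illustration (the two-field shape of the placeholder is assumed from the hunk above):

import pickle
from decimal import Decimal

from snowfakery.utils.pickle import RestrictedPickler

pickler = RestrictedPickler({}, set())   # nothing is allow-listed
blob = pickle.dumps(Decimal("1.5"))      # pickled elsewhere by ordinary pickle
placeholder = pickler.loads(blob)        # warns, then returns the placeholder
assert placeholder == ("decimal", "Decimal")  # Type_Cannot_Be_Used_With_Random_Reference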