Skip to content

Commit

Permalink
Merge pull request #5 from eriknw/pickle
Browse files Browse the repository at this point in the history
Try to pickle nicely
  • Loading branch information
eriknw authored Jul 21, 2021
2 parents bbfe0f4 + ed0414c commit 02aedbc
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 46 deletions.
128 changes: 82 additions & 46 deletions afar/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(self, *names, data=None):
self._frame = None
# For now, save the following to help debug
self._where = None
self._scoped = None
self._magic_func = None
self._body_start = None
self._lines = None

Expand Down Expand Up @@ -184,61 +184,24 @@ def _exit(self, exc_type, exc_value, exc_traceback):
endline = maxline + 5 # give us some wiggle room

self.context_body = get_body(self._lines[self._body_start : endline])

# Create a new function from the code block of the context.
# For now, we require that the source code is available.
# There may be a more reliable way to get the context block,
# but let's see how far this can take us!
source = "def _magic_function_():\n" + "".join(self.context_body)
c = compile(
source,
frame.f_code.co_filename,
"exec",
)
d = {}
exec(c, frame.f_globals, d)
self._func = d["_magic_function_"]

# If no variable names were given, only get the last assignment
names = self.names
if not names:
for inst in list(dis.get_instructions(self._func)):
if inst.opname in {"STORE_NAME", "STORE_FAST", "STORE_DEREF", "STORE_GLOBAL"}:
names = (inst.argval,)

# Use innerscope! We only keep the globals, locals, and closures we need.
self._scoped = innerscope.scoped_function(self._func, self.data)
if self._scoped.missing:
# Gather the necessary closures and locals
f_locals = frame.f_locals
update = {key: f_locals[key] for key in self._scoped.missing if key in f_locals}
self._scoped = self._scoped.bind(update)
self._magic_func, names, futures = abracadabra(self)

if self._where == "remotely":
# Submit to dask.distributed! First, separate the Futures.
futures = {
key: val
for key, val in self._scoped.outer_scope.items()
if isinstance(val, distributed.Future)
}
for key in futures:
del self._scoped.outer_scope[key]

client = distributed.client._get_global_client()
remote_dict = client.submit(afar_run, self._scoped, names, futures, **submit_kwargs)
remote_dict = client.submit(run_afar, self._magic_func, names, futures, **submit_kwargs)
if self._gather_data:
futures_to_name = {
client.submit(afar_get, remote_dict, name, **submit_kwargs): name
client.submit(get_afar, remote_dict, name, **submit_kwargs): name
for name in names
}
for future, result in distributed.as_completed(futures_to_name, with_results=True):
self.data[futures_to_name[future]] = result
else:
for name in names:
self.data[name] = client.submit(afar_get, remote_dict, name, **submit_kwargs)
self.data[name] = client.submit(get_afar, remote_dict, name, **submit_kwargs)
elif self._where == "locally":
# Run locally. This is handy for testing and debugging.
results = self._scoped()
results = self._magic_func()
for name in names:
self.data[name] = results[name]
elif self._where == "later":
Expand All @@ -258,13 +221,86 @@ class Get(Run):
_gather_data = True


def afar_run(sfunc, names, futures):
sfunc = sfunc.bind(futures)
def abracadabra(runner):
    """Turn the context body captured by ``runner`` into a ``MagicFunction``.

    The body lines in ``runner.context_body`` are wrapped in a function named
    ``_afar_magic_``, compiled under the filename ``"<afar>"``, and executed
    with the globals of ``runner._frame``.  The resulting function is then
    wrapped with ``innerscope.scoped_function`` using ``runner.data``.

    Parameters
    ----------
    runner
        Object providing ``context_body``, ``names``, ``data``, ``_frame`` and
        ``_where``.

    Returns
    -------
    tuple
        ``(magic_func, names, futures)`` where ``names`` is ``runner.names``
        or, if that is empty, a 1-tuple holding the target of the last store
        instruction found in the function, and ``futures`` is a dict of the
        ``distributed.Future`` values removed from the outer scope when
        ``runner._where == "remotely"`` (``None`` otherwise).
    """
    # For now, we require that the source code of the context is available.
    body = "".join(runner.context_body)
    source = "def _afar_magic_():\n" + body
    namespace = {}
    exec(compile(source, "<afar>", "exec"), runner._frame.f_globals, namespace)
    func = namespace["_afar_magic_"]

    # Without explicit names, fall back to the last assignment in the body.
    names = runner.names
    if not names:
        store_ops = {"STORE_NAME", "STORE_FAST", "STORE_DEREF", "STORE_GLOBAL"}
        stored = [
            inst.argval
            for inst in dis.get_instructions(func)
            if inst.opname in store_ops
        ]
        if stored:
            names = (stored[-1],)

    # innerscope keeps only the globals, locals, and closures that are needed.
    scoped = innerscope.scoped_function(func, runner.data)
    if scoped.missing:
        # Fill in whatever is still missing from the locals of the frame.
        frame_locals = runner._frame.f_locals
        scoped = scoped.bind(
            {key: frame_locals[key] for key in scoped.missing if key in frame_locals}
        )

    futures = None
    if runner._where == "remotely":
        # Separate the Futures from the outer scope before submitting.
        futures = {}
        for key, val in scoped.outer_scope.items():
            if isinstance(val, distributed.Future):
                futures[key] = val
        for key in futures:
            del scoped.outer_scope[key]

    return MagicFunction(source, scoped), names, futures


class MagicFunction:
    """Callable pairing generated source text with its scoped function.

    Pickling stores the source text and the scoped function's ``outer_scope``
    rather than the function object; unpickling recompiles the source and
    rebuilds the scoped function with ``innerscope.scoped_function``.
    """

    def __init__(self, source, scoped):
        # Source text defining ``_afar_magic_``; recompiled in __setstate__.
        self._source = source
        # Scoped function object that is invoked by __call__.
        self._scoped = scoped

    def __call__(self):
        """Call the wrapped scoped function and return its result."""
        return self._scoped()

    def __getstate__(self):
        """Return the instance state with ``_scoped`` swapped for its outer scope."""
        # The function created with `compile` and `exec` is not serialized;
        # the source is saved so the function can be recreated on load.
        state = self.__dict__.copy()
        state.pop("_scoped")
        state["outer_scope"] = self._scoped.outer_scope
        return state

    def __setstate__(self, state):
        """Restore attributes and rebuild ``_scoped`` from the saved source."""
        outer_scope = state.pop("outer_scope")
        self.__dict__.update(state)
        namespace = {}
        # NOTE: `exec` uses `outer_scope` as the globals of the new function
        # (Python inserts `__builtins__` into that dict if it is absent).
        exec(compile(self._source, "<afar>", "exec"), outer_scope, namespace)
        self._scoped = innerscope.scoped_function(
            namespace["_afar_magic_"], outer_scope
        )


def run_afar(magic_func, names, futures):
    """Bind ``futures`` to the scoped function, run it, and pick out ``names``.

    Parameters
    ----------
    magic_func
        Object whose ``_scoped`` attribute provides ``bind``.
    names
        Iterable of keys to extract from the results.
    futures
        Mapping passed to ``_scoped.bind`` before the call.

    Returns
    -------
    dict
        ``{name: results[name]}`` for every entry of ``names``.
    """
    bound = magic_func._scoped.bind(futures)
    outcome = bound()
    selected = {}
    for name in names:
        selected[name] = outcome[name]
    return selected


def afar_get(d, k):
def get_afar(d, k):
    """Return the item stored under key ``k`` in ``d``."""
    item = d[k]
    return item


Expand Down
13 changes: 13 additions & 0 deletions afar/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import afar
import pickle
import pytest
from pytest import raises

Expand Down Expand Up @@ -98,6 +99,18 @@ def test_endline():
# fmt: on


def test_pickle():
    """Check that ``run._magic_func`` survives a pickle round trip."""
    run = afar.run()
    with run, locally:
        a = 1
    assert run.data == {"a": 1}
    # Serialize and restore the function object created for the block above.
    func = run._magic_func
    s = pickle.dumps(func)
    func2 = pickle.loads(s)
    # The restored function must give the same result when called ...
    assert dict(func2()) == {"a": 1}
    # ... and carry the same bytecode as the function it was pickled from.
    assert func._scoped.func.__code__.co_code == func2._scoped.func.__code__.co_code


def test_end_of_file():
data = {}
end_of_file(data)
Expand Down

0 comments on commit 02aedbc

Please sign in to comment.