diff --git a/README.md b/README.md
index 64c7770..68958d0 100644
--- a/README.md
+++ b/README.md
@@ -8,25 +8,33 @@
> _Robert A. Heinlein_
-`afar` explores new syntax around context managers. For example:
+`afar` allows you to run code on a remote [Dask](https://dask.org/) [worker](https://distributed.dask.org/en/latest/) using context managers. For example:
```python
import afar
-with afar.run() as results, locally:
- x = 1
- y = x + 1
->>> results.x
-1
->>> results.y
-2
-```
-Soon, we will be able to run code on a remote [dask](https://dask.org/) worker with syntax like:
-```python
-with afar.run() as result, remotely:
+with afar.run, remotely:
import dask_cudf
df = dask_cudf.read_parquet("s3://...")
result = df.sum().compute()
```
+`result` is a [Dask Future](https://docs.dask.org/en/latest/futures.html) whose data resides on a worker. Call `result.result()` to copy the data to the local process.
+
+By default, only the last assignment is saved. One can specify which variables to save:
+```python
+with afar.run("a", "b"), remotely:
+ a = 1
+ b = a + 1
+```
+`a` and `b` are now both Futures. They can be used directly in other `afar.run` contexts:
+```python
+with afar.run as data, remotely:
+ c = a + b
+
+assert c.result() == 3
+assert data["c"].result() == 3
+```
+`data` is a dictionary that maps variable names to Futures. Use it to retrieve results when the variables in the calling scope cannot be updated in place (for example, inside a function).
+
For motivation, see https://github.com/dask/distributed/issues/4003
-### *This code is highly experimental and magical!*
+### *This code is highly experimental and magical!*
\ No newline at end of file
diff --git a/afar/__init__.py b/afar/__init__.py
index ae546ea..922c773 100644
--- a/afar/__init__.py
+++ b/afar/__init__.py
@@ -1,5 +1,6 @@
from .core import run, remotely, locally # noqa
from ._version import get_versions
-__version__ = get_versions()['version']
+
+__version__ = get_versions()["version"]
del get_versions
diff --git a/afar/_version.py b/afar/_version.py
index ebd65e9..fae6886 100644
--- a/afar/_version.py
+++ b/afar/_version.py
@@ -1,4 +1,3 @@
-
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by githubs download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
@@ -58,17 +57,18 @@ class NotThisMethod(Exception):
def register_vcs_handler(vcs, method): # decorator
"""Create decorator to mark a method as the handler of a VCS."""
+
def decorate(f):
"""Store f in HANDLERS[vcs][method]."""
if vcs not in HANDLERS:
HANDLERS[vcs] = {}
HANDLERS[vcs][method] = f
return f
+
return decorate
-def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
- env=None):
+def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
"""Call the given command(s)."""
assert isinstance(commands, list)
p = None
@@ -76,10 +76,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
try:
dispcmd = str([c] + args)
# remember shell=False, so use git.cmd on windows, not just git
- p = subprocess.Popen([c] + args, cwd=cwd, env=env,
- stdout=subprocess.PIPE,
- stderr=(subprocess.PIPE if hide_stderr
- else None))
+ p = subprocess.Popen(
+ [c] + args,
+ cwd=cwd,
+ env=env,
+ stdout=subprocess.PIPE,
+ stderr=(subprocess.PIPE if hide_stderr else None),
+ )
break
except EnvironmentError:
e = sys.exc_info()[1]
@@ -114,16 +117,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
for i in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
- return {"version": dirname[len(parentdir_prefix):],
- "full-revisionid": None,
- "dirty": False, "error": None, "date": None}
+ return {
+ "version": dirname[len(parentdir_prefix) :],
+ "full-revisionid": None,
+ "dirty": False,
+ "error": None,
+ "date": None,
+ }
else:
rootdirs.append(root)
root = os.path.dirname(root) # up a level
if verbose:
- print("Tried directories %s but none started with prefix %s" %
- (str(rootdirs), parentdir_prefix))
+ print(
+ "Tried directories %s but none started with prefix %s"
+ % (str(rootdirs), parentdir_prefix)
+ )
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@@ -183,7 +192,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
- tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
+ tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)])
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
@@ -192,7 +201,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
- tags = set([r for r in refs if re.search(r'\d', r)])
+ tags = set([r for r in refs if re.search(r"\d", r)])
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
@@ -200,19 +209,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
for ref in sorted(tags):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
- r = ref[len(tag_prefix):]
+ r = ref[len(tag_prefix) :]
if verbose:
print("picking %s" % r)
- return {"version": r,
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": None,
- "date": date}
+ return {
+ "version": r,
+ "full-revisionid": keywords["full"].strip(),
+ "dirty": False,
+ "error": None,
+ "date": date,
+ }
# no suitable tags, so version is "0+unknown", but full hex is still there
if verbose:
print("no suitable tags, using unknown + full revision id")
- return {"version": "0+unknown",
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": "no suitable tags", "date": None}
+ return {
+ "version": "0+unknown",
+ "full-revisionid": keywords["full"].strip(),
+ "dirty": False,
+ "error": "no suitable tags",
+ "date": None,
+ }
@register_vcs_handler("git", "pieces_from_vcs")
@@ -227,8 +243,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
- out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root,
- hide_stderr=True)
+ out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
if rc != 0:
if verbose:
print("Directory %s not under git control" % root)
@@ -236,10 +251,11 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
- describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty",
- "--always", "--long",
- "--match", "%s*" % tag_prefix],
- cwd=root)
+ describe_out, rc = run_command(
+ GITS,
+ ["describe", "--tags", "--dirty", "--always", "--long", "--match", "%s*" % tag_prefix],
+ cwd=root,
+ )
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
@@ -262,17 +278,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
dirty = git_describe.endswith("-dirty")
pieces["dirty"] = dirty
if dirty:
- git_describe = git_describe[:git_describe.rindex("-dirty")]
+ git_describe = git_describe[: git_describe.rindex("-dirty")]
# now we have TAG-NUM-gHEX or HEX
if "-" in git_describe:
# TAG-NUM-gHEX
- mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
+ mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
if not mo:
# unparseable. Maybe git-describe is misbehaving?
- pieces["error"] = ("unable to parse git-describe output: '%s'"
- % describe_out)
+ pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
return pieces
# tag
@@ -281,10 +296,9 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix))
- pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
- % (full_tag, tag_prefix))
+ pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (full_tag, tag_prefix)
return pieces
- pieces["closest-tag"] = full_tag[len(tag_prefix):]
+ pieces["closest-tag"] = full_tag[len(tag_prefix) :]
# distance: number of commits since tag
pieces["distance"] = int(mo.group(2))
@@ -295,13 +309,11 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
else:
# HEX: no tags
pieces["closest-tag"] = None
- count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"],
- cwd=root)
+ count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
pieces["distance"] = int(count_out) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
- date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"],
- cwd=root)[0].strip()
+ date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
# Use only the last line. Previous lines may contain GPG signature
# information.
date = date.splitlines()[-1]
@@ -335,8 +347,7 @@ def render_pep440(pieces):
rendered += ".dirty"
else:
# exception #1
- rendered = "0+untagged.%d.g%s" % (pieces["distance"],
- pieces["short"])
+ rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
@@ -450,11 +461,13 @@ def render_git_describe_long(pieces):
def render(pieces, style):
"""Render the given version pieces into the requested style."""
if pieces["error"]:
- return {"version": "unknown",
- "full-revisionid": pieces.get("long"),
- "dirty": None,
- "error": pieces["error"],
- "date": None}
+ return {
+ "version": "unknown",
+ "full-revisionid": pieces.get("long"),
+ "dirty": None,
+ "error": pieces["error"],
+ "date": None,
+ }
if not style or style == "default":
style = "pep440" # the default
@@ -474,9 +487,13 @@ def render(pieces, style):
else:
raise ValueError("unknown style '%s'" % style)
- return {"version": rendered, "full-revisionid": pieces["long"],
- "dirty": pieces["dirty"], "error": None,
- "date": pieces.get("date")}
+ return {
+ "version": rendered,
+ "full-revisionid": pieces["long"],
+ "dirty": pieces["dirty"],
+ "error": None,
+ "date": pieces.get("date"),
+ }
def get_versions():
@@ -490,8 +507,7 @@ def get_versions():
verbose = cfg.verbose
try:
- return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
- verbose)
+ return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
except NotThisMethod:
pass
@@ -500,13 +516,16 @@ def get_versions():
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
- for i in cfg.versionfile_source.split('/'):
+ for i in cfg.versionfile_source.split("/"):
root = os.path.dirname(root)
except NameError:
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None,
- "error": "unable to find root of source tree",
- "date": None}
+ return {
+ "version": "0+unknown",
+ "full-revisionid": None,
+ "dirty": None,
+ "error": "unable to find root of source tree",
+ "date": None,
+ }
try:
pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
@@ -520,6 +539,10 @@ def get_versions():
except NotThisMethod:
pass
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None,
- "error": "unable to compute version", "date": None}
+ return {
+ "version": "0+unknown",
+ "full-revisionid": None,
+ "dirty": None,
+ "error": "unable to compute version",
+ "date": None,
+ }
diff --git a/afar/core.py b/afar/core.py
index 915ec5b..a4845ad 100644
--- a/afar/core.py
+++ b/afar/core.py
@@ -1,6 +1,8 @@
import dis
import inspect
import innerscope
+from operator import getitem
+from dask import distributed
_errors_to_locations = {}
@@ -34,15 +36,40 @@ def __exit__(self, exc_type, exc_value, exc_traceback): # pragma: no cover
locally = Where("locally")
-class run:
+class Run:
+ def __init__(self, *names):
+ self.names = names  # variables to save; if empty, only the last assignment is kept
+ self._results = None  # dict handed out by __enter__; None whenever no context is active
+ self._frame = None  # caller's frame, captured in __enter__
+ # For now, save the following to help debug
+ self._where = None
+ self._scoped = None
+ self._with_lineno = None
+
+ def __call__(self, *names):
+ return Run(*names)  # fresh instance, so afar.run("a", "b") never mutates the shared afar.run
+
def __enter__(self):
+    if self._results is not None:  # guard first: don't clobber an active context's frame
+        raise RuntimeError("afar.run is already active; it cannot be nested or re-entered")
     self._frame = inspect.currentframe().f_back
     self._with_lineno = self._frame.f_lineno
-    return self
+    self._results = {}  # filled in by _exit; this same dict is bound to the `as` target
+    return self._results
def __exit__(self, exc_type, exc_value, exc_traceback):
+ try:
+ return self._exit(exc_type, exc_value, exc_traceback)
+ finally:  # always reset, even if _exit raises, so this Run can be entered again
+ self._frame = None
+ self._results = None
+
+ def _exit(self, exc_type, exc_value, exc_traceback):
frame = self._frame
- self._frame = None
+ if self._results is None:
+ if exc_type is None:
+ raise RuntimeError("uh oh!")
+ return False
if exc_type is None or exc_traceback.tb_frame is not frame:
return False
@@ -86,6 +113,13 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
exec(c, frame.f_globals, d)
self._func = d["_magic_function_"]
+ # If no variable names were given, only get the last assignment
+ names = self.names
+ if not names:
+ for inst in list(dis.get_instructions(self._func)):
+ if inst.opname in {"STORE_NAME", "STORE_FAST", "STORE_DEREF", "STORE_GLOBAL"}:
+ names = (inst.argval,)
+
# Use innerscope! We only keep the globals, locals, and closures we need.
self._scoped = innerscope.scoped_function(self._func)
if self._scoped.missing:
@@ -93,7 +127,37 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
f_locals = frame.f_locals
update = {key: f_locals[key] for key in self._scoped.missing if key in f_locals}
self._scoped = self._scoped.bind(update)
- # For now, just run the function and set the attributes
- results = self._scoped()
- self.__dict__.update(results.inner_scope)
+
+ if self._where == "remotely":
+ # Submit to dask.distributed! First, separate the Futures.
+ futures = {
+ key: val
+ for key, val in self._scoped.outer_scope.items()
+ if isinstance(val, distributed.Future)
+ }
+ for key in futures:
+ del self._scoped.outer_scope[key]
+
+ client = distributed.client._get_global_client()
+ remote_dict = client.submit(run_on_worker, self._scoped, names, futures)
+ for name in names:
+ self._results[name] = client.submit(getitem, remote_dict, name)
+ else:
+ # Run locally. This is handy for testing and debugging.
+ results = self._scoped()
+ for name in names:
+ self._results[name] = results[name]
+
+ # Try to update the variables in the frame.
+ # This currently only works if f_locals is f_globals, or if tracing (don't ask).
+ frame.f_locals.update(self._results)
return True
+
+
+def run_on_worker(sfunc, names, futures):
+    """Bind `futures` into `sfunc`, run it, and return only the variables in `names`."""
+    scope = sfunc.bind(futures)()
+    return {name: scope[name] for name in names}
+
+
+run = Run()
diff --git a/afar/tests/test_core.py b/afar/tests/test_core.py
index c42a800..f7a7943 100644
--- a/afar/tests/test_core.py
+++ b/afar/tests/test_core.py
@@ -4,10 +4,11 @@
def test_a_modest_beginning():
- with afar.run(), remotely:
- pass
+ with afar.run(), locally:
+ x = 1
+ y = x + 1
- with afar.run(), afar.remotely:
+ with afar.run(), afar.locally:
pass
with afar.run(), locally:
@@ -16,8 +17,8 @@ def test_a_modest_beginning():
with afar.run(), afar.locally:
pass
- with raises(NameError, match="remotelyblah"):
- with afar.run(), remotelyblah:
+ with raises(NameError, match="locallyblah"):
+ with afar.run(), locallyblah:
pass
with afar.run():
@@ -36,16 +37,17 @@ def f():
return results
results = f()
- assert results.x == 1
- assert results.y == 12
+ assert "x" not in results
+ assert results["y"] == 12
assert not hasattr(results, "w")
assert not hasattr(results, "z")
- with afar.run() as results, afar.locally:
+ with afar.run as results, afar.locally:
x = z
y = x + 1
- assert results.x == 1
- assert results.y == 2
+ with raises(UnboundLocalError):
+ x
+ assert results == {"y": 2}
# fmt: off
with \
@@ -54,6 +56,5 @@ def f():
:
x = z
y = x + 1
- assert results.x == 1
- assert results.y == 2
+ assert results == {'y': 2}
# fmt: on
diff --git a/afar/tests/test_remotely.py b/afar/tests/test_remotely.py
new file mode 100644
index 0000000..f10f72f
--- /dev/null
+++ b/afar/tests/test_remotely.py
@@ -0,0 +1,13 @@
+import afar
+from operator import add
+from dask.distributed import Client
+
+# TODO: better testing infrastructure
+if __name__ == "__main__":
+    client = Client()
+    two = client.submit(add, 1, 1)
+
+    with afar.run, afar.remotely:  # bare `remotely` is never imported here -> NameError
+        three = two + 1
+
+    assert three.result() == 3
diff --git a/setup.py b/setup.py
index a0597a0..0cb6cc1 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@
license="BSD",
python_requires=">=3.7",
setup_requires=[],
- install_requires=["innerscope"],
+ install_requires=["innerscope", "distributed"],
tests_require=["pytest"],
include_package_data=True,
classifiers=[