diff --git a/.github/workflows/test_conda.yml b/.github/workflows/test_conda.yml index e0a5907..99c5988 100644 --- a/.github/workflows/test_conda.yml +++ b/.github/workflows/test_conda.yml @@ -31,8 +31,7 @@ jobs: activate-environment: afar - name: Install dependencies run: | - conda install -y -c conda-forge distributed pytest - pip install innerscope + conda install -y -c conda-forge distributed pytest innerscope pip install -e . - name: PyTest run: | diff --git a/.github/workflows/test_pip.yml b/.github/workflows/test_pip.yml index 63a70e7..4b51141 100644 --- a/.github/workflows/test_pip.yml +++ b/.github/workflows/test_pip.yml @@ -39,7 +39,7 @@ jobs: run: | pip install black flake8 flake8 . - black afar *.py --check --diff + black . --check --diff - name: Coverage env: GITHUB_TOKEN: ${{ secrets.github_token }} diff --git a/afar/__init__.py b/afar/__init__.py index c85f0e6..a5fd7cf 100644 --- a/afar/__init__.py +++ b/afar/__init__.py @@ -1,3 +1,18 @@ +"""afar runs code within a context manager or IPython magic on a Dask cluster. + +>>> with afar.run, remotely: +... import dask_cudf +... df = dask_cudf.read_parquet("s3://...") +... result = df.sum().compute() + +or to use an IPython magic: + +>>> %load_ext afar +>>> %afar z = x + y + +Read the documentation at https://github.com/eriknw/afar +""" + from ._core import get, run # noqa from ._version import get_versions from ._where import later, locally, remotely # noqa diff --git a/afar/_abra.py b/afar/_abra.py index 50eb25a..fca6b63 100644 --- a/afar/_abra.py +++ b/afar/_abra.py @@ -1,3 +1,8 @@ +"""Perform a magic trick: given lines of code, create a function to run remotely. + +This callable object is able to provide the values of the requested argument +names and return the final expression so it can be displayed. 
+""" import dis from types import FunctionType @@ -5,7 +10,7 @@ from innerscope import scoped_function from ._reprs import get_repr_methods -from ._utils import code_replace, is_kernel +from ._utils import code_replace, is_ipython def endswith_expr(func): @@ -85,7 +90,7 @@ def cadabra(context_body, where, names, data, global_ns, local_ns): # Create a new function from the code block of the context. # For now, we require that the source code is available. source = "def _afar_magic_():\n" + "".join(context_body) - func, display_expr = create_func(source, global_ns, is_kernel()) + func, display_expr = create_func(source, global_ns, is_ipython()) # If no variable names were given, only get the last assignment if not names: diff --git a/afar/_core.py b/afar/_core.py index 78fe8dc..130a952 100644 --- a/afar/_core.py +++ b/afar/_core.py @@ -1,64 +1,26 @@ +"""Define the user-facing `run` object; this is where it all comes together.""" import dis -from functools import partial -from inspect import currentframe, findsource +import sys +from inspect import currentframe +from uuid import uuid4 from weakref import WeakKeyDictionary, WeakSet from dask import distributed +from dask.distributed import get_worker from ._abra import cadabra -from ._printing import PrintRecorder, print_outputs, print_outputs_async -from ._reprs import repr_afar -from ._utils import is_kernel, supports_async_output +from ._inspect import get_body, get_body_start, get_lines +from ._printing import PrintRecorder +from ._reprs import display_repr, repr_afar +from ._utils import supports_async_output from ._where import find_where -def get_body_start(lines, with_start): - line = lines[with_start] - stripped = line.lstrip() - body = line[: len(line) - len(stripped)] + " pass\n" - body *= 2 - with_lines = [stripped] - try: - code = compile(stripped, "", "exec") - except Exception: - pass - else: - raise RuntimeError( - "Failed to analyze the context! 
When using afar, " - "please put the context body on a new line." - ) - for i, line in enumerate(lines[with_start:]): - if i > 0: - with_lines.append(line) - if ":" in line: - source = "".join(with_lines) + body - try: - code = compile(source, "", "exec") - except Exception: - pass - else: - num_with = code.co_code.count(dis.opmap["SETUP_WITH"]) - body_start = with_start + i + 1 - return num_with, body_start - raise RuntimeError("Failed to analyze the context!") - - -def get_body(lines): - head = "def f():\n with x:\n " - tail = " pass\n pass\n" - while lines: - source = head + " ".join(lines) + tail - try: - compile(source, "", "exec") - except Exception: - lines.pop() - else: - return lines - raise RuntimeError("Failed to analyze the context body!") - - class Run: _gather_data = False + # Used to update outputs asynchronously + _outputs = {} + _channel = "afar-" + uuid4().hex def __init__(self, *names, client=None, data=None): self.names = names @@ -94,36 +56,8 @@ def __enter__(self): if self.data: raise RuntimeError("uh oh!") self.data = {} - try: - lines, offset = findsource(self._frame) - except OSError: - # Try to fine the source if we are in %%time or %%timeit magic - if self._frame.f_code.co_filename in {"", ""} and is_kernel(): - from IPython import get_ipython - - ip = get_ipython() - if ip is None: - raise - cell = ip.history_manager._i00 # The current cell! 
- lines = cell.splitlines(keepends=True) - # strip the magic - for i, line in enumerate(lines): - if line.strip().startswith("%%time"): - lines = lines[i + 1 :] - break - else: - raise - # strip blank lines - for i, line in enumerate(lines): - if line.strip(): - if i: - lines = lines[i:] - lines[-1] += "\n" - break - else: - raise - else: - raise + + lines = get_lines(self._frame) while not lines[with_lineno].lstrip().startswith("with"): with_lineno -= 1 @@ -236,14 +170,6 @@ def _run( else: weak_futures = self._client_to_futures[client] - has_print = "print" in self._magic_func._scoped.builtin_names - capture_print = ( - self._gather_data # we're blocking anyway to gather data - or display_expr # we need to display an expression (sync or async) - or has_print # print is in the context body - or supports_async_output() # no need to block, so why not? - ) - to_scatter = data.keys() & self._magic_func._scoped.outer_scope.keys() if to_scatter: # Scatter value in `data` that we need in this calculation. @@ -261,34 +187,36 @@ def _run( data.update(scattered) for key in to_scatter: del self._magic_func._scoped.outer_scope[key] + + capture_print = True + if capture_print and self._channel not in client._event_handlers: + client.subscribe_topic(self._channel, self._handle_print) + # When would be a good time to unsubscribe? 
+ async_print = capture_print and supports_async_output() + if capture_print: + unique_key = uuid4().hex + self._setup_print(unique_key, async_print) + else: + unique_key = None + # Scatter magic_func to avoid "Large object" UserWarning - magic_func = client.scatter(self._magic_func) + magic_func = client.scatter(self._magic_func, hash=False) weak_futures.add(magic_func) + remote_dict = client.submit( - run_afar, magic_func, names, futures, capture_print, pure=False, **submit_kwargs + run_afar, + magic_func, + names, + futures, + capture_print, + self._channel, + unique_key, + pure=False, + **submit_kwargs, ) weak_futures.add(remote_dict) magic_func.release() # Let go ASAP - if display_expr: - return_future = client.submit(get_afar, remote_dict, "_afar_return_value_") - repr_future = client.submit( - repr_afar, - return_future, - self._magic_func._repr_methods, - ) - weak_futures.add(repr_future) - if return_expr: - weak_futures.add(return_future) - else: - return_future.release() # Let go ASAP - return_future = None - else: - repr_future = None - if capture_print: - stdout_future = client.submit(get_afar, remote_dict, "_afar_stdout_") - weak_futures.add(stdout_future) - stderr_future = client.submit(get_afar, remote_dict, "_afar_stderr_") - weak_futures.add(stderr_future) + if self._gather_data: futures_to_name = { client.submit(get_afar, remote_dict, name, **submit_kwargs): name @@ -304,21 +232,6 @@ def _run( weak_futures.add(future) data[name] = future remote_dict.release() # Let go ASAP - - if capture_print and supports_async_output(): - # Display in `out` cell when data is ready: non-blocking - from IPython.display import display - from ipywidgets import Output - - out = Output() - display(out) - out.append_stdout("\N{SPARKLES} Running afar... \N{SPARKLES}") - stdout_future.add_done_callback( - partial(print_outputs_async, out, stderr_future, repr_future) - ) - elif capture_print: - # blocks! 
- print_outputs(stdout_future, stderr_future, repr_future) elif where == "locally": # Run locally. This is handy for testing and debugging. results = self._magic_func() @@ -352,6 +265,52 @@ def cancel(self, *, client=None, force=False): ) weak_futures.clear() + def _setup_print(self, key, async_print): + if async_print: + from IPython.display import display + from ipywidgets import Output + + out = Output() + display(out) + out.append_stdout("\N{SPARKLES} Running afar... \N{SPARKLES}") + else: + out = None + self._outputs[key] = [out, False] # False means has not been updated + + @classmethod + def _handle_print(cls, event): + # XXX: can we assume all messages from a single task arrive in FIFO order? + _, msg = event + key, action, payload = msg + if key not in cls._outputs: + return + out, is_updated = cls._outputs[key] + if out is not None: + if action == "begin": + if is_updated: + out.outputs = type(out.outputs)() + out.append_stdout("\N{SPARKLES} Running afar... (restarted) \N{SPARKLES}") + cls._outputs[key][1] = False # is not updated + else: + if not is_updated: + # Clear the "Running afar..." 
message + out.outputs = type(out.outputs)() + cls._outputs[key][1] = True # is updated + # ipywidgets.Output is pretty slow if there are lots of messages + if action == "stdout": + out.append_stdout(payload) + elif action == "stderr": + out.append_stderr(payload) + elif action == "stdout": + print(payload, end="") + elif action == "stderr": + print(payload, end="", file=sys.stderr) + if action == "display_expr": + display_repr(payload, out=out) + del cls._outputs[key] + elif action == "finish": + del cls._outputs[key] + class Get(Run): """Unlike ``run``, ``get`` automatically gathers the data locally""" @@ -359,25 +318,41 @@ class Get(Run): _gather_data = True -def run_afar(magic_func, names, futures, capture_print): +def run_afar(magic_func, names, futures, capture_print, channel, unique_key): if capture_print: - rec = PrintRecorder() - if "print" in magic_func._scoped.builtin_names and "print" not in futures: - sfunc = magic_func._scoped.bind(futures, print=rec) + try: + worker = get_worker() + send_finish = True + except ValueError: + worker = None + try: + if capture_print and worker is not None: + worker.log_event(channel, (unique_key, "begin", None)) + rec = PrintRecorder(channel, unique_key) + if "print" in magic_func._scoped.builtin_names and "print" not in futures: + sfunc = magic_func._scoped.bind(futures, print=rec) + else: + sfunc = magic_func._scoped.bind(futures) + with rec: + results = sfunc() else: sfunc = magic_func._scoped.bind(futures) - with rec: results = sfunc() - else: - sfunc = magic_func._scoped.bind(futures) - results = sfunc() - rv = {key: results[key] for key in names} - if magic_func._display_expr: - rv["_afar_return_value_"] = results.return_value - if capture_print: - rv["_afar_stdout_"] = rec.stdout.getvalue() - rv["_afar_stderr_"] = rec.stderr.getvalue() + rv = {key: results[key] for key in names} + + if magic_func._display_expr and worker is not None: + # Hopefully computing the repr is fast. 
If it is slow, perhaps it would be + # better to add the return value to rv and call repr_afar as a separate task. + # Also, pretty_repr must be msgpack serializable if done via events. Hence, + # custom _ipython_display_ doesn't work, and we resort to using a basic repr. + pretty_repr = repr_afar(results.return_value, magic_func._repr_methods) + if pretty_repr is not None: + worker.log_event(channel, (unique_key, "display_expr", pretty_repr)) + send_finish = False + finally: + if capture_print and worker is not None and send_finish: + worker.log_event(channel, (unique_key, "finish", None)) return rv diff --git a/afar/_inspect.py b/afar/_inspect.py new file mode 100644 index 0000000..dcdc77a --- /dev/null +++ b/afar/_inspect.py @@ -0,0 +1,84 @@ +"""Utilities to get the lines of the context body.""" +import dis +from inspect import findsource + +from ._utils import is_ipython + + +def get_lines(frame): + try: + lines, offset = findsource(frame) + except OSError: + # Try to find the source if we are in %%time or %%timeit magic + if frame.f_code.co_filename in {"", ""} and is_ipython(): + from IPython import get_ipython + + ip = get_ipython() + if ip is None: + raise + cell = ip.history_manager._i00 # The current cell! + lines = cell.splitlines(keepends=True) + # strip the magic + for i, line in enumerate(lines): + if line.strip().startswith("%%time"): + lines = lines[i + 1 :] + break + else: + raise + # strip blank lines + for i, line in enumerate(lines): + if line.strip(): + if i: + lines = lines[i:] + lines[-1] += "\n" + break + else: + raise + else: + raise + return lines + + +def get_body_start(lines, with_start): + line = lines[with_start] + stripped = line.lstrip() + body = line[: len(line) - len(stripped)] + " pass\n" + body *= 2 + with_lines = [stripped] + try: + code = compile(stripped, "", "exec") + except Exception: + pass + else: + raise RuntimeError( + "Failed to analyze the context! When using afar, " + "please put the context body on a new line." 
+ ) + for i, line in enumerate(lines[with_start:]): + if i > 0: + with_lines.append(line) + if ":" in line: + source = "".join(with_lines) + body + try: + code = compile(source, "", "exec") + except Exception: + pass + else: + num_with = code.co_code.count(dis.opmap["SETUP_WITH"]) + body_start = with_start + i + 1 + return num_with, body_start + raise RuntimeError("Failed to analyze the context!") + + +def get_body(lines): + head = "def f():\n with x:\n " + tail = " pass\n pass\n" + while lines: + source = head + " ".join(lines) + tail + try: + compile(source, "", "exec") + except Exception: + lines.pop() + else: + return lines + raise RuntimeError("Failed to analyze the context body!") diff --git a/afar/_magic.py b/afar/_magic.py index 889a183..194edfa 100644 --- a/afar/_magic.py +++ b/afar/_magic.py @@ -1,3 +1,4 @@ +"""Define the IPython magic for using afar""" from textwrap import indent from dask.distributed import Client diff --git a/afar/_printing.py b/afar/_printing.py index b8083e8..3cf42d4 100644 --- a/afar/_printing.py +++ b/afar/_printing.py @@ -1,9 +1,10 @@ +"""Classes used to capture print statements within a Dask task.""" import builtins import sys from io import StringIO from threading import Lock, local -from ._reprs import display_repr +from dask.distributed import get_worker # Here's the plan: we'll capture all print statements to stdout and stderr @@ -21,9 +22,9 @@ class PrintRecorder: local_print = LocalPrint() print_lock = Lock() - def __init__(self): - self.stdout = StringIO() - self.stderr = StringIO() + def __init__(self, channel, key): + self.channel = channel + self.key = key def __enter__(self): with self.print_lock: @@ -44,48 +45,21 @@ def __exit__(self, exc_type, exc_value, exc_traceback): def __call__(self, *args, file=None, **kwargs): if file is None or file is sys.stdout: - file = self.stdout + file = StringIO() + stream_name = "stdout" elif file is sys.stderr: - file = self.stderr + file = StringIO() + stream_name = "stderr" + else: 
+ stream_name = None LocalPrint.printer(*args, **kwargs, file=file) - - -def print_outputs(stdout_future, stderr_future, repr_future): - """Print results to the user""" - stdout_val = stdout_future.result() - stdout_future.release() - if stdout_val: - print(stdout_val, end="") - stderr_val = stderr_future.result() - stderr_future.release() - if stderr_val: - print(stderr_val, end="", file=sys.stderr) - if repr_future is not None: - repr_val = repr_future.result() - repr_future.release() - if repr_val is not None: - display_repr(repr_val) - - -def print_outputs_async(out, stderr_future, repr_future, stdout_future): - """Display output streams and final expression to the user. - - This is used as a callback to `stdout_future`. - """ - try: - stdout_val = stdout_future.result() - # out.clear_output() # Not thread-safe! - # See: https://github.com/jupyter-widgets/ipywidgets/issues/3260 - out.outputs = type(out.outputs)() # current workaround - if stdout_val: - out.append_stdout(stdout_val) - stderr_val = stderr_future.result() - if stderr_val: - out.append_stderr(stderr_val) - if repr_future is not None: - repr_val = repr_future.result() - if repr_val is not None: - display_repr(repr_val, out=out) - except Exception as exc: - print(exc, file=sys.stderr) - raise + if stream_name is not None: + try: + worker = get_worker() + except ValueError: + pass + else: + worker.log_event(self.channel, (self.key, stream_name, file.getvalue())) + # Print locally too + stream = sys.stdout if stream_name == "stdout" else sys.stderr + LocalPrint.printer(file.getvalue(), end="", file=stream) diff --git a/afar/_reprs.py b/afar/_reprs.py index 19ad0f9..f3ea60d 100644 --- a/afar/_reprs.py +++ b/afar/_reprs.py @@ -1,3 +1,4 @@ +"""Utilities to calculate the (pretty) repr of objects remotely and display locally.""" import sys import traceback @@ -13,7 +14,9 @@ def __init__(self): self._attrs = [] def __getattr__(self, attr): - if "canary" not in attr: + if "canary" not in attr and attr != 
"_ipython_display_": + # _ipython_display_ requires sending the object back to the client. + # Let's not bother with this hassle for now. self._attrs.append(attr) raise AttributeError(attr) @@ -43,6 +46,7 @@ def repr_afar(val, repr_methods): continue if method_name == "_ipython_display_": # Custom display! Send the object to the client + # We don't allow _ipython_display_ at the moment return val, method_name, False try: rv = method() @@ -99,6 +103,7 @@ def display_repr(results, out=None): from IPython.display import display if method_name == "_ipython_display_": + # We don't allow _ipython_display_ at the moment if out is None: display(val) else: diff --git a/afar/_utils.py b/afar/_utils.py index c957814..c3aec0e 100644 --- a/afar/_utils.py +++ b/afar/_utils.py @@ -1,3 +1,4 @@ +import builtins import sys from types import CodeType @@ -5,13 +6,17 @@ def is_terminal(): - if "IPython" not in sys.modules: # IPython hasn't been imported + if not is_ipython(): return False from IPython import get_ipython return type(get_ipython()).__name__ == "TerminalInteractiveShell" +def is_ipython(): + return hasattr(builtins, "__IPYTHON__") and "IPython" in sys.modules + + def supports_async_output(): if is_kernel() and not is_terminal(): try: diff --git a/pyproject.toml b/pyproject.toml index aa4949a..225abc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,3 @@ [tool.black] line-length = 100 +extend-exclude = "test_notebook.ipynb" diff --git a/requirements.txt b/requirements.txt index 17d4993..bee4088 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ innerscope -distributed +distributed >=2021.9.1 diff --git a/setup.py b/setup.py index 820fff9..b557df5 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ +from setuptools import find_packages, setup + import versioneer -from setuptools import setup, find_packages install_requires = open("requirements.txt").read().strip().split("\n") with open("README.md") as f: