From d907bdbffc38858c6c04fd384c5df5f70a02ee7f Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Mon, 30 Aug 2021 12:14:27 -0500 Subject: [PATCH] Nicer exceptions, and let Run objects have a client. --- afar/_core.py | 49 +++++++++++++++++++++++++++++-------------------- afar/_magic.py | 5 +++-- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/afar/_core.py b/afar/_core.py index 5402ea8..78fe8dc 100644 --- a/afar/_core.py +++ b/afar/_core.py @@ -60,9 +60,10 @@ def get_body(lines): class Run: _gather_data = False - def __init__(self, *names, data=None): + def __init__(self, *names, client=None, data=None): self.names = names self.data = data + self.client = client self.context_body = None # afar.run can be used as a singleton without calling it. # If we do this, we shouldn't keep data around. @@ -76,13 +77,15 @@ def __init__(self, *names, data=None): self._body_start = None self._lines = None - def __call__(self, *names, data=None): + def __call__(self, *names, client=None, data=None): if data is None: if self.data is None: data = {} else: data = self.data - return type(self)(*names, data=data) + if client is None: + client = self.client + return type(self)(*names, client=client, data=data) def __enter__(self): self._frame = currentframe().f_back @@ -143,33 +146,34 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): self._where = None + if self.data is None: + if exc_type is None: + raise RuntimeError("uh oh!") + return False + if exc_type is None or exc_traceback.tb_frame is not self._frame: + return False + where = find_where(exc_type, exc_value) + if where is None: + # The exception is valid + return False + try: - return self._exit(exc_type, exc_value, exc_traceback) - except KeyboardInterrupt: + return self._exit(where) + except KeyboardInterrupt as exc: # Cancel all pending tasks if self._where == "remotely": self.cancel() - raise + raise exc from None + except Exception as exc: + raise exc from None finally: self._frame = None self._lines = None if self._is_singleton: self.data = None - def _exit(self, exc_type, exc_value, exc_traceback): + def _exit(self, where): frame = self._frame - if self.data is None: - if exc_type is None: - raise RuntimeError("uh oh!") - return False - if exc_type is None or exc_traceback.tb_frame is not frame: - return False - - where = find_where(exc_type, exc_value) - if where is None: - # The exception is valid - return False - # What line does the context end? maxline = self._body_start for offset, line in dis.findlinestarts(frame.f_code): @@ -187,7 +191,7 @@ def _exit(self, exc_type, exc_value, exc_traceback): context_body, self.names, self.data, - client=where.client, + client=self.client or where.client, submit_kwargs=where.submit_kwargs, global_ns=frame.f_globals, local_ns=frame.f_locals, @@ -221,6 +225,11 @@ def _run( if where == "remotely": if client is None: client = distributed.client._get_global_client() + if client is None: + raise TypeError( + "No dask.distributed client found. " + "You must create and connect to a Dask cluster before using afar." + ) if client not in self._client_to_futures: weak_futures = WeakSet() self._client_to_futures[client] = weak_futures diff --git a/afar/_magic.py b/afar/_magic.py index dfe60a7..826a915 100644 --- a/afar/_magic.py +++ b/afar/_magic.py @@ -73,7 +73,6 @@ def afar(self, line, cell=None, *, local_ns): if "c" in opts and "client" in opts: raise UsageError("-c and --client options may not be used at the same time") where = remotely - client = None not_found = "argument not found in local namespace" if "r" in opts or "run" in opts: @@ -84,6 +83,7 @@ def afar(self, line, cell=None, *, local_ns): raise UsageError(f"-r or --run argument must be of type Run; got: {type(runner)}") else: runner = run() + client = runner.client data = runner.data if "d" in opts or "data" in opts: @@ -103,7 +103,8 @@ def afar(self, line, cell=None, *, local_ns): raise UsageError( f"-w or --where argument must be of type Where; got: {type(where)}" ) - client = where.client + if client is None: + client = where.client if "c" in opts or "client" in opts: client = opts.get("c", opts.get("client"))