Skip to content

Commit

Permalink
Merge pull request #24 from eriknw/better_exceptions
Browse files Browse the repository at this point in the history
Nicer exceptions, and let Run objects have a client.
  • Loading branch information
eriknw authored Aug 30, 2021
2 parents 1b9c1a6 + d907bdb commit c388549
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 22 deletions.
49 changes: 29 additions & 20 deletions afar/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ def get_body(lines):
class Run:
_gather_data = False

def __init__(self, *names, data=None):
def __init__(self, *names, client=None, data=None):
self.names = names
self.data = data
self.client = client
self.context_body = None
# afar.run can be used as a singleton without calling it.
# If we do this, we shouldn't keep data around.
Expand All @@ -76,13 +77,15 @@ def __init__(self, *names, data=None):
self._body_start = None
self._lines = None

def __call__(self, *names, data=None):
def __call__(self, *names, client=None, data=None):
if data is None:
if self.data is None:
data = {}
else:
data = self.data
return type(self)(*names, data=data)
if client is None:
client = self.client
return type(self)(*names, client=client, data=data)

def __enter__(self):
self._frame = currentframe().f_back
Expand Down Expand Up @@ -143,33 +146,34 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, exc_traceback):
self._where = None
if self.data is None:
if exc_type is None:
raise RuntimeError("uh oh!")
return False
if exc_type is None or exc_traceback.tb_frame is not self._frame:
return False
where = find_where(exc_type, exc_value)
if where is None:
# The exception is valid
return False

try:
return self._exit(exc_type, exc_value, exc_traceback)
except KeyboardInterrupt:
return self._exit(where)
except KeyboardInterrupt as exc:
# Cancel all pending tasks
if self._where == "remotely":
self.cancel()
raise
raise exc from None
except Exception as exc:
raise exc from None
finally:
self._frame = None
self._lines = None
if self._is_singleton:
self.data = None

def _exit(self, exc_type, exc_value, exc_traceback):
def _exit(self, where):
frame = self._frame
if self.data is None:
if exc_type is None:
raise RuntimeError("uh oh!")
return False
if exc_type is None or exc_traceback.tb_frame is not frame:
return False

where = find_where(exc_type, exc_value)
if where is None:
# The exception is valid
return False

# What line does the context end?
maxline = self._body_start
for offset, line in dis.findlinestarts(frame.f_code):
Expand All @@ -187,7 +191,7 @@ def _exit(self, exc_type, exc_value, exc_traceback):
context_body,
self.names,
self.data,
client=where.client,
client=self.client or where.client,
submit_kwargs=where.submit_kwargs,
global_ns=frame.f_globals,
local_ns=frame.f_locals,
Expand Down Expand Up @@ -221,6 +225,11 @@ def _run(
if where == "remotely":
if client is None:
client = distributed.client._get_global_client()
if client is None:
raise TypeError(
"No dask.distributed client found. "
"You must create and connect to a Dask cluster before using afar."
)
if client not in self._client_to_futures:
weak_futures = WeakSet()
self._client_to_futures[client] = weak_futures
Expand Down
5 changes: 3 additions & 2 deletions afar/_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def afar(self, line, cell=None, *, local_ns):
if "c" in opts and "client" in opts:
raise UsageError("-c and --client options may not be used at the same time")
where = remotely
client = None

not_found = "argument not found in local namespace"
if "r" in opts or "run" in opts:
Expand All @@ -84,6 +83,7 @@ def afar(self, line, cell=None, *, local_ns):
raise UsageError(f"-r or --run argument must be of type Run; got: {type(runner)}")
else:
runner = run()
client = runner.client

data = runner.data
if "d" in opts or "data" in opts:
Expand All @@ -103,7 +103,8 @@ def afar(self, line, cell=None, *, local_ns):
raise UsageError(
f"-w or --where argument must be of type Where; got: {type(where)}"
)
client = where.client
if client is None:
client = where.client

if "c" in opts or "client" in opts:
client = opts.get("c", opts.get("client"))
Expand Down

0 comments on commit c388549

Please sign in to comment.