Skip to content

Commit

Permalink
Merge pull request #5 from cspencerjones/wrapper
Browse files Browse the repository at this point in the history
Added wrapper for xarrays/dask arrays
  • Loading branch information
rabernat authored Oct 20, 2020
2 parents 074ab2f + 59c561f commit d0cd78f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
2 changes: 1 addition & 1 deletion fastjmd95/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .jmd95numba import rho, drhodt, drhods
from .jmd95wrapper import rho, drhodt, drhods
36 changes: 36 additions & 0 deletions fastjmd95/jmd95wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

import numpy as np
import dask.array as dsa
import dask
import xarray as xr

import fastjmd95.jmd95numba as jmd95numba

def _any_dask_array(*args):
return any([isinstance(a, dask.array.core.Array) for a in args])

def _any_xarray(*args):
return any([isinstance(a, xr.DataArray) for a in args])

def maybe_wrap_arrays(func):
def wrapper(*args):
if _any_dask_array(*args):
rho = dsa.map_blocks(func,*args)
elif _any_xarray(*args):
rho = xr.apply_ufunc(func,*args,output_dtypes=[float],dask='parallelized')
else:
rho = func(*args)
return rho
return wrapper

@maybe_wrap_arrays
def rho(s,t,p):
return jmd95numba.rho(s,t,p)

@maybe_wrap_arrays
def drhodt(s,t,p):
return jmd95numba.drhodt(s,t,p)

@maybe_wrap_arrays
def drhods(s,t,p):
return jmd95numba.drhods(s,t,p)
19 changes: 16 additions & 3 deletions fastjmd95/test/test_jmd95.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import dask
import dask.array
import xarray as xr

@pytest.fixture
def s_t_p():
Expand All @@ -19,6 +20,10 @@ def s_t_p():
def _chunk(*args):
return [dask.array.from_array(a, chunks=(100,)) for a in args]

def _make_xarray(*args,withdask=False):
if withdask==True:
args = _chunk(*args)
return [xr.DataArray(a, coords=[np.arange(0,len(a))], dims=["i"]) for a in args]

@pytest.fixture
def no_client():
Expand Down Expand Up @@ -52,16 +57,24 @@ def distributed_client():


all_clients = ['no_client', 'threaded_client', 'processes_client', 'distributed_client']
all_arrays = ['dask_arrays', 'xarrays']
# https://stackoverflow.com/questions/45225950/passing-yield-fixtures-as-test-parameters-with-a-temp-directory
@pytest.mark.parametrize('client', all_clients)
@pytest.mark.parametrize('array_type', all_arrays)
@pytest.mark.parametrize('function,expected',
[(rho, rho_expected),
(drhodt, drhodt_expected),
(drhods, drhods_expected)])
def test_functions(request, client, s_t_p, function, expected):
def test_functions(request, client, array_type, s_t_p, function, expected):
s, t, p = s_t_p
if client != 'no_client':
s, t, p = _chunk(s, t, p)
if array_type == 'dask_arrays':
if client != 'no_client':
s, t, p = _chunk(s, t, p)
elif array_type == 'xarrays':
if client != 'no_client':
s, t, p = _make_xarray(s, t, p, withdask=True)
else:
s, t, p = _make_xarray(s, t, p, withdask=False)
client = request.getfixturevalue(client)
actual = function(s, t, p)
np.testing.assert_allclose(actual, expected, rtol=1e-2)

0 comments on commit d0cd78f

Please sign in to comment.