diff --git a/fastjmd95/__init__.py b/fastjmd95/__init__.py index 61e3d9e..1d937da 100644 --- a/fastjmd95/__init__.py +++ b/fastjmd95/__init__.py @@ -1 +1 @@ -from .jmd95numba import rho, drhodt, drhods +from .jmd95wrapper import rho, drhodt, drhods \ No newline at end of file diff --git a/fastjmd95/jmd95wrapper.py b/fastjmd95/jmd95wrapper.py new file mode 100644 index 0000000..5e25e05 --- /dev/null +++ b/fastjmd95/jmd95wrapper.py @@ -0,0 +1,36 @@ + +import numpy as np +import dask.array as dsa +import dask +import xarray as xr + +import fastjmd95.jmd95numba as jmd95numba + +def _any_dask_array(*args): + return any([isinstance(a, dask.array.core.Array) for a in args]) + +def _any_xarray(*args): + return any([isinstance(a, xr.DataArray) for a in args]) + +def maybe_wrap_arrays(func): + def wrapper(*args): + if _any_dask_array(*args): + rho = dsa.map_blocks(func,*args) + elif _any_xarray(*args): + rho = xr.apply_ufunc(func,*args,output_dtypes=[float],dask='parallelized') + else: + rho = func(*args) + return rho + return wrapper + +@maybe_wrap_arrays +def rho(s,t,p): + return jmd95numba.rho(s,t,p) + +@maybe_wrap_arrays +def drhodt(s,t,p): + return jmd95numba.drhodt(s,t,p) + +@maybe_wrap_arrays +def drhods(s,t,p): + return jmd95numba.drhods(s,t,p) diff --git a/fastjmd95/test/test_jmd95.py b/fastjmd95/test/test_jmd95.py index 4147661..0f3535a 100644 --- a/fastjmd95/test/test_jmd95.py +++ b/fastjmd95/test/test_jmd95.py @@ -7,6 +7,7 @@ import dask import dask.array +import xarray as xr @pytest.fixture def s_t_p(): @@ -19,6 +20,10 @@ def s_t_p(): def _chunk(*args): return [dask.array.from_array(a, chunks=(100,)) for a in args] +def _make_xarray(*args,withdask=False): + if withdask==True: + args = _chunk(*args) + return [xr.DataArray(a, coords=[np.arange(0,len(a))], dims=["i"]) for a in args] @pytest.fixture def no_client(): @@ -52,16 +57,24 @@ def distributed_client(): all_clients = ['no_client', 'threaded_client', 'processes_client', 'distributed_client'] +all_arrays = ['dask_arrays', 'xarrays'] # https://stackoverflow.com/questions/45225950/passing-yield-fixtures-as-test-parameters-with-a-temp-directory @pytest.mark.parametrize('client', all_clients) +@pytest.mark.parametrize('array_type', all_arrays) @pytest.mark.parametrize('function,expected', [(rho, rho_expected), (drhodt, drhodt_expected), (drhods, drhods_expected)]) -def test_functions(request, client, s_t_p, function, expected): +def test_functions(request, client, array_type, s_t_p, function, expected): s, t, p = s_t_p - if client != 'no_client': - s, t, p = _chunk(s, t, p) + if array_type == 'dask_arrays': + if client != 'no_client': + s, t, p = _chunk(s, t, p) + elif array_type == 'xarrays': + if client != 'no_client': + s, t, p = _make_xarray(s, t, p, withdask=True) + else: + s, t, p = _make_xarray(s, t, p, withdask=False) client = request.getfixturevalue(client) actual = function(s, t, p) np.testing.assert_allclose(actual, expected, rtol=1e-2)