From cdfdbf587f785ccf300e49b53975229d2d9a1b3b Mon Sep 17 00:00:00 2001 From: Spencer Jones Date: Mon, 12 Oct 2020 16:12:59 -0400 Subject: [PATCH 1/4] Added wrapper for xarrays/dask arrays --- fastjmd95/__init__.py | 1 + fastjmd95/jmd95wrapper.py | 47 ++++++++++++++++++++++++++++++++++++ fastjmd95/test/test_jmd95.py | 8 +++--- 3 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 fastjmd95/jmd95wrapper.py diff --git a/fastjmd95/__init__.py b/fastjmd95/__init__.py index 61e3d9e..3dfb698 100644 --- a/fastjmd95/__init__.py +++ b/fastjmd95/__init__.py @@ -1 +1,2 @@ from .jmd95numba import rho, drhodt, drhods +from .jmd95wrapper import rho_x, drhodt_x, drhods_x \ No newline at end of file diff --git a/fastjmd95/jmd95wrapper.py b/fastjmd95/jmd95wrapper.py new file mode 100644 index 0000000..e11706e --- /dev/null +++ b/fastjmd95/jmd95wrapper.py @@ -0,0 +1,47 @@ + +import numpy as np +import dask.array as dsa +import dask +import xarray as xr + +import fastjmd95.jmd95numba as jmd95numba + + +def rho_x(s, t, p): + if any([isinstance(s,dask.array.core.Array), + isinstance(t,dask.array.core.Array), + isinstance(p,dask.array.core.Array)]): + rho = dsa.map_blocks(jmd95numba.rho,s,t,p) + elif any([[isinstance(s,xr.DataArray), + isinstance(t,xr.DataArray), + isinstance(p,xr.DataArray)]]): + rho = xr.apply_ufunc(jmd95numba.rho,s,t,p,dask='allowed') + else: + rho = jmd95numba.rho(s,t,p) + return rho + +def drhodt_x(s, t, p): + if any([isinstance(s,dask.array.core.Array), + isinstance(t,dask.array.core.Array), + isinstance(p,dask.array.core.Array)]): + drhodt = dsa.map_blocks(jmd95numba.drhodt,s,t,p) + elif any([[isinstance(s,xr.DataArray), + isinstance(t,xr.DataArray), + isinstance(p,xr.DataArray)]]): + drhodt = xr.apply_ufunc(jmd95numba.drhodt,s,t,p,dask='allowed') + else: + drhodt = jmd95numba.rho(s,t,p) + return drhodt + +def drhods_x(s, t, p): + if any([isinstance(s,dask.array.core.Array), + isinstance(t,dask.array.core.Array), + isinstance(p,dask.array.core.Array)]): + drhods = dsa.map_blocks(jmd95numba.drhods,s,t,p) + elif any([[isinstance(s,xr.DataArray), + isinstance(t,xr.DataArray), + isinstance(p,xr.DataArray)]]): + drhods = xr.apply_ufunc(jmd95numba.drhods,s,t,p,dask='allowed') + else: + drhods = jmd95numba.rho(s,t,p) + return drhods \ No newline at end of file diff --git a/fastjmd95/test/test_jmd95.py b/fastjmd95/test/test_jmd95.py index 4147661..3d35b41 100644 --- a/fastjmd95/test/test_jmd95.py +++ b/fastjmd95/test/test_jmd95.py @@ -2,7 +2,7 @@ from itertools import product import pytest -from fastjmd95 import rho, drhodt, drhods +from fastjmd95 import rho_x, drhodt_x, drhods_x from .reference_values import rho_expected, drhodt_expected, drhods_expected import dask @@ -55,9 +55,9 @@ def distributed_client(): # https://stackoverflow.com/questions/45225950/passing-yield-fixtures-as-test-parameters-with-a-temp-directory @pytest.mark.parametrize('client', all_clients) @pytest.mark.parametrize('function,expected', - [(rho, rho_expected), - (drhodt, drhodt_expected), - (drhods, drhods_expected)]) + [(rho_x, rho_expected), + (drhodt_x, drhodt_expected), + (drhods_x, drhods_expected)]) def test_functions(request, client, s_t_p, function, expected): s, t, p = s_t_p if client != 'no_client': From d2cfebef91a03e9b1f818d71c522eb03754d4b39 Mon Sep 17 00:00:00 2001 From: Spencer Jones Date: Tue, 13 Oct 2020 09:55:26 -0400 Subject: [PATCH 2/4] code refactor & rename functions --- fastjmd95/__init__.py | 3 +- fastjmd95/jmd95wrapper.py | 67 +++++++++++++++++------------------- fastjmd95/test/test_jmd95.py | 8 ++--- 3 files changed, 36 insertions(+), 42 deletions(-) diff --git a/fastjmd95/__init__.py b/fastjmd95/__init__.py index 3dfb698..1d937da 100644 --- a/fastjmd95/__init__.py +++ b/fastjmd95/__init__.py @@ -1,2 +1 @@ -from .jmd95numba import rho, drhodt, drhods -from .jmd95wrapper import rho_x, drhodt_x, drhods_x \ No newline at end of file +from .jmd95wrapper import rho, drhodt, drhods \ No newline at end of file diff --git a/fastjmd95/jmd95wrapper.py b/fastjmd95/jmd95wrapper.py index e11706e..eafb00b 100644 --- a/fastjmd95/jmd95wrapper.py +++ b/fastjmd95/jmd95wrapper.py @@ -6,42 +6,37 @@ import fastjmd95.jmd95numba as jmd95numba +def _any_dask_array(a,b,c): + any_dask = any([isinstance(a,dask.array.core.Array), + isinstance(b,dask.array.core.Array), + isinstance(c,dask.array.core.Array)]) + return any_dask -def rho_x(s, t, p): - if any([isinstance(s,dask.array.core.Array), - isinstance(t,dask.array.core.Array), - isinstance(p,dask.array.core.Array)]): - rho = dsa.map_blocks(jmd95numba.rho,s,t,p) - elif any([[isinstance(s,xr.DataArray), - isinstance(t,xr.DataArray), - isinstance(p,xr.DataArray)]]): - rho = xr.apply_ufunc(jmd95numba.rho,s,t,p,dask='allowed') - else: - rho = jmd95numba.rho(s,t,p) - return rho +def _any_xarray(a,b,c): + any_xarray = any([[isinstance(a,xr.DataArray), + isinstance(b,xr.DataArray), + isinstance(c,xr.DataArray)]]) + return any_xarray -def drhodt_x(s, t, p): - if any([isinstance(s,dask.array.core.Array), - isinstance(t,dask.array.core.Array), - isinstance(p,dask.array.core.Array)]): - drhodt = dsa.map_blocks(jmd95numba.drhodt,s,t,p) - elif any([[isinstance(s,xr.DataArray), - isinstance(t,xr.DataArray), - isinstance(p,xr.DataArray)]]): - drhodt = xr.apply_ufunc(jmd95numba.drhodt,s,t,p,dask='allowed') - else: - drhodt = jmd95numba.rho(s,t,p) - return drhodt +def my_decorator(func): + def wrapper(*args): + if _any_dask_array(*args): + rho = dsa.map_blocks(func,*args) + elif _any_xarray(*args): + rho = xr.apply_ufunc(func,*args,dask='allowed') + else: + rho = func(*args) + return rho + return wrapper -def drhods_x(s, t, p): - if any([isinstance(s,dask.array.core.Array), - isinstance(t,dask.array.core.Array), - isinstance(p,dask.array.core.Array)]): - drhods = dsa.map_blocks(jmd95numba.drhods,s,t,p) - elif any([[isinstance(s,xr.DataArray), - isinstance(t,xr.DataArray), - isinstance(p,xr.DataArray)]]): - drhods = xr.apply_ufunc(jmd95numba.drhods,s,t,p,dask='allowed') - else: - drhods = jmd95numba.rho(s,t,p) - return drhods \ No newline at end of file +@my_decorator +def rho(s,t,p): + return jmd95numba.rho(s,t,p) + +@my_decorator +def drhodt(s,t,p): + return jmd95numba.drhodt(s,t,p) + +@my_decorator +def drhods(s,t,p): + return jmd95numba.drhods(s,t,p) diff --git a/fastjmd95/test/test_jmd95.py b/fastjmd95/test/test_jmd95.py index 3d35b41..4147661 100644 --- a/fastjmd95/test/test_jmd95.py +++ b/fastjmd95/test/test_jmd95.py @@ -2,7 +2,7 @@ from itertools import product import pytest -from fastjmd95 import rho_x, drhodt_x, drhods_x +from fastjmd95 import rho, drhodt, drhods from .reference_values import rho_expected, drhodt_expected, drhods_expected import dask @@ -55,9 +55,9 @@ def distributed_client(): # https://stackoverflow.com/questions/45225950/passing-yield-fixtures-as-test-parameters-with-a-temp-directory @pytest.mark.parametrize('client', all_clients) @pytest.mark.parametrize('function,expected', - [(rho_x, rho_expected), - (drhodt_x, drhodt_expected), - (drhods_x, drhods_expected)]) + [(rho, rho_expected), + (drhodt, drhodt_expected), + (drhods, drhods_expected)]) def test_functions(request, client, s_t_p, function, expected): s, t, p = s_t_p if client != 'no_client': From 00469b5ddd1c695b407452959bc8699b120dc0d6 Mon Sep 17 00:00:00 2001 From: Spencer Jones Date: Wed, 14 Oct 2020 09:17:27 -0400 Subject: [PATCH 3/4] syntax tweaks --- fastjmd95/jmd95wrapper.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/fastjmd95/jmd95wrapper.py b/fastjmd95/jmd95wrapper.py index eafb00b..72627f0 100644 --- a/fastjmd95/jmd95wrapper.py +++ b/fastjmd95/jmd95wrapper.py @@ -6,19 +6,13 @@ import fastjmd95.jmd95numba as jmd95numba -def _any_dask_array(a,b,c): - any_dask = any([isinstance(a,dask.array.core.Array), - isinstance(b,dask.array.core.Array), - isinstance(c,dask.array.core.Array)]) - return any_dask +def _any_dask_array(*args): + return any([isinstance(a, dask.array.core.Array) for a in args]) -def _any_xarray(a,b,c): - any_xarray = any([[isinstance(a,xr.DataArray), - isinstance(b,xr.DataArray), - isinstance(c,xr.DataArray)]]) - return any_xarray +def _any_xarray(*args): + return any([isinstance(a, xr.DataArray) for a in args]) -def my_decorator(func): +def maybe_wrap_arrays(func): def wrapper(*args): if _any_dask_array(*args): rho = dsa.map_blocks(func,*args) @@ -29,14 +23,14 @@ def wrapper(*args): return rho return wrapper -@my_decorator +@maybe_wrap_arrays def rho(s,t,p): return jmd95numba.rho(s,t,p) -@my_decorator +@maybe_wrap_arrays def drhodt(s,t,p): return jmd95numba.drhodt(s,t,p) -@my_decorator +@maybe_wrap_arrays def drhods(s,t,p): return jmd95numba.drhods(s,t,p) From 59c561ffb310ecdc24c61d4ddb2f2299ed64b628 Mon Sep 17 00:00:00 2001 From: Spencer Jones Date: Wed, 14 Oct 2020 11:37:22 -0400 Subject: [PATCH 4/4] added testing and fixed apply_ufunc --- fastjmd95/jmd95wrapper.py | 2 +- fastjmd95/test/test_jmd95.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/fastjmd95/jmd95wrapper.py b/fastjmd95/jmd95wrapper.py index 72627f0..5e25e05 100644 --- a/fastjmd95/jmd95wrapper.py +++ b/fastjmd95/jmd95wrapper.py @@ -17,7 +17,7 @@ def wrapper(*args): if _any_dask_array(*args): rho = dsa.map_blocks(func,*args) elif _any_xarray(*args): - rho = xr.apply_ufunc(func,*args,dask='allowed') + rho = xr.apply_ufunc(func,*args,output_dtypes=[float],dask='parallelized') else: rho = func(*args) return rho diff --git a/fastjmd95/test/test_jmd95.py b/fastjmd95/test/test_jmd95.py index 4147661..0f3535a 100644 --- a/fastjmd95/test/test_jmd95.py +++ b/fastjmd95/test/test_jmd95.py @@ -7,6 +7,7 @@ import dask import dask.array +import xarray as xr @pytest.fixture def s_t_p(): @@ -19,6 +20,10 @@ def s_t_p(): def _chunk(*args): return [dask.array.from_array(a, chunks=(100,)) for a in args] +def _make_xarray(*args,withdask=False): + if withdask==True: + args = _chunk(*args) + return [xr.DataArray(a, coords=[np.arange(0,len(a))], dims=["i"]) for a in args] @pytest.fixture def no_client(): @@ -52,16 +57,24 @@ def distributed_client(): all_clients = ['no_client', 'threaded_client', 'processes_client', 'distributed_client'] +all_arrays = ['dask_arrays', 'xarrays'] # https://stackoverflow.com/questions/45225950/passing-yield-fixtures-as-test-parameters-with-a-temp-directory @pytest.mark.parametrize('client', all_clients) +@pytest.mark.parametrize('array_type', all_arrays) @pytest.mark.parametrize('function,expected', [(rho, rho_expected), (drhodt, drhodt_expected), (drhods, drhods_expected)]) -def test_functions(request, client, s_t_p, function, expected): +def test_functions(request, client, array_type, s_t_p, function, expected): s, t, p = s_t_p - if client != 'no_client': - s, t, p = _chunk(s, t, p) + if array_type == 'dask_arrays': + if client != 'no_client': + s, t, p = _chunk(s, t, p) + elif array_type == 'xarrays': + if client != 'no_client': + s, t, p = _make_xarray(s, t, p, withdask=True) + else: + s, t, p = _make_xarray(s, t, p, withdask=False) client = request.getfixturevalue(client) actual = function(s, t, p) np.testing.assert_allclose(actual, expected, rtol=1e-2)