
Performance issues on google cloud (and beyond) #6

jbusecke opened this issue Oct 28, 2020 · 12 comments

@jbusecke (Contributor)
I am using fastjmd95 to infer potential density from CMIP6 model output. I have recently run into performance issues in a complicated workflow, and I think I can trace some of them back to the step involving fastjmd95.

Here is a small example that reproduces the issue:

# Load a single model from the CMIP archive
import xarray as xr
import gcsfs
from fastjmd95 import jmd95numba

gcs = gcsfs.GCSFileSystem(token='anon')
so = xr.open_zarr(gcs.get_mapper('gs://cmip6/CMIP/NCAR/CESM2/historical/r1i1p1f1/Omon/so/gn/'), consolidated=True).so
thetao = xr.open_zarr(gcs.get_mapper('gs://cmip6/CMIP/NCAR/CESM2/historical/r1i1p1f1/Omon/thetao/gn/'), consolidated=True).thetao

# calculate sigma0 based on the instruction notebook (https://nbviewer.jupyter.org/github/xgcm/fastjmd95/blob/master/doc/fastjmd95_tutorial.ipynb)
sigma_0 = xr.apply_ufunc(
    jmd95numba.rho, so, thetao, 0, dask='parallelized', output_dtypes=[so.dtype]
) - 1000

I then performed some tests on the Google Cloud deployment (dask cluster with 5 workers).
When I trigger a computation on the variables that are simply read from storage (so.mean().load()), everything works fine: the memory load is low and the task stream is dense.

But when I try the same with the derived variable (sigma_0.mean().load()), things look really ugly: The memory fills up almost immediately and spilling to disk starts. From the Progress Pane it seems like dask is trying to load a large chunk of the dataset into memory before the rho calculation is applied.
[Dashboard screenshot: memory usage filling up and tasks spilling to disk]

To me it seems like the scheduler is going wide on the task graph rather than deep; going deeper first might free up some memory?
I am not well versed enough in dask to diagnose what is going on, but any tips would be much appreciated.
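
For reference, the chunk layout of the inputs versus the derived array can be inspected like this (a diagnostic sketch using standard xarray/dask attributes):

# .data exposes the underlying dask array; .chunks lists the chunk sizes per dimension
print("so:      ", so.data.chunks)
print("thetao:  ", thetao.data.chunks)
print("sigma_0: ", sigma_0.data.chunks)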

@jbusecke (Contributor, Author) commented Oct 28, 2020

I have just tried my older method using gsw:

import gsw
import xarray as xr

def _sigma0(lon, lat, z, temp, salt):
    # depth -> pressure, practical -> absolute salinity,
    # potential -> conservative temperature, then sigma0
    pr = gsw.conversions.p_from_z(-z, lat)
    sa = gsw.SA_from_SP(salt, pr, lon, lat)
    ct = gsw.CT_from_pt(sa, temp)
    sigma0 = gsw.sigma0(sa, ct)
    return sigma0

def reconstruct_sigma0(lon, lat, z, temp, salt):
    # wrap the numpy-level function so it maps over dask chunks lazily
    kwargs = dict(dask="parallelized", output_dtypes=[salt.dtype])
    ds_sigma0 = xr.apply_ufunc(_sigma0, lon, lat, z, temp, salt, **kwargs)
    return ds_sigma0

For now that works (it takes a long time and has high memory use, but the example above actually crashes eventually); it is clunky because of the latitude dependence, though, and I would like to use a more performant implementation in the future. Just wanted to add another data point.
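
For context, I call it roughly like this (the ds variable and coordinate names are just illustrative placeholders for my dataset):

# hypothetical call; coordinate names (lon, lat, lev) depend on the model's grid metadata
sigma0 = reconstruct_sigma0(ds.lon, ds.lat, ds.lev, thetao, so)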

EDIT: I spoke too soon. This failed with some obscure broadcasting issue. Could it be that there is some new bug in xr.apply_ufunc?

@rabernat (Contributor)

Julius, can you try this with the latest master? In #5, @cspencerjones implemented the xarray wrapper layer, so you should not have to call apply_ufunc at all. I'm not sure this makes any difference, but I would like to see.

@jbusecke (Contributor, Author)

I should have mentioned this earlier: this is installed from the latest master (as of today). I am having trouble reproducing this behavior, which makes me think that this might be another problem on the Google Cloud side?

@jbusecke (Contributor, Author) commented Oct 28, 2020

Oh I think I misunderstood. You mean use it like this?

sigma_0 = jmd95numba.rho(so, thetao, 0) - 1000

EDIT: I just tried it, and it gives me a rather obscure error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<timed eval> in <module>

/srv/conda/envs/notebook/lib/python3.8/site-packages/xarray/core/dataarray.py in load(self, **kwargs)
    806         dask.array.compute
    807         """
--> 808         ds = self._to_temp_dataset().load(**kwargs)
    809         new = self._from_temp_dataset(ds)
    810         self._variable = new._variable

/srv/conda/envs/notebook/lib/python3.8/site-packages/xarray/core/dataset.py in load(self, **kwargs)
    652 
    653             # evaluate all the dask arrays simultaneously
--> 654             evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    655 
    656             for k, data in zip(lazy_data, evaluated_data):

/srv/conda/envs/notebook/lib/python3.8/site-packages/dask/base.py in compute(*args, **kwargs)
    450         postcomputes.append(x.__dask_postcompute__())
    451 
--> 452     results = schedule(dsk, keys, **kwargs)
    453     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    454 

/srv/conda/envs/notebook/lib/python3.8/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
   2723                     should_rejoin = False
   2724             try:
-> 2725                 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
   2726             finally:
   2727                 for f in futures.values():

/srv/conda/envs/notebook/lib/python3.8/site-packages/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
   1984             else:
   1985                 local_worker = None
-> 1986             return self.sync(
   1987                 self._gather,
   1988                 futures,

/srv/conda/envs/notebook/lib/python3.8/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    830             return future
    831         else:
--> 832             return sync(
    833                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    834             )

/srv/conda/envs/notebook/lib/python3.8/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    338     if error[0]:
    339         typ, exc, tb = error[0]
--> 340         raise exc.with_traceback(tb)
    341     else:
    342         return result[0]

/srv/conda/envs/notebook/lib/python3.8/site-packages/distributed/utils.py in f()
    322             if callback_timeout is not None:
    323                 future = asyncio.wait_for(future, callback_timeout)
--> 324             result[0] = yield future
    325         except Exception as exc:
    326             error[0] = sys.exc_info()

/srv/conda/envs/notebook/lib/python3.8/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/srv/conda/envs/notebook/lib/python3.8/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
   1849                             exc = CancelledError(key)
   1850                         else:
-> 1851                             raise exception.with_traceback(traceback)
   1852                         raise exc
   1853                     if errors == "skip":

/srv/conda/envs/notebook/lib/python3.8/site-packages/distributed/protocol/pickle.py in loads()
     73             return pickle.loads(x, buffers=buffers)
     74         else:
---> 75             return pickle.loads(x)
     76     except Exception as e:
     77         logger.info("Failed to deserialize %s", x[:10000], exc_info=True)

/srv/conda/envs/notebook/lib/python3.8/site-packages/numpy/core/__init__.py in _ufunc_reconstruct()
    124     # scipy.special.expit for instance.
    125     mod = __import__(module, fromlist=[name])
--> 126     return getattr(mod, name)
    127 
    128 def _ufunc_reduce(func):

AttributeError: module '__mp_main__' has no attribute 'rho'

@jbusecke (Contributor, Author)

One possible issue is that salinity and temperature for this model are chunked differently in time 🙀. I rechunked them right after import, and that seems to at least make the cluster die more slowly, but it is still not good.
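
The rechunking itself was nothing fancy, just something along these lines (the target chunk size here is only an example, not a recommendation):

# give both inputs the same chunking along time before combining them
so = so.chunk({'time': 12})
thetao = thetao.chunk({'time': 12})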

@jbusecke (Contributor, Author)

An update from my side. Thanks to @cspencerjones I was able to get this going.

I tried 3 different approaches:

1. sigma_0 = xr.apply_ufunc(jmd95numba.rho, so, thetao, 0, dask='parallelized', output_dtypes=[so.dtype]) - 1000
   This barely works (with rechunking as mentioned above). Takes about 45 min in my setup and is a pain to watch on the dashboard.

2. sigma_0_a = jmd95numba.rho(so, thetao, 0) - 1000
   This still crashes with AttributeError: module '__mp_main__' has no attribute 'rho'.

3. sigma_0_b = jmd95wrapper.rho(so, thetao, 0) - 1000
   This one did the trick! It still has quite high memory usage but runs through smoothly, taking about 5 minutes to compute the mean over the full dataset!

It would be great if (2) worked out of the box, but I think adding (3) to the docs would already be a great step forward.

This really makes me wonder if there is an underlying problem with xr.apply_ufunc... I'll investigate that more and see if I can raise an issue with xarray. Or do any of you see a reason why (3) should behave very differently from (1)?

@cspencerjones

I wonder if output_dtypes makes a difference? I don't really see why it would, but inside fastjmd95 I call apply_ufunc with output_dtypes=[float]. What is so.dtype?

@rabernat (Contributor)

You are not supposed to import from either jmd95wrapper or jmd95numba. Just import jmd95. That will import from jmd95wrapper. This is our only public API:

from .jmd95wrapper import rho, drhodt, drhods

It looks like 1 and 3 do the same thing, no?

elif _any_xarray(*args):
    rho = xr.apply_ufunc(func, *args, output_dtypes=[float], dask='parallelized')
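
In other words, the intended usage should be just this (a sketch, assuming the top-level package re-exports rho as shown above):

from fastjmd95 import rho

# rho dispatches to the xarray wrapper above when given DataArrays
sigma_0 = rho(so, thetao, 0) - 1000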

@rabernat (Contributor)

We should also probably be using a more dynamic output_dtype in jmd95wrapper. The underlying code is float32 or float64, depending on the inputs.
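
Something along these lines inside the wrapper might do it (an untested sketch that just promotes the dtypes of whichever arguments are arrays):

import numpy as np

# Sketch: infer the output dtype from the array arguments instead of hard-coding float
dtypes = [a.dtype for a in args if hasattr(a, 'dtype')]
out_dtype = np.result_type(*dtypes) if dtypes else np.float64
rho = xr.apply_ufunc(func, *args, output_dtypes=[out_dtype], dask='parallelized')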

@jbusecke (Contributor, Author) commented Oct 29, 2020

It looks like 1 and 3 do the same thing, no?

There is certainly a difference for my use case, but now I am even more confused as to why. Is it the dtype? Let me check that on my inputs...

EDIT: Both of my inputs are float32, but methods 1–3 all produce float64 no matter what. I don't get it...

@jbusecke (Contributor, Author)

You are not supposed to import from either jmd95wrapper or jmd95numba. Just import jmd95. That will import from jmd95wrapper.

OK, that makes sense, but the information in the notebook needs to be updated (that was what I based my trial on).
This method works pretty well for me. I still have no clue why the 'manual' wrapping in xr.apply_ufunc would be different, though.

@rabernat (Contributor)

This is related to issues @stb2145 was having today.
