diff --git a/mllam_data_prep/derived_variables.py b/mllam_data_prep/derived_variables.py
index cda1bdf..b6b67db 100644
--- a/mllam_data_prep/derived_variables.py
+++ b/mllam_data_prep/derived_variables.py
@@ -38,19 +38,25 @@ def derive_variables(fp, derived_variables, chunking):
     ds_subset = xr.Dataset()
     ds_subset.attrs.update(ds.attrs)
     for _, derived_variable in derived_variables.items():
-        required_variables = derived_variable.kwargs
+        required_kwargs = dict(derived_variable.kwargs)
         function_name = derived_variable.function
         derived_variable_attributes = derived_variable.attributes or {}
-        ds_input = ds[required_variables.keys()]
+
+        # Split lat/lon off a copy of the kwargs (config stays intact); they are derived separately
+        latlon_coords_to_include = {}
+        for k in list(required_kwargs):
+            if k in ("lat", "lon"):
+                latlon_coords_to_include[k] = required_kwargs.pop(k)
+
+        # Subset the dataset
+        ds_input = ds[required_kwargs.keys()]
 
         # Any coordinates needed for the derivation, for which chunking should be performed,
-        # should be converted to variables since it is not possible for coordinates to be
-        # chunked dask arrays
+        # should be converted to variables since it is not possible for *indexed* coordinates
+        # to be chunked dask arrays
         chunks = {d: chunking.get(d, int(ds_input[d].count())) for d in ds_input.dims}
         required_coordinates = [
-            req_var
-            for req_var in required_variables.keys()
-            if req_var in ds_input.coords
+            req_var for req_var in required_kwargs.keys() if req_var in ds_input.coords
         ]
         ds_input = ds_input.drop_indexes(required_coordinates, errors="ignore")
         for req_coord in required_coordinates:
@@ -60,9 +66,15 @@ def derive_variables(fp, derived_variables, chunking):
         # Chunk the data variables
         ds_input = ds_input.chunk(chunks)
 
-        # Calculate the derived variable
-        kwargs = {v: ds_input[k] for k, v in required_variables.items()}
+        # Add function arguments to kwargs
+        kwargs = {}
+        if latlon_coords_to_include:
+            latlon = get_latlon_coords_for_input(ds)
+            for k, v in latlon_coords_to_include.items():
+                kwargs[v] = latlon[k]
+        kwargs.update({v: ds_input[k] for k, v in required_kwargs.items()})
         func = _get_derived_variable_function(function_name)
+        # Calculate the derived variable
         derived_field = func(**kwargs)
 
         # Check the derived field(s)
@@ -408,3 +420,14 @@ def cyclic_encoding(data_array, da_max):
     data_array_cos = np.cos((data_array / da_max) * 2 * np.pi)
 
     return data_array_cos, data_array_sin
+
+
+def get_latlon_coords_for_input(ds_input):
+    """Return the ``lat`` and ``lon`` variables of ``ds_input``, rechunked with -1.
+
+    Placeholder implementation: it only selects the two variables from the
+    dataset it is given; it does not compute them.
+    """
+    # `chunk` takes the chunk spec as its first positional argument; a second
+    # positional value would be interpreted as `name_prefix`, so pass one -1 only.
+    return ds_input[["lat", "lon"]].chunk(-1)