Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the downsampling feature to same module as interpolation #145

Merged
merged 1 commit into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion oceanstream/L2_calibrated_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
from .processed_data_io import read_processed, write_processed
from .sv_computation import compute_sv
from .sv_dataset_extension import enrich_sv_dataset
from .sv_interpolation import interpolate_sv
from .sv_interpolation import interpolate_sv, regrid_dataset
from .target_strength_computation import compute_target_strength
180 changes: 179 additions & 1 deletion oceanstream/L2_calibrated_data/sv_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
it offers edge filling capabilities to enhance the quality of interpolated echograms.
"""
from pathlib import Path
from typing import Union
from typing import Hashable, Tuple, Union

import numpy as np
import xarray as xr
from scipy.interpolate import interp1d

from oceanstream.L2_calibrated_data.processed_data_io import read_processed

Expand Down Expand Up @@ -99,3 +100,180 @@ def interpolate_sv(
dataset["Sv_interpolated"] = interpolated_sv

return dataset


def find_impacted_variables(
    dataset: xr.Dataset, target_dim: str = "range_sample"
) -> list[Hashable]:
    """
    Return the names of every variable in ``dataset`` that depends on ``target_dim``.

    The target dimension itself is excluded from the result even though it
    appears in ``dataset.variables``.

    Parameters:
    - dataset (xr.Dataset): Dataset whose variables are inspected.
    - target_dim (str, optional): Dimension name to look for in each variable's
      dims. Defaults to "range_sample".

    Returns:
    - list[Hashable]: Names of the variables carrying ``target_dim`` as a dimension.

    Example:
    >> impacted_vars = find_impacted_variables(dataset)
    Expected Output:
    A list of variable names from the dataset which have 'range_sample' as one of their dimensions.

    >> impacted_vars = find_impacted_variables(dataset, target_dim='time')
    Expected Output:
    A list of variable names from the dataset which have 'time' as one of their dimensions.
    """
    return [
        name
        for name, array in dataset.variables.items()
        if target_dim in array.dims and name != target_dim
    ]


def find_lowest_resolution_channel(dataset: xr.Dataset) -> Tuple[int, int]:
    """
    Identify the channel whose echo range reaches the shallowest maximum depth.

    For every channel (indexed along 'frequency_nominal'), the deepest valid
    sample index over all pings is computed from 'echo_range'; the channel with
    the smallest such index is treated as the lowest-resolution channel.

    Parameters:
    - dataset (xr.Dataset): Dataset containing 'echo_range', 'range_sample'
      and 'frequency_nominal'.

    Returns:
    - Tuple[int, int]: ``(channel_index, max_depth_index)`` — the index of the
      lowest-resolution channel and the deepest sample index it reaches.

    Example:
    >> lowest_res_channel, max_depth_index = find_lowest_resolution_channel(dataset)
    Expected Output:
    A tuple where the first element is the index of the channel with the lowest resolution,
    and the second element is the maximum depth index within that channel.
    """
    # Start from the largest possible index so any real channel beats it.
    shallowest_idx = len(dataset["range_sample"].values)
    shallowest_channel = 0
    echo_range = dataset["echo_range"].values
    for channel_idx in range(len(dataset["frequency_nominal"])):
        # Deepest valid sample index reached by any ping on this channel.
        channel_max_idx = np.max(np.nanargmax(echo_range[channel_idx, :, :], axis=1))
        if channel_max_idx < shallowest_idx:
            shallowest_idx = channel_max_idx
            shallowest_channel = channel_idx
    return shallowest_channel, shallowest_idx


def resample_xarray(da: xr.DataArray, old_depth_da, new_depth_da, new_range_sample) -> xr.DataArray:
    """
    Resamples an xarray DataArray to a new depth profile, channel by channel.

    Note on method: the resampling uses nearest-neighbour interpolation
    (``scipy.interpolate.interp1d(kind="nearest")``), not linear interpolation.
    For 'Sv' (volume backscattering strength) data, values are converted from
    decibel to linear scale before resampling and back to decibel afterwards,
    so the resampling operates in linear space.

    Parameters:
    - da (xr.DataArray): The DataArray to resample; expected dimensions are
      ('channel', 'ping_time', 'range_sample').
    - old_depth_da (array-like): Per-channel original depth values, indexed as
      ``old_depth_da[ch]``.  # presumably rows of 'echo_range' — confirm with caller
    - new_depth_da (array-like): The new depth values to which 'da' will be resampled.
    - new_range_sample (array-like): Values forming the new 'range_sample' coordinate.

    Returns:
    - xr.DataArray: The resampled DataArray with dims
      ('channel', 'ping_time', 'range_sample') and the original channel order.

    Example:
    >> resampled_da = resample_xarray(da, old_depth_da, new_depth_da, new_range_sample)
    Expected Output:
    An xarray DataArray resampled to the new depth profile specified by 'new_depth_da' and 'new_range_sample'.

    Note:
    - Target depths outside a channel's original depth range are filled with NaN
      (``bounds_error=False, fill_value=np.nan``).
    - ``db_to_linear`` and ``linear_to_db`` are helpers defined elsewhere in this module.
    """
    interpolated_data_arrays = []
    channel_order_list = []
    time_coord = da["ping_time"]
    depth_values = old_depth_da
    for ch in range(da.sizes["channel"]):
        # Number of depth samples available for this channel's source grid.
        depth_idx = len(depth_values[ch])
        if da.name == "Sv":
            # Sv is stored in dB; convert to linear scale before resampling.
            data_array = db_to_linear(da[ch, :, :depth_idx]).values
        else:
            data_array = da[ch, :, :depth_idx].values
        new_data_array = np.zeros((data_array.shape[0], len(new_range_sample)))
        for i in range(len(data_array)):  # Loop over the time dimension
            # Build a nearest-neighbour interpolator for this ping's slice;
            # depths outside the source range become NaN.
            interp_func = interp1d(
                depth_values[ch],
                data_array[i],
                kind="nearest",
                bounds_error=False,
                fill_value=np.nan,
            )
            # Interpolate the data to the new depth values
            new_data_array[i, :] = interp_func(new_depth_da[:])
        # Create a DataArray for the interpolated data of this channel
        channel_data = xr.DataArray(
            new_data_array,
            coords={"ping_time": time_coord, "range_sample": new_range_sample},
            dims=["ping_time", "range_sample"],
        )
        channel_order_list.append(da["channel"][ch].values)
        interpolated_data_arrays.append(channel_data)
    # Combine interpolated data from all channels
    combined_data = xr.concat(interpolated_data_arrays, dim="channel")
    combined_data = combined_data.assign_coords(channel=("channel", channel_order_list))
    if da.name == "Sv":
        # Convert Sv back to dB after resampling in linear space.
        combined_data = linear_to_db(combined_data)
    return combined_data


def regrid_dataset(dataset: xr.Dataset) -> xr.Dataset:
    """
    Regrids (downsamples) a Dataset onto the depth grid of its lowest-resolution channel.

    The channel reaching the shallowest maximum depth index is found via
    ``find_lowest_resolution_channel``; its depth values at the deepest-reaching
    ping define the target grid. Every variable carrying the 'range_sample'
    dimension is resampled onto that grid with ``resample_xarray`` (which uses
    nearest-neighbour interpolation); all other variables, coordinates and the
    dataset attributes are copied unchanged.

    Parameters:
    - dataset (xr.Dataset): The Dataset to be regridded. It should contain
      'echo_range' and 'range_sample' among other coordinates and dimensions.

    Returns:
    - xr.Dataset: A new Dataset with variables resampled to the common depth profile.

    Example:
    >> regridded_dataset = regrid_dataset(dataset)
    Expected Output:
    A new xarray Dataset with variables that have been resampled to match the depth profile of
    the channel with the lowest resolution in the original dataset.

    Note:
    - The input dataset is not modified; a new Dataset is returned.
    - Variables without the 'range_sample' dimension are copied as is.
    """
    new_dataset = xr.Dataset()
    # Channel with the shallowest reach; its grid becomes the target grid.
    channel, max_depth_idx = find_lowest_resolution_channel(dataset)

    # For each ping of the chosen channel: index of its deepest valid sample.
    per_ping_depth = np.nanargmax(dataset["echo_range"].values[channel, :, :], axis=1)
    # New range_sample coordinate spans 0 .. deepest index of the chosen channel.
    new_range_sample = np.arange(np.max(per_ping_depth) + 1)
    # Ping where the chosen channel reaches deepest; its depths define the grids.
    ping_with_max_depth = np.nanargmax(per_ping_depth)
    # Source depth values for every channel at that ping.
    bin_depths = dataset["echo_range"].values[:, ping_with_max_depth, :]
    # Target depth values, truncated to the chosen channel's reach.
    new_bin_depths = dataset["echo_range"].values[channel, ping_with_max_depth, : max_depth_idx + 1]
    for coord_name in dataset.coords:
        if coord_name == "range_sample":
            # Replace the depth-sample coordinate with the downsampled one.
            new_dataset.coords[coord_name] = ("range_sample", new_range_sample)
        else:
            new_dataset.coords[coord_name] = dataset.coords[coord_name]
    impacted_vars = find_impacted_variables(dataset, target_dim="range_sample")
    for var_name, data_array in dataset.data_vars.items():
        if var_name in impacted_vars:
            # Only variables on the 'range_sample' dimension need resampling.
            data_array = resample_xarray(data_array, bin_depths, new_bin_depths, new_range_sample)
        new_dataset[var_name] = data_array
    new_dataset.attrs = dataset.attrs.copy()
    return new_dataset
19 changes: 19 additions & 0 deletions tests/test_sv_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
db_to_linear,
interpolate_sv,
linear_to_db,
regrid_dataset,
)
from oceanstream.L2_calibrated_data.sv_computation import compute_sv
from oceanstream.utils import add_metadata_to_mask, attach_mask_to_dataset
from tests.conftest import TEST_DATA_FOLDER

Expand Down Expand Up @@ -186,3 +188,20 @@ def test_deterministic_behavior(complete_dataset_jr179):
result2 = interpolate_sv(dataset_with_mask2)

assert np.array_equal(result1["Sv_interpolated"], result2["Sv_interpolated"], equal_nan=True)


def test_regrid_dataset(ed_ek_80_for_Sv):
    # Compute Sv from the EK80 fixture, then regrid it onto the common depth grid.
    original = compute_sv(ed_ek_80_for_Sv, waveform_mode="CW", encode_mode="complex")
    regridded = regrid_dataset(original)

    # Regridding must preserve dimension names, attributes and the variable set.
    assert set(original.dims.keys()) == set(regridded.dims.keys()), "Dimension names are different"
    assert original.attrs == regridded.attrs, "Attributes are different"
    assert set(original.variables) == set(regridded.variables), "Variables are different"


Loading