Skip to content

Commit

Permalink
Modified shoal detection to work with the new refactorings (#132)
Browse files Browse the repository at this point in the history
* Modified shoal detection to work with the new refactorings

* Added dask_image to requirements-dev
  • Loading branch information
ruxandra-valcu authored Nov 13, 2023
1 parent 527ec12 commit 920eed2
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 63 deletions.
6 changes: 1 addition & 5 deletions oceanstream/L3_regridded_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@
)
from .mvbs_computation import compute_mvbs
from .nasc_computation import compute_per_dataset_nasc
from .shoal_detection_handler import (
attach_shoal_mask_to_ds,
combine_shoal_masks_multichannel,
create_shoal_mask_multichannel,
)
from .shoal_detection_handler import attach_shoal_mask_to_ds, create_shoal_mask_multichannel
from .shoal_process import (
process_shoals,
process_single_shoal,
Expand Down
54 changes: 10 additions & 44 deletions oceanstream/L3_regridded_data/shoal_detection_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,14 @@

from oceanstream.utils import add_metadata_to_mask, attach_mask_to_dataset

WEILL_DEFAULT_PARAMETERS = {"thr": -70, "maxvgap": -5, "maxhgap": 0, "minvlen": 0, "minhlen": 0}
WEILL_DEFAULT_PARAMETERS = {
"thr": -70,
"maxvgap": 5,
"maxhgap": 5,
"minvlen": 0,
"minhlen": 0,
"dask_chunking": {"ping_time": 1000, "range_sample": 1000},
}


def create_shoal_mask_multichannel(
Expand Down Expand Up @@ -70,49 +77,9 @@ def create_shoal_mask_multichannel(
>>> mask, mask_ = create_shoal_mask_multichannel(Sv, parameters, method)
"""
mask, mask_ = get_shoal_mask_multichannel(Sv, parameters, method)
mask = get_shoal_mask_multichannel(Sv, parameters, method)
mask_type_value = method
mask.attrs["shoal detection mask type"] = mask_type_value
mask_.attrs["shoal detection mask type"] = mask_type_value
return mask, mask_


def combine_shoal_masks_multichannel(mask: xr.DataArray, mask_: xr.DataArray) -> xr.DataArray:
"""
Combines the provided multichannel masks (`mask` and `mask_`) to produce a final mask that contains `True` values
only where both input masks are `True` for each channel.
Parameters:
mask : xr.DataArray
A multichannel mask for the Sv data. Regions satisfying the thresholding criteria
for shoal identification are filled with `True`, else the regions are filled with `False`.
mask_ : xr.DataArray
A mask indicating the valid samples for the first mask. Edge regions are
filled with 'False', whereas the portion in which shoals could be detected is 'True'.
Returns:
xr.DataArray
A final multichannel mask for the Sv data. Regions that meet the thresholding criteria
for shoal identification and fall within valid samples are marked as True.
All other regions are marked as False.
Example:
>>> mask, mask_ = create_shoal_mask_multichannel(Sv_dataset)
>>> combined_masks = combine_masks_multichannel(mask, mask_)
"""
# Check if both masks have the 'channel' dimension
if "channel" not in mask.dims or "channel" not in mask_.dims:
raise ValueError("Both masks must have a 'channel' dimension for multichannel processing.")
# Ensure the channels in both masks match
if not all(mask["channel"].values == mask_["channel"].values):
raise ValueError("Channels in both masks must match.")
combined_masks = xr.where(mask & mask_, True, False)
combined_masks.attrs["shoal detection mask type"] = mask.attrs["shoal detection mask type"]
return combined_masks


def attach_shoal_mask_to_ds(
Expand All @@ -138,7 +105,6 @@ def attach_shoal_mask_to_ds(
Example:
>>> ds_with_shoal_mask = attach_shoal_mask_to_ds(ds, parameters, method)
"""
mask, mask_ = create_shoal_mask_multichannel(ds, parameters, method)
shoal_mask = combine_shoal_masks_multichannel(mask, mask_)
shoal_mask = create_shoal_mask_multichannel(ds, parameters, method)
shoal_mask = add_metadata_to_mask(mask=shoal_mask, metadata={"mask_type": "shoal"})
return attach_mask_to_dataset(ds, shoal_mask)
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,5 @@ sphinxcontrib-mermaid
twine
wheel
haversine
dask_image
git+https://github.com/OceanStreamIO/echopype.git@next-dev
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ geopy
pydantic>2
echopype
haversine
dask_image
17 changes: 3 additions & 14 deletions tests/test_shoal_detection_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from oceanstream.L3_regridded_data.shoal_detection_handler import (
attach_shoal_mask_to_ds,
combine_shoal_masks_multichannel,
create_shoal_mask_multichannel,
)

Expand All @@ -16,24 +15,14 @@ def _count_false_values(mask: xr.DataArray) -> int:

@pytest.fixture(scope="session")
def shoal_masks(ek_60_Sv_denoised):
mask, mask_ = create_shoal_mask_multichannel(ek_60_Sv_denoised)
return mask, mask_
mask = create_shoal_mask_multichannel(ek_60_Sv_denoised)
return mask


# @pytest.mark.ignore
def test_create_shoal_mask_multichannel(shoal_masks):
mask, mask_ = shoal_masks
mask = shoal_masks
assert _count_false_values(mask) == 4873071
assert _count_false_values(mask_) == 0


# @pytest.mark.ignore
def test_combine_shoal_masks_multichannel(shoal_masks):
mask, mask_ = shoal_masks
combined_masks = combine_shoal_masks_multichannel(mask, mask_)

assert _count_false_values(combined_masks) == 4873071


@pytest.mark.ignore
def test_attach_shoal_mask_to_ds(ek_60_Sv_denoised):
Expand Down

0 comments on commit 920eed2

Please sign in to comment.