Skip to content

Commit

Permalink
Merge pull request #30 from lucduron/fix-issue-29
Browse files Browse the repository at this point in the history
 Fix #27  - support for 1D results
 Fix #29 - dtype is set properly
  • Loading branch information
tomsail authored Mar 21, 2024
2 parents bf2e8ac + b51cd23 commit 198b001
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 3 deletions.
Binary file added tests/data/r1d_tomsail.slf
Binary file not shown.
16 changes: 16 additions & 0 deletions tests/io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@
],
)

DIMS = pytest.mark.parametrize(
"slf_in",
[
pytest.param("tests/data/r3d_tidal_flats.slf", id="3D"),
pytest.param("tests/data/r2d_tidal_flats.slf", id="2D"),
pytest.param("tests/data/r1d_tomsail.slf", id="1D"),
],
)


def write_netcdf(ds, nc_out):
# Remove dict and multi-dimensional arrays not supported in netCDF
Expand All @@ -29,6 +38,7 @@ def write_netcdf(ds, nc_out):
def test_open_dataset(slf_in):
ds = xr.open_dataset(slf_in, engine="selafin")
assert isinstance(ds, xr.Dataset)
repr(ds)

# Dimensions
assert ds.sizes["time"] == 17
Expand Down Expand Up @@ -151,3 +161,9 @@ def test_from_scratch(tmp_path):

# Writing to a SELAFIN file
ds.selafin.write(slf_out)


@DIMS
def test_dim(slf_in):
ds = xr.open_dataset(slf_in, engine="selafin")
repr(ds)
2 changes: 1 addition & 1 deletion xarray_selafin/Serafin.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def np_type(self):
def _check_dim(self):
# verify data consistence and determine 2D or 3D
if self.is_2d:
if self.nb_nodes_per_elem != 3:
if self.nb_nodes_per_elem not in (1, 3):
raise SerafinValidationError("Unknown mesh type")
else:
if self.nb_nodes_per_elem != 6:
Expand Down
4 changes: 2 additions & 2 deletions xarray_selafin/xarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def write_serafin(fout, ds):
else: # 2D
is_2d = True
nplan = 1 # just to do a multiplication
slf_header.nb_nodes_per_elem = 3
slf_header.nb_nodes_per_elem = ds.attrs["ikle2"].shape[1]
slf_header.nb_elements = len(ds.attrs["ikle2"])

slf_header.nb_nodes = ds.sizes["node"] * nplan
Expand Down Expand Up @@ -284,7 +284,7 @@ def open_dataset(

# Create data variables
data_vars = {}
dtype = np.float64
dtype = np.dtype(slf.header.np_float_type)

if nplan == 0:
shape = (len(times), npoin2)
Expand Down

0 comments on commit 198b001

Please sign in to comment.