Skip to content

Commit

Permalink
ensure shape is always passed to array_strategy_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
TomNicholas committed Nov 2, 2023
1 parent 700d652 commit 0e01d76
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions xarray/testing/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,10 @@ def variables(
)

else:
# nothing provided, so generate everything consistently by drawing dims to match data
array_strategy = _array_strategy_fn(dtype=_dtype)
# nothing provided, so generate everything consistently
# We still generate the shape first here just so that we always pass shape to array_strategy_fn
_shape = draw(npst.array_shapes())
array_strategy = _array_strategy_fn(shape=_shape, dtype=_dtype)
_data = draw(array_strategy)
dim_names = draw(dimension_names(min_dims=_data.ndim, max_dims=_data.ndim))

Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def fixed_array_strategy_fn(*, shape=None, dtype=None):

var = data.draw(
variables(
array_strategy_fn=fixed_array_strategy_fn, dtype=st.just(arr.dtype) # type: ignore[arg-type]
array_strategy_fn=fixed_array_strategy_fn, dims=st.just({"x": 2, "y": 2}), dtype=st.just(arr.dtype) # type: ignore[arg-type]
)
)

Expand Down

0 comments on commit 0e01d76

Please sign in to comment.