Skip to content
This repository has been archived by the owner on Sep 11, 2023. It is now read-only.

Commit

Permalink
Issue/624 pv types (#641)
Browse files: browse the repository at this point in the history
* assert the types we want of pv data - TDD

* change data types to float32 or int32

* tidy

Co-authored-by: Jacob Bieker <[email protected]>
  • Loading branch information
peterdudfield and jacobbieker authored May 9, 2022
1 parent 0711c52 commit 854c0bd
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
7 changes: 5 additions & 2 deletions nowcasting_dataset/data_sources/pv/pv_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,14 +396,17 @@ def get_example(self, location: SpaceTimeLocation) -> xr.Dataset:
data=pv_system_row_number,
dims=["id"],
)
pv["x_osgb"] = x_coords
pv["y_osgb"] = y_coords
pv["x_osgb"] = x_coords.astype("float32")
pv["y_osgb"] = y_coords.astype("float32")
pv["pv_system_row_number"] = pv_system_row_number

# pad with zeros so that there are always n_pv_systems_per_example PV systems
pad_n = self.n_pv_systems_per_example - len(pv.id)
pv = pv.pad(id=(0, pad_n), power_mw=((0, 0), (0, pad_n)), constant_values=0)

# cast id to int32
pv.__setitem__("id", pv.id.astype("int32"))

return pv

def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLocation]:
Expand Down
9 changes: 7 additions & 2 deletions nowcasting_dataset/dataset/xr_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ def join_list_dataset_to_batch_dataset(datasets: list[xr.Dataset]) -> xr.Dataset
new_dataset = dataset.expand_dims(dim="example").assign_coords(example=("example", [i]))
new_datasets.append(new_dataset)

return xr.concat(new_datasets, dim="example")
joined_dataset = xr.concat(new_datasets, dim="example")

# cast the example index to int32
joined_dataset.__setitem__("example", joined_dataset.example.astype("int32"))

return joined_dataset


def convert_coordinates_to_indexes_for_list_datasets(
Expand All @@ -43,7 +48,7 @@ def convert_coordinates_to_indexes(dataset: xr.Dataset) -> xr.Dataset:

for original_dim_name in original_dim_names:
original_coords = dataset[original_dim_name]
new_index_coords = np.arange(len(original_coords))
new_index_coords = np.arange(len(original_coords)).astype("int32")
new_index_dim_name = f"{original_dim_name}_index"
dataset[original_dim_name] = new_index_coords
dataset = dataset.rename({original_dim_name: new_index_dim_name})
Expand Down
6 changes: 6 additions & 0 deletions tests/data_sources/pv/test_pv_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ def test_get_example_and_batch(): # noqa: D103
# start at 6, to avoid some nans
batch = pv_data_source.get_batch(locations=locations[6:16])
assert batch.power_mw.shape == (10, 19, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE)
assert str(batch.x_osgb.dtype) == "float32"
assert str(batch.y_osgb.dtype) == "float32"
assert str(batch.id.dtype) == "int32"
assert str(batch.example.dtype) == "int32"
assert str(batch.id_index.dtype) == "int32"
assert str(batch.time_index.dtype) == "int32"


def test_drop_pv_systems_which_produce_overnight(): # noqa: D103
Expand Down

0 comments on commit 854c0bd

Please sign in to comment.