Skip to content
This repository has been archived by the owner on Sep 11, 2023. It is now read-only.

Issue/624 pv types #641

Merged
merged 4 commits into from
May 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions nowcasting_dataset/data_sources/pv/pv_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,14 +396,17 @@ def get_example(self, location: SpaceTimeLocation) -> xr.Dataset:
data=pv_system_row_number,
dims=["id"],
)
pv["x_osgb"] = x_coords
pv["y_osgb"] = y_coords
pv["x_osgb"] = x_coords.astype("float32")
pv["y_osgb"] = y_coords.astype("float32")
pv["pv_system_row_number"] = pv_system_row_number

# pad out so that there are always n_pv_systems_per_example, pad with zeros
pad_n = self.n_pv_systems_per_example - len(pv.id)
pv = pv.pad(id=(0, pad_n), power_mw=((0, 0), (0, pad_n)), constant_values=0)

# format id
pv.__setitem__("id", pv.id.astype("int32"))

return pv

def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLocation]:
Expand Down
9 changes: 7 additions & 2 deletions nowcasting_dataset/dataset/xr_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ def join_list_dataset_to_batch_dataset(datasets: list[xr.Dataset]) -> xr.Dataset
new_dataset = dataset.expand_dims(dim="example").assign_coords(example=("example", [i]))
new_datasets.append(new_dataset)

return xr.concat(new_datasets, dim="example")
joined_dataset = xr.concat(new_datasets, dim="example")

# format example index
joined_dataset.__setitem__("example", joined_dataset.example.astype("int32"))

return joined_dataset


def convert_coordinates_to_indexes_for_list_datasets(
Expand All @@ -43,7 +48,7 @@ def convert_coordinates_to_indexes(dataset: xr.Dataset) -> xr.Dataset:

for original_dim_name in original_dim_names:
original_coords = dataset[original_dim_name]
new_index_coords = np.arange(len(original_coords))
new_index_coords = np.arange(len(original_coords)).astype("int32")
new_index_dim_name = f"{original_dim_name}_index"
dataset[original_dim_name] = new_index_coords
dataset = dataset.rename({original_dim_name: new_index_dim_name})
Expand Down
6 changes: 6 additions & 0 deletions tests/data_sources/pv/test_pv_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ def test_get_example_and_batch(): # noqa: D103
# start at 6, to avoid some nans
batch = pv_data_source.get_batch(locations=locations[6:16])
assert batch.power_mw.shape == (10, 19, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE)
assert str(batch.x_osgb.dtype) == "float32"
assert str(batch.y_osgb.dtype) == "float32"
assert str(batch.id.dtype) == "int32"
assert str(batch.example.dtype) == "int32"
assert str(batch.id_index.dtype) == "int32"
assert str(batch.time_index.dtype) == "int32"


def test_drop_pv_systems_which_produce_overnight(): # noqa: D103
Expand Down