diff --git a/tests/conftest.py b/tests/conftest.py index 98a13fe..7d5e89a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,42 @@ from pvnet_summation.data.datamodule import DataModule +def construct_batch_by_sample_duplication(og_batch, i): + """From a batch of data, take the ith sample and repeat it 317 to create a new batch""" + new_batch = {} + + # Need to loop through these keys and add to batch + ununsed_keys = list(og_batch.keys()) + + # NWP is nested so needs to be treated differently + if BatchKey.nwp in og_batch: + og_nwp_batch = og_batch[BatchKey.nwp] + new_nwp_batch = {} + for nwp_source, og_nwp_source_batch in og_nwp_batch.items(): + new_nwp_source_batch = {} + for key, value in og_nwp_source_batch.items(): + if isinstance(value, torch.Tensor): + n_dims = len(value.shape) + repeats = (317,) + tuple(1 for dim in range(n_dims - 1)) + new_nwp_source_batch[key] = value[i : i + 1].repeat(repeats)[:317] + else: + new_nwp_source_batch[key] = value + new_nwp_batch[nwp_source] = new_nwp_source_batch + + new_batch[BatchKey.nwp] = new_nwp_batch + ununsed_keys.remove(BatchKey.nwp) + + for key in ununsed_keys: + if isinstance(og_batch[key], torch.Tensor): + n_dims = len(og_batch[key].shape) + repeats = (317,) + tuple(1 for dim in range(n_dims - 1)) + new_batch[key] = og_batch[key][i : i + 1].repeat(repeats)[:317] + else: + new_batch[key] = og_batch[key] + + return new_batch + + @pytest.fixture() def sample_data(): # Copy small batches to fake 317 GSPs in each @@ -28,28 +64,22 @@ def sample_data(): times = [] file_n = 0 - for file in glob.glob("tests/data/sample_batches/train/*.pt"): - batch = torch.load(file) - - this_batch = {} - for i in range(batch[BatchKey.gsp_time_utc].shape[0]): + for file in glob.glob("tests/test_data/sample_batches/train/*.pt"): + og_batch = torch.load(file) + + for i in range(og_batch[BatchKey.gsp_time_utc].shape[0]): + # Duplicate sample to fake 317 GSPs - for key in batch.keys(): - if isinstance(batch[key], torch.Tensor): - n_dims = len(batch[key].shape) - repeats = (317,) + tuple(1 for dim in range(n_dims - 1)) - this_batch[key] = batch[key][i : i + 1].repeat(repeats)[:317] - else: - this_batch[key] = batch[key] + new_batch = construct_batch_by_sample_duplication(og_batch, i) # Save fopr both train and val - torch.save(this_batch, f"{tmpdirname}/train/{file_n:06}.pt") - torch.save(this_batch, f"{tmpdirname}/val/{file_n:06}.pt") + torch.save(new_batch, f"{tmpdirname}/train/{file_n:06}.pt") + torch.save(new_batch, f"{tmpdirname}/val/{file_n:06}.pt") file_n += 1 - times += [batch[BatchKey.gsp_time_utc][i].numpy().astype("datetime64[s]")] - + times += [new_batch[BatchKey.gsp_time_utc][i].numpy().astype("datetime64[s]")] + times = np.unique(np.sort(np.concatenate(times))) da_output = xr.DataArray( @@ -79,7 +109,7 @@ def sample_data(): ) ds.to_zarr(f"{tmpdirname}/gsp.zarr") - + yield tmpdirname, f"{tmpdirname}/gsp.zarr" @@ -109,7 +139,7 @@ def model_kwargs(): # These kwargs define the pvnet model which the summation model uses kwargs = dict( model_name="openclimatefix/pvnet_v2", - model_version="805ca9b2ee3120592b0b70b7c75a454e2b4e4bec", + model_version="22e577100d55787eb2547d701275b9bb48f7bfa0", ) return kwargs diff --git a/tests/data/sample_batches/data_configuration.yaml b/tests/data/sample_batches/data_configuration.yaml deleted file mode 100644 index 697651b..0000000 --- a/tests/data/sample_batches/data_configuration.yaml +++ /dev/null @@ -1,81 +0,0 @@ -general: - description: Config for producing batches on GCP - name: gcp_pvnet - -input_data: - default_history_minutes: 120 - default_forecast_minutes: 480 - - gsp: - gsp_zarr_path: gs://solar-pv-nowcasting-data/PV/GSP/v5/pv_gsp.zarr - history_minutes: 120 - forecast_minutes: 480 - time_resolution_minutes: 30 - start_datetime: "2020-01-01T00:00:00" - end_datetime: "2021-09-01T00:00:00" - gsp_image_size_pixels_height: 64 - gsp_image_size_pixels_width: 64 - gsp_meters_per_pixel: 2000 - n_gsp_per_example: 32 - is_live: false - log_level: DEBUG - metadata_only: false - - nwp: - nwp_zarr_path: /mnt/disks/nwp/UKV_intermediate_version_7.zarr - history_minutes: 120 - forecast_minutes: 480 - time_resolution_minutes: 60 - nwp_channels: - - t # live = t2m - - dswrf - nwp_image_size_pixels_height: 24 - nwp_image_size_pixels_width: 24 - nwp_meters_per_pixel: 2000 - log_level: DEBUG - - satellite: - satellite_zarr_path: - - /mnt/disks/data_ssd/2017_nonhrv.zarr - - /mnt/disks/data_ssd/2018_nonhrv.zarr - - /mnt/disks/data_ssd/2019_nonhrv.zarr - - /mnt/disks/data_ssd/2020_nonhrv.zarr - - /mnt/disks/data_ssd/2021_nonhrv.zarr - history_minutes: 90 - forecast_minutes: 0 - time_resolution_minutes: 5 - satellite_channels: - - IR_016 - - IR_039 - - IR_087 - - IR_097 - - IR_108 - - IR_120 - - IR_134 - - VIS006 - - VIS008 - - WV_062 - - WV_073 - satellite_image_size_pixels_height: 24 - satellite_image_size_pixels_width: 24 - satellite_meters_per_pixel: 6000 - keep_dawn_dusk_hours: 4 - live_delay_minutes: 30 - is_live: false - log_level: DEBUG - - hrvsatellite: - hrvsatellite_zarr_path: - - gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/v4/2019_hrv.zarr - - gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/v4/2020_hrv.zarr - - gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/v4/2021_hrv.zarr - history_minutes: 60 - forecast_minutes: 0 - time_resolution_minutes: 5 - hrvsatellite_channels: - - HRV - hrvsatellite_image_size_pixels_height: 24 - hrvsatellite_image_size_pixels_width: 24 - -output_data: - filepath: "not-needed" diff --git a/tests/data/sample_batches/train/000000.pt b/tests/data/sample_batches/train/000000.pt deleted file mode 100644 index 6e99981..0000000 Binary files a/tests/data/sample_batches/train/000000.pt and /dev/null differ diff --git a/tests/test_data/sample_batches/data_configuration.yaml b/tests/test_data/sample_batches/data_configuration.yaml new file mode 100644 index 0000000..27f0099 --- /dev/null +++ b/tests/test_data/sample_batches/data_configuration.yaml @@ -0,0 +1,424 @@ +general: + description: Config for producing batches on GCP + name: gcp_pvnet + +input_data: + default_history_minutes: 120 + default_forecast_minutes: 480 + + gsp: + gsp_zarr_path: /mnt/disks/nwp_rechunk/pv_gsp_temp.zarr + history_minutes: 120 + forecast_minutes: 480 + time_resolution_minutes: 30 + metadata_only: false + + pv: + pv_files_groups: + - label: solar_sheffield_passiv + pv_filename: /mnt/disks/nwp_rechunk/passive/v1.1/passiv.netcdf + pv_metadata_filename: /mnt/disks/nwp_rechunk/passive/v0/system_metadata_OCF_ONLY.csv + pv_ml_ids: + [ + 154, + 155, + 156, + 158, + 159, + 160, + 162, + 164, + 165, + 166, + 167, + 168, + 169, + 171, + 173, + 177, + 178, + 179, + 181, + 182, + 185, + 186, + 187, + 188, + 189, + 190, + 191, + 192, + 193, + 197, + 198, + 199, + 200, + 202, + 204, + 205, + 206, + 208, + 209, + 211, + 214, + 215, + 216, + 217, + 218, + 219, + 220, + 221, + 225, + 229, + 230, + 232, + 233, + 234, + 236, + 242, + 243, + 245, + 252, + 254, + 255, + 256, + 257, + 258, + 260, + 261, + 262, + 265, + 267, + 268, + 272, + 273, + 275, + 276, + 277, + 280, + 281, + 282, + 283, + 287, + 289, + 291, + 292, + 293, + 294, + 295, + 296, + 297, + 298, + 301, + 302, + 303, + 304, + 306, + 307, + 309, + 310, + 311, + 317, + 318, + 319, + 320, + 321, + 322, + 323, + 325, + 326, + 329, + 332, + 333, + 335, + 336, + 338, + 340, + 342, + 344, + 345, + 346, + 348, + 349, + 352, + 354, + 355, + 356, + 357, + 360, + 362, + 363, + 368, + 369, + 370, + 371, + 372, + 374, + 375, + 376, + 378, + 380, + 382, + 384, + 385, + 388, + 390, + 391, + 393, + 396, + 397, + 398, + 399, + 400, + 401, + 403, + 404, + 405, + 406, + 407, + 409, + 411, + 412, + 413, + 414, + 415, + 416, + 417, + 418, + 419, + 420, + 421, + 422, + 423, + 424, + 425, + 426, + 427, + 429, + 431, + 435, + 437, + 438, + 440, + 441, + 444, + 447, + 450, + 451, + 453, + 456, + 457, + 458, + 459, + 464, + 465, + 466, + 467, + 468, + 470, + 471, + 473, + 474, + 476, + 477, + 479, + 480, + 481, + 482, + 485, + 486, + 488, + 490, + 491, + 492, + 493, + 496, + 498, + 501, + 503, + 506, + 507, + 508, + 509, + 510, + 511, + 512, + 513, + 515, + 516, + 517, + 519, + 520, + 521, + 522, + 524, + 526, + 527, + 528, + 531, + 532, + 536, + 537, + 538, + 540, + 541, + 542, + 543, + 544, + 545, + 549, + 550, + 551, + 552, + 553, + 554, + 556, + 557, + 560, + 561, + 563, + 566, + 568, + 571, + 572, + 575, + 576, + 577, + 579, + 580, + 581, + 582, + 584, + 585, + 588, + 590, + 594, + 595, + 597, + 600, + 602, + 603, + 604, + 606, + 611, + 613, + 614, + 616, + 618, + 620, + 622, + 623, + 624, + 625, + 626, + 628, + 629, + 630, + 631, + 636, + 637, + 638, + 640, + 641, + 642, + 644, + 645, + 646, + 650, + 651, + 652, + 653, + 654, + 655, + 657, + 660, + 661, + 662, + 663, + 666, + 667, + 668, + 670, + 675, + 676, + 679, + 681, + 683, + 684, + 685, + 687, + 696, + 698, + 701, + 702, + 703, + 704, + 706, + 710, + 722, + 723, + 724, + 725, + 727, + 728, + 729, + 730, + 732, + 733, + 734, + 735, + 736, + 737, + ] + history_minutes: 180 + forecast_minutes: 0 + time_resolution_minutes: 5 + + nwp: + ukv: + nwp_zarr_path: + - /mnt/disks/nwp_rechunk/UKV_intermediate_version_7.1.zarr + - /mnt/disks/nwp_rechunk/UKV_2021_NWP_missing_chunked.zarr + - /mnt/disks/nwp_rechunk/UKV_2022_NWP_chunked.zarr + - /mnt/disks/nwp_rechunk/UKV_2023_chunked.zarr + history_minutes: 120 + forecast_minutes: 480 + time_resolution_minutes: 60 + nwp_channels: + - t # live = t2m + - dswrf + #- lcc + #- mcc + #- hcc + #- dlwrf + nwp_image_size_pixels_height: 24 + nwp_image_size_pixels_width: 24 + nwp_provider: ukv + + satellite: + satellite_zarr_path: + - /mnt/disks/nwp_rechunk/filled_sat/2017_nonhrv.zarr + - /mnt/disks/nwp_rechunk/filled_sat/2018_nonhrv.zarr + - /mnt/disks/nwp_rechunk/filled_sat/2019_nonhrv.zarr + - /mnt/disks/nwp_rechunk/filled_sat/2020_nonhrv.zarr + - /mnt/disks/nwp_rechunk/filled_sat/2021_nonhrv.zarr + - /mnt/disks/nwp_rechunk/filled_sat/2022_nonhrv.zarr + - /mnt/disks/nwp_rechunk/filled_sat/2023_nonhrv.zarr + history_minutes: 90 + forecast_minutes: 0 + live_delay_minutes: 30 + time_resolution_minutes: 5 + satellite_channels: + - IR_016 + - IR_039 + - IR_087 + - IR_097 + - IR_108 + - IR_120 + - IR_134 + - VIS006 + - VIS008 + - WV_062 + - WV_073 + satellite_image_size_pixels_height: 24 + satellite_image_size_pixels_width: 24 diff --git a/tests/data/sample_batches/datamodule.yaml b/tests/test_data/sample_batches/datamodule.yaml similarity index 95% rename from tests/data/sample_batches/datamodule.yaml rename to tests/test_data/sample_batches/datamodule.yaml index 2c30ed1..411a1ad 100644 --- a/tests/data/sample_batches/datamodule.yaml +++ b/tests/test_data/sample_batches/datamodule.yaml @@ -1,6 +1,6 @@ _target_: pvnet.data.datamodule.DataModule configuration: /home/jamesfulton/repos/PVNet/configs/datamodule/configuration/gcp_configuration.yaml -num_workers: 20 +num_workers: 2 prefetch_factor: 2 batch_size: 8 train_period: diff --git a/tests/test_data/sample_batches/train/000000.pt b/tests/test_data/sample_batches/train/000000.pt new file mode 100644 index 0000000..543affe Binary files /dev/null and b/tests/test_data/sample_batches/train/000000.pt differ