Skip to content

Commit

Permalink
uk_regional update
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-e-h-p committed Jan 22, 2025
1 parent 59ad792 commit 7e06ac1
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 47 deletions.
4 changes: 2 additions & 2 deletions ocf_data_sampler/sample/uk_regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ class PVNetSample(SampleBase):
SatelliteSampleKey.satellite_actual
}

def __init__(self, data: Optional[Dict[str, Any]] = None):
def __init__(self):
logger.debug("Initialise PVNetSample instance")
super().__init__(data)
super().__init__()

def to_numpy(self) -> Dict[str, np.ndarray]:
""" Convert sample to numpy arrays - nested handling """
Expand Down
89 changes: 44 additions & 45 deletions tests/test_sample/test_uk_regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,39 +82,38 @@ def create_test_data():
def test_sample_init():
""" Initialisation """
sample = PVNetSample()
test_data = create_test_data()
sample = PVNetSample(data=test_data)
assert isinstance(sample._data, dict)
assert len(sample._data) == 0


def test_sample_save_load():
test_data = create_test_data()
sample = PVNetSample(data=test_data)

with tempfile.NamedTemporaryFile(suffix='.pt') as tf:
sample.save(tf.name)
loaded = PVNetSample.load(tf.name)

assert set(loaded._data.keys()) == set(sample._data.keys())

assert isinstance(loaded._data['nwp'], dict)
assert 'ukv' in loaded._data['nwp']
sample = PVNetSample()
sample._data = create_test_data()

with tempfile.NamedTemporaryFile(suffix='.pt') as tf:
sample.save(tf.name)
loaded = PVNetSample.load(tf.name)

assert set(loaded._data.keys()) == set(sample._data.keys())
assert isinstance(loaded._data['nwp'], dict)
assert 'ukv' in loaded._data['nwp']

assert loaded._data[GSPSampleKey.gsp].shape == (7,)
assert loaded._data[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
assert loaded._data[GSPSampleKey.solar_azimuth].shape == (7,)
assert loaded._data[GSPSampleKey.solar_elevation].shape == (7,)
assert loaded._data[GSPSampleKey.gsp].shape == (7,)
assert loaded._data[SatelliteSampleKey.satellite_actual].shape == (7, 1, 2, 2)
assert loaded._data[GSPSampleKey.solar_azimuth].shape == (7,)
assert loaded._data[GSPSampleKey.solar_elevation].shape == (7,)

np.testing.assert_array_almost_equal(
loaded._data[GSPSampleKey.gsp],
sample._data[GSPSampleKey.gsp]
)
np.testing.assert_array_almost_equal(
loaded._data[GSPSampleKey.gsp],
sample._data[GSPSampleKey.gsp]
)


def test_save_unsupported_format():
""" Test saving - unsupported file format """
test_data = create_test_data()
sample = PVNetSample(data=test_data)
sample = PVNetSample()
sample._data = create_test_data()

with tempfile.NamedTemporaryFile(suffix='.npz') as tf:
with pytest.raises(ValueError, match="Only .pt format is supported"):
sample.save(tf.name)
Expand Down Expand Up @@ -149,24 +148,24 @@ def test_dataset_get_sample(pvnet_config_filename):


def test_sample_to_numpy():
mixed_data = {
'nwp': {
'ukv': {
'nwp': np.random.rand(4, 1, 2, 2),
'x': np.array([1, 2]),
'y': np.array([1, 2])
}
},
GSPSampleKey.gsp: np.random.rand(7),
SatelliteSampleKey.satellite_actual: np.random.rand(7, 1, 2, 2),
GSPSampleKey.solar_azimuth: np.random.rand(7),
GSPSampleKey.solar_elevation: np.random.rand(7)
}
sample = PVNetSample(data=mixed_data)
numpy_data = sample.to_numpy()
assert isinstance(numpy_data[GSPSampleKey.gsp], np.ndarray)
assert isinstance(numpy_data[SatelliteSampleKey.satellite_actual], np.ndarray)
assert isinstance(numpy_data[GSPSampleKey.solar_azimuth], np.ndarray)
assert isinstance(numpy_data[GSPSampleKey.solar_elevation], np.ndarray)
sample = PVNetSample()
sample._data = {
'nwp': {
'ukv': {
'nwp': np.random.rand(4, 1, 2, 2),
'x': np.array([1, 2]),
'y': np.array([1, 2])
}
},
GSPSampleKey.gsp: np.random.rand(7),
SatelliteSampleKey.satellite_actual: np.random.rand(7, 1, 2, 2),
GSPSampleKey.solar_azimuth: np.random.rand(7),
GSPSampleKey.solar_elevation: np.random.rand(7)
}

numpy_data = sample.to_numpy()

assert isinstance(numpy_data[GSPSampleKey.gsp], np.ndarray)
assert isinstance(numpy_data[SatelliteSampleKey.satellite_actual], np.ndarray)
assert isinstance(numpy_data[GSPSampleKey.solar_azimuth], np.ndarray)
assert isinstance(numpy_data[GSPSampleKey.solar_elevation], np.ndarray)

0 comments on commit 7e06ac1

Please sign in to comment.