Skip to content
This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Jul 29, 2021
1 parent a2ed86f commit 4b79071
Showing 1 changed file with 6 additions and 15 deletions.
21 changes: 6 additions & 15 deletions satflow/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,38 +52,33 @@ def test_satflow_all():
config = load_config("satflow/tests/configs/satflow_all.yaml")
cloudflow = SatFlowDataset([dataset], config)
data = next(iter(cloudflow))
x, image, y = data
x, image = data
channels = check_channels(config)
assert x.shape == (13, channels, 128, 128)
assert y.shape == (24, 1, 128, 128)
assert image.shape == (24, 12, 128, 128)
assert not np.allclose(image[0], image[-1])
assert not np.allclose(y[0], y[-1])


def test_satflow_large():
dataset = wds.WebDataset("datasets/satflow-test.tar")
config = load_config("satflow/tests/configs/satflow_large.yaml")
cloudflow = SatFlowDataset([dataset], config)
data = next(iter(cloudflow))
x, image, y = data
x, image = data
channels = check_channels(config)
assert x.shape == (13, channels, 256, 256)
assert y.shape == (24, 1, 256, 256)
assert image.shape == (24, 12, 256, 256)
assert not np.allclose(image[0], image[-1])
assert not np.allclose(y[0], y[-1])


def test_satflow_crop():
dataset = wds.WebDataset("datasets/satflow-test.tar")
config = load_config("satflow/tests/configs/satflow_crop.yaml")
cloudflow = SatFlowDataset([dataset], config)
data = next(iter(cloudflow))
x, image, y = data
x, image = data
channels = check_channels(config)
assert x.shape == (13, channels, 256, 256)
assert y.shape == (24, 1, 64, 64)
assert image.shape == (24, 12, 64, 64)
assert not np.allclose(image[0], image[-1])
assert not np.allclose(x[0], x[-1])
Expand Down Expand Up @@ -117,13 +112,11 @@ def test_satflow_time_channels_all():
config = load_config("satflow/tests/configs/satflow_time_channels_all.yaml")
cloudflow = SatFlowDataset([dataset], config)
data = next(iter(cloudflow))
x, image, y = data
x, image = data
channels = check_channels(config)
assert x.shape == (channels, 128, 128)
assert y.shape == (24, 128, 128)
assert image.shape == (12 * 24, 128, 128)
assert not np.allclose(image[0], image[-1])
assert not np.allclose(y[0], y[-1])


def test_cloudflow():
Expand All @@ -141,16 +134,14 @@ def test_satflow_all_deterministic_validation():
config = load_config("satflow/tests/configs/satflow_all.yaml")
cloudflow = SatFlowDataset([dataset], config, train=False)
data = next(iter(cloudflow))
x, image, y = data
x, image = data
dataset2 = wds.WebDataset("datasets/satflow-test.tar")
cloudflow2 = SatFlowDataset([dataset2], config, train=False)
data = next(iter(cloudflow2))
x2, image2, y2 = data
x2, image2 = data
np.testing.assert_almost_equal(x, x2)
np.testing.assert_almost_equal(image, image2)
np.testing.assert_almost_equal(y, y2)
assert not np.allclose(image[0], image[-1])
assert not np.allclose(y[0], y[-1])


def test_satflow_all_deterministic_validation_restart():
Expand Down

0 comments on commit 4b79071

Please sign in to comment.