diff --git a/graph_weather/data/dataloader.py b/graph_weather/data/dataloader.py index e30e0e67..c732d4e9 100644 --- a/graph_weather/data/dataloader.py +++ b/graph_weather/data/dataloader.py @@ -2,9 +2,9 @@ The dataloader has to do a few things for the model to work correctly -1. Load the land-0sea mask, orography dataset, regridded from 0.1 to the +1. Load the land-0sea mask, orography dataset, regridded from 0.1 to the correct resolution -2. Calculate the top-of-atmosphere solar radiation for each location at +2. Calculate the top-of-atmosphere solar radiation for each location at fcurrent time and 10 other times +- 12 hours 3. Add day-of-year, sin(lat), cos(lat), sin(lon), cos(lon) as well @@ -126,7 +126,7 @@ def __getitem__(self, item): ], axis=-1, ) - # Not want to predict non-physics variables -> Output only the data variables? + # Not want to predict non-physics variables -> Output only the data variables? # Would be simpler, and just add in the new ones each time output_data = np.stack( diff --git a/train/deepspeed_graph.py b/train/deepspeed_graph.py index 8a191949..9a2b9575 100644 --- a/train/deepspeed_graph.py +++ b/train/deepspeed_graph.py @@ -33,6 +33,7 @@ def configure_optimizers(self): def forward(self, x): return self.model(x) + class FakeDataset(Dataset): def __init__(self): super(FakeDataset, self).__init__() diff --git a/train/run.py b/train/run.py index 7a60227e..9766190e 100644 --- a/train/run.py +++ b/train/run.py @@ -335,7 +335,7 @@ def __iter__(self): seed=np.random.randint(low=-1000, high=10000), buffer_size=4 ) for data in iter(self.dataset): - #TODO Currently leaves out lat/lon/Sun irradience, and land/sea mask and topographic data + # TODO Currently leaves out lat/lon/Sun irradience, and land/sea mask and topographic data data.update( { key: np.expand_dims(np.asarray(value), axis=-1) diff --git a/train/run_fulll.py b/train/run_fulll.py index 089a2d4b..1ff0111b 100644 --- a/train/run_fulll.py +++ b/train/run_fulll.py @@ -19,6 +19,7 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(BASE_DIR) + class XrDataset(Dataset): def __init__(self): super().__init__()