Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 4, 2024
1 parent 478e06c commit 7eb1f69
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 4 deletions.
6 changes: 3 additions & 3 deletions graph_weather/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
The dataloader has to do a few things for the model to work correctly
1. Load the land-0sea mask, orography dataset, regridded from 0.1 to the
1. Load the land-0sea mask, orography dataset, regridded from 0.1 to the
correct resolution
2. Calculate the top-of-atmosphere solar radiation for each location at
2. Calculate the top-of-atmosphere solar radiation for each location at
fcurrent time and 10 other
times +- 12 hours
3. Add day-of-year, sin(lat), cos(lat), sin(lon), cos(lon) as well
Expand Down Expand Up @@ -126,7 +126,7 @@ def __getitem__(self, item):
],
axis=-1,
)
# Not want to predict non-physics variables -> Output only the data variables?
# Not want to predict non-physics variables -> Output only the data variables?
# Would be simpler, and just add in the new ones each time

output_data = np.stack(
Expand Down
1 change: 1 addition & 0 deletions train/deepspeed_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def configure_optimizers(self):
def forward(self, x):
return self.model(x)


class FakeDataset(Dataset):
def __init__(self):
super(FakeDataset, self).__init__()
Expand Down
2 changes: 1 addition & 1 deletion train/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def __iter__(self):
seed=np.random.randint(low=-1000, high=10000), buffer_size=4
)
for data in iter(self.dataset):
#TODO Currently leaves out lat/lon/Sun irradience, and land/sea mask and topographic data
# TODO Currently leaves out lat/lon/Sun irradience, and land/sea mask and topographic data
data.update(
{
key: np.expand_dims(np.asarray(value), axis=-1)
Expand Down
1 change: 1 addition & 0 deletions train/run_fulll.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)


class XrDataset(Dataset):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 7eb1f69

Please sign in to comment.