Skip to content

Commit

Permalink
Dataloader and Noise (#109)
Browse files Browse the repository at this point in the history
* Dataloader and Noise

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add comments

* Reorganize functions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
gbruno16 and pre-commit-ci[bot] authored Jun 8, 2024
1 parent e140814 commit b0bab00
Show file tree
Hide file tree
Showing 12 changed files with 1,350 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ target-version = "py311"
fix=false
# Group violations by containing file.
output-format = "github"
lint.ignore-init-module-imports = true
#lint.ignore-init-module-imports = true

[lint.mccabe]
# Unlike Flake8, default to a complexity level of 10.
Expand Down
1 change: 1 addition & 0 deletions environment_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies:
- zarr
- h3-py
- numpy
- pyshtools
- pip:
- datasets
- einops
Expand Down
1 change: 1 addition & 0 deletions environment_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies:
- zarr
- h3-py
- numpy
- pyshtools
- pip:
- datasets
- einops
Expand Down
24 changes: 21 additions & 3 deletions graph_weather/data/IFSAnalysis_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
The dataloader for IFS analysis.
"""

import numpy as np
import torchvision.transforms as transforms
import xarray as xr
Expand All @@ -23,7 +27,21 @@


class IFSAnalisysDataset(Dataset):
"""
Dataset for IFSAnalysis.
Args:
filepath: path of the dataset.
features: list of features.
start_year: initial year. Defaults to 2016.
end_year: ending year. Defaults to 2022.
"""

def __init__(self, filepath: str, features: list, start_year: int = 2016, end_year: int = 2022):
"""
Initialize the dataset object.
"""

super().__init__()
assert (
start_year <= end_year
Expand All @@ -46,15 +64,15 @@ def __getitem__(self, idx):
end = self.data.isel(time=idx + 1)

# Extract NWP features
input_data = self.nwp_features_extraction(start)
output_data = self.nwp_features_extraction(end)
input_data = self._nwp_features_extraction(start)
output_data = self._nwp_features_extraction(end)

return (
(transforms).ToTensor()(input_data).view(-1, input_data.shape[-1]),
(transforms).ToTensor()(output_data).view(-1, output_data.shape[-1]),
)

def nwp_features_extraction(self, data):
def _nwp_features_extraction(self, data):
data_cube = np.stack(
[
(data[f"{var}"].values - IFS_MEAN[f"{var}"]) / (IFS_STD[f"{var}"] + 1e-6)
Expand Down
Loading

0 comments on commit b0bab00

Please sign in to comment.