ruff fixed (docstring remaining) (#100)
* ruff fixed (docstring remaining)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update run.py

* docstrings added

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
rahul-maurya11b and pre-commit-ci[bot] authored Apr 5, 2024
1 parent c80ca5b commit ab7bc74
Showing 10 changed files with 284 additions and 44 deletions.
20 changes: 10 additions & 10 deletions .ruff.toml
@@ -1,10 +1,10 @@
  # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
- select = ["E", "F", "D", "I"]
- ignore = ["D200","D202","D210","D212","D415","D105",]
+ lint.select = ["E", "F", "D", "I"]
+ lint.ignore = ["D200","D202","D210","D212","D415","D105"]

  # Allow autofix for all enabled rules (when `--fix` is provided).
- fixable = ["A", "B", "C", "D", "E", "F", "I"]
- unfixable = []
+ lint.fixable = ["A", "B", "C", "D", "E", "F", "I"]
+ lint.unfixable = []

  # Exclude a variety of commonly ignored directories.
  exclude = [
@@ -35,22 +35,22 @@ exclude = [
  line-length = 100

  # Allow unused variables when underscore-prefixed.
- dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+ lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

  # Assume Python 3.11.
  target-version = "py311"
- fix = false
+ fix=false
  # Group violations by containing file.
  output-format = "github"
- ignore-init-module-imports = true
+ lint.ignore-init-module-imports = true

- [mccabe]
+ [lint.mccabe]
  # Unlike Flake8, default to a complexity level of 10.
  max-complexity = 10

- [pydocstyle]
+ [lint.pydocstyle]
  # Use Google-style docstrings.
  convention = "google"

- [per-file-ignores]
+ [lint.per-file-ignores]
  "__init__.py" = ["F401", "E402"]
3 changes: 2 additions & 1 deletion graph_weather/data/const.py
@@ -6,7 +6,8 @@
  2. GFS Forecast Fields
  3. ERA5 Reanalysis Fields
- where the variance is the variance in the 3 hour change for a variable averaged across all lat/lon and pressure levels
+ where the variance is the variance in the 3 hour change for a variable averaged across all lat/lon
+ and pressure levels
  and time for (~100 random temporal frames, more the better)
  Min/Max/Mean/Stddev for all those plus each type of observation in observation files
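
The docstring above defines the normalization statistic as the variance of the 3-hour change, averaged across lat/lon, pressure level, and time. A minimal sketch of that computation with xarray; the dataset path, variable name, and dimension names are hypothetical:

    import xarray as xr

    # Hypothetical analysis dataset with 3-hourly time steps and
    # dims (time, lat, lon, level).
    ds = xr.open_zarr("analysis.zarr")

    # Variance of the 3-hour change, averaged over all lat/lon, pressure
    # levels, and time (ideally ~100 random temporal frames).
    delta = ds["temperature"].diff("time")
    variance = float(delta.var(dim=["time", "lat", "lon", "level"]))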
36 changes: 31 additions & 5 deletions graph_weather/data/dataloader.py
@@ -2,8 +2,10 @@
  The dataloader has to do a few things for the model to work correctly
- 1. Load the land-sea mask, orography dataset, regridded from 0.1 to the correct resolution
- 2. Calculate the top-of-atmosphere solar radiation for each location at current time and 10 other
+ 1. Load the land-sea mask, orography dataset, regridded from 0.1 to the
+    correct resolution
+ 2. Calculate the top-of-atmosphere solar radiation for each location at
+    current time and 10 other
  times +- 12 hours
  3. Add day-of-year, sin(lat), cos(lat), sin(lon), cos(lon) as well
  4. Batch data as either in geometric batches, or more normally
@@ -20,7 +22,26 @@

  class AnalysisDataset(Dataset):
+     """
+     Dataset class for analysis data.
+
+     Args:
+         filepaths: List of file paths.
+         invariant_path: Path to the invariant file.
+         mean: Mean value.
+         std: Standard deviation value.
+         coarsen: Coarsening factor. Defaults to 8.
+
+     Methods:
+         __init__: Initialize the AnalysisDataset object.
+         __len__: Get the length of the dataset.
+         __getitem__: Get an item from the dataset.
+     """

      def __init__(self, filepaths, invariant_path, mean, std, coarsen: int = 8):
+         """
+         Initialize the AnalysisDataset object.
+         """
          super().__init__()
          self.filepaths = sorted(filepaths)
          self.invariant_path = invariant_path
@@ -124,7 +145,8 @@ def __getitem__(self, item):
              ],
              axis=-1,
          )
-         # Not want to predict non-physics variables -> Output only the data variables? Would be simpler, and just add in the new ones each time
+         # Not want to predict non-physics variables -> Output only the data variables?
+         # Would be simpler, and just add in the new ones each time

          output_data = np.stack(
              [
@@ -154,9 +176,13 @@ def __getitem__(self, item):
  obs_data = xr.open_zarr(
      "/home/jacob/Development/prepbufr.gdas.20160101.t00z.nr.48h.raw.zarr", consolidated=True
  )
- # TODO Embedding? These should stay consistent across all of the inputs, so can just load the values, not the strings?
- # Should only take in the quality markers, observations, reported observation time relative to start point
+ # TODO Embedding? These should stay consistent across all of the inputs, so can just load the values,
+ # not the strings?
+ # Should only take in the quality markers, observations, reported observation time relative to start
+ # point
  # Observation errors, and background values, lat/lon/height/speed of observing thing

  print(obs_data)
  print(obs_data.hdr_inst_typ.values)
  print(obs_data.hdr_irpt_typ.values)
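
A minimal usage sketch for the AnalysisDataset documented above; the file paths, normalization statistics, and loader settings are hypothetical:

    from glob import glob

    from torch.utils.data import DataLoader

    from graph_weather.data.dataloader import AnalysisDataset

    # Hypothetical paths and normalization statistics.
    dataset = AnalysisDataset(
        filepaths=glob("/data/analysis/*.zarr"),
        invariant_path="/data/invariant.zarr",
        mean=0.0,
        std=1.0,
        coarsen=8,  # default coarsening factor from the signature above
    )
    loader = DataLoader(dataset, batch_size=1, num_workers=2)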
34 changes: 17 additions & 17 deletions graph_weather/models/analysis.py
@@ -32,23 +32,23 @@ def __init__(
          Args:
              observation_lat_lons: Lat/lon points of the observations
-             output_lat_lons: List of latitude and longitudes for the output analysis
-             resolution: Resolution of the H3 grid, prefer even resolutions, as
-             odd ones have octagons and heptagons as well
-             observation_dim: Input feature size
-             analysis_dim: Output Analysis feature dim
-             node_dim: Node hidden dimension
-             edge_dim: Edge hidden dimension
-             num_blocks: Number of message passing blocks in the Processor
-             hidden_dim_processor_node: Hidden dimension of the node processors
-             hidden_dim_processor_edge: Hidden dimension of the edge processors
-             hidden_layers_processor_node: Number of hidden layers in the node processors
-             hidden_layers_processor_edge: Number of hidden layers in the edge processors
-             hidden_dim_decoder: Number of hidden dimensions in the decoder
-             hidden_layers_decoder: Number of layers in the decoder
-             norm_type: Type of norm for the MLPs
-             one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
-             use_checkpointing: Whether to use gradient checkpointing or not
+             output_lat_lons: List of latitude and longitudes for the output analysis
+             resolution: Resolution of the H3 grid, prefer even resolutions, as
+                 odd ones have octagons and heptagons as well
+             observation_dim: Input feature size
+             analysis_dim: Output Analysis feature dim
+             node_dim: Node hidden dimension
+             edge_dim: Edge hidden dimension
+             num_blocks: Number of message passing blocks in the Processor
+             hidden_dim_processor_node: Hidden dimension of the node processors
+             hidden_dim_processor_edge: Hidden dimension of the edge processors
+             hidden_layers_processor_node: Number of hidden layers in the node processors
+             hidden_layers_processor_edge: Number of hidden layers in the edge processors
+             hidden_dim_decoder: Number of hidden dimensions in the decoder
+             hidden_layers_decoder: Number of layers in the decoder
+             norm_type: Type of norm for the MLPs,
+                 one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
+             use_checkpointing: Whether to use gradient checkpointing or not
"""
super().__init__()

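
A minimal construction sketch for the model whose arguments are documented above, assuming this file defines the assimilator exported as GraphWeatherAssimilator; the grid spacings and feature sizes are hypothetical:

    from graph_weather import GraphWeatherAssimilator

    # Hypothetical observation and output grids.
    obs_lat_lons = [(lat, lon) for lat in range(-90, 90, 5) for lon in range(0, 360, 5)]
    out_lat_lons = [(lat, lon) for lat in range(-90, 90, 1) for lon in range(0, 360, 1)]

    model = GraphWeatherAssimilator(
        observation_lat_lons=obs_lat_lons,
        output_lat_lons=out_lat_lons,
        observation_dim=2,  # input feature size per observation
        analysis_dim=24,    # output analysis feature dim
        resolution=2,       # even H3 resolutions preferred, per the docstring
    )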
3 changes: 2 additions & 1 deletion graph_weather/models/forecast.py
@@ -39,7 +39,8 @@ def __init__(
                  odd ones have octagons and heptagons as well
              feature_dim: Input feature size
              aux_dim: Number of non-NWP features (i.e. landsea mask, lat/lon, etc)
-             output_dim: Optional, output feature size, useful if want only subset of variables in output
+             output_dim: Optional, output feature size, useful if want only subset of variables in
+                 output
              node_dim: Node hidden dimension
              edge_dim: Edge hidden dimension
              num_blocks: Number of message passing blocks in the Processor
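
A minimal sketch of constructing the forecaster with a reduced output subset via output_dim; the grid spacing and dimensions are hypothetical:

    from graph_weather import GraphWeatherForecaster

    lat_lons = [(lat, lon) for lat in range(-90, 90, 1) for lon in range(0, 360, 1)]

    model = GraphWeatherForecaster(
        lat_lons=lat_lons,
        feature_dim=78,  # NWP input features
        aux_dim=24,      # non-NWP features: land-sea mask, lat/lon, etc.
        output_dim=26,   # predict only a subset of the input variables
    )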
2 changes: 2 additions & 0 deletions graph_weather/models/losses.py
@@ -25,6 +25,8 @@ def __init__(
          Args:
              feature_variance: Variance for each of the physical features
              lat_lons: List of lat/lon pairs, used to generate weighting
+             device: Device to run on, depending on whether a GPU is supported
+             normalize: Option to normalize the loss
          """
          # TODO Rescale by nominal static air density at each pressure level
          super().__init__()
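
A minimal sketch of building the loss, assuming this file defines NormalizedMSELoss as used elsewhere in the repository; the feature count and variance values are hypothetical:

    import torch

    from graph_weather.models.losses import NormalizedMSELoss

    lat_lons = [(lat, lon) for lat in range(-90, 90, 1) for lon in range(0, 360, 1)]

    criterion = NormalizedMSELoss(
        feature_variance=torch.ones(78),  # per-feature variance placeholder
        lat_lons=lat_lons,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )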
61 changes: 58 additions & 3 deletions train/deepspeed_graph.py
@@ -1,6 +1,9 @@
"""Module for training the Graph Weather forecaster model using PyTorch Lightning."""

import pytorch_lightning as pl
import torch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, Dataset

from graph_weather import GraphWeatherForecaster

@@ -11,13 +14,42 @@

  class LitModel(pl.LightningModule):
+     """
+     LightningModule for the weather forecasting model.
+
+     Args:
+         lat_lons: List of latitude and longitude coordinates.
+         feature_dim: Dimension of the input features.
+         aux_dim: Dimension of the auxiliary features.
+
+     Methods:
+         __init__: Initialize the LitModel object.
+     """

      def __init__(self, lat_lons, feature_dim, aux_dim):
+         """
+         Initialize the LitModel object.
+
+         Args:
+             lat_lons: List of latitude and longitude coordinates.
+             feature_dim: Dimension of the input features.
+             aux_dim: Dimension of the auxiliary features.
+         """
          super().__init__()
          self.model = GraphWeatherForecaster(
              lat_lons=lat_lons, feature_dim=feature_dim, aux_dim=aux_dim
          )

      def training_step(self, batch):
+         """
+         Performs a training step.
+
+         Args:
+             batch: A batch of training data.
+
+         Returns:
+             The computed loss.
+         """
          x, y = batch
          x = x.half()
          y = y.half()
@@ -27,18 +59,41 @@ def training_step(self, batch):
          return loss

      def configure_optimizers(self):
+         """
+         Configures the optimizer used during training.
+
+         Returns:
+             The optimizer.
+         """
          return torch.optim.AdamW(self.parameters())

      def forward(self, x):
-         return self.model(x)
+         """
+         Forward pass.
+
+         Args:
+             x (torch.Tensor): Input data.
+
+         Returns:
+             torch.Tensor: Output of the model.
+         """
+         return self.model(x)


- # Fake data
- from torch.utils.data import DataLoader, Dataset

  class FakeDataset(Dataset):
+     """
+     Dataset class for generating fake data.
+
+     Methods:
+         __init__: Initialize the FakeDataset object.
+         __len__: Return the length of the dataset.
+         __getitem__: Get an item from the dataset.
+     """

      def __init__(self):
+         """
+         Initialize the FakeDataset object.
+         """
          super(FakeDataset, self).__init__()

      def __len__(self):
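
A minimal sketch of wiring LitModel, FakeDataset, and the Trainer together; the grid spacing, feature sizes, and trainer settings, including the DeepSpeed strategy this script is named for, are hypothetical:

    lat_lons = [(lat, lon) for lat in range(-90, 90, 1) for lon in range(0, 360, 1)]

    model = LitModel(lat_lons=lat_lons, feature_dim=78, aux_dim=24)
    trainer = Trainer(
        accelerator="gpu",
        devices=1,
        precision=16,  # training_step casts inputs to half precision
        strategy="deepspeed_stage_2",
        max_epochs=1,
    )
    trainer.fit(model, DataLoader(FakeDataset(), batch_size=1))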