Skip to content

Commit

Permalink
ruff fixed (docstring remaining)
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-maurya11b committed Apr 4, 2024
1 parent c80ca5b commit 478e06c
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 46 deletions.
20 changes: 10 additions & 10 deletions .ruff.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
select = ["E", "F", "D", "I"]
ignore = ["D200","D202","D210","D212","D415","D105",]
lint.select = ["E", "F", "D", "I"]
lint.ignore = ["D200","D202","D210","D212","D415","D105","D101","D107","D103","D102","D100"]

# Allow autofix for all enabled rules (when `--fix`) is provided.
fixable = ["A", "B", "C", "D", "E", "F", "I"]
unfixable = []
lint.fixable = ["A", "B", "C", "D", "E", "F", "I"]
lint.unfixable = []

# Exclude a variety of commonly ignored directories.
exclude = [
Expand Down Expand Up @@ -35,22 +35,22 @@ exclude = [
line-length = 100

# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

# Assume Python 3.10.
target-version = "py311"
fix = false
fix=false
# Group violations by containing file.
output-format = "github"
ignore-init-module-imports = true
lint.ignore-init-module-imports = true

[mccabe]
[lint.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 10

[pydocstyle]
[lint.pydocstyle]
# Use Google-style docstrings.
convention = "google"

[per-file-ignores]
[lint.per-file-ignores]
"__init__.py" = ["F401", "E402"]
3 changes: 2 additions & 1 deletion graph_weather/data/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
2. GFS Forecast Fields
3. ERA5 Reanalysis Fields
where the variance is the variance in the 3 hour change for a variable averaged across all lat/lon and pressure levels
where the variance is the variance in the 3 hour change for a variable averaged across all lat/lon
and pressure levels
and time for (~100 random temporal frames, more the better)
Min/Max/Mean/Stddev for all those plus each type of observation in observation files
Expand Down
17 changes: 12 additions & 5 deletions graph_weather/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
The dataloader has to do a few things for the model to work correctly
1. Load the land-0sea mask, orography dataset, regridded from 0.1 to the correct resolution
2. Calculate the top-of-atmosphere solar radiation for each location at fcurrent time and 10 other
1. Load the land-0sea mask, orography dataset, regridded from 0.1 to the
correct resolution
2. Calculate the top-of-atmosphere solar radiation for each location at
fcurrent time and 10 other
times +- 12 hours
3. Add day-of-year, sin(lat), cos(lat), sin(lon), cos(lon) as well
3. Batch data as either in geometric batches, or more normally
Expand Down Expand Up @@ -124,7 +126,8 @@ def __getitem__(self, item):
],
axis=-1,
)
# Not want to predict non-physics variables -> Output only the data variables? Would be simpler, and just add in the new ones each time
# Not want to predict non-physics variables -> Output only the data variables?
# Would be simpler, and just add in the new ones each time

output_data = np.stack(
[
Expand Down Expand Up @@ -154,9 +157,13 @@ def __getitem__(self, item):
obs_data = xr.open_zarr(
"/home/jacob/Development/prepbufr.gdas.20160101.t00z.nr.48h.raw.zarr", consolidated=True
)
# TODO Embedding? These should stay consistent across all of the inputs, so can just load the values, not the strings?
# Should only take in the quality markers, observations, reported observation time relative to start point

# TODO Embedding? These should stay consistent across all of the inputs, so can just load the values
# not the strings?
# Should only take in the quality markers, observations, reported observation time relative to start
# point
# Observation errors, and background values, lat/lon/height/speed of observing thing

print(obs_data)
print(obs_data.hdr_inst_typ.values)
print(obs_data.hdr_irpt_typ.values)
Expand Down
34 changes: 17 additions & 17 deletions graph_weather/models/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,23 @@ def __init__(
Args:
observation_lat_lons: Lat/lon points of the observations
output_lat_lons: List of latitude and longitudes for the output analysis
resolution: Resolution of the H3 grid, prefer even resolutions, as
odd ones have octogons and heptagons as well
observation_dim: Input feature size
analysis_dim: Output Analysis feature dim
node_dim: Node hidden dimension
edge_dim: Edge hidden dimension
num_blocks: Number of message passing blocks in the Processor
hidden_dim_processor_node: Hidden dimension of the node processors
hidden_dim_processor_edge: Hidden dimension of the edge processors
hidden_layers_processor_node: Number of hidden layers in the node processors
hidden_layers_processor_edge: Number of hidden layers in the edge processors
hidden_dim_decoder:Number of hidden dimensions in the decoder
hidden_layers_decoder: Number of layers in the decoder
norm_type: Type of norm for the MLPs
one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
use_checkpointing: Whether to use gradient checkpointing or not
output_lat_lons: List of latitude and longitudes for the output analysis
resolution: Resolution of the H3 grid, prefer even resolutions, as
odd ones have octogons and heptagons as well
observation_dim: Input feature size
analysis_dim: Output Analysis feature dim
node_dim: Node hidden dimension
edge_dim: Edge hidden dimension
num_blocks: Number of message passing blocks in the Processor
hidden_dim_processor_node: Hidden dimension of the node processors
hidden_dim_processor_edge: Hidden dimension of the edge processors
hidden_layers_processor_node: Number of hidden layers in the node processors
hidden_layers_processor_edge: Number of hidden layers in the edge processors
hidden_dim_decoder:Number of hidden dimensions in the decoder
hidden_layers_decoder: Number of layers in the decoder
norm_type: Type of norm for the MLPs
one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
use_checkpointing: Whether to use gradient checkpointing or not
"""
super().__init__()

Expand Down
3 changes: 2 additions & 1 deletion graph_weather/models/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __init__(
odd ones have octogons and heptagons as well
feature_dim: Input feature size
aux_dim: Number of non-NWP features (i.e. landsea mask, lat/lon, etc)
output_dim: Optional, output feature size, useful if want only subset of variables in output
output_dim: Optional, output feature size, useful if want only subset of variables in
output
node_dim: Node hidden dimension
edge_dim: Edge hidden dimension
num_blocks: Number of message passing blocks in the Processor
Expand Down
2 changes: 2 additions & 0 deletions graph_weather/models/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def __init__(
Args:
feature_variance: Variance for each of the physical features
lat_lons: List of lat/lon pairs, used to generate weighting
device: checks for device whether it supports gpu or not
normalize: option for normalize
"""
# TODO Rescale by nominal static air density at each pressure level
super().__init__()
Expand Down
6 changes: 1 addition & 5 deletions train/deepspeed_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytorch_lightning as pl
import torch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, Dataset

from graph_weather import GraphWeatherForecaster

Expand Down Expand Up @@ -32,11 +33,6 @@ def configure_optimizers(self):
def forward(self, x):
return self.model(x)


# Fake data
from torch.utils.data import DataLoader, Dataset


class FakeDataset(Dataset):
def __init__(self):
super(FakeDataset, self).__init__()
Expand Down
8 changes: 5 additions & 3 deletions train/run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Training script for training the weather forecasting model"""

import time

import datasets
import numpy as np
import pandas as pd
Expand All @@ -22,7 +24,7 @@ def worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)


def get_mean_stds():
def get_mean_stds(): # noqa: D103
names = [
"CLMR",
"GRLE",
Expand Down Expand Up @@ -333,7 +335,7 @@ def __iter__(self):
seed=np.random.randint(low=-1000, high=10000), buffer_size=4
)
for data in iter(self.dataset):
# TODO Currently leaves out lat/lon/Sun irradience, and land/sea mask and topographic data
#TODO Currently leaves out lat/lon/Sun irradience, and land/sea mask and topographic data
data.update(
{
key: np.expand_dims(np.asarray(value), axis=-1)
Expand Down Expand Up @@ -468,7 +470,7 @@ def __iter__(self):
).to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
print("Done Setup")
import time


for epoch in range(100): # loop over the dataset multiple times
running_loss = 0.0
Expand Down
8 changes: 4 additions & 4 deletions train/run_fulll.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import json
import os
import sys

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)
import time

import numpy as np
import torch
Expand All @@ -18,6 +16,8 @@
from graph_weather.data import const
from graph_weather.models.losses import NormalizedMSELoss

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

class XrDataset(Dataset):
def __init__(self):
Expand Down Expand Up @@ -110,7 +110,7 @@ def __getitem__(self, item):
model = GraphWeatherForecaster(lat_lons, feature_dim=597, num_blocks=6).to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
print("Done Setup")
import time


for epoch in range(100): # loop over the dataset multiple times
running_loss = 0.0
Expand Down

0 comments on commit 478e06c

Please sign in to comment.