Merge pull request #16 from MeteoSwiss/setup_tsa
Updated code for the latest single Zarr archive on Tsa.
Apologies for the bypass, but the git history must be intact before the major re-training.
sadamov authored Apr 26, 2024
2 parents 22cddcb + bdff899 commit a8e0b60
Showing 18 changed files with 514 additions and 290 deletions.
create_grid_features.py (2 changes: 1 addition & 1 deletion)
@@ -54,7 +54,7 @@ def main():
     # Concatenate grid features
     grid_features = torch.cat(
         (grid_xy, geopotential, grid_border_mask), dim=1
-    )  # (N_grid, 4)
+    )  # (N_grid, 2+N_fields+1)

     torch.save(grid_features, os.path.join(static_dir_path, "grid_features.pt"))
create_mesh.py (12 changes: 10 additions & 2 deletions)
@@ -308,12 +308,20 @@ def main():
             plot_graph(
                 pyg_down, title=f"Down graph, {from_level} -> {to_level}"
             )
-            plt.show()
+            plt.savefig(
+                os.path.join(
+                    graph_dir_path, f"mesh_down_graph_{from_level}.png"
+                )
+            )

             plot_graph(
                 pyg_up, title=f"Up graph, {to_level} -> {from_level}"
             )
-            plt.show()
+            plt.savefig(
+                os.path.join(
+                    graph_dir_path, f"mesh_up_graph_{to_level}.png"
+                )
+            )

     # Save up and down edges
     save_edges_list(up_graphs, "mesh_up", graph_dir_path)
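Replacing plt.show() with plt.savefig() lets the script run in non-interactive batch jobs where no display is attached. A minimal headless-plotting sketch (the directory name is hypothetical, and the repo's plot_graph helper is stood in for by a plain line plot):

    import os

    import matplotlib

    matplotlib.use("Agg")  # non-interactive backend; set before importing pyplot
    import matplotlib.pyplot as plt

    graph_dir_path = "graphs/example"  # hypothetical output directory
    os.makedirs(graph_dir_path, exist_ok=True)

    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1])  # stand-in for plot_graph(...)
    fig.savefig(os.path.join(graph_dir_path, "mesh_down_graph_0.png"))
    plt.close(fig)  # release the figure; nothing is ever shown on screen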
create_parameter_weights.py (299 changes: 172 additions & 127 deletions)
@@ -1,163 +1,208 @@
 # Standard library
 import os
+import subprocess
 from argparse import ArgumentParser

 # Third-party
 import numpy as np
 import torch
+import torch.distributed as dist
+from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm

 # First-party
 from neural_lam import constants
-from neural_lam.weather_dataset import WeatherDataset
+from neural_lam.weather_dataset import WeatherDataModule


-def main():
-    """
-    Pre-compute parameter weights to be used in loss function
-    """
-    parser = ArgumentParser(description="Training arguments")
-    parser.add_argument(
-        "--dataset",
-        type=str,
-        default="meps_example",
-        help="Dataset to compute weights for (default: meps_example)",
-    )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=32,
-        help="Batch size when iterating over the dataset",
-    )
-    parser.add_argument(
-        "--step_length",
-        type=int,
-        default=1,
-        help="Step length in hours to consider single time step (default: 1)",
-    )
-    parser.add_argument(
-        "--n_workers",
-        type=int,
-        default=4,
-        help="Number of workers in data loader (default: 4)",
-    )
-    args = parser.parse_args()
-
-    static_dir_path = os.path.join("data", args.dataset, "static")
-
-    # Create parameter weights based on height
-    w_list = []
-    for var_name, pw in zip(
-        constants.PARAM_NAMES_SHORT, constants.PARAM_WEIGHTS.values()
-    ):
-        # Determine the levels to iterate over
-        levels = (
-            constants.LEVEL_WEIGHTS.values()
-            if constants.IS_3D[var_name]
-            else [1]
-        )
-
-        # Iterate over the levels
-        for lw in levels:
-            w_list.append(pw * lw)
-
-    w_list = np.array(w_list)
-
-    print("Saving parameter weights...")
-    np.save(
-        os.path.join(static_dir_path, "parameter_weights.npy"),
-        w_list.astype("float32"),
-    )
+def get_rank():
+    """Get the rank of the current process in the distributed group."""
+    return int(os.environ["SLURM_PROCID"])
+
+
+def get_world_size():
+    """Get the number of processes in the distributed group."""
+    return int(os.environ["SLURM_NTASKS"])
+
+
+def setup(rank, world_size):  # pylint: disable=redefined-outer-name
+    """Initialize the distributed group."""
+    try:
+        master_node = (
+            subprocess.check_output(
+                "scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1",
+                shell=True,
+            )
+            .strip()
+            .decode("utf-8")
+        )
+    except Exception as e:
+        print(f"Error getting master node IP: {e}")
+        raise
+    master_port = "12355"
+    os.environ["MASTER_ADDR"] = master_node
+    os.environ["MASTER_PORT"] = master_port
+    if torch.cuda.is_available():
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    else:
+        dist.init_process_group("gloo", rank=rank, world_size=world_size)
+
+
+def cleanup():
+    """Destroy the distributed group."""
+    dist.destroy_process_group()
+
+
+def main(rank, world_size):  # pylint: disable=redefined-outer-name
+    """Compute the mean and standard deviation of the input data."""
+    setup(rank, world_size)
+    parser = ArgumentParser(description="Training arguments")
+    parser.add_argument("--dataset", type=str, default="meps_example")
+    parser.add_argument("--batch_size", type=int, default=32)
+    parser.add_argument("--subset", type=int, default=8760)
+    parser.add_argument("--n_workers", type=int, default=4)
+    args = parser.parse_args()
+
+    if args.subset % (world_size * args.batch_size) != 0:
+        raise ValueError(
+            "Subset size must be divisible by (world_size * batch_size)"
+        )
+
+    device = torch.device(
+        f"cuda:{rank % torch.cuda.device_count()}"
+        if torch.cuda.is_available()
+        else "cpu"
+    )
+    static_dir_path = os.path.join("data", args.dataset, "static")

-    # Load dataset without any subsampling
-    ds = WeatherDataset(
-        args.dataset,
-        split="train",
+    data_module = WeatherDataModule(
+        dataset_name=args.dataset,
         standardize=False,
-    )  # Without standardization
-    loader = torch.utils.data.DataLoader(
-        ds, args.batch_size, shuffle=False, num_workers=args.n_workers
+        subset=args.subset,
+        batch_size=args.batch_size,
+        num_workers=args.n_workers,
     )
-    # Compute mean and std.-dev. of each parameter (+ flux forcing)
-    # across full dataset
-    print("Computing mean and std.-dev. for parameters...")
+    data_module.setup(stage="fit")
+
+    train_sampler = DistributedSampler(
+        data_module.train_dataset, num_replicas=world_size, rank=rank
+    )
+    train_loader = torch.utils.data.DataLoader(
+        data_module.train_dataset,
+        batch_size=args.batch_size,
+        sampler=train_sampler,
+        num_workers=args.n_workers,
+    )
+
+    if rank == 0:
+        w_list = [
+            pw * lw
+            for var_name, pw in zip(
+                constants.PARAM_NAMES_SHORT, constants.PARAM_WEIGHTS.values()
+            )
+            for lw in (
+                constants.LEVEL_WEIGHTS.values()
+                if constants.IS_3D[var_name]
+                else [1]
+            )
+        ]
+        np.save(
+            os.path.join(static_dir_path, "parameter_weights.npy"),
+            np.array(w_list, dtype="float32"),
+        )
+
     means = []
     squares = []
     flux_means = []
     flux_squares = []
-    for batch_data in tqdm(loader):
-        if constants.GRID_FORCING_DIM > 0:
-            init_batch, target_batch, _, forcing_batch = batch_data
-            flux_batch = forcing_batch[:, :, :, 0]  # Flux is first index
-            flux_means.append(torch.mean(flux_batch))  # (,)
-            flux_squares.append(torch.mean(flux_batch**2))  # (,)
-        else:
-            init_batch, target_batch, _ = batch_data
-
-        batch = torch.cat(
-            (init_batch, target_batch), dim=1
-        )  # (N_batch, N_t, N_grid, d_features)
-        means.append(torch.mean(batch, dim=(1, 2)))  # (N_batch, d_features,)
-        squares.append(
-            torch.mean(batch**2, dim=(1, 2))
-        )  # (N_batch, d_features,)
-
-    mean = torch.mean(torch.cat(means, dim=0), dim=0)  # (d_features)
-    second_moment = torch.mean(torch.cat(squares, dim=0), dim=0)
-    std = torch.sqrt(second_moment - mean**2)  # (d_features)
-
-    if constants.GRID_FORCING_DIM > 0:
-        flux_mean = torch.mean(torch.stack(flux_means))  # (,)
-        flux_second_moment = torch.mean(torch.stack(flux_squares))  # (,)
-        flux_std = torch.sqrt(flux_second_moment - flux_mean**2)  # (,)
-        flux_stats = torch.stack((flux_mean, flux_std))
-
-        print("Saving mean flux_stats...")
-        torch.save(flux_stats, os.path.join(static_dir_path, "flux_stats.pt"))
-
-    print("Saving mean, std.-dev...")
-    torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt"))
-    torch.save(std, os.path.join(static_dir_path, "parameter_std.pt"))
-
-    # Compute mean and std.-dev. of one-step differences across the dataset
-    print("Computing mean and std.-dev. for one-step differences...")
-    ds_standard = WeatherDataset(
-        args.dataset,
-        split="train",
+    for init_batch, target_batch, _, forcing_batch in tqdm(
+        train_loader, disable=rank != 0
+    ):
+        batch = torch.cat((init_batch, target_batch), dim=1).to(device)
+        means.append(torch.mean(batch, dim=(1, 2)))
+        squares.append(torch.mean(batch**2, dim=(1, 2)))
+        if constants.GRID_FORCING_DIM > 0:
+            flux_batch = forcing_batch[:, :, :, 1].to(device)
+            flux_means.append(torch.mean(flux_batch))
+            flux_squares.append(torch.mean(flux_batch**2))
+
+    dist.barrier()
+
+    means_gathered = [None] * world_size
+    squares_gathered = [None] * world_size
+    dist.all_gather_object(means_gathered, torch.cat(means, dim=0))
+    dist.all_gather_object(squares_gathered, torch.cat(squares, dim=0))
+
+    if rank == 0:
+        means_all = torch.cat(means_gathered, dim=0)
+        squares_all = torch.cat(squares_gathered, dim=0)
+        mean = torch.mean(means_all, dim=0)
+        second_moment = torch.mean(squares_all, dim=0)
+        std = torch.sqrt(second_moment - mean**2)
+        torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt"))
+        torch.save(std, os.path.join(static_dir_path, "parameter_std.pt"))
+
+        if constants.GRID_FORCING_DIM > 0:
+            flux_means_all = torch.stack(flux_means)
+            flux_squares_all = torch.stack(flux_squares)
+            flux_mean = torch.mean(flux_means_all)
+            flux_second_moment = torch.mean(flux_squares_all)
+            flux_std = torch.sqrt(flux_second_moment - flux_mean**2)
+            torch.save(
+                {"mean": flux_mean, "std": flux_std},
+                os.path.join(static_dir_path, "flux_stats.pt"),
+            )
+
+    data_module = WeatherDataModule(
+        dataset_name=args.dataset,
         standardize=True,
-    )  # Re-load with standardization
-    loader_standard = torch.utils.data.DataLoader(
-        ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers
+        subset=args.subset,
+        batch_size=args.batch_size,
+        num_workers=args.n_workers,
     )
+    data_module.setup(stage="fit")
+
+    train_sampler = DistributedSampler(
+        data_module.train_dataset, num_replicas=world_size, rank=rank
+    )
+    train_loader = torch.utils.data.DataLoader(
+        data_module.train_dataset,
+        batch_size=args.batch_size,
+        sampler=train_sampler,
+        num_workers=args.n_workers,
+    )

+    # Compute mean and std-dev of one-step differences
     diff_means = []
     diff_squares = []
-    for batch_data in tqdm(loader_standard):
-        if constants.GRID_FORCING_DIM > 0:
-            init_batch, target_batch, _, forcing_batch = batch_data
-            flux_batch = forcing_batch[:, :, :, 0]  # Flux is first index
-            flux_means.append(torch.mean(flux_batch))  # (,)
-            flux_squares.append(torch.mean(flux_batch**2))  # (,)
-        else:
-            init_batch, target_batch, _ = batch_data
-        batch_diffs = init_batch[:, 1:] - target_batch
-        # (N_batch', N_t-1, N_grid, d_features)
-
-        diff_means.append(
-            torch.mean(batch_diffs, dim=(1, 2))
-        )  # (N_batch', d_features,)
-        diff_squares.append(
-            torch.mean(batch_diffs**2, dim=(1, 2))
-        )  # (N_batch', d_features,)
-
-    diff_mean = torch.mean(torch.cat(diff_means, dim=0), dim=0)  # (d_features)
-    diff_second_moment = torch.mean(torch.cat(diff_squares, dim=0), dim=0)
-    diff_std = torch.sqrt(diff_second_moment - diff_mean**2)  # (d_features)
-
-    print("Saving one-step difference mean and std.-dev...")
-    torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt"))
-    torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt"))
+    for init_batch, target_batch, _, _ in tqdm(train_loader, disable=rank != 0):
+        batch = torch.cat((init_batch, target_batch), dim=1).to(device)
+        diffs = batch[:, 1:] - batch[:, :-1]
+        diff_means.append(torch.mean(diffs, dim=(1, 2)))
+        diff_squares.append(torch.mean(diffs**2, dim=(1, 2)))
+
+    dist.barrier()
+
+    diff_means_gathered = [None] * world_size
+    diff_squares_gathered = [None] * world_size
+    dist.all_gather_object(diff_means_gathered, torch.cat(diff_means, dim=0))
+    dist.all_gather_object(
+        diff_squares_gathered, torch.cat(diff_squares, dim=0)
+    )
+
+    if rank == 0:
+        diff_means_all = torch.cat(diff_means_gathered, dim=0)
+        diff_squares_all = torch.cat(diff_squares_gathered, dim=0)
+        diff_mean = torch.mean(diff_means_all, dim=0)
+        diff_second_moment = torch.mean(diff_squares_all, dim=0)
+        diff_std = torch.sqrt(diff_second_moment - diff_mean**2)
+        torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt"))
+        torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt"))
+
+    cleanup()


 if __name__ == "__main__":
-    main()
+    rank = get_rank()
+    world_size = get_world_size()
+    main(rank, world_size)
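Why the per-rank statistics can be combined this way (a side note, not from the commit): every rank sees batches of identical size, enforced by the subset-divisibility check above, so the gathered per-batch means are equally weighted and averaging them reproduces the global mean; the standard deviation then follows from E[x^2] - E[x]^2. A small self-contained check of that identity, with hypothetical sizes:

    import torch

    # With equal-sized batches, averaging per-batch statistics matches the
    # global statistics of the pooled data (the sizes here are illustrative).
    world_size, n_batches, batch_size = 4, 2, 8
    x = torch.randn(world_size * n_batches * batch_size, 3)  # 3 features

    per_batch = x.view(-1, batch_size, 3)  # split into equal batches
    means = per_batch.mean(dim=1)  # per-batch means, as in the script
    squares = (per_batch**2).mean(dim=1)  # per-batch second moments

    mean = means.mean(dim=0)
    std = torch.sqrt(squares.mean(dim=0) - mean**2)

    assert torch.allclose(mean, x.mean(dim=0), atol=1e-5)
    assert torch.allclose(std, x.std(dim=0, unbiased=False), atol=1e-4)

Since get_rank() and get_world_size() read SLURM_PROCID and SLURM_NTASKS, the script presumably expects one SLURM task per process, e.g. a launch via srun inside the job allocation.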
create_static_features.py (4 changes: 2 additions & 2 deletions)
@@ -34,10 +34,10 @@ def main():
     parser.add_argument(
         "--field_names",
         nargs="+",
-        default=["HSURF", "FI", "HFL"],
+        default=["HSURF", "FI"],
         help=(
             "Names of the fields to extract from the .nc file "
-            '(default: ["HSURF", "FI", "HFL"])'
+            '(default: ["HSURF", "FI"])'
         ),
     )
     parser.add_argument(
environment.yml (1 change: 0 additions & 1 deletion)
@@ -19,7 +19,6 @@ dependencies:
   - pyproj
   - pyprojroot
   - pytorch=2.2.2=py3.12_cuda11.8_cudnn8.7.0_0
-  - pytorch-cuda=11.8
   - pytorch-lightning
   - scikit-learn
   - scipy
(Diffs for the remaining 13 changed files are not shown.)
