Commit 82a36a1
cleanup, modularize, linting, list-comprehensions
sadamov committed May 31, 2024
1 parent 743a52c commit 82a36a1
Showing 1 changed file with 105 additions and 103 deletions.
208 changes: 105 additions & 103 deletions create_parameter_weights.py
@@ -36,9 +36,11 @@ def __init__(
)

def __getitem__(self, idx):
if idx >= self.total_samples:
return self.base_dataset[self.original_indices[-1]]
return self.base_dataset[idx % len(self.base_dataset)]
return self.base_dataset[
self.original_indices[-1]
if idx >= self.total_samples
else idx % len(self.base_dataset)
]

def __len__(self):
return self.total_samples + self.padded_samples
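
For context, a minimal self-contained sketch (toy names and sizes, not part of this commit) of the wrap-around indexing the refactored __getitem__ implements: indices beyond total_samples map to the last original sample, everything else wraps around the base dataset.

class ToyPaddedDataset:
    """Toy stand-in for PaddedWeatherDataset to show the indexing behaviour."""

    def __init__(self, base, total_samples, padded_samples):
        self.base = base
        self.total_samples = total_samples
        self.padded_samples = padded_samples
        self.original_indices = list(range(len(base)))

    def __getitem__(self, idx):
        return self.base[
            self.original_indices[-1]
            if idx >= self.total_samples
            else idx % len(self.base)
        ]

    def __len__(self):
        return self.total_samples + self.padded_samples


ds = ToyPaddedDataset(base=["s0", "s1", "s2"], total_samples=3, padded_samples=1)
print([ds[i] for i in range(len(ds))])  # ['s0', 's1', 's2', 's2'] -- padding repeats the last sample
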
@@ -48,17 +50,11 @@ def get_original_indices(self):


def get_rank():
"""Get the rank of the current process in the distributed group."""
if "SLURM_PROCID" in os.environ:
return int(os.environ["SLURM_PROCID"])
return 0
return int(os.environ.get("SLURM_PROCID", 0))


def get_world_size():
"""Get the number of processes in the distributed group."""
if "SLURM_NTASKS" in os.environ:
return int(os.environ["SLURM_NTASKS"])
return 1
return int(os.environ.get("SLURM_NTASKS", 1))


def setup(rank, world_size): # pylint: disable=redefined-outer-name
@@ -73,22 +69,57 @@ def setup(rank, world_size): # pylint: disable=redefined-outer-name
.decode("utf-8")
)
else:
print(
"\033[91mCareful, you are running this script with --parallelize "
"without any scheduler. In most cases this will result in slower "
"execution and the --parallelize flag should be removed.\033[0m"
)
master_node = "localhost"
master_port = "12355"
os.environ["MASTER_ADDR"] = master_node
os.environ["MASTER_PORT"] = master_port
if torch.cuda.is_available():
dist.init_process_group("nccl", rank=rank, world_size=world_size)
else:
dist.init_process_group("gloo", rank=rank, world_size=world_size)
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group(
"nccl" if torch.cuda.is_available() else "gloo",
rank=rank,
world_size=world_size,
)
print(
f"Initialized {dist.get_backend()} process group with "
f"world size "
f"{world_size}."
f"Initialized {dist.get_backend()} process group with world size {world_size}."
)
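
A minimal standalone sketch of the condensed initialisation, run here as a one-process group on localhost purely for illustration (the real script derives the master node from scontrol when running under SLURM):

import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12355")

dist.init_process_group(
    "nccl" if torch.cuda.is_available() else "gloo",
    rank=0,
    world_size=1,
)
print(
    f"Initialized {dist.get_backend()} process group with "
    f"world size {dist.get_world_size()}."
)
dist.destroy_process_group()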


def main(): # pylint: disable=redefined-outer-name
def save_stats(
static_dir_path, means, squares, flux_means, flux_squares, filename_prefix
):
means = torch.stack(means) if len(means) > 1 else means[0]
squares = torch.stack(squares) if len(squares) > 1 else squares[0]
mean = torch.mean(means, dim=0)
second_moment = torch.mean(squares, dim=0)
std = torch.sqrt(second_moment - mean**2)
torch.save(
mean.cpu(), os.path.join(static_dir_path, f"{filename_prefix}_mean.pt")
)
torch.save(
std.cpu(), os.path.join(static_dir_path, f"{filename_prefix}_std.pt")
)

if len(flux_means) == 0:
return
flux_means = (
torch.stack(flux_means) if len(flux_means) > 1 else flux_means[0]
)
flux_squares = (
torch.stack(flux_squares) if len(flux_squares) > 1 else flux_squares[0]
)
flux_mean = torch.mean(flux_means)
flux_second_moment = torch.mean(flux_squares)
flux_std = torch.sqrt(flux_second_moment - flux_mean**2)
torch.save(
torch.stack((flux_mean, flux_std)).cpu(),
os.path.join(static_dir_path, f"{filename_prefix}_flux_stats.pt"),
)
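
A toy check (invented data, not from the repository) of the statistic this helper computes: the standard deviation is recovered from batch-wise means and second moments via std = sqrt(E[x^2] - E[x]^2).

import torch

x = torch.randn(10, 4)  # 10 samples, 4 parameters
means = [xb.mean(dim=0) for xb in x.split(2)]  # per-batch E[x]
squares = [(xb**2).mean(dim=0) for xb in x.split(2)]  # per-batch E[x^2]

mean = torch.stack(means).mean(dim=0)
second_moment = torch.stack(squares).mean(dim=0)
std = torch.sqrt(second_moment - mean**2)

# Agrees with the population (uncorrected) statistics of the full tensor.
assert torch.allclose(mean, x.mean(dim=0), atol=1e-5)
assert torch.allclose(std, x.std(dim=0, unbiased=False), atol=1e-5)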


def main():
parser = ArgumentParser(description="Training arguments")
parser.add_argument(
"--data_config",
@@ -129,16 +160,15 @@ def main(): # pylint: disable=redefined-outer-name

rank = get_rank()
world_size = get_world_size()

config_loader = config.Config.from_file(args.data_config)

if args.parallelize:

setup(rank, world_size)
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
device = torch.device(
f"cuda:{rank}" if torch.cuda.is_available() else "cpu"
)
torch.cuda.set_device(device) if torch.cuda.is_available() else None

if rank == 0:
static_dir_path = os.path.join(
Expand Down Expand Up @@ -171,14 +201,13 @@ def main(): # pylint: disable=redefined-outer-name
pred_length=63,
standardize=False,
)
ds = PaddedWeatherDataset(
ds,
world_size,
args.batch_size,
duplication_factor=args.duplication_factor,
)
if args.parallelize:
ds = PaddedWeatherDataset(
ds,
world_size,
args.batch_size,
duplication_factor=args.duplication_factor,
)

sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank)
else:
sampler = None
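
For reference, a small sketch (toy sizes, not from the repository) of how DistributedSampler shards the padded dataset so every rank sees a disjoint, equally sized slice:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

ds_toy = TensorDataset(torch.arange(8))
for rank_id in range(2):  # pretend world_size == 2
    sampler_toy = DistributedSampler(
        ds_toy, num_replicas=2, rank=rank_id, shuffle=False
    )
    print(rank_id, list(sampler_toy))
# 0 [0, 2, 4, 6]
# 1 [1, 3, 5, 7]
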
@@ -190,14 +219,9 @@ def main(): # pylint: disable=redefined-outer-name
sampler=sampler,
)

# Compute mean and std.-dev. of each parameter (+ flux forcing) across
# full dataset
if rank == 0:
print("Computing mean and std.-dev. for parameters...")
means = []
squares = []
flux_means = []
flux_squares = []
means, squares, flux_means, flux_squares = [], [], [], []

for init_batch, target_batch, forcing_batch in tqdm(loader):
if args.parallelize:
@@ -214,45 +238,32 @@ def main(): # pylint: disable=redefined-outer-name
flux_squares.append(torch.mean(flux_batch**2).cpu())

if args.parallelize:
means_gathered = [None] * world_size
squares_gathered = [None] * world_size
means_gathered, squares_gathered = [None] * world_size, [
None
] * world_size
dist.all_gather_object(means_gathered, torch.cat(means, dim=0))
dist.all_gather_object(squares_gathered, torch.cat(squares, dim=0))
if rank == 0:
means_all = torch.cat(means_gathered, dim=0)
squares_all = torch.cat(squares_gathered, dim=0)
original_indices = ds.get_original_indices()
means = [means_all[i] for i in original_indices]
squares = [squares_all[i] for i in original_indices]
means_gathered, squares_gathered = torch.cat(
means_gathered, dim=0
), torch.cat(squares_gathered, dim=0)
means, squares = [
means_gathered[i] for i in ds.get_original_indices()
], [squares_gathered[i] for i in ds.get_original_indices()]

if rank == 0:
if len(means) > 1:
means = torch.stack(means)
squares = torch.stack(squares)
else:
means = means[0]
squares = squares[0]
mean = torch.mean(means, dim=0)
second_moment = torch.mean(squares, dim=0)
std = torch.sqrt(second_moment - mean**2)
torch.save(
mean.cpu(), os.path.join(static_dir_path, "parameter_mean.pt")
)
torch.save(std.cpu(), os.path.join(static_dir_path, "parameter_std.pt"))
if len(flux_means) > 1:
flux_means_all = torch.stack(flux_means)
flux_squares_all = torch.stack(flux_squares)
else:
flux_means_all = flux_means[0]
flux_squares_all = flux_squares[0]
flux_mean = torch.mean(flux_means_all)
flux_second_moment = torch.mean(flux_squares_all)
flux_std = torch.sqrt(flux_second_moment - flux_mean**2)
torch.save(
torch.stack((flux_mean, flux_std)).cpu(),
os.path.join(static_dir_path, "flux_stats.pt"),
save_stats(
static_dir_path,
means,
squares,
flux_means,
flux_squares,
"parameter",
)

if args.parallelize:
dist.barrier()

if rank == 0:
print("Computing mean and std.-dev. for one-step differences...")
ds_standard = WeatherDataset(
Expand All @@ -262,14 +273,13 @@ def main(): # pylint: disable=redefined-outer-name
pred_length=63,
standardize=True,
)
ds_standard = PaddedWeatherDataset(
ds_standard,
world_size,
args.batch_size,
duplication_factor=args.duplication_factor,
)
if args.parallelize:
ds_standard = PaddedWeatherDataset(
ds_standard,
world_size,
args.batch_size,
duplication_factor=args.duplication_factor,
)

sampler_standard = DistributedSampler(
ds_standard, num_replicas=world_size, rank=rank
)
@@ -284,8 +294,7 @@ def main(): # pylint: disable=redefined-outer-name
)
used_subsample_len = (65 // args.step_length) * args.step_length

diff_means = []
diff_squares = []
diff_means, diff_squares = [], []

for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0):
if args.parallelize:
@@ -301,41 +310,34 @@ def main(): # pylint: disable=redefined-outer-name
dim=0,
)
batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1]

diff_means.append(torch.mean(batch_diffs, dim=(1, 2)).cpu())
diff_squares.append(torch.mean(batch_diffs**2, dim=(1, 2)).cpu())

if args.parallelize:
dist.barrier()

diff_means_gathered = [None] * world_size
diff_squares_gathered = [None] * world_size
diff_means_gathered, diff_squares_gathered = [None] * world_size, [
None
] * world_size
dist.all_gather_object(
diff_means_gathered, torch.cat(diff_means, dim=0)
)
dist.all_gather_object(
diff_squares_gathered, torch.cat(diff_squares, dim=0)
)

if rank == 0:
diff_means_all = torch.cat(diff_means_gathered, dim=0)
diff_squares_all = torch.cat(diff_squares_gathered, dim=0)
original_indices = ds_standard.get_original_indices()
diff_means = [diff_means_all[i] for i in original_indices]
diff_squares = [diff_squares_all[i] for i in original_indices]
diff_means_gathered, diff_squares_gathered = torch.cat(
diff_means_gathered, dim=0
), torch.cat(diff_squares_gathered, dim=0)
diff_means, diff_squares = [
diff_means_gathered[i]
for i in ds_standard.get_original_indices()
], [
diff_squares_gathered[i]
for i in ds_standard.get_original_indices()
]

if rank == 0:
if len(diff_means) > 1:
diff_means = torch.stack(diff_means)
diff_squares = torch.stack(diff_squares)
else:
diff_means = diff_means[0]
diff_squares = diff_squares[0]
diff_mean = torch.mean(diff_means, dim=0)
diff_second_moment = torch.mean(diff_squares, dim=0)
diff_std = torch.sqrt(diff_second_moment - diff_mean**2)

torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt"))
torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt"))
save_stats(static_dir_path, diff_means, diff_squares, [], [], "diff")

if args.parallelize:
dist.destroy_process_group()
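
Finally, a hedged sketch (invented values, and a made-up index mapping) of the gather-and-reorder pattern used in both statistics passes: each rank contributes its per-sample results via dist.all_gather_object, and rank 0 restores the original dataset order with the indices tracked by the padded dataset, dropping the padded duplicate:

import torch

# What dist.all_gather_object would fill in: one tensor of per-sample stats per rank.
means_gathered = [torch.tensor([0.0, 2.0, 4.0]), torch.tensor([1.0, 3.0, 4.0])]

means_all = torch.cat(means_gathered, dim=0)  # rank-major order, incl. one padded duplicate
original_indices = [0, 3, 1, 4, 2]  # positions of the real samples in gathered order
means = [means_all[i] for i in original_indices]
print(torch.stack(means))  # tensor([0., 1., 2., 3., 4.]) -- back in dataset order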
