Commit

try
TJ-Solergibert committed Nov 26, 2024
1 parent a7ca23b commit 27abd7c
Showing 1 changed file with 73 additions and 75 deletions.
148 changes: 73 additions & 75 deletions src/nanotron/serialize/main.py
@@ -1,32 +1,31 @@
import os
from pathlib import Path
from typing import Optional, cast

import torch
from datasets.download.streaming_download_manager import xPath
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import LambdaLR

from nanotron import distributed as dist
from nanotron import logging
from nanotron import optim as optim
from nanotron.config import Config
from nanotron.constants import MODEL_CONFIG_FILE_NAME
from nanotron.distributed import get_global_rank
from nanotron.logging import log_rank
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.s3_checkpoints import S3Mover, check_path_is_local, fs_open
from nanotron.sanity_checks import (
    assert_tensor_synced_across_pg,
    check_optim_state_in_sync,
)
-from nanotron.serialize.metadata import CheckpointMetadata, TrainingMetadata, load_meta, save_meta
+from nanotron.serialize.metadata import TrainingMetadata, save_meta
from nanotron.serialize.optimizer import (
    load_lr_scheduler,
    load_optimizer,
    save_lr_scheduler,
    save_optimizer,
)
-from nanotron.serialize.weights import load_weights, save_weights
+from nanotron.serialize.weights import save_weights

"""
We're going to use safetensors. The reason is that loading segments is going to be much easier
@@ -40,7 +39,7 @@
Version 1:
- serialize -> dumps every process weights in individual files
-- load -> assume topology is exactly the same
+- load -> assume topology is exactly the same.
"""


@@ -204,46 +203,7 @@ def save(
    dist.barrier(parallel_context.world_pg)


-def load(
-    model: nn.Module,
-    optimizer: optim.BaseOptimizer,
-    lr_scheduler,
-    parallel_context: ParallelContext,
-    root_folder: Path,
-) -> CheckpointMetadata:
-    """
-    Load checkpoint, raise if checkpoint is assumed corrupted. Inplace updates `model` and `optimizer` to have the newest parameters.
-    TODO @thomasw21: Make this topology agnostic
-    :param filepath: Path
-    :return:
-    """
-    checkpoint_metadata = load_meta(parallel_context=parallel_context, root_folder=root_folder)
-    load_weights(model=model, parallel_context=parallel_context, root_folder=root_folder)
-
-    # SANITY CHECK: assert that optimizer's named_params still point to model's params (check only the first one)
-    if isinstance(optimizer, optim.ZeroDistributedOptimizer):
-        if (
-            len(optimizer.zero_named_param_groups) > 0
-            and len(optimizer.zero_named_param_groups[0]["named_params"]) > 0
-        ):
-            optim_model_param_name, optim_model_param = optimizer.zero_named_param_groups[0]["named_params"][0]
-            if isinstance(model, DistributedDataParallel):
-                optim_model_param_name = f"module.{optim_model_param_name}"
-            param = next(p for n, p in model.named_parameters() if n == optim_model_param_name)
-            assert param.data_ptr() == optim_model_param.data_ptr()
-
-    load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder)
-    load_lr_scheduler(
-        lr_scheduler=lr_scheduler,
-        is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer),
-        parallel_context=parallel_context,
-        root_folder=root_folder,
-    )
-    return checkpoint_metadata


-def parse_ckpt_path(config: Config) -> Optional[Path]:
+def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Optional[Path]:
"""Parse checkpoint path from config and download checkpoint from S3 if needed.
Args:
@@ -253,33 +213,71 @@ def parse_ckpt_path(config: Config) -> Optional[Path]:
        Path to checkpoint or None if no checkpoint.
    """
    load_from_candidate = config.checkpoints.resume_checkpoint_path
-    if load_from_candidate is None:
-        return None
-
-    latest_meta_path: Path = config.checkpoints.resume_checkpoint_path / "latest.txt"
-    if latest_meta_path.exists():
-        with open(config.checkpoints.resume_checkpoint_path / "latest.txt", mode="r") as fi:
-            # TODO @thomasw21: make a better structure system so that we get typing correct
-            load_from_candidate = int(fi.read())
-        checkpoint_path = config.checkpoints.resume_checkpoint_path / str(load_from_candidate)
-
-    elif (config.checkpoints.resume_checkpoint_path / MODEL_CONFIG_FILE_NAME).exists():
-        # we assume that the checkpoint path is a path to a checkpoint
-        checkpoint_path = config.checkpoints.resume_checkpoint_path
+    if load_from_candidate is not None:
+        if check_path_is_local(load_from_candidate):
+            latest_meta_path: xPath = config.checkpoints.resume_checkpoint_path / "latest.txt"
+            if latest_meta_path.exists():
+                with fs_open(config.checkpoints.resume_checkpoint_path / "latest.txt", mode="r") as fi:
+                    # TODO @thomasw21: make a better structure system so that we get typing correct
+                    load_from_candidate = int(fi.read())
+                checkpoint_path = config.checkpoints.resume_checkpoint_path / str(load_from_candidate)
+
+            elif (config.checkpoints.resume_checkpoint_path / "model_config.json").exists():
+                # we assume that the checkpoint path is a path to a checkpoint
+                checkpoint_path = config.checkpoints.resume_checkpoint_path
+
+            else:
+                log_rank(
+                    f"No previous checkpoint found in: {latest_meta_path}",
+                    logger=logger,
+                    level=logging.INFO,
+                    rank=0,
+                )
+                return None

-    else:
-        log_rank(
-            f"No previous checkpoint found in: {latest_meta_path}",
-            logger=logger,
-            level=logging.INFO,
-            rank=0,
-        )
-        return None
+            log_rank(
+                f"Loading checkpoint from {checkpoint_path}",
+                logger=logger,
+                level=logging.INFO,
+                rank=0,
+            )
+        else:
+            latest_meta_path = config.checkpoints.resume_checkpoint_path / "latest.txt"
+            if latest_meta_path.exists():
+                # if latest.txt exists, we assume that the checkpoint path is a path to a folder containing the checkpoint
+                with fs_open(latest_meta_path, mode="r") as fi:
+                    latest_iteration = int(fi.read())
+                s3_path = config.checkpoints.resume_checkpoint_path / str(latest_iteration)  # load_path
+                checkpoint_path = config.checkpoints.checkpoints_path / str(latest_iteration)  # save_path
+            elif config.checkpoints.resume_checkpoint_path.exists():
+                # we assume that the checkpoint path is a path to a checkpoint
+                s3_path = config.checkpoints.resume_checkpoint_path  # load_path
+                checkpoint_path = config.checkpoints.checkpoints_path / load_from_candidate.name  # save_path
+            else:
+                log_rank(
+                    f"No previous checkpoint found in: {config.checkpoints.resume_checkpoint_path}\n Initializing from scratch.",
+                    logger=logger,
+                    level=logging.WARNING,
+                    rank=0,
+                )
+                return None
+            log_rank(
+                f"Downloading checkpoint from S3 in {checkpoint_path} ",
+                logger=logger,
+                level=logging.WARNING,
+                rank=0,
+            )
+            # Download checkpoint from S3
+            s3_mover = S3Mover(
+                local_path=os.path.join(checkpoint_path),
+                s3_path=os.path.join(s3_path),
+                s5cmd_numworkers=config.s3_upload.s5cmd_numworkers,
+                s5cmd_concurrency=config.s3_upload.s5cmd_concurrency,
+                s5cmd_path=config.s3_upload.s5cmd_path,
+                dummy=bool(int(os.environ.get("LOCAL_RANK", None)) != 0),
+            )
+            s3_mover.distributed_wait_for_completion(parallel_context.world_pg)
+            s3_mover.start_downloading()
+            s3_mover.distributed_wait_for_completion(parallel_context.world_pg)

-    log_rank(
-        f"Loading checkpoint from {checkpoint_path}",
-        logger=logger,
-        level=logging.INFO,
-        rank=0,
-    )
-    return checkpoint_path
+        return checkpoint_path
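
For context, a brief call-site sketch that is not part of this commit: after this change parse_ckpt_path also needs the ParallelContext, because S3Mover.distributed_wait_for_completion barriers on world_pg so every rank waits for the download before loading. The import path and the surrounding trainer variables are assumptions for illustration.

from nanotron.parallel import ParallelContext
from nanotron.serialize.main import parse_ckpt_path  # import path assumed

def resolve_checkpoint(config, parallel_context: ParallelContext):
    # Before this commit: parse_ckpt_path(config=config)
    # After: the parallel context is required so all ranks synchronize around the S3 download.
    checkpoint_path = parse_ckpt_path(config=config, parallel_context=parallel_context)
    if checkpoint_path is None:
        # No checkpoint to resume from; the trainer should initialize from scratch.
        return None
    return checkpoint_path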
