remove unused func
NouamaneTazi committed Nov 22, 2024
1 parent a87b83e commit ca0ffa6
Showing 1 changed file with 2 additions and 42 deletions.
44 changes: 2 additions & 42 deletions src/nanotron/serialize/main.py
@@ -5,7 +5,6 @@
 import torch
 from datasets.download.streaming_download_manager import xPath
 from torch import nn
-from torch.nn.parallel import DistributedDataParallel
 from torch.optim.lr_scheduler import LambdaLR

 from nanotron import distributed as dist
@@ -21,14 +20,12 @@
     assert_tensor_synced_across_pg,
     check_optim_state_in_sync,
 )
-from nanotron.serialize.metadata import CheckpointMetadata, TrainingMetadata, load_meta, save_meta
+from nanotron.serialize.metadata import TrainingMetadata, save_meta
 from nanotron.serialize.optimizer import (
-    load_lr_scheduler,
-    load_optimizer,
     save_lr_scheduler,
     save_optimizer,
 )
-from nanotron.serialize.weights import load_weights, save_weights
+from nanotron.serialize.weights import save_weights

 """
 We're going to use safetensors. The reason is that loading segments is going to be much easier
@@ -206,43 +203,6 @@ def save(
     dist.barrier(parallel_context.world_pg)


-def load(
-    model: nn.Module,
-    optimizer: optim.BaseOptimizer,
-    lr_scheduler,
-    parallel_context: ParallelContext,
-    root_folder: Path,
-) -> CheckpointMetadata:
-    """
-    Load checkpoint, raise if checkpoint is assumed corrupted. Inplace updates `model` and `optimizer` to have the newest parameters.
-    TODO @thomasw21: Make this topology agnostic
-    :param filepath: Path
-    :return:
-    """
-    checkpoint_metadata = load_meta(parallel_context=parallel_context, root_folder=root_folder)
-    load_weights(model=model, parallel_context=parallel_context, root_folder=root_folder)
-
-    # SANITY CHECK: assert that optimizer's named_params still point to model's params (check only the first one)
-    if isinstance(optimizer, optim.ZeroDistributedOptimizer):
-        if (
-            len(optimizer.zero_named_param_groups) > 0
-            and len(optimizer.zero_named_param_groups[0]["named_params"]) > 0
-        ):
-            optim_model_param_name, optim_model_param = optimizer.zero_named_param_groups[0]["named_params"][0]
-            if isinstance(model, DistributedDataParallel):
-                optim_model_param_name = f"module.{optim_model_param_name}"
-            param = next(p for n, p in model.named_parameters() if n == optim_model_param_name)
-            assert param.data_ptr() == optim_model_param.data_ptr()
-
-    load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder)
-    load_lr_scheduler(
-        lr_scheduler=lr_scheduler,
-        root_folder=root_folder,
-    )
-    return checkpoint_metadata
-
-
 def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Optional[Path]:
     """Parse checkpoint path from config and download checkpoint from S3 if needed.
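For context, the deleted load() was a thin wrapper: it called load_meta, load_weights, load_optimizer, and load_lr_scheduler in sequence (plus a ZeRO sanity check), and those loaders still exist in nanotron.serialize per the imports shown above. A caller that still needs the old behavior could compose them directly. The sketch below is not part of this commit: the loader calls and keyword arguments are copied from the deleted code, while the import locations of ParallelContext and optim and the name load_checkpoint are assumptions.

    # Minimal sketch of the removed behavior; assumes nanotron.parallel.ParallelContext
    # and nanotron.optim exist at these paths, and uses a hypothetical function name.
    from pathlib import Path

    from torch import nn

    from nanotron import optim                      # assumed import location
    from nanotron.parallel import ParallelContext   # assumed import location
    from nanotron.serialize.metadata import CheckpointMetadata, load_meta
    from nanotron.serialize.optimizer import load_lr_scheduler, load_optimizer
    from nanotron.serialize.weights import load_weights


    def load_checkpoint(
        model: nn.Module,
        optimizer: optim.BaseOptimizer,
        lr_scheduler,
        parallel_context: ParallelContext,
        root_folder: Path,
    ) -> CheckpointMetadata:
        # Same sequence the removed function used: metadata, then weights,
        # then optimizer state, then LR scheduler state.
        checkpoint_metadata = load_meta(parallel_context=parallel_context, root_folder=root_folder)
        load_weights(model=model, parallel_context=parallel_context, root_folder=root_folder)
        load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder)
        load_lr_scheduler(lr_scheduler=lr_scheduler, root_folder=root_folder)
        return checkpoint_metadata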
