diff --git a/3.test_cases/10.FSDP/.gitignore b/3.test_cases/10.FSDP/.gitignore new file mode 100644 index 00000000..4d806e6c --- /dev/null +++ b/3.test_cases/10.FSDP/.gitignore @@ -0,0 +1,4 @@ +miniconda3 +pt_fsdp +checkpoints +Miniconda3-latest-Linux-x86_64.sh diff --git a/3.test_cases/10.FSDP/model_utils/checkpoint.py b/3.test_cases/10.FSDP/model_utils/checkpoint.py index a8e94293..55d8111c 100644 --- a/3.test_cases/10.FSDP/model_utils/checkpoint.py +++ b/3.test_cases/10.FSDP/model_utils/checkpoint.py @@ -51,11 +51,21 @@ def get_last_checkpoint(checkpoint_paths, model_type): steps = [int(re.findall(r'\d+steps', checkpoint.stem)[0].replace('steps','')) \ for checkpoint in checkpoint_paths] checkpoints = sorted([(step, path) for step,path in zip(steps, checkpoint_paths)]) - return checkpoints[-1][1].as_posix() + + # find last checkpoint, skipping incomplete ones + for step, path in reversed(checkpoints): + metadata_path = path.joinpath(".metadata") + if not metadata_path.exists(): + logger.warn(f"{metadata_path} not found. Skipping this incomplete checkpoint") + continue + return path.as_posix() + else: + return None def load_checkpoint(model, optimizer, scheduler, checkpoint_dir, model_type, device): - checkpoint_paths = list(Path(checkpoint_dir).glob(f"{model_type}*")) - if len(checkpoint_paths)==0: + checkpoint_paths = list(Path(checkpoint_dir).glob(f"{model_type}-*steps")) + last_checkpoint = get_last_checkpoint(checkpoint_paths, model_type) + if last_checkpoint is None: if dist.get_rank() == 0: logger.info("No Checkpoints Found") return( @@ -65,7 +75,6 @@ def load_checkpoint(model, optimizer, scheduler, checkpoint_dir, model_type, dev 0, 0, ) - last_checkpoint = get_last_checkpoint(checkpoint_paths, model_type) if dist.get_rank() == 0: logger.info("Loading checkpoint from %s ...", last_checkpoint) with FSDP.state_dict_type(