Skip to content

Commit

Permalink
Add offset and length arguments for checkpoint validation functions (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored May 3, 2024
1 parent ee55424 commit 2c0d039
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 21 deletions.
34 changes: 16 additions & 18 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import warnings
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from packaging import version
Expand Down Expand Up @@ -54,16 +54,13 @@
_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME = f'__{dist.get_global_rank()}_0.distcp'


def _get_checkpoint_validation_function() -> Optional[Callable[[Union[Path, str]], bool]]:
"""Get the validation function by name.
Args:
name (str): Qualified name of the checkpoint validation function.
It should be in the form '{module_name}.{fn_name}'.
def _get_checkpoint_validation_function(
) -> Optional[Callable[[Union[Path, str], Optional[List[Tuple[int, int]]]], bool]]:
"""Get the validation function specified by the environment variable `CHECKPOINT_VALIDATION_FUNCTION`.
Returns:
Callable[[Union[Path, str]], bool] The checkpoint validation function that returns
True given a valid checkpoint and False otherwise.
Callable[[Union[Path, str], Optional[int], Optional[int]], bool] The checkpoint validation function that returns
True given a valid checkpoint and optionally a list of offsets and lengths to check and False otherwise.
"""
name = os.environ.get('CHECKPOINT_VALIDATION_FUNCTION', None)
if name is None:
Expand All @@ -76,14 +73,16 @@ def _get_checkpoint_validation_function() -> Optional[Callable[[Union[Path, str]
return fn


def _ensure_valid_checkpoint(checkpoint_filepath: Union[Path, str]) -> Union[Path, str]:
def _ensure_valid_checkpoint(checkpoint_filepath: Union[Path, str],
specs: Optional[List[Tuple[int, int]]] = None) -> Union[Path, str]:
"""Ensures that the checkpoint at checkpoint_filepath is valid.
using the function specified by the CHECKPOINT_VALIDATION_FUNCTION environment variable.
If CHECKPOINT_VALIDATION_FUNCTION is not set, we skip validation.
Args:
checkpoint_filepath (Union[Path,str]): The path to the checkpoint file.
specs (Optional[List[Tuple[int,int]]]): A list of offsets and lengths to check. Defaults to None.
Raises:
ValueError if checkpoint file is invalid.
Expand All @@ -93,11 +92,10 @@ def _ensure_valid_checkpoint(checkpoint_filepath: Union[Path, str]) -> Union[Pat

# No function name has been specified.
if validate is None:
log.debug('No validation function specified. Skipping checkpoint validation.')
return checkpoint_filepath

# Validate the checkpoint.
if not validate(checkpoint_filepath):
if not validate(checkpoint_filepath, specs):
raise ValueError(f'Checkpoint at {checkpoint_filepath} is invalid.')

log.debug(f'Checkpoint at {checkpoint_filepath} is valid.')
Expand Down Expand Up @@ -169,13 +167,13 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
Raises:
ValueError if the data file is invalid.
"""
validated_checkpoint_paths = set()
path_to_specs: Dict[str, List[Tuple[int, int]]] = {}
for read_item in plan.items:
data_path = os.path.join(self.path, self.storage_data[read_item.storage_index].relative_path)
if data_path in validated_checkpoint_paths:
continue
_ensure_valid_checkpoint(data_path)
validated_checkpoint_paths.add(data_path)
item_md = self.storage_data[read_item.storage_index]
path = os.path.join(self.path, item_md.relative_path)
path_to_specs.setdefault(path, []).append((item_md.offset, item_md.length))
for path, spec in path_to_specs.items():
_ensure_valid_checkpoint(path, spec)
return super().read_data(plan, planner)

def read_metadata(self) -> Metadata:
Expand Down
18 changes: 16 additions & 2 deletions tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tempfile
import time
from glob import glob
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -1766,7 +1766,14 @@ def test_rotate_checkpoints(
dist.barrier() # all ranks finish before cleaning up tmpdir


def simple_validate(filepath: str):
def simple_validate(filepath: str, specs: Optional[List[Tuple[int, int]]] = None) -> bool:
if specs is not None:
with open(filepath, 'r') as f:
for offset, length in specs:
f.seek(offset)
if f.read(length) != 'good':
return False
return True
with open(filepath, 'r') as f:
return f.read() == 'good'

Expand Down Expand Up @@ -1795,6 +1802,13 @@ def test_checkpoint_validation(tmp_path):
result = _ensure_valid_checkpoint(checkpoint_filepath)
assert result == checkpoint_filepath

# Correct usage with offset and lengths and successful validation.
with open(checkpoint_filepath, 'w') as f:
f.write('good good')
with patch.dict(os.environ, {'CHECKPOINT_VALIDATION_FUNCTION': 'tests.trainer.test_checkpoint.simple_validate'}):
result = _ensure_valid_checkpoint(checkpoint_filepath, specs=[(0, 4), (5, 4)])
assert result == checkpoint_filepath

# Correct usage and failed validation.
with open(checkpoint_filepath, 'w') as f:
f.write('bad')
Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/test_fsdp_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def test_checkpoint_loading_with_validation(world_size, tmp_path, is_valid_check
expectation = pytest.raises(ValueError)

def mock_get_checkpoint_validation_function():
return lambda _: is_valid_checkpoint
return lambda checkpoint_path, specs: is_valid_checkpoint

tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
save_folder = os.path.join(tmp_paths[0], 'checkpoints')
Expand Down

0 comments on commit 2c0d039

Please sign in to comment.