
Commit 4f2fb19
Merge branch 'dev' into composer_lora
dakinggg committed Jan 26, 2024
2 parents 67e3cb2 + 322e90f commit 4f2fb19
Showing 11 changed files with 89 additions and 76 deletions.
2 changes: 1 addition & 1 deletion composer/_version.py
@@ -3,4 +3,4 @@

"""The Composer Version."""

__version__ = '0.18.0'
__version__ = '0.18.1'
8 changes: 4 additions & 4 deletions composer/core/state.py
@@ -874,7 +874,7 @@ def get_model_state_dict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The state dict for the model.
"""
if version.parse(torch.__version__) > version.parse('2.1.3'):
if version.parse(torch.__version__) > version.parse('2.2.9'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
@@ -909,7 +909,7 @@ def get_optim_state_dict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The state dict for the optimizer.
"""
if version.parse(torch.__version__) > version.parse('2.1.3'):
if version.parse(torch.__version__) > version.parse('2.2.9'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
@@ -1216,7 +1216,7 @@ def load_model_state(
model_on_rank = state_dict['model'] is not None

if model_on_rank:
if version.parse(torch.__version__) > version.parse('2.1.3'):
if version.parse(torch.__version__) > version.parse('2.2.9'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
set_model_state_dict(
model=self.model,
@@ -1277,7 +1277,7 @@ def load_optim_state(self, state_dict: Dict[str, Any], strict: bool = True):
strict (bool): Whether the keys (i.e., optimizer parameter names) in the optimizer
state dict should perfectly match the keys in the optimizer instance.
"""
if version.parse(torch.__version__) > version.parse('2.1.3'):
if version.parse(torch.__version__) > version.parse('2.2.9'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict
optimizer = self.optimizers[0]
set_optimizer_state_dict(
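The four hunks above raise the torch version gate guarding the `torch.distributed.checkpoint.state_dict` APIs. A minimal sketch of that gating pattern, assuming a plain module (the FSDP wrapping and Composer `State` plumbing are omitted):

```python
import torch
from packaging import version


def gated_model_state_dict(model: torch.nn.Module):
    # Mirror of the gate above: prefer the unified checkpoint API on newer
    # torch, and fall back to the module's own state_dict() otherwise.
    if version.parse(torch.__version__) > version.parse('2.2.9'):
        from torch.distributed.checkpoint.state_dict import (StateDictOptions,
                                                             get_model_state_dict)
        # full_state_dict=True requests an unsharded, rank-0-style dict.
        return get_model_state_dict(model, options=StateDictOptions(full_state_dict=True))
    return model.state_dict()
```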
2 changes: 1 addition & 1 deletion composer/loggers/remote_uploader_downloader.py
@@ -103,7 +103,7 @@ class RemoteUploaderDownloader(LoggerDestination):
backend_kwargs={
'provider': 's3',
'container': 'my-bucket',
'provider_kwargs=': {
'provider_kwargs': {
'key': 'AKIA...',
'secret': '*********',
'region': 'ap-northeast-1',
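The one-line fix above removes a stray `=` from the documented `backend_kwargs` key. A usage sketch with the corrected key (all credential values are placeholders):

```python
from composer.loggers import RemoteUploaderDownloader

remote_uploader = RemoteUploaderDownloader(
    bucket_uri='libcloud://my-bucket',
    backend_kwargs={
        'provider': 's3',
        'container': 'my-bucket',
        'provider_kwargs': {  # corrected key: no trailing '='
            'key': 'AKIA...',           # placeholder access key
            'secret': '*********',      # placeholder secret
            'region': 'ap-northeast-1',
        },
    },
)
```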
10 changes: 0 additions & 10 deletions composer/trainer/mosaic_fsdp.py
@@ -69,16 +69,6 @@ def patch_pytorch():
from torch.distributed.fsdp import _runtime_utils
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

# Monkeypatch dtensor support
from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0
FullyShardedDataParallel.__init__ = init_fn_t2p2p0 # type: ignore

# Monkeypath state_dict
from torch.distributed.checkpoint import state_dict # type: ignore

from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p2p0
state_dict._verify_options = _verify_options_t2p2p0

elif version.parse(torch.__version__) < version.parse('2.3.1'):
# Monkey patch for torch < 2.3.1 ie torch == 2.3.0
# Note: this is the same patch as 2.2.0, we are just making a new if branch
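The deleted block drops dtensor and state-dict monkeypatches this torch range no longer needs. The pattern `patch_pytorch` applies throughout is: gate on the installed torch version, then swap a private helper for a local replacement. A sketch built from the surviving context lines (the version bounds here are illustrative, not the real branch structure):

```python
import torch
from packaging import version


def patch_pytorch_sketch():
    # Illustrative bounds; the real function branches over several ranges.
    if version.parse(torch.__version__) < version.parse('2.3.1'):
        from torch.distributed.fsdp import _runtime_utils

        # No-op a validation helper that rejects hybrid-shard setups, as in
        # the surviving context above.
        _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
```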
5 changes: 4 additions & 1 deletion composer/trainer/trainer.py
@@ -986,7 +986,10 @@ def __init__(
assert not isinstance(device_train_microbatch_size, str)

# Distributed
dist.initialize_dist(device, dist_timeout)
if deepspeed_config is not None or fsdp_config is not None or dist.get_world_size() > 1:
# Deepspeed and FSDP both require torch.distributed to be initialized, even if the world size is 1
# And torch.distributed is always required for multi-rank training
dist.initialize_dist(device, dist_timeout)

# Reproducibility
rank_zero_seed, seed = _distribute_and_get_random_seed(seed, device)
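The new guard initializes the process group only when DeepSpeed or FSDP is configured, or when training spans multiple ranks. A standalone sketch of the same decision using raw `torch.distributed` in place of Composer's `dist` helpers (the env-var handling is an assumption about the launcher):

```python
import os

import torch.distributed as torch_dist


def maybe_init_distributed(needs_dist_backend: bool) -> None:
    # Before initialization, world size is only knowable from the launcher.
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    if (needs_dist_backend or world_size > 1) and not torch_dist.is_initialized():
        # DeepSpeed and FSDP need a process group even at world size 1;
        # multi-rank training always does. Assumes MASTER_ADDR, MASTER_PORT,
        # and RANK are provided by the launcher.
        torch_dist.init_process_group(backend='gloo')
```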
67 changes: 33 additions & 34 deletions composer/utils/checkpoint.py
@@ -367,8 +367,7 @@ def load_sharded_checkpoint(
ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None,
exclude_algorithms: Optional[list[str]] = None,
algorithm_passes: Optional[list[AlgorithmPass]] = None,
) -> list[dict]:

) -> Union[list[dict], None]:
if not using_torch_2():
raise ValueError(
f'Sharded checkpoint loading requires torch version >= 2.0.0. You have torch version {torch.__version__}')
@@ -389,7 +388,6 @@ def load_sharded_checkpoint(
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner

# This function is used so we can figure out which ranks need to load saved rngs and which can just make their own.
def _get_num_ranks_that_saved_rng(metadata: Metadata):
rng_inds = []
for field_name, field_value in metadata.planner_data.items():
@@ -501,13 +499,17 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
with torch.no_grad():
# 1. Load model and metadata first
if load_weights_only:
state_dict = {'state': {'model': state.get_model_state_dict()}}
state_dict: Dict[str, Any] = {'state': {'model': state.get_model_state_dict()}}
else:
cur_state_dict = state.state_dict()
# For older versions of torch, we load optimizer separately.
if version.parse(torch.__version__) < version.parse('2.1.3'):
if version.parse(torch.__version__) < version.parse('2.2.9'):
cur_state_dict.pop('optimizers')
state_dict = {'state': cur_state_dict}
num_rng_ranks = _get_num_ranks_that_saved_rng(storage_reader.read_metadata())
state_dict: Dict[str, Any] = {
'state': cur_state_dict,
'rng': reproducibility.get_rng_state()[:num_rng_ranks],
}

if ignore_keys:
# Filter provided list of key paths
@@ -518,17 +520,32 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
# Ensure state exists
state_dict['state'] = state_dict.get('state', {})

if version.parse(torch.__version__) > version.parse('2.1.3'):
# Only some ranks are meant to load checkpoint
expect_file = False
process_group = None
device_mesh = state.fsdp_device_mesh
if device_mesh is not None and device_mesh.ndim == 2:
# If hybrid shard, only rank in first replica saves
expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0)
if expect_file:
process_group = device_mesh.get_group(1) # Shard process_group for first replica
log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}')
else:
expect_file = True

if version.parse(torch.__version__) > version.parse('2.2.9'):
dist_cp.load( # type: ignore
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
process_group=process_group,
)
else:
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
process_group=process_group,
)

state.load_state_dict(
@@ -540,33 +557,14 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
)

# 2. Optionally load optimizer
# if we are using later than 2.1.0 then optimizer will already be loaded
if version.parse(torch.__version__) < version.parse('2.1.3') and not load_weights_only:
# if we are using later than 2.2.9 then optimizer will already be loaded
if version.parse(torch.__version__) < version.parse('2.2.9') and not load_weights_only:
optim_state = load_sharded_optimizer_state_dict(model_state_dict=state.state_dict()['model'],
optimizer_key='optimizers',
storage_reader=storage_reader)
state._legacy_load_optim_state(optim_state)

# 3. Optionally load RNG
rng_state_dicts = reproducibility.get_rng_state()
if not load_weights_only:
# If we are resuming on more ranks than were used at save time we only want to load in rngs for those ranks
num_ranks_that_saved_rng = _get_num_ranks_that_saved_rng(storage_reader.read_metadata())
rng_state_dicts_load = {}
rng_state_dicts_load['rng'] = rng_state_dicts[:num_ranks_that_saved_rng] if len(
rng_state_dicts) > num_ranks_that_saved_rng else rng_state_dicts
dist_cp.load_state_dict(
state_dict=rng_state_dicts_load,
storage_reader=storage_reader,
planner=load_planner,
)
# We also want to append newly generated rng states for the ranks that don't have an rng state to load in
# if we are resuming on more ranks than were used at save time.
if len(rng_state_dicts) > num_ranks_that_saved_rng:
rng_state_dicts_load['rng'].extend(rng_state_dicts[num_ranks_that_saved_rng:])
rng_state_dicts = rng_state_dicts_load['rng']

return rng_state_dicts
return state_dict.get('rng', None)


def _get_local_rank_zero_path(path: Optional[str]) -> str:
@@ -641,7 +639,7 @@ def download_checkpoint(path: str,
raise FileNotFoundError(
(f'Checkpoint {_format_path_with_current_rank(path)} does not exist, '
f'but is required for sharded checkpointing on rank {dist.get_global_rank()}. '
'Please ensure that the checkpoint exists and your load_path was specified as a format string'
'Please ensure that the checkpoint exists and your load_path was specified as a format string '
'with the {rank} argument.')) from e

if extracted_checkpoint_folder is not None:
@@ -968,12 +966,12 @@ def _save_checkpoint(
state_dict['state'] = state_dict.get('state', {})

if state.fsdp_sharded_state_dict_enabled:
# To load optimizer states with 2.0 <= torch < 2.1.3 , the optimizer state must be at the top
# To load optimizer states with 2.0 <= torch < 2.2.9 , the optimizer state must be at the top
# level of the state dict because the load_sharded_optimizer_state_dict function
# requires a top level state dict key for the optimizer.
# See https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/checkpoint/optimizer.py#L271
# for more info.
if using_torch_2() and version.parse(torch.__version__) < version.parse('2.1.3'):
if using_torch_2() and version.parse(torch.__version__) < version.parse('2.2.9'):
if not weights_only:
state_dict['optimizers'] = state_dict['state'].pop('optimizers')
log.debug('State dict created.')
@@ -1010,15 +1008,16 @@ def _save_checkpoint(
process_group = None
device_mesh = state.fsdp_device_mesh
if device_mesh is not None and device_mesh.ndim == 2:
# If hybrid shard, only rank in first replica saves
expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0)
if expect_file:
process_group = device_mesh.get_group(1) # Only save on first replica
process_group = device_mesh.get_group(1) # Shard process_group for first replica
log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}')
else:
expect_file = True

if expect_file:
if version.parse(torch.__version__) > version.parse('2.1.3'):
if version.parse(torch.__version__) > version.parse('2.2.9'):
dist_cp.save( # type: ignore
state_dict=state_dict,
storage_writer=dist_cp.FileSystemWriter(dirname),
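Two patterns recur in the checkpoint changes above: on a 2-D hybrid-shard device mesh, only ranks in the first replica touch checkpoint files (over their shard process group), and the save/load entry points are version-gated between torch's newer `dist_cp.save`/`dist_cp.load` and the older `*_state_dict` variants. A condensed sketch of the save side, assuming a prepared `state_dict` and target `dirname`:

```python
import torch
import torch.distributed.checkpoint as dist_cp
from packaging import version


def gate_on_first_replica(device_mesh):
    # Mirror of the gating above: mesh dim 0 indexes replicas, so only the
    # first replica's ranks perform I/O, over their shard process group.
    expect_file, process_group = True, None
    if device_mesh is not None and device_mesh.ndim == 2:
        expect_file = device_mesh.get_local_rank(mesh_dim=0) == 0
        if expect_file:
            process_group = device_mesh.get_group(1)  # shard group, replica 0
    return expect_file, process_group


def gated_save(state_dict, dirname, device_mesh=None):
    expect_file, process_group = gate_on_first_replica(device_mesh)
    if expect_file:  # as in _save_checkpoint above
        if version.parse(torch.__version__) > version.parse('2.2.9'):
            dist_cp.save(state_dict=state_dict,
                         storage_writer=dist_cp.FileSystemWriter(dirname),
                         process_group=process_group)
        else:
            dist_cp.save_state_dict(state_dict=state_dict,
                                    storage_writer=dist_cp.FileSystemWriter(dirname),
                                    process_group=process_group)
```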
25 changes: 20 additions & 5 deletions composer/utils/file_helpers.py
@@ -21,8 +21,8 @@
from composer.utils import dist
from composer.utils.iter_helpers import iterate_with_callback
from composer.utils.misc import partial_format
from composer.utils.object_store import (GCSObjectStore, MLFlowObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore,
UCObjectStore)
from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore,
OCIObjectStore, S3ObjectStore, UCObjectStore)
from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX

if TYPE_CHECKING:
@@ -319,6 +319,7 @@ def parse_uri(uri: str) -> Tuple[str, str, str]:
Tuple[str, str, str]: A tuple containing the backend (e.g. s3), bucket name, and path.
Backend name will be empty string if the input is a local path
"""
uri = uri.replace('AZURE_BLOBS', 'azure') # urlparse does not support _ in scheme
parse_result = urlparse(uri)
backend, net_loc, path = parse_result.scheme, parse_result.netloc, parse_result.path
bucket_name = net_loc if '@' not in net_loc else net_loc.split('@')[0]
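The added `uri.replace(...)` works around `urlparse` rejecting underscores in a URL scheme: an `AZURE_BLOBS://` URI is normalized to `azure://` before parsing. The behavior it avoids:

```python
from urllib.parse import urlparse

# An underscore is not a valid scheme character, so urlparse finds no
# scheme and the whole URI lands in .path:
assert urlparse('AZURE_BLOBS://my-container/ckpt.pt').scheme == ''
# After normalization, the scheme parses as expected:
assert urlparse('azure://my-container/ckpt.pt').scheme == 'azure'
```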
@@ -354,6 +355,13 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
return GCSObjectStore(bucket=bucket_name)
elif backend == 'oci':
return OCIObjectStore(bucket=bucket_name)
elif backend == 'azure':
return LibcloudObjectStore(
provider='AZURE_BLOBS',
container=bucket_name,
key_environ='AZURE_ACCOUNT_NAME',
secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
)
elif backend == 'dbfs':
if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
store = None
@@ -411,14 +419,21 @@ def maybe_create_remote_uploader_downloader_from_uri(
return None
if backend in ['s3', 'oci', 'gs']:
return RemoteUploaderDownloader(bucket_uri=f'{backend}://{bucket_name}')

elif backend == 'azure':
return RemoteUploaderDownloader(
bucket_uri=f'libcloud://{bucket_name}',
backend_kwargs={
'provider': 'AZURE_BLOBS',
'container': bucket_name,
'key_environ': 'AZURE_ACCOUNT_NAME',
'secret_environ': 'AZURE_ACCOUNT_ACCESS_KEY',
},
)
elif backend == 'dbfs':
return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path})

elif backend == 'wandb':
raise NotImplementedError(f'There is no implementation for WandB via URI. Please use '
'WandBLogger with log_artifacts set to True')

else:
raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported RemoteUploaderDownloader object stores')
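With the new branches, an `azure://` URI resolves to a Libcloud-backed store whose credentials are read from the two named environment variables. A hedged usage sketch (account values are placeholders):

```python
import os

from composer.utils.file_helpers import maybe_create_object_store_from_uri

# Placeholder credentials, read by LibcloudObjectStore when the store is used.
os.environ['AZURE_ACCOUNT_NAME'] = 'myaccount'
os.environ['AZURE_ACCOUNT_ACCESS_KEY'] = 'bXkta2V5'

# Returns a LibcloudObjectStore with provider 'AZURE_BLOBS' for 'my-container'.
store = maybe_create_object_store_from_uri('azure://my-container/checkpoints/latest.pt')
```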
4 changes: 2 additions & 2 deletions docker/README.md
@@ -15,8 +15,8 @@ all dependencies for both NLP and Vision models. They are built on top of the
<!-- BEGIN_COMPOSER_BUILD_MATRIX -->
| Composer Version | CUDA Support | Docker Tag |
|--------------------|----------------|----------------------------------------------------------------|
| 0.18.0 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.18.0` |
| 0.18.0 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.18.0_cpu` |
| 0.18.1 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.18.1` |
| 0.18.1 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.18.1_cpu` |
<!-- END_COMPOSER_BUILD_MATRIX -->

**Note**: For a lightweight installation, we recommend using a [MosaicML PyTorch Image](#pytorch-images) and manually
12 changes: 6 additions & 6 deletions docker/build_matrix.yaml
@@ -193,9 +193,9 @@
TORCHVISION_VERSION: 0.18.0
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04
COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.18.0
COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.18.1
CUDA_VERSION: 12.1.0
IMAGE_NAME: composer-0-18-0
IMAGE_NAME: composer-0-18-1
MOFED_VERSION: 5.5-1.0.3.2
NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471
brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471
@@ -216,23 +216,23 @@
PYTORCH_NIGHTLY_VERSION: ''
PYTORCH_VERSION: 2.1.2
TAGS:
- mosaicml/composer:0.18.0
- mosaicml/composer:0.18.1
- mosaicml/composer:latest
TARGET: composer_stage
TORCHVISION_VERSION: 0.16.2
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: ubuntu:20.04
COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.18.0
COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.18.1
CUDA_VERSION: ''
IMAGE_NAME: composer-0-18-0-cpu
IMAGE_NAME: composer-0-18-1-cpu
MOFED_VERSION: 5.5-1.0.3.2
NVIDIA_REQUIRE_CUDA_OVERRIDE: ''
PYTHON_VERSION: '3.10'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
PYTORCH_VERSION: 2.1.2
TAGS:
- mosaicml/composer:0.18.0_cpu
- mosaicml/composer:0.18.1_cpu
- mosaicml/composer:latest_cpu
TARGET: composer_stage
TORCHVISION_VERSION: 0.16.2
2 changes: 1 addition & 1 deletion docker/generate_build_matrix.py
@@ -246,7 +246,7 @@ def _main():
composer_entries = []

# The `GIT_COMMIT` is a placeholder and Jenkins will substitute it with the actual git commit for the `composer_staging` images
composer_versions = ['0.18.0'] # Only build images for the latest composer version
composer_versions = ['0.18.1'] # Only build images for the latest composer version
composer_python_versions = [LATEST_PYTHON_VERSION] # just build composer against the latest

for product in itertools.product(composer_python_versions, composer_versions, cuda_options):
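For context, the matrix entries above come from a Cartesian product over Python versions, Composer versions, and CUDA options. A toy reproduction of the expansion (the `cuda_options` stand-in and the naming scheme are assumptions based on the generated YAML):

```python
import itertools

composer_python_versions = ['3.10']  # LATEST_PYTHON_VERSION in the script
composer_versions = ['0.18.1']       # only the latest composer version
cuda_options = [True, False]         # illustrative stand-in

for python_version, composer_version, use_cuda in itertools.product(
        composer_python_versions, composer_versions, cuda_options):
    suffix = '' if use_cuda else '-cpu'
    image_name = f"composer-{composer_version.replace('.', '-')}{suffix}"
    print(image_name, f'python {python_version}')  # composer-0-18-1, composer-0-18-1-cpu
```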