From 4a53dfeb043ed4e26bec7740fb1711e540ebe5a4 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 23 Jan 2024 10:30:08 -0800 Subject: [PATCH 1/5] Azure checkpointing support (#2893) * v1 * fix * fix * logs * dump env * fix * logs * force logs * bucket support * typo * more logs * logs * more logs * fix autoresume * logs * fix * fix * lint * morelogs * logs * fix autoresume * fix * lint * fix * fix lstirp * strip prefix * muck around * logs * azure * timestamp * fix * state * logs * logs * remove * game * fix * lint --- .../loggers/remote_uploader_downloader.py | 2 +- composer/utils/checkpoint.py | 2 +- composer/utils/file_helpers.py | 25 +++++++++++++++---- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/composer/loggers/remote_uploader_downloader.py b/composer/loggers/remote_uploader_downloader.py index 76d70bdf3e..0ee65c832b 100644 --- a/composer/loggers/remote_uploader_downloader.py +++ b/composer/loggers/remote_uploader_downloader.py @@ -103,7 +103,7 @@ class RemoteUploaderDownloader(LoggerDestination): backend_kwargs={ 'provider': 's3', 'container': 'my-bucket', - 'provider_kwargs=': { + 'provider_kwargs': { 'key': 'AKIA...', 'secret': '*********', 'region': 'ap-northeast-1', diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 63e87f57fe..2af494e68b 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -641,7 +641,7 @@ def download_checkpoint(path: str, raise FileNotFoundError( (f'Checkpoint {_format_path_with_current_rank(path)} does not exist, ' f'but is required for sharded checkpointing on rank {dist.get_global_rank()}. ' - 'Please ensure that the checkpoint exists and your load_path was specified as a format string' + 'Please ensure that the checkpoint exists and your load_path was specified as a format string ' 'with the {rank} argument.')) from e if extracted_checkpoint_folder is not None: diff --git a/composer/utils/file_helpers.py b/composer/utils/file_helpers.py index d62487e106..c42aa7ce6f 100644 --- a/composer/utils/file_helpers.py +++ b/composer/utils/file_helpers.py @@ -21,8 +21,8 @@ from composer.utils import dist from composer.utils.iter_helpers import iterate_with_callback from composer.utils.misc import partial_format -from composer.utils.object_store import (GCSObjectStore, MLFlowObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore, - UCObjectStore) +from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore, + OCIObjectStore, S3ObjectStore, UCObjectStore) from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX if TYPE_CHECKING: @@ -319,6 +319,7 @@ def parse_uri(uri: str) -> Tuple[str, str, str]: Tuple[str, str, str]: A tuple containing the backend (e.g. s3), bucket name, and path. Backend name will be empty string if the input is a local path """ + uri = uri.replace('AZURE_BLOBS', 'azure') # urlparse does not support _ in scheme parse_result = urlparse(uri) backend, net_loc, path = parse_result.scheme, parse_result.netloc, parse_result.path bucket_name = net_loc if '@' not in net_loc else net_loc.split('@')[0] @@ -354,6 +355,13 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]: return GCSObjectStore(bucket=bucket_name) elif backend == 'oci': return OCIObjectStore(bucket=bucket_name) + elif backend == 'azure': + return LibcloudObjectStore( + provider='AZURE_BLOBS', + container=bucket_name, + key_environ='AZURE_ACCOUNT_NAME', + secret_environ='AZURE_ACCOUNT_ACCESS_KEY', + ) elif backend == 'dbfs': if path.startswith(MLFLOW_DBFS_PATH_PREFIX): store = None @@ -411,14 +419,21 @@ def maybe_create_remote_uploader_downloader_from_uri( return None if backend in ['s3', 'oci', 'gs']: return RemoteUploaderDownloader(bucket_uri=f'{backend}://{bucket_name}') - + elif backend == 'azure': + return RemoteUploaderDownloader( + bucket_uri=f'libcloud://{bucket_name}', + backend_kwargs={ + 'provider': 'AZURE_BLOBS', + 'container': bucket_name, + 'key_environ': 'AZURE_ACCOUNT_NAME', + 'secret_environ': 'AZURE_ACCOUNT_ACCESS_KEY', + }, + ) elif backend == 'dbfs': return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path}) - elif backend == 'wandb': raise NotImplementedError(f'There is no implementation for WandB via URI. Please use ' 'WandBLogger with log_artifacts set to True') - else: raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use ' 'one of the supported RemoteUploaderDownloader object stores') From cfc439a5919f76288a921e05e23079682b846d40 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 23 Jan 2024 16:49:57 -0800 Subject: [PATCH 2/5] Pass PG into checkpoint load and load rng with state_dict (#2897) * checkdown * remove comment * lint * comments * fix * accelerate test * fix test * lint * fix test --- composer/utils/checkpoint.py | 59 +++++++++++---------------- tests/trainer/test_fsdp_checkpoint.py | 22 +++++++--- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 2af494e68b..c47184eada 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -367,8 +367,7 @@ def load_sharded_checkpoint( ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None, exclude_algorithms: Optional[list[str]] = None, algorithm_passes: Optional[list[AlgorithmPass]] = None, -) -> list[dict]: - +) -> Union[list[dict], None]: if not using_torch_2(): raise ValueError( f'Sharded checkpoint loading requires torch version >= 2.0.0. You have torch version {torch.__version__}') @@ -389,16 +388,6 @@ def load_sharded_checkpoint( from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner - # This function is used so we can figure out which ranks need to load saved rngs and which can just make their own. - def _get_num_ranks_that_saved_rng(metadata: Metadata): - rng_inds = [] - for field_name, field_value in metadata.planner_data.items(): - if 'rng' in field_name: - _, rng_rank_index, _ = field_value - rng_inds.append(rng_rank_index) - rng_inds = set(rng_inds) - return len(rng_inds) - class FileSystemReaderWithValidation(dist_cp.FileSystemReader): """FileSystemReader that validates checkpoint files prior to reading.""" @@ -501,13 +490,16 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): with torch.no_grad(): # 1. Load model and metadata first if load_weights_only: - state_dict = {'state': {'model': state.get_model_state_dict()}} + state_dict: Dict[str, Any] = {'state': {'model': state.get_model_state_dict()}} else: cur_state_dict = state.state_dict() # For older versions of torch, we load optimizer separately. if version.parse(torch.__version__) < version.parse('2.1.3'): cur_state_dict.pop('optimizers') - state_dict = {'state': cur_state_dict} + state_dict: Dict[str, Any] = { + 'state': cur_state_dict, + 'rng': reproducibility.get_rng_state(), + } if ignore_keys: # Filter provided list of key paths @@ -518,17 +510,32 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): # Ensure state exists state_dict['state'] = state_dict.get('state', {}) + # Only some ranks are meant to load checkpoint + expect_file = False + process_group = None + device_mesh = state.fsdp_device_mesh + if device_mesh is not None and device_mesh.ndim == 2: + # If hybrid shard, only rank in first replica saves + expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0) + if expect_file: + process_group = device_mesh.get_group(1) # Shard process_group for first replica + log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}') + else: + expect_file = True + if version.parse(torch.__version__) > version.parse('2.1.3'): dist_cp.load( # type: ignore state_dict=state_dict, storage_reader=storage_reader, planner=load_planner, + process_group=process_group, ) else: dist_cp.load_state_dict( state_dict=state_dict, storage_reader=storage_reader, planner=load_planner, + process_group=process_group, ) state.load_state_dict( @@ -547,26 +554,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): storage_reader=storage_reader) state._legacy_load_optim_state(optim_state) - # 3. Optionally load RNG - rng_state_dicts = reproducibility.get_rng_state() - if not load_weights_only: - # If we are resuming on more ranks than were used at save time we only want to load in rngs for those ranks - num_ranks_that_saved_rng = _get_num_ranks_that_saved_rng(storage_reader.read_metadata()) - rng_state_dicts_load = {} - rng_state_dicts_load['rng'] = rng_state_dicts[:num_ranks_that_saved_rng] if len( - rng_state_dicts) > num_ranks_that_saved_rng else rng_state_dicts - dist_cp.load_state_dict( - state_dict=rng_state_dicts_load, - storage_reader=storage_reader, - planner=load_planner, - ) - # We also want to append newly generated rng states for the ranks that don't have an rng state to load in - # if we are resuming on more ranks than were used at save time. - if len(rng_state_dicts) > num_ranks_that_saved_rng: - rng_state_dicts_load['rng'].extend(rng_state_dicts[num_ranks_that_saved_rng:]) - rng_state_dicts = rng_state_dicts_load['rng'] - - return rng_state_dicts + return state_dict.get('rng', None) def _get_local_rank_zero_path(path: Optional[str]) -> str: @@ -1010,9 +998,10 @@ def _save_checkpoint( process_group = None device_mesh = state.fsdp_device_mesh if device_mesh is not None and device_mesh.ndim == 2: + # If hybrid shard, only rank in first replica saves expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0) if expect_file: - process_group = device_mesh.get_group(1) # Only save on first replica + process_group = device_mesh.get_group(1) # Shard process_group for first replica log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}') else: expect_file = True diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 28216d3541..5bd416f4c7 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -11,7 +11,7 @@ import uuid from contextlib import nullcontext as does_not_raise from functools import partial -from typing import Any, Callable, Optional, Sequence +from typing import Any, Callable, Optional, Sequence, Union from unittest.mock import patch import numpy as np @@ -600,12 +600,16 @@ def mock_get_checkpoint_validation_function(): @pytest.mark.gpu @world_size(2) -@pytest.mark.parametrize('weights_only', [False, True]) -@pytest.mark.parametrize('optimizer', ['adam', 'adamw']) @pytest.mark.parametrize('state_dict_type', ['sharded', 'local']) -@pytest.mark.parametrize('precision', ['amp_bf16', 'amp_fp16']) @pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False]) -@pytest.mark.parametrize('autoresume', [True, False]) +@pytest.mark.parametrize('weights_only,optimizer,precision,autoresume,load_ignore_keys', [ + [False, 'adamw', 'amp_bf16', False, None], + [True, 'adamw', 'amp_bf16', False, None], + [False, 'adam', 'amp_bf16', False, None], + [False, 'adamw', 'amp_fp16', False, None], + [False, 'adamw', 'amp_bf16', True, None], + [False, 'adamw', 'amp_bf16', False, ['rng']], +]) @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'), reason='requires PyTorch 1.13 or higher') @pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning') @@ -619,6 +623,7 @@ def test_fsdp_partitioned_state_dict_load( precision: str, optimizer: str, weights_only: bool, + load_ignore_keys: Union[list[str], None], use_remote, s3_bucket, s3_ephemeral_prefix, @@ -630,6 +635,7 @@ def test_fsdp_partitioned_state_dict_load( pytest.xfail(('Loading a state_dict_type="local" checkpoint with strict=True ' 'errors out. See https://github.com/pytorch/pytorch/issues/102667 ' 'for more info')) + load_ignore_keys = [] if load_ignore_keys is None else load_ignore_keys if autoresume: local_run_name = f'my-cool-autoresume-run-{uuid.uuid1()}' @@ -700,6 +706,7 @@ def test_fsdp_partitioned_state_dict_load( optimizer=optimizer, load_weights_only=weights_only, fsdp_config=fsdp_config, + load_ignore_keys=load_ignore_keys, ) state_dict_from_trainer2 = trainer2.state.state_dict() rng2 = trainer2._rng_state @@ -709,7 +716,10 @@ def test_fsdp_partitioned_state_dict_load( state_dict_from_trainer2, ) if not weights_only: - _compare_rng_states_between_trainers(rng1, rng2) + if any('rng' in x for x in load_ignore_keys): + assert rng1 is not None and rng2 is None + else: + _compare_rng_states_between_trainers(rng1, rng2) _compare_optims_between_state_dicts( state_dict_from_trainer1_ba2, state_dict_from_trainer2, From 704c07ee5b1a1723711a6af2059eee1c044af85f Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 23 Jan 2024 17:44:07 -0800 Subject: [PATCH 3/5] Remove monkeypatch and new state dict APIs for torch 2.2 (#2899) * fix mosaicfsdp * bump to 2.3 * remove init --- composer/core/state.py | 8 ++++---- composer/trainer/mosaic_fsdp.py | 10 ---------- composer/trainer/trainer.py | 5 ++++- composer/utils/checkpoint.py | 14 +++++++------- 4 files changed, 15 insertions(+), 22 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 4967aa1dba..59b5babfe7 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -874,7 +874,7 @@ def get_model_state_dict(self) -> Dict[str, Any]: Returns: Dict[str, Any]: The state dict for the model. """ - if version.parse(torch.__version__) > version.parse('2.1.3'): + if version.parse(torch.__version__) > version.parse('2.2.9'): from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: raise NotImplementedError( @@ -909,7 +909,7 @@ def get_optim_state_dict(self) -> Dict[str, Any]: Returns: Dict[str, Any]: The state dict for the optimizer. """ - if version.parse(torch.__version__) > version.parse('2.1.3'): + if version.parse(torch.__version__) > version.parse('2.2.9'): from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: raise NotImplementedError( @@ -1216,7 +1216,7 @@ def load_model_state( model_on_rank = state_dict['model'] is not None if model_on_rank: - if version.parse(torch.__version__) > version.parse('2.1.3'): + if version.parse(torch.__version__) > version.parse('2.2.9'): from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict set_model_state_dict( model=self.model, @@ -1277,7 +1277,7 @@ def load_optim_state(self, state_dict: Dict[str, Any], strict: bool = True): strict (bool): Whether the keys (i.e., optimizer parameter names) in the optimizer state dict should perfectly match the keys in the optimizer instance. """ - if version.parse(torch.__version__) > version.parse('2.1.3'): + if version.parse(torch.__version__) > version.parse('2.2.9'): from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict optimizer = self.optimizers[0] set_optimizer_state_dict( diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py index 1a8ab77bbf..1b346e92e4 100644 --- a/composer/trainer/mosaic_fsdp.py +++ b/composer/trainer/mosaic_fsdp.py @@ -69,16 +69,6 @@ def patch_pytorch(): from torch.distributed.fsdp import _runtime_utils _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None - # Monkeypatch dtensor support - from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0 - FullyShardedDataParallel.__init__ = init_fn_t2p2p0 # type: ignore - - # Monkeypath state_dict - from torch.distributed.checkpoint import state_dict # type: ignore - - from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p2p0 - state_dict._verify_options = _verify_options_t2p2p0 - elif version.parse(torch.__version__) < version.parse('2.3.1'): # Monkey patch for torch < 2.3.1 ie torch == 2.3.0 # Note: this is the same patch as 2.2.0, we are just making a new if branch diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index e8c587288a..80a519d758 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -986,7 +986,10 @@ def __init__( assert not isinstance(device_train_microbatch_size, str) # Distributed - dist.initialize_dist(device, dist_timeout) + if deepspeed_config is not None or fsdp_config is not None or dist.get_world_size() > 1: + # Deepspeed and FSDP both require torch.distributed to be initialized, even if the world size is 1 + # And torch.distributed is always required for multi-rank training + dist.initialize_dist(device, dist_timeout) # Reproducibility rank_zero_seed, seed = _distribute_and_get_random_seed(seed, device) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index c47184eada..ddb2f3236a 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -494,7 +494,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): else: cur_state_dict = state.state_dict() # For older versions of torch, we load optimizer separately. - if version.parse(torch.__version__) < version.parse('2.1.3'): + if version.parse(torch.__version__) < version.parse('2.2.9'): cur_state_dict.pop('optimizers') state_dict: Dict[str, Any] = { 'state': cur_state_dict, @@ -523,7 +523,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): else: expect_file = True - if version.parse(torch.__version__) > version.parse('2.1.3'): + if version.parse(torch.__version__) > version.parse('2.2.9'): dist_cp.load( # type: ignore state_dict=state_dict, storage_reader=storage_reader, @@ -547,8 +547,8 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): ) # 2. Optionally load optimizer - # if we are using later than 2.1.0 then optimizer will already be loaded - if version.parse(torch.__version__) < version.parse('2.1.3') and not load_weights_only: + # if we are using later than 2.2.9 then optimizer will already be loaded + if version.parse(torch.__version__) < version.parse('2.2.9') and not load_weights_only: optim_state = load_sharded_optimizer_state_dict(model_state_dict=state.state_dict()['model'], optimizer_key='optimizers', storage_reader=storage_reader) @@ -956,12 +956,12 @@ def _save_checkpoint( state_dict['state'] = state_dict.get('state', {}) if state.fsdp_sharded_state_dict_enabled: - # To load optimizer states with 2.0 <= torch < 2.1.3 , the optimizer state must be at the top + # To load optimizer states with 2.0 <= torch < 2.2.9 , the optimizer state must be at the top # level of the state dict because the load_sharded_optimizer_state_dict function # requires a top level state dict key for the optimizer. # See https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/checkpoint/optimizer.py#L271 # for more info. - if using_torch_2() and version.parse(torch.__version__) < version.parse('2.1.3'): + if using_torch_2() and version.parse(torch.__version__) < version.parse('2.2.9'): if not weights_only: state_dict['optimizers'] = state_dict['state'].pop('optimizers') log.debug('State dict created.') @@ -1007,7 +1007,7 @@ def _save_checkpoint( expect_file = True if expect_file: - if version.parse(torch.__version__) > version.parse('2.1.3'): + if version.parse(torch.__version__) > version.parse('2.2.9'): dist_cp.save( # type: ignore state_dict=state_dict, storage_writer=dist_cp.FileSystemWriter(dirname), From 2553e544b744da2d75a4e9911d196d47108b0ed7 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Wed, 24 Jan 2024 16:50:40 -0800 Subject: [PATCH 4/5] Only load RNG keys that exist (#2901) * cut * fix call * fix test --- composer/utils/checkpoint.py | 12 +++++++++++- tests/trainer/test_fsdp_checkpoint.py | 6 +----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index ddb2f3236a..628910c95c 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -388,6 +388,15 @@ def load_sharded_checkpoint( from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner + def _get_num_ranks_that_saved_rng(metadata: Metadata): + rng_inds = [] + for field_name, field_value in metadata.planner_data.items(): + if 'rng' in field_name: + _, rng_rank_index, _ = field_value + rng_inds.append(rng_rank_index) + rng_inds = set(rng_inds) + return len(rng_inds) + class FileSystemReaderWithValidation(dist_cp.FileSystemReader): """FileSystemReader that validates checkpoint files prior to reading.""" @@ -496,9 +505,10 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): # For older versions of torch, we load optimizer separately. if version.parse(torch.__version__) < version.parse('2.2.9'): cur_state_dict.pop('optimizers') + num_rng_ranks = _get_num_ranks_that_saved_rng(storage_reader.read_metadata()) state_dict: Dict[str, Any] = { 'state': cur_state_dict, - 'rng': reproducibility.get_rng_state(), + 'rng': reproducibility.get_rng_state()[:num_rng_ranks], } if ignore_keys: diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 5bd416f4c7..0799d815d4 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -561,11 +561,7 @@ def test_checkpoint_loading_with_validation(world_size, tmp_path, is_valid_check # Set the error expectations. expectation = does_not_raise() if not is_valid_checkpoint: - if using_torch_2() and state_dict_type == 'sharded': - from torch.distributed.checkpoint import CheckpointException - expectation = pytest.raises(CheckpointException) - else: - expectation = pytest.raises(ValueError) + expectation = pytest.raises(ValueError) def mock_get_checkpoint_validation_function(): return lambda _: is_valid_checkpoint From 322e90fc241a04a1fd4f39ad705e19fe6a85e695 Mon Sep 17 00:00:00 2001 From: Brian <23239305+b-chu@users.noreply.github.com> Date: Thu, 25 Jan 2024 15:27:40 -0500 Subject: [PATCH 5/5] Bump version to 0.18.1 (#2905) --- composer/_version.py | 2 +- docker/README.md | 4 ++-- docker/build_matrix.yaml | 12 ++++++------ docker/generate_build_matrix.py | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/composer/_version.py b/composer/_version.py index 54887b79cb..600bc4d182 100644 --- a/composer/_version.py +++ b/composer/_version.py @@ -3,4 +3,4 @@ """The Composer Version.""" -__version__ = '0.18.0' +__version__ = '0.18.1' diff --git a/docker/README.md b/docker/README.md index e3bab86b5c..32f6f6e0e9 100644 --- a/docker/README.md +++ b/docker/README.md @@ -15,8 +15,8 @@ all dependencies for both NLP and Vision models. They are built on top of the | Composer Version | CUDA Support | Docker Tag | |--------------------|----------------|----------------------------------------------------------------| -| 0.18.0 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.18.0` | -| 0.18.0 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.18.0_cpu` | +| 0.18.1 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.18.1` | +| 0.18.1 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.18.1_cpu` | **Note**: For a lightweight installation, we recommended using a [MosaicML PyTorch Image](#pytorch-images) and manually diff --git a/docker/build_matrix.yaml b/docker/build_matrix.yaml index 7a87276336..7ba413a4bb 100644 --- a/docker/build_matrix.yaml +++ b/docker/build_matrix.yaml @@ -193,9 +193,9 @@ TORCHVISION_VERSION: 0.18.0 - AWS_OFI_NCCL_VERSION: '' BASE_IMAGE: nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04 - COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.18.0 + COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.18.1 CUDA_VERSION: 12.1.0 - IMAGE_NAME: composer-0-18-0 + IMAGE_NAME: composer-0-18-1 MOFED_VERSION: 5.5-1.0.3.2 NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471 brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 @@ -216,15 +216,15 @@ PYTORCH_NIGHTLY_VERSION: '' PYTORCH_VERSION: 2.1.2 TAGS: - - mosaicml/composer:0.18.0 + - mosaicml/composer:0.18.1 - mosaicml/composer:latest TARGET: composer_stage TORCHVISION_VERSION: 0.16.2 - AWS_OFI_NCCL_VERSION: '' BASE_IMAGE: ubuntu:20.04 - COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.18.0 + COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.18.1 CUDA_VERSION: '' - IMAGE_NAME: composer-0-18-0-cpu + IMAGE_NAME: composer-0-18-1-cpu MOFED_VERSION: 5.5-1.0.3.2 NVIDIA_REQUIRE_CUDA_OVERRIDE: '' PYTHON_VERSION: '3.10' @@ -232,7 +232,7 @@ PYTORCH_NIGHTLY_VERSION: '' PYTORCH_VERSION: 2.1.2 TAGS: - - mosaicml/composer:0.18.0_cpu + - mosaicml/composer:0.18.1_cpu - mosaicml/composer:latest_cpu TARGET: composer_stage TORCHVISION_VERSION: 0.16.2 diff --git a/docker/generate_build_matrix.py b/docker/generate_build_matrix.py index 28d52aefa9..e51662729d 100644 --- a/docker/generate_build_matrix.py +++ b/docker/generate_build_matrix.py @@ -246,7 +246,7 @@ def _main(): composer_entries = [] # The `GIT_COMMIT` is a placeholder and Jenkins will substitute it with the actual git commit for the `composer_staging` images - composer_versions = ['0.18.0'] # Only build images for the latest composer version + composer_versions = ['0.18.1'] # Only build images for the latest composer version composer_python_versions = [LATEST_PYTHON_VERSION] # just build composer against the latest for product in itertools.product(composer_python_versions, composer_versions, cuda_options):