From f1ce95399ec39e5ae20afc9f21bab01ad42c22ca Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Thu, 31 Aug 2023 19:01:21 +0800
Subject: [PATCH 1/4] hybrid plugin support huggingface from_pretrained

---
 .../hybrid_parallel_checkpoint_io.py          |   3 +
 colossalai/checkpoint_io/utils.py             |  31 ++++-
 .../test_huggingface_compatibility.py         | 123 ++++++++++++++++++
 3 files changed, 153 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_checkpoint_io/test_huggingface_compatibility.py

diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index c128858b1efe..bcd4d44e19a2 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -26,6 +26,7 @@
     load_shard_state_dict,
     load_state_dict_into_model,
     load_states_into_optimizer,
+    save_config_file,
     save_param_groups,
     save_state_dict_shards,
     search_tp_partition_dim,
@@ -204,6 +205,7 @@ def save_sharded_model(self,
             if control_saving:
                 index_file.append_meta_data("total_size", total_size)
                 index_file.write_index_file(save_index_file)
+                save_config_file(model, checkpoint)
                 if self.verbose:
                     logging.info(f"The model is split into checkpoint shards. "
                                  f"You can find where each parameters has been saved in the "
@@ -251,6 +253,7 @@ def save_sharded_model(self,
                 final_index_file.append_weight_map(weight, weight_filename)

             final_index_file.write_index_file(final_index_file_path)
+            save_config_file(model, checkpoint)
             rmtree(tmp_index_file_folder)
             if self.verbose:
                 logging.info(f"The model is split into checkpoint shards. "
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 0025d07dfc8e..ceca10963228 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -9,12 +9,12 @@
 from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple

 import torch
-import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
+from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype
+from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model

-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
@@ -335,6 +335,29 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None:
     torch.save(param_groups, group_file_path)


+def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
+    """
+    Save config.json/generation_config.json if model is a Huggingface pretrained model.
+    """
+    if not isinstance(model, PreTrainedModel):
+        return
+
+    model = unwrap_huggingface_model(model)
+
+    # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
+    dtype = get_parameter_dtype(model)
+    model.config.torch_dtype = str(dtype).split(".")[1]
+
+    # Attach architecture to the config
+    model.config.architectures = [model.__class__.__name__]
+
+    # Save the config
+    if is_master:
+        model.config.save_pretrained(checkpoint_path)
+        if model.can_generate():
+            model.generation_config.save_pretrained(checkpoint_path)
+
+
 def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
     """
     Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
@@ -709,5 +732,5 @@ def get_shard_filename(weights_name: str, idx: int):
     get shard file name
     """
     shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
-    shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
+    shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
     return shard_file
diff --git a/tests/test_checkpoint_io/test_huggingface_compatibility.py b/tests/test_checkpoint_io/test_huggingface_compatibility.py
new file mode 100644
index 000000000000..1992a067e576
--- /dev/null
+++ b/tests/test_checkpoint_io/test_huggingface_compatibility.py
@@ -0,0 +1,123 @@
+import pytest
+import torch
+import torch.distributed as dist
+from torch.optim import Adam
+from utils import shared_tempdir
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import (
+    check_state_dict_equal,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
+from tests.kit.model_zoo import model_zoo
+
+
+def run_check(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, shard=True, size_per_shard=32):
+
+    def _criterion(outputs, inputs):
+        outputs = output_transform_fn(outputs)
+        loss = criterion(outputs)
+        return loss
+
+    def _preprocess_data(data):
+        if booster.plugin.stage_manager is not None:
+            for k, v in data.items():
+                if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
+                    new_shape = [1] * v.dim()
+                    new_shape[0] = 4
+                    data[k] = v.to('cuda').repeat(*new_shape)
+            return iter([data])
+        else:
+            return {k: v.cuda() for k, v in data.items()}
+
+    model = model_fn()
+    optimizer = Adam((model.parameters()), lr=0.001)
+    criterion = loss_fn
+    plugin = HybridParallelPlugin(**test_config)
+    booster = Booster(plugin=plugin)
+
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    data = data_gen_fn()
+    model.train()
+    if booster.plugin.stage_manager is not None:
+        booster.execute_pipeline(_preprocess_data(data),
+                                 model,
+                                 _criterion,
+                                 optimizer,
+                                 return_loss=True,
+                                 return_outputs=False)
+    else:
+        output = model(**_preprocess_data(data))
+        loss = criterion(output)
+        optimizer.backward(loss)
+
+    optimizer.step()
+
+    with shared_tempdir() as tempdir:
+
+        model_ckpt_path = f"{tempdir}/model"
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        dist.barrier()
+
+        new_model = model.__class__.from_pretrained(model_ckpt_path)
+        check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
+
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
+@clear_cache_before_run()
+@parameterize('shard', [True])
+@parameterize('size_per_shard', [32])
+@parameterize('test_config', [{
+    'tp_size': 4,
+    'pp_size': 1,
+    'precision': 'fp32',
+}, {
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'precision': 'fp16',
+    'initial_scale': 1
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'zero_stage': 1,
+    'precision': 'fp16',
+    'initial_scale': 1
+}])
+def test_compatibility(shard, size_per_shard, test_config):
+
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        run_check(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    torch.cuda.empty_cache()
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    test_compatibility()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_hybrid_IO_huggingface_compability():
+    spawn(run_dist, 4)

From 92887dd17f9eae86fb084cad651ddb14c7b56ada Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Thu, 31 Aug 2023 21:35:21 +0800
Subject: [PATCH 2/4] add huggingface compatibility tests

---
 ...test_hybrid_huggingface_compatibility.py} | 32 +++++++++++--------
 1 file changed, 19 insertions(+), 13 deletions(-)
 rename tests/test_checkpoint_io/{test_huggingface_compatibility.py => test_hybrid_huggingface_compatibility.py} (77%)

diff --git a/tests/test_checkpoint_io/test_huggingface_compatibility.py b/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py
similarity index 77%
rename from tests/test_checkpoint_io/test_huggingface_compatibility.py
rename to tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py
index 1992a067e576..df907605d869 100644
--- a/tests/test_checkpoint_io/test_huggingface_compatibility.py
+++ b/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py
@@ -19,7 +19,13 @@
 from tests.kit.model_zoo import model_zoo


-def run_check(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, shard=True, size_per_shard=32):
+def exam_from_pretrained(model_fn,
+                         data_gen_fn,
+                         output_transform_fn,
+                         loss_fn,
+                         test_config,
+                         shard=True,
+                         size_per_shard=32):

     def _criterion(outputs, inputs):
         outputs = output_transform_fn(outputs)
@@ -67,7 +73,10 @@ def _preprocess_data(data):
         booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
         dist.barrier()

-        new_model = model.__class__.from_pretrained(model_ckpt_path)
+        new_model = model.unwrap().__class__.from_pretrained(model_ckpt_path)
+        new_optimizer = Adam(new_model.parameters(), lr=1e-3)
+        new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
+
         check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)

     Randomizer.reset_index()
@@ -75,8 +84,6 @@ def _preprocess_data(data):


 @clear_cache_before_run()
-@parameterize('shard', [True])
-@parameterize('size_per_shard', [32])
 @parameterize('test_config', [{
     'tp_size': 4,
     'pp_size': 1,
@@ -101,23 +108,22 @@ def _preprocess_data(data):
     'precision': 'fp16',
     'initial_scale': 1
 }])
-def test_compatibility(shard, size_per_shard, test_config):
-
+def run_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
-
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        run_check(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
-
+        exam_from_pretrained(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
     clear_layout_converter()
     torch.cuda.empty_cache()


 def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    test_compatibility()
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_test()


 @pytest.mark.dist
+@pytest.mark.parametrize('world_size', [4])
 @rerun_if_address_is_in_use()
-def test_hybrid_IO_huggingface_compability():
-    spawn(run_dist, 4)
+def test_huggingface_compatibility(world_size):
+    spawn(run_dist, world_size)

From 997551422c17b39eaf6f61c38ded9f5f8d822786 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Fri, 1 Sep 2023 12:25:47 +0800
Subject: [PATCH 3/4] add folder cleaning

---
 .../hybrid_parallel_checkpoint_io.py          | 16 +++---
 colossalai/checkpoint_io/utils.py             | 50 +++++++++++++++++--
 2 files changed, 56 insertions(+), 10 deletions(-)

diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index bcd4d44e19a2..fef5b0d16d60 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -221,9 +221,9 @@ def save_sharded_model(self,
             Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

             # Manage filenames of sharded weights and index file for each pipeline stage.
-            weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
-            weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors")
-            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
+            weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+            weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
+            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)

             total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
                                                 checkpoint=checkpoint,
                                                 index_file=index_file,
                                                 base_filename=weights_name,
                                                 is_master=control_saving,
-                                                use_safetensors=use_safetensors)
+                                                use_safetensors=use_safetensors,
+                                                use_pp_format=True)
             if control_saving:
                 assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
                 index_file.append_meta_data("total_size", total_size)
@@ -426,15 +427,16 @@ def save_sharded_optimizer(self,
             Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

             # Manage filenames of sharded weights and index file for each pipeline stage.
-            states_name = states_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
-            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
+            states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)

             total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
                                                 checkpoint=checkpoint,
                                                 index_file=index_file,
                                                 base_filename=states_name,
-                                                is_master=control_saving)
+                                                is_master=control_saving,
+                                                use_pp_format=True)

             if control_saving:
                 assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
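The switch to 1-based stage numbering above means a pipeline-parallel stage now writes shard files named like pytorch_model-stage-00001-shard-00001.bin, while the flat layout keeps names like pytorch_model-00001.bin, and save_state_dict_shards is told which layout is in use via the new use_pp_format flag. The utils.py diff below adds a clean_folder helper that uses one regex per layout to remove files that look like weight shards but were not written by the current save. A minimal standalone sketch of that filter follows; the helper name would_remove and the example filenames are hypothetical, and the real code additionally checks os.path.isfile before deleting anything.

import re

# A paraphrase of the filename filter added to clean_folder() in the utils.py diff below.
def would_remove(filename: str, weights_name: str, shard_filenames: list, use_pp_format: bool) -> bool:
    weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
    filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
    if not use_pp_format:
        reg = re.compile(r"(.*?)-\d{5}")
    else:
        # Pipeline-parallel stages write e.g. "pytorch_model-stage-00001-shard-00001.bin"
        # (stage numbering is 1-based after this patch).
        reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
    return (filename.startswith(weights_no_suffix) and filename not in shard_filenames
            and reg.fullmatch(filename_no_suffix) is not None)

just_saved = ["pytorch_model-00001.bin", "pytorch_model-00002.bin"]
print(would_remove("pytorch_model-00003.bin", "pytorch_model.bin", just_saved, False))       # True: shard-like name not written by this save
print(would_remove("pytorch_model-00001.bin", "pytorch_model.bin", just_saved, False))       # False: just written, protected by shard_filenames
print(would_remove("pytorch_model.bin.index.json", "pytorch_model.bin", just_saved, False))  # False: does not match the shard pattern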
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index ceca10963228..0300e62653eb 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -228,7 +228,8 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
                            index_file: "CheckpointIndexFile",
                            base_filename: str,
                            is_master: bool,
-                           use_safetensors: bool = False) -> int:
+                           use_safetensors: bool = False,
+                           use_pp_format: bool = False) -> int:
     '''
     Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
     Args:
@@ -236,14 +237,16 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
         checkpoint (str): The path of checkpoint directory as string.
         index_file (CheckpointIndexFile): The index file object to be updated.
         base_filename (str): Decides the prefix of filenames of shards.
-        is_master (bool): Whether current rank is master.
-        use_safetensors (bool): Whether to use safetensors to save checkpoint.
+        is_master (bool): Whether current rank is main process.
+        use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
+        use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.

     Returns:
         int: the total size of shards
     '''
     total_size = 0
+    shard_filenames = []
     for idx, shard_pair in enumerate(sharded_state_dict):
         shard, current_size = shard_pair
         if not is_master:
@@ -257,8 +260,12 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]

         # Only save on master rank.
         save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+        shard_filenames.append(shard_file)
         del shard

+    # Clean folder, deleted unneeded files.
+    clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
+
     return total_size


@@ -335,9 +342,46 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None:
     torch.save(param_groups, group_file_path)


+def clean_folder(checkpoint_path: str,
+                 weights_name: str,
+                 shard_filenames: List[str],
+                 is_master: bool = True,
+                 use_pp_format: bool = False):
+    """
+    Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
+
+    Args:
+        checkpoint_path (str): Path to the checkpoint directory.
+        weights_name (str): Decides the prefix of filenames of weight shards.
+        shard_filenames (List[str]): The list of saved shard filenames which should not be removed.
+        is_master (bool, optional): Whether current rank is main process. Defaults to True.
+        use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
+
+    """
+    if is_master:
+        for filename in os.listdir(checkpoint_path):
+            full_filename = os.path.join(checkpoint_path, filename)
+            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+            filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+            if not use_pp_format:
+                reg = re.compile(r"(.*?)-\d{5}")
+            else:
+                # When this checkpoint is created by pipeline parallel process, the pattern is a little different.
+                reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
+            if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
+                    and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
+                os.remove(full_filename)
+
+
 def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
     """
     Save config.json/generation_config.json if model is a Huggingface pretrained model.
+    This method can only be called when a model is saved in a sharded way.
+
+    Args:
+        model (nn.Module): The model whose config should be saved if it's a huggingface model.
+        checkpoint_path (str): Path to the checkpoint directory.
+        is_master (bool): Whether current rank is main process.
     """
     if not isinstance(model, PreTrainedModel):
         return

From 944906ab7cc1319dc4531bb533c2d76815325168 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Fri, 1 Sep 2023 16:00:34 +0800
Subject: [PATCH 4/4] fix bugs

---
 .github/workflows/build_on_pr.yml                   | 2 +-
 colossalai/booster/plugin/hybrid_parallel_plugin.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index 4c7e08e5799e..3f91dc33a660 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -208,7 +208,7 @@ jobs:

       - name: Execute Unit Testing
         run: |
-          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-cov=. --durations=10 tests/
+          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/
         env:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 277843b66568..eced4fc1a16b 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -141,10 +141,10 @@ def get_param_info(optim: Optimizer):


 def init_pipeline_optimizer(optim: Optimizer, model: Module):
-    params = set(model.parameters())
+    model_params = set(model.parameters())
     new_param_groups = []
     for group in optim.param_groups:
-        params = [p for p in group['params'] if p in params]
+        params = [p for p in group['params'] if p in model_params]
         new_param_groups.append({**group, 'params': params})
     optim.__setstate__({'param_groups': new_param_groups})
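Taken together, the patches make a sharded checkpoint written by the hybrid parallel plugin directly loadable with HuggingFace's from_pretrained, because save_config_file now places config.json (and generation_config.json for models that can generate) next to the weight shards and the index file. A minimal sketch of that round trip, mirroring the test above: the GPT-2 model, the checkpoint directory and the plugin settings are placeholder choices, and it assumes the script is launched on 4 ranks (tp_size * pp_size) with the process group already initialized, e.g. via colossalai.launch as run_dist does in the test.

import torch.distributed as dist
from torch.optim import Adam
from transformers import GPT2LMHeadModel  # placeholder; any HuggingFace PreTrainedModel behaves the same

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin


def save_and_reload(checkpoint_dir: str = "./gpt2_hybrid_ckpt"):
    # Assumes colossalai.launch (or launch_from_torch) has already initialized
    # the distributed environment on every rank, as run_dist does in the test above.
    plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision="fp16", initial_scale=1)
    booster = Booster(plugin=plugin)

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    optimizer = Adam(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # ... training steps elided ...

    # Sharded save: writes the weight shards, the index file and, with this PR,
    # config.json / generation_config.json into checkpoint_dir.
    booster.save_model(model, checkpoint_dir, shard=True, size_per_shard=32)
    dist.barrier()

    # The folder is now a plain HuggingFace checkpoint and can be reloaded
    # without ColossalAI at all.
    return GPT2LMHeadModel.from_pretrained(checkpoint_dir)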