diff --git a/finetune/run_classifier_deepspeed.py b/finetune/run_classifier_deepspeed.py
index 7ea3c46..4610b2f 100644
--- a/finetune/run_classifier_deepspeed.py
+++ b/finetune/run_classifier_deepspeed.py
@@ -14,6 +14,7 @@ sys.path.append(tencentpretrain_dir)
 
 from tencentpretrain.opts import deepspeed_opts
+from tencentpretrain.model_loader import *
 from finetune.run_classifier import *
 
 
@@ -129,10 +130,16 @@ def main():
     args.tokenizer = str2tokenizer[args.tokenizer](args)
 
     # Build classification model.
-    model = Classifier(args)
+    if args.enable_zero3:
+        with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config):
+            model = Classifier(args)
+        if args.pretrained_model_path:
+            model = _load_state_dict_into_model(model, args.pretrained_model_path)
+    else:
+        model = Classifier(args)
 
-    # Load or initialize parameters.
-    load_or_initialize_parameters(args, model)
+        # Load or initialize parameters.
+        load_or_initialize_parameters(args, model)
 
     # Get logger.
     args.logger = init_logger(args)
diff --git a/inference/run_classifier_deepspeed_infer.py b/inference/run_classifier_deepspeed_infer.py
index 8bd0c0f..fe26b8e 100644
--- a/inference/run_classifier_deepspeed_infer.py
+++ b/inference/run_classifier_deepspeed_infer.py
@@ -34,7 +34,6 @@ def main():
     parser.add_argument("--output_prob", action="store_true", help="Write probabilities to output file.")
 
     deepspeed_opts(parser)
-    parser.add_argument("--mp_size", type=int, default=1, help="Model parallel size.")
 
     args = parser.parse_args()
 
@@ -47,48 +46,51 @@ def main():
     # Build classification model and load parameters.
     args.soft_targets, args.soft_alpha = False, False
     deepspeed.init_distributed()
-    model = Classifier(args)
-
-    if args.load_model_path:
+    if args.enable_zero3:
+        with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config):
+            model = Classifier(args)
+        model = _load_state_dict_into_model(model, args.load_model_path)
+    else:
+        model = Classifier(args)
         model = load_model(model, args.load_model_path)
 
-    model = deepspeed.init_inference(model=model, mp_size=args.mp_size, replace_method=None)
+    model = deepspeed.initialize(model=model, config_params=args.deepspeed_config)[0]
     rank = dist.get_rank()
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    if rank == 0:
-        dataset = read_dataset(args, args.test_path)
+    dataset = read_dataset(args, args.test_path)
 
-        src = torch.LongTensor([sample[0] for sample in dataset])
-        seg = torch.LongTensor([sample[1] for sample in dataset])
+    src = torch.LongTensor([sample[0] for sample in dataset])
+    seg = torch.LongTensor([sample[1] for sample in dataset])
 
-        batch_size = args.batch_size
-        instances_num = src.size()[0]
+    batch_size = args.batch_size
+    instances_num = src.size()[0]
 
-        print("The number of prediction instances: ", instances_num)
+    print("The number of prediction instances: ", instances_num)
 
-        model.eval()
+    model.eval()
 
-        with open(args.prediction_path, mode="w", encoding="utf-8") as f:
+    with open(args.prediction_path, mode="w", encoding="utf-8") as f:
+        if rank == 0:
             f.write("label")
             if args.output_logits:
                 f.write("\t" + "logits")
             if args.output_prob:
                 f.write("\t" + "prob")
             f.write("\n")
-            for i, (src_batch, seg_batch) in enumerate(batch_loader(batch_size, src, seg)):
-                src_batch = src_batch.to(device)
-                seg_batch = seg_batch.to(device)
-                with torch.no_grad():
-                    _, logits = model(src_batch, None, seg_batch)
-
-                pred = torch.argmax(logits, dim=1)
-                pred = pred.cpu().numpy().tolist()
-                prob = nn.Softmax(dim=1)(logits)
-                logits = logits.cpu().numpy().tolist()
-                prob = prob.cpu().numpy().tolist()
-
+        for i, (src_batch, seg_batch) in enumerate(batch_loader(batch_size, src, seg)):
+            src_batch = src_batch.to(device)
+            seg_batch = seg_batch.to(device)
+            with torch.no_grad():
+                _, logits = model(src_batch, None, seg_batch)
+
+            pred = torch.argmax(logits, dim=1)
+            pred = pred.cpu().numpy().tolist()
+            prob = nn.Softmax(dim=1)(logits)
+            logits = logits.cpu().numpy().tolist()
+            prob = prob.cpu().numpy().tolist()
+            if rank == 0:
                 for j in range(len(pred)):
                     f.write(str(pred[j]))
                     if args.output_logits:
diff --git a/inference/run_classifier_infer.py b/inference/run_classifier_infer.py
index e8a7c7d..f106bab 100644
--- a/inference/run_classifier_infer.py
+++ b/inference/run_classifier_infer.py
@@ -15,7 +15,7 @@ from tencentpretrain.utils import *
 from tencentpretrain.utils.config import load_hyperparam
 from tencentpretrain.utils.seed import set_seed
-from tencentpretrain.model_loader import load_model
+from tencentpretrain.model_loader import *
 from tencentpretrain.opts import infer_opts, tokenizer_opts
 from finetune.run_classifier import Classifier
 
diff --git a/models/deepspeed_zero3_config.json b/models/deepspeed_zero3_config.json
new file mode 100644
index 0000000..784860c
--- /dev/null
+++ b/models/deepspeed_zero3_config.json
@@ -0,0 +1,44 @@
+{
+  "gradient_accumulation_steps": 1,
+  "train_micro_batch_size_per_gpu": 1,
+  "steps_per_print": 100,
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 1e-5,
+      "weight_decay": 1e-2
+    }
+  },
+  "flops_profiler": {
+    "enabled": true,
+    "profile_step": 1,
+    "module_depth": -1,
+    "top_modules": 3,
+    "detailed": true
+  },
+  "fp16": {
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "zero_optimization": {
+    "stage": 3,
+    "offload_param": {
+      "device": "cpu",
+      "pin_memory": true
+    },
+    "offload_optimizer": {
+      "device": "cpu",
+      "pin_memory": true
+    }
+  },
+  "activation_checkpointing": {
+    "partition_activations": false,
+    "contiguous_memory_optimization": false,
+    "cpu_checkpointing": false
+  },
+  "wall_clock_breakdown": false,
+  "zero_allow_untested_optimizer": true
+}
diff --git a/scripts/generate_lm.py b/scripts/generate_lm.py
index a10cae0..b6d5cc1 100644
--- a/scripts/generate_lm.py
+++ b/scripts/generate_lm.py
@@ -17,7 +17,7 @@ from tencentpretrain.utils.constants import *
 from tencentpretrain.utils import *
 from tencentpretrain.utils.config import load_hyperparam
-from tencentpretrain.model_loader import load_model
+from tencentpretrain.model_loader import *
 from tencentpretrain.opts import infer_opts, tokenizer_opts
 
diff --git a/scripts/generate_lm_deepspeed.py b/scripts/generate_lm_deepspeed.py
index 009de69..5fca576 100644
--- a/scripts/generate_lm_deepspeed.py
+++ b/scripts/generate_lm_deepspeed.py
@@ -29,7 +29,6 @@
     tokenizer_opts(parser)
 
     deepspeed_opts(parser)
-    parser.add_argument("--mp_size", type=int, default=1, help="Model parallel size.")
 
     args = parser.parse_args()
 
@@ -40,36 +39,40 @@
     args.tokenizer = str2tokenizer[args.tokenizer](args)
 
-    model = GenerateLm(args)
-    model = load_model(model, args.load_model_path)
+    if args.enable_zero3:
+        with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config):
+            model = GenerateLm(args)
+        model = _load_state_dict_into_model(model, args.load_model_path)
+    else:
+        model = GenerateLm(args)
+        model = load_model(model, args.load_model_path)
     deepspeed.init_distributed()
-    model = deepspeed.init_inference(model=model, mp_size=args.mp_size, replace_method=None)
+    model = deepspeed.initialize(model=model, config_params=args.deepspeed_config)[0]
     rank = dist.get_rank()
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    if rank == 0:
-        model.eval()
-
-        with open(args.test_path, mode="r", encoding="utf-8") as f:
-            line = f.readline().strip()
-            src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(line))
-            seg = [1] * len(src)
-            beginning_length = len(src)
-            if len(src) > args.seq_length:
-                src = src[:args.seq_length]
-                seg = seg[:args.seq_length]
-            src_tensor, seg_tensor = torch.LongTensor([src]).to(device), torch.LongTensor([seg]).to(device)
-
-        with open(args.prediction_path, mode="w", encoding="utf-8") as f:
-            for i in range(args.seq_length - beginning_length):
-                output = model(src_tensor, seg_tensor)
-                next_token_logits = output[0][-1] / args.temperature
-                filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
-                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
-
-                src_tensor = torch.cat([src_tensor, next_token.view(1, 1).to(device)], dim=1)
-                seg_tensor = torch.cat([seg_tensor, torch.tensor([[1]]).to(device)], dim=1)
-
+    model.eval()
+
+    with open(args.test_path, mode="r", encoding="utf-8") as f:
+        line = f.readline().strip()
+        src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(line))
+        seg = [1] * len(src)
+        beginning_length = len(src)
+        if len(src) > args.seq_length:
+            src = src[:args.seq_length]
+            seg = seg[:args.seq_length]
+        src_tensor, seg_tensor = torch.LongTensor([src]).to(device), torch.LongTensor([seg]).to(device)
+
+    with open(args.prediction_path, mode="w", encoding="utf-8") as f:
+        for i in range(args.seq_length - beginning_length):
+            output = model(src_tensor, seg_tensor)
+            next_token_logits = output[0][-1] / args.temperature
+            filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
+            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+
+            src_tensor = torch.cat([src_tensor, next_token.view(1, 1).to(device)], dim=1)
+            seg_tensor = torch.cat([seg_tensor, torch.tensor([[1]]).to(device)], dim=1)
+        if rank == 0:
             f.write(line + "\n")
             generated_sentence = "".join(
                 args.tokenizer.convert_ids_to_tokens([token_id.item() for token_id in src_tensor[0]])
diff --git a/scripts/generate_seq2seq.py b/scripts/generate_seq2seq.py
index 1f43199..4f2b7b0 100644
--- a/scripts/generate_seq2seq.py
+++ b/scripts/generate_seq2seq.py
@@ -14,7 +14,7 @@ from tencentpretrain.utils.constants import *
 from tencentpretrain.utils import *
 from tencentpretrain.utils.config import load_hyperparam
-from tencentpretrain.model_loader import load_model
+from tencentpretrain.model_loader import *
 from tencentpretrain.opts import infer_opts, tokenizer_opts
 from scripts.generate_lm import top_k_top_p_filtering
 
diff --git a/scripts/generate_seq2seq_deepspeed.py b/scripts/generate_seq2seq_deepspeed.py
index 845aad5..8dbce1d 100644
--- a/scripts/generate_seq2seq_deepspeed.py
+++ b/scripts/generate_seq2seq_deepspeed.py
@@ -29,7 +29,6 @@
     parser.add_argument("--tgt_seq_length", type=int, default=128, help="Sequence length.")
 
     deepspeed_opts(parser)
-    parser.add_argument("--mp_size", type=int, default=1, help="Model parallel size.")
 
     args = parser.parse_args()
 
@@ -45,36 +44,40 @@
     args.vocab_path = args.tgt_vocab_path
     args.tgt_tokenizer = str2tokenizer[args.tgt_tokenizer](args)
     args.tgt_vocab = args.tgt_tokenizer.vocab
-
-    model = GenerateSeq2seq(args)
-    model = load_model(model, args.load_model_path)
+    if args.enable_zero3:
+        with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config):
+            model = GenerateSeq2seq(args)
+        model = _load_state_dict_into_model(model, args.load_model_path)
+    else:
+        model = GenerateSeq2seq(args)
+        model = load_model(model, args.load_model_path)
     deepspeed.init_distributed()
-    model = deepspeed.init_inference(model=model, mp_size=args.mp_size, replace_method=None)
+    model = deepspeed.initialize(model=model, config_params=args.deepspeed_config)[0]
     rank = dist.get_rank()
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    if rank == 0:
-        model.eval()
-
-        with open(args.test_path, mode="r", encoding="utf-8") as f:
-            line = f.readline().strip()
-            src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(line) + [SEP_TOKEN])
-            seg = [1] * len(src)
-            tgt = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN])
-            beginning_length = len(src)
-            if len(src) > args.seq_length:
-                src = src[:args.seq_length]
-                seg = seg[:args.seq_length]
-            src_tensor, seg_tensor, tgt_tensor = torch.LongTensor([src]).to(device), torch.LongTensor([seg]).to(device), torch.LongTensor([tgt]).to(device)
-
-        with open(args.prediction_path, mode="w", encoding="utf-8") as f:
-            for i in range(args.tgt_seq_length-1):
-                output = model(src_tensor, seg_tensor, tgt_tensor)
-                next_token_logits = output[0][-1] / args.temperature
-                filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
-                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
-                tgt_tensor = torch.cat([tgt_tensor, next_token.view(1, 1).to(device)], dim=1)
+    model.eval()
+
+    with open(args.test_path, mode="r", encoding="utf-8") as f:
+        line = f.readline().strip()
+        src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(line) + [SEP_TOKEN])
+        seg = [1] * len(src)
+        tgt = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN])
+        beginning_length = len(src)
+        if len(src) > args.seq_length:
+            src = src[:args.seq_length]
+            seg = seg[:args.seq_length]
+        src_tensor, seg_tensor, tgt_tensor = torch.LongTensor([src]).to(device), torch.LongTensor([seg]).to(device), torch.LongTensor([tgt]).to(device)
+
+    with open(args.prediction_path, mode="w", encoding="utf-8") as f:
+        for i in range(args.tgt_seq_length-1):
+            output = model(src_tensor, seg_tensor, tgt_tensor)
+            next_token_logits = output[0][-1] / args.temperature
+            filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
+            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+            tgt_tensor = torch.cat([tgt_tensor, next_token.view(1, 1).to(device)], dim=1)
+        if rank == 0:
             f.write(line + "\n")
             generated_sentence = "".join(
                 args.tgt_tokenizer.convert_ids_to_tokens([token_id.item() for token_id in tgt_tensor[0]])
diff --git a/tencentpretrain/model_loader.py b/tencentpretrain/model_loader.py
index 836b1f6..7e6bba4 100644
--- a/tencentpretrain/model_loader.py
+++ b/tencentpretrain/model_loader.py
@@ -10,3 +10,48 @@ def load_model(model, model_path):
     else:
         model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)
     return model
+
+
+def _load_state_dict_into_model(model_to_load, model_path, start_prefix=""):
+    # Convert old format to new format if needed from a PyTorch state_dict
+
+    # copy state_dict so _load_from_state_dict can modify it
+    state_dict = torch.load(model_path, map_location="cpu")
+    metadata = getattr(state_dict, "_metadata", None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    error_msgs = []
+
+    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+    # so we need to apply the function recursively.
+    def load(module, state_dict, prefix=""):
+        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+        # Parameters of module and children will start with prefix. We can exit early if there are none in this
+        # state_dict
+        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
+            import deepspeed
+            # In sharded models, each shard has only part of the full state_dict, so only gather
+            # parameters that are in the current state_dict.
+            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
+            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
+            if len(params_to_gather) > 0:
+                # because zero3 puts placeholders in model params, this context
+                # manager gathers (unpartitions) the params of the current layer, then loads from
+                # the state dict and then re-partitions them again
+                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
+                    if torch.distributed.get_rank() == 0:
+                        module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, state_dict, prefix + name + ".")
+
+    load(model_to_load, state_dict, prefix=start_prefix)
+    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
+    # it's safe to delete it.
+    del state_dict
+
+    return model_to_load
diff --git a/tencentpretrain/opts.py b/tencentpretrain/opts.py
index 029d30a..34e53ed 100755
--- a/tencentpretrain/opts.py
+++ b/tencentpretrain/opts.py
@@ -208,6 +208,8 @@ def tgt_tokenizer_opts(parser):
 def deepspeed_opts(parser):
     parser.add_argument("--deepspeed", action="store_true",
                         help=".")
+    parser.add_argument("--enable_zero3", action="store_true",
+                        help=".")
     parser.add_argument("--deepspeed_config", default="models/deepspeed_config.json", type=str,
                         help=".")
     parser.add_argument("--deepspeed_checkpoint_activations", action='store_true',
diff --git a/tencentpretrain/trainer.py b/tencentpretrain/trainer.py
index a5b01c5..b609636 100755
--- a/tencentpretrain/trainer.py
+++ b/tencentpretrain/trainer.py
@@ -2,7 +2,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel
 
-from tencentpretrain.model_loader import load_model
+from tencentpretrain.model_loader import *
 from tencentpretrain.model_saver import save_model
 from tencentpretrain.model_builder import build_model
 from tencentpretrain.utils.logging import init_logger
@@ -23,12 +23,20 @@ def train_and_validate(args):
     args.vocab = args.tokenizer.vocab
 
     # Build model.
-    model_for_training = build_model(args)
+    if args.deepspeed and args.enable_zero3:
+        import deepspeed
+        with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config):
+            model_for_training = build_model(args)
+    else:
+        model_for_training = build_model(args)
 
     # Load or initialize parameters.
     if args.pretrained_model_path is not None:
         # Initialize with pretrained model.
-        model_for_training = load_model(model_for_training, args.pretrained_model_path)
+        if args.deepspeed and args.enable_zero3:
+            model_for_training = _load_state_dict_into_model(model_for_training, args.pretrained_model_path)
+        else:
+            model_for_training = load_model(model_for_training, args.pretrained_model_path)
     else:
         # Initialize with normal distribution.
         if args.deep_init:
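
Usage sketch (assumed, not part of the patch): the new flag is expected to be passed through the standard deepspeed launcher together with the ZeRO-3 config added above; the model, vocab, config, and dataset paths below are placeholders, and the remaining arguments are the usual TencentPretrain classifier options.

    deepspeed finetune/run_classifier_deepspeed.py --deepspeed --enable_zero3 \
        --deepspeed_config models/deepspeed_zero3_config.json \
        --pretrained_model_path models/bert_base.bin \
        --vocab_path models/google_zh_vocab.txt \
        --config_path models/bert/base_config.json \
        --train_path datasets/chnsenticorp/train.tsv \
        --dev_path datasets/chnsenticorp/dev.tsv \
        --epochs_num 3 --batch_size 8

Inference and the generation scripts follow the same pattern, with --load_model_path pointing at the fine-tuned checkpoint.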