Add support for deepspeed zero-3 #33

Status: Open. Wants to merge 1 commit into base: main.
13 changes: 10 additions & 3 deletions finetune/run_classifier_deepspeed.py
@@ -14,6 +14,7 @@
sys.path.append(tencentpretrain_dir)

from tencentpretrain.opts import deepspeed_opts
from tencentpretrain.model_loader import *
from finetune.run_classifier import *


@@ -129,10 +130,16 @@ def main():
args.tokenizer = str2tokenizer[args.tokenizer](args)

# Build classification model.
model = Classifier(args)
if args.enable_zero3:
with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config):
model = Classifier(args)
if args.pretrained_model_path:
model = _load_state_dict_into_model(model, args.load_model_path)
else:
model = Classifier(args)

# Load or initialize parameters.
load_or_initialize_parameters(args, model)
# Load or initialize parameters.
load_or_initialize_parameters(args, model)

# Get logger.
args.logger = init_logger(args)
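This hunk establishes the pattern reused by the other entry points in this PR: with `--enable_zero3`, the model is constructed inside a `deepspeed.zero.Init` context so its parameters are partitioned across ranks as they are created, and a pretrained checkpoint is then loaded shard by shard with the `_load_state_dict_into_model` helper added to `tencentpretrain/model_loader.py`; without the flag, the original construction path is kept. A minimal sketch of the pattern, assuming the argument names used in this PR and the `Classifier`, `_load_state_dict_into_model`, and `load_or_initialize_parameters` symbols imported by the script:

```python
# Sketch only: mirrors the branch added above, not a drop-in replacement for it.
import deepspeed

def build_classifier(args):
    if args.enable_zero3:
        # Parameters created inside zero.Init are partitioned immediately (ZeRO stage 3),
        # so a model too large for one GPU can still be instantiated.
        with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config):
            model = Classifier(args)
        if args.pretrained_model_path:
            # Gather each layer's shards, copy weights from the checkpoint, re-partition.
            model = _load_state_dict_into_model(model, args.load_model_path)
    else:
        # Original path: normal construction plus the existing initialization routine.
        model = Classifier(args)
        load_or_initialize_parameters(args, model)
    return model
```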
54 changes: 28 additions & 26 deletions inference/run_classifier_deepspeed_infer.py
@@ -34,7 +34,6 @@ def main():
parser.add_argument("--output_prob", action="store_true", help="Write probabilities to output file.")

deepspeed_opts(parser)
parser.add_argument("--mp_size", type=int, default=1, help="Model parallel size.")

args = parser.parse_args()

@@ -47,48 +46,51 @@
# Build classification model and load parameters.
args.soft_targets, args.soft_alpha = False, False
deepspeed.init_distributed()
model = Classifier(args)

if args.load_model_path:
if args.enable_zero3:
with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config):
model = Classifier(args)
model = _load_state_dict_into_model(model, args.load_model_path)
else:
model = Classifier(args)
model = load_model(model, args.load_model_path)

model = deepspeed.init_inference(model=model, mp_size=args.mp_size, replace_method=None)
model = deepspeed.initialize(model=model, config_params=args.deepspeed_config)[0]

rank = dist.get_rank()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if rank == 0:
dataset = read_dataset(args, args.test_path)
dataset = read_dataset(args, args.test_path)

src = torch.LongTensor([sample[0] for sample in dataset])
seg = torch.LongTensor([sample[1] for sample in dataset])
src = torch.LongTensor([sample[0] for sample in dataset])
seg = torch.LongTensor([sample[1] for sample in dataset])

batch_size = args.batch_size
instances_num = src.size()[0]
batch_size = args.batch_size
instances_num = src.size()[0]

print("The number of prediction instances: ", instances_num)
print("The number of prediction instances: ", instances_num)

model.eval()
model.eval()

with open(args.prediction_path, mode="w", encoding="utf-8") as f:
with open(args.prediction_path, mode="w", encoding="utf-8") as f:
if rank == 0:
f.write("label")
if args.output_logits:
f.write("\t" + "logits")
if args.output_prob:
f.write("\t" + "prob")
f.write("\n")
for i, (src_batch, seg_batch) in enumerate(batch_loader(batch_size, src, seg)):
src_batch = src_batch.to(device)
seg_batch = seg_batch.to(device)
with torch.no_grad():
_, logits = model(src_batch, None, seg_batch)

pred = torch.argmax(logits, dim=1)
pred = pred.cpu().numpy().tolist()
prob = nn.Softmax(dim=1)(logits)
logits = logits.cpu().numpy().tolist()
prob = prob.cpu().numpy().tolist()

for i, (src_batch, seg_batch) in enumerate(batch_loader(batch_size, src, seg)):
src_batch = src_batch.to(device)
seg_batch = seg_batch.to(device)
with torch.no_grad():
_, logits = model(src_batch, None, seg_batch)

pred = torch.argmax(logits, dim=1)
pred = pred.cpu().numpy().tolist()
prob = nn.Softmax(dim=1)(logits)
logits = logits.cpu().numpy().tolist()
prob = prob.cpu().numpy().tolist()
if rank == 0:
for j in range(len(pred)):
f.write(str(pred[j]))
if args.output_logits:
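Beyond swapping `deepspeed.init_inference` for `deepspeed.initialize(...)[0]`, this hunk moves dataset reading and the forward loop out of the `if rank == 0:` block: under ZeRO-3 every rank holds only a shard of each parameter, so every rank must take part in the forward pass, while file output stays restricted to rank 0. A minimal sketch of that pattern, with `model`, `batch_loader`, `src`, `seg`, and `args` taken from the script above and treated as given:

```python
# Sketch: all ranks compute, rank 0 writes.
import torch
import torch.distributed as dist

rank = dist.get_rank()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

with open(args.prediction_path, mode="w", encoding="utf-8") as f:
    for src_batch, seg_batch in batch_loader(args.batch_size, src, seg):
        with torch.no_grad():
            # Collective forward pass: must run on every rank under ZeRO-3.
            _, logits = model(src_batch.to(device), None, seg_batch.to(device))
        pred = torch.argmax(logits, dim=1).cpu().tolist()
        if rank == 0:  # only rank 0 touches the prediction file
            f.write("\n".join(str(p) for p in pred) + "\n")
```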
2 changes: 1 addition & 1 deletion inference/run_classifier_infer.py
@@ -15,7 +15,7 @@
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.utils.seed import set_seed
from tencentpretrain.model_loader import load_model
from tencentpretrain.model_loader import *
from tencentpretrain.opts import infer_opts, tokenizer_opts
from finetune.run_classifier import Classifier

44 changes: 44 additions & 0 deletions models/deepspeed_zero3_config.json
@@ -0,0 +1,44 @@
{
"gradient_accumulation_steps": 1,
"train_micro_batch_size_per_gpu":1,
"steps_per_print": 100,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-5,
"weight_decay": 1e-2
}
},
"flops_profiler": {
"enabled": true,
"profile_step": 1,
"module_depth": -1,
"top_modules": 3,
"detailed": true
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 3,
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
},
"activation_checkpointing": {
"partition_activations": false,
"contiguous_memory_optimization": false,
"cpu_checkpointing": false
},
"wall_clock_breakdown": false,
"zero_allow_untested_optimizer": true
}
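The new config enables ZeRO stage 3 with both parameters and optimizer state offloaded to CPU, fp16 with dynamic loss scaling, and a plain Adam optimizer; it is the file the scripts above pass via `--deepspeed_config`. A minimal sketch of how the same file is consumed twice, first to scope model construction and then to build the engine (the path is the one added here; the stand-in module and everything else is illustrative):

```python
# Sketch: one JSON file drives both partitioned construction and engine creation.
import deepspeed
import torch.nn as nn

config_path = "models/deepspeed_zero3_config.json"

# Any module built inside zero.Init has its parameters partitioned across ranks.
with deepspeed.zero.Init(config_dict_or_path=config_path):
    model = nn.Linear(768, 2)  # stand-in for Classifier(args) or GenerateLm(args)

# deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler);
# the scripts in this PR keep only the engine and use it for forward passes.
engine = deepspeed.initialize(model=model, config_params=config_path)[0]
```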
2 changes: 1 addition & 1 deletion scripts/generate_lm.py
@@ -17,7 +17,7 @@
from tencentpretrain.utils.constants import *
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.model_loader import load_model
from tencentpretrain.model_loader import *
from tencentpretrain.opts import infer_opts, tokenizer_opts


57 changes: 30 additions & 27 deletions scripts/generate_lm_deepspeed.py
@@ -29,7 +29,6 @@
tokenizer_opts(parser)

deepspeed_opts(parser)
parser.add_argument("--mp_size", type=int, default=1, help="Model parallel size.")

args = parser.parse_args()

@@ -40,36 +39,40 @@

args.tokenizer = str2tokenizer[args.tokenizer](args)

model = GenerateLm(args)
model = load_model(model, args.load_model_path)
if args.enable_zero3:
with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config):
model = GenerateLm(args)
model = _load_state_dict_into_model(model, args.load_model_path)
else:
model = GenerateLm(args)
model = load_model(model, args.load_model_path)
deepspeed.init_distributed()
model = deepspeed.init_inference(model=model, mp_size=args.mp_size, replace_method=None)
model = deepspeed.initialize(model=model, config_params=args.deepspeed_config)[0]

rank = dist.get_rank()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if rank == 0:
model.eval()

with open(args.test_path, mode="r", encoding="utf-8") as f:
line = f.readline().strip()
src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(line))
seg = [1] * len(src)
beginning_length = len(src)
if len(src) > args.seq_length:
src = src[:args.seq_length]
seg = seg[:args.seq_length]
src_tensor, seg_tensor = torch.LongTensor([src]).to(device), torch.LongTensor([seg]).to(device)

with open(args.prediction_path, mode="w", encoding="utf-8") as f:
for i in range(args.seq_length - beginning_length):
output = model(src_tensor, seg_tensor)
next_token_logits = output[0][-1] / args.temperature
filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

src_tensor = torch.cat([src_tensor, next_token.view(1, 1).to(device)], dim=1)
seg_tensor = torch.cat([seg_tensor, torch.tensor([[1]]).to(device)], dim=1)

model.eval()

with open(args.test_path, mode="r", encoding="utf-8") as f:
line = f.readline().strip()
src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(line))
seg = [1] * len(src)
beginning_length = len(src)
if len(src) > args.seq_length:
src = src[:args.seq_length]
seg = seg[:args.seq_length]
src_tensor, seg_tensor = torch.LongTensor([src]).to(device), torch.LongTensor([seg]).to(device)

with open(args.prediction_path, mode="w", encoding="utf-8") as f:
for i in range(args.seq_length - beginning_length):
output = model(src_tensor, seg_tensor)
next_token_logits = output[0][-1] / args.temperature
filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

src_tensor = torch.cat([src_tensor, next_token.view(1, 1).to(device)], dim=1)
seg_tensor = torch.cat([seg_tensor, torch.tensor([[1]]).to(device)], dim=1)
if rank == 0:
f.write(line + "\n")
generated_sentence = "".join(
args.tokenizer.convert_ids_to_tokens([token_id.item() for token_id in src_tensor[0]])
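As in the classifier inference script, `deepspeed.init_inference(..., mp_size=...)` is replaced by `deepspeed.initialize(...)[0]` and the `--mp_size` option is dropped; the generation loop now runs on every rank, with rank 0 alone writing the output file. A short before/after sketch of the engine construction, assuming an already-built `GenerateLm` model and the `--deepspeed_config` path from this PR:

```python
# Sketch of the engine-construction change only; not a full generation script.
import deepspeed

# Before: tensor-parallel inference engine, which is not aware of ZeRO-3 partitioned weights.
# model = deepspeed.init_inference(model=model, mp_size=args.mp_size, replace_method=None)

# After: a regular DeepSpeed engine built from the same JSON config used for zero.Init.
model = deepspeed.initialize(model=model, config_params=args.deepspeed_config)[0]
```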
2 changes: 1 addition & 1 deletion scripts/generate_seq2seq.py
@@ -14,7 +14,7 @@
from tencentpretrain.utils.constants import *
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.model_loader import load_model
from tencentpretrain.model_loader import *
from tencentpretrain.opts import infer_opts, tokenizer_opts
from scripts.generate_lm import top_k_top_p_filtering

55 changes: 29 additions & 26 deletions scripts/generate_seq2seq_deepspeed.py
@@ -29,7 +29,6 @@
parser.add_argument("--tgt_seq_length", type=int, default=128,
help="Sequence length.")
deepspeed_opts(parser)
parser.add_argument("--mp_size", type=int, default=1, help="Model parallel size.")

args = parser.parse_args()

@@ -45,36 +44,40 @@
args.vocab_path = args.tgt_vocab_path
args.tgt_tokenizer = str2tokenizer[args.tgt_tokenizer](args)
args.tgt_vocab = args.tgt_tokenizer.vocab

model = GenerateSeq2seq(args)
model = load_model(model, args.load_model_path)
if args.enable_zero3:
with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config):
model = GenerateSeq2seq(args)
model = _load_state_dict_into_model(model, args.load_model_path)
else:
model = GenerateSeq2seq(args)
model = load_model(model, args.load_model_path)
deepspeed.init_distributed()
model = deepspeed.init_inference(model=model, mp_size=args.mp_size, replace_method=None)
model = deepspeed.initialize(model=model, config_params=args.deepspeed_config)[0]

rank = dist.get_rank()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if rank == 0:
model.eval()

with open(args.test_path, mode="r", encoding="utf-8") as f:
line = f.readline().strip()
src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(line) + [SEP_TOKEN])
seg = [1] * len(src)
tgt = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN])
beginning_length = len(src)
if len(src) > args.seq_length:
src = src[:args.seq_length]
seg = seg[:args.seq_length]
src_tensor, seg_tensor, tgt_tensor = torch.LongTensor([src]).to(device), torch.LongTensor([seg]).to(device), torch.LongTensor([tgt]).to(device)

with open(args.prediction_path, mode="w", encoding="utf-8") as f:
for i in range(args.tgt_seq_length-1):
output = model(src_tensor, seg_tensor, tgt_tensor)
next_token_logits = output[0][-1] / args.temperature
filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
tgt_tensor = torch.cat([tgt_tensor, next_token.view(1, 1).to(device)], dim=1)

model.eval()

with open(args.test_path, mode="r", encoding="utf-8") as f:
line = f.readline().strip()
src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(line) + [SEP_TOKEN])
seg = [1] * len(src)
tgt = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN])
beginning_length = len(src)
if len(src) > args.seq_length:
src = src[:args.seq_length]
seg = seg[:args.seq_length]
src_tensor, seg_tensor, tgt_tensor = torch.LongTensor([src]).to(device), torch.LongTensor([seg]).to(device), torch.LongTensor([tgt]).to(device)

with open(args.prediction_path, mode="w", encoding="utf-8") as f:
for i in range(args.tgt_seq_length-1):
output = model(src_tensor, seg_tensor, tgt_tensor)
next_token_logits = output[0][-1] / args.temperature
filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
tgt_tensor = torch.cat([tgt_tensor, next_token.view(1, 1).to(device)], dim=1)
if rank == 0:
f.write(line + "\n")
generated_sentence = "".join(
args.tgt_tokenizer.convert_ids_to_tokens([token_id.item() for token_id in tgt_tensor[0]])
45 changes: 45 additions & 0 deletions tencentpretrain/model_loader.py
@@ -10,3 +10,48 @@ def load_model(model, model_path):
else:
model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)
return model


def _load_state_dict_into_model(model_to_load, model_path, start_prefix=""):
# Convert old format to new format if needed from a PyTorch state_dict

# copy state_dict so _load_from_state_dict can modify it
state_dict = torch.load(model_path, map_location="cpu")
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata

error_msgs = []

# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module, state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
import deepspeed
# In sharded models, each shard has only part of the full state_dict, so only gather
# parameters that are in the current state_dict.
named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
if len(params_to_gather) > 0:
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
if torch.distributed.get_rank() == 0:
module._load_from_state_dict(*args)

for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")

load(model_to_load, state_dict, prefix=start_prefix)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
# it's safe to delete it.
del state_dict

return model_to_load
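`_load_state_dict_into_model` closely follows the helper of the same name in Hugging Face Transformers: it walks the module tree, and for each submodule whose parameters appear in the checkpoint it gathers the ZeRO-3 shards with `deepspeed.zero.GatheredParameters`, copies the weights on rank 0 via `_load_from_state_dict`, and lets the context manager re-partition them on exit. A minimal usage sketch; the checkpoint path is an assumption, and `Classifier`/`args` stand for any module and its options:

```python
# Sketch: pairing zero.Init construction with the shard-aware loader defined above.
import deepspeed
from tencentpretrain.model_loader import _load_state_dict_into_model

with deepspeed.zero.Init(config_dict_or_path="models/deepspeed_zero3_config.json"):
    model = Classifier(args)  # assumed: any module whose state_dict matches the checkpoint

# Each layer is gathered, loaded from the full state_dict on rank 0, then re-partitioned.
model = _load_state_dict_into_model(model, "models/classifier_model.bin")  # hypothetical path
```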
2 changes: 2 additions & 0 deletions tencentpretrain/opts.py
@@ -208,6 +208,8 @@ def tgt_tokenizer_opts(parser):
def deepspeed_opts(parser):
parser.add_argument("--deepspeed", action="store_true",
help=".")
parser.add_argument("--enable_zero3", action="store_true",
help=".")
parser.add_argument("--deepspeed_config", default="models/deepspeed_config.json", type=str,
help=".")
parser.add_argument("--deepspeed_checkpoint_activations", action='store_true',
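The new `--enable_zero3` flag is registered inside `deepspeed_opts`, so every script that already calls that helper picks it up without further changes. A small sketch of how the flag is parsed (the argument values are hypothetical):

```python
# Sketch: the flag rides along with the existing DeepSpeed options.
import argparse
from tencentpretrain.opts import deepspeed_opts

parser = argparse.ArgumentParser()
deepspeed_opts(parser)  # now also registers --enable_zero3

args = parser.parse_args([
    "--deepspeed",
    "--enable_zero3",
    "--deepspeed_config", "models/deepspeed_zero3_config.json",
])
assert args.deepspeed and args.enable_zero3
```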