Skip to content

Commit

Permalink
modify train ruGPT3XL finetune example
Browse files Browse the repository at this point in the history
  • Loading branch information
king-menin committed Dec 6, 2022
1 parent 580d649 commit 47ccf82
Show file tree
Hide file tree
Showing 8 changed files with 1,499 additions and 1,531 deletions.
406 changes: 0 additions & 406 deletions examples/Finetune_RuGPTs_with_HF.ipynb

This file was deleted.

574 changes: 287 additions & 287 deletions examples/RuGPT3FinetuneHF.ipynb

Large diffs are not rendered by default.

530 changes: 450 additions & 80 deletions examples/ruGPT3XL_finetune_example.ipynb

Large diffs are not rendered by default.

1,448 changes: 724 additions & 724 deletions examples/ruGPT3XL_generation.ipynb

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion pretrain_gpt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
print_args, print_rank_0,
get_sparse_attention_config, top_k_logits, DEEPSPEED_WRAP
)
from huggingface_hub import hf_hub_download
from src.download_utils import WEIGHTS_NAME

# Flag to use Pytorch ddp which uses overlapping communication and computation.
USE_TORCH_DDP = False
Expand Down Expand Up @@ -75,7 +77,12 @@ def get_model(args):
sparse_mode=args.sparse_mode)

if args.load_huggingface is not None:
model = load_huggingface_model(model, args.load_huggingface, args.huggingface_double_pos_embeddings)
if args.load_huggingface == "sberbank-ai/rugpt3xl":
weights_path = hf_hub_download(args.load_huggingface, WEIGHTS_NAME)
checkpoint = torch.load(weights_path, map_location="cpu")['module']
model.load_state_dict(checkpoint, strict=False)
else:
model = load_huggingface_model(model, args.load_huggingface, args.huggingface_double_pos_embeddings)

if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on model parallel rank {}: {}'.format(
Expand Down
43 changes: 21 additions & 22 deletions scripts/deepspeed_gpt3_xl_finetune.sh
Original file line number Diff line number Diff line change
@@ -1,42 +1,41 @@
#! /bin/bash
%%bash

# Model parallel size
MP_SIZE=1
# Change for multinode config
NUM_GPUS_PER_WORKER=1

gpt_options=" \
--train-data-path /path/2/train/data/files.list \
--max-files-per-process 20000 \
--logging-dir=/path/2/log/dir \
--train-data-path examples/train.list \
--test-data-path examples/valid.list \
--load-huggingface sberbank-ai/rugpt3xl \
--save /path/2/save/model \
--tokenizer-path sberbank-ai/rugpt3xl \
--cache-prefix p5 \
--save-interval 500 \
--no-load-optim \
--finetune \
--log-interval 100 \
--model-parallel-size 1 \
--logging-dir=examples/log/ \
--save examples/model \
--save-interval 200 \
--model-parallel-size ${MP_SIZE} \
--num-layers 24 \
--hidden-size 2048 \
--num-attention-heads 16 \
--batch-size 2 \
--batch-size 1 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--train-iters 20000 \
--train-iters 1000 \
--distributed-backend nccl \
--lr 0.000015 \
--warmup 0.0 \
--lr-decay-style constant \
--lr 0.0002 \
--lr-decay-style cosine \
--weight-decay 1e-2 \
--warmup .01 \
--log-interval 50 \
--fp16 \
--sparse-mode alternating \
--checkpoint-activations \
--deepspeed-activation-checkpointing \
--sparse-mode alternating \
--deepspeed \
--deepspeed_config ../src/deepspeed_config/gpt3_xl_sparse_2048.json \
--deepspeed_config src/deepspeed_config/gpt3_xl_sparse_2048.json \
"

run_cmd="USE_DEEPSPEED=1 mpirun --np ${NUM_GPUS_PER_WORKER} python ../pretrain_gpt3.py $@ ${gpt_options}"
echo ${run_cmd}
eval ${run_cmd}
run_cmd="USE_DEEPSPEED=1 python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS_PER_WORKER} pretrain_gpt3.py $@ ${gpt_options}"
echo "${run_cmd}"
eval "${run_cmd}"

set +x
12 changes: 4 additions & 8 deletions src/download_utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
import os

from transformers.file_utils import (
cached_path,
hf_bucket_url,
is_remote_url,
)
from transformers.utils import logging
from huggingface_hub import hf_hub_download


logger = logging.get_logger(__name__)
WEIGHTS_NAME = "mp_rank_00_model_states.pt"
DEEPSPEED_CONFIG_NAME = "deepspeed_config.json"


def download_model_files(pretrained_model_name_or_path):
weights_path = download_file_from_hf(pretrained_model_name_or_path, WEIGHTS_NAME)
deepspeed_config_path = download_file_from_hf(pretrained_model_name_or_path, DEEPSPEED_CONFIG_NAME)
weights_path = hf_hub_download(pretrained_model_name_or_path, WEIGHTS_NAME)
deepspeed_config_path = hf_hub_download(pretrained_model_name_or_path, DEEPSPEED_CONFIG_NAME)
return weights_path, deepspeed_config_path


Expand Down
8 changes: 5 additions & 3 deletions src/xl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from src import mpu
from .fp16 import FP16_Module
from .model import GPT3Model
from .download_utils import download_model_files
from .download_utils import download_model_files, DEEPSPEED_CONFIG_NAME, hf_hub_download
from transformers.utils import logging


Expand Down Expand Up @@ -80,11 +80,11 @@ def get_model(deepspeed_config_path):

def setup_model(weights_path, deepspeed_config_path):
model = get_model(deepspeed_config_path)
logger.info("Load checkpoint from " + weights_path)
print("Load checkpoint from " + weights_path)
checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)['module']
model.load_state_dict(checkpoint, strict=False)
model.eval()
logger.info("Model Loaded")
print("Model Loaded")
return model


Expand Down Expand Up @@ -180,6 +180,8 @@ def from_pretrained(cls, model_name_or_path=None, seq_len=512, weights_path=None
logger.info("Check cached model files...")
if weights_path is None:
weights_path, deepspeed_config_path = download_model_files(model_name_or_path)
if deepspeed_config_path is None:
deepspeed_config_path = hf_hub_download(model_name_or_path, DEEPSPEED_CONFIG_NAME)
model = setup_model(weights_path, deepspeed_config_path)
model.cuda()
model = model.eval()
Expand Down

0 comments on commit 47ccf82

Please sign in to comment.