
Commit

try fix loading deepspeed for generation
Anton Emelyanov committed Feb 11, 2021
1 parent 31dd756 commit 49c00eb
Showing 2 changed files with 65 additions and 1 deletion.
55 changes: 54 additions & 1 deletion README.md
@@ -49,7 +49,11 @@ print(generated_text)

For more information about the 🤗HuggingFace interface, please follow this [documentation](https://huggingface.co/transformers/main_classes/model.html#transformers.generation_utils.GenerationMixin.generate).
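
As an illustration of the options described there, here is a minimal generation sketch (the checkpoint name, prompt and sampling parameters below are only examples):

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model_name = "sberbank-ai/rugpt3small_based_on_gpt2"  # example checkpoint
tok = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

input_ids = tok.encode("Александр Сергеевич Пушкин родился в ", return_tensors="pt")
out = model.generate(
    input_ids,
    max_length=100,
    do_sample=True,          # sample instead of greedy decoding
    top_k=50,                # keep only the 50 most likely next tokens
    top_p=0.95,              # nucleus sampling
    no_repeat_ngram_size=3,  # do not repeat any 3-gram
)
print(tok.decode(out[0]))
```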

#### Data issues
For training, pass a single txt file.

## Megatron interface
### Without deepspeed
To use our code for finetuning without deepspeed (not recommended), install apex:

```bash
@@ -64,7 +68,56 @@ sh setup.sh

An example of finetuning, generation and loading/converting megatron checkpoints is available [here](examples/Finetune_and_generate_RuGPTs_only_with_megatron.ipynb) or [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sberbank-ai/ru-gpts/blob/master/examples/Finetune_and_generate_RuGPTs_only_with_megatron.ipynb).

Note! This way is valid for all RuGPTs models except RuGPT3XL.
**Note!** This way is valid for all RuGPTs models except RuGPT3XL.

### Megatron with deepspeed
To use our code for finetuning with deepspeed (recommended), install apex (see the previous section) and deepspeed:

```bash
pip install deepspeed==0.3.7
```

An example of finetuning, generation and loading/converting megatron checkpoints is available [here](examples/Finetune_and_generate_RuGPTs_deepspeed_megatron.ipynb) or [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sberbank-ai/ru-gpts/blob/master/examples/Finetune_and_generate_RuGPTs_deepspeed_megatron.ipynb).

**Note!** To use deepspeed, set the `USE_DEEPSPEED=1` environment variable before all your python scripts and launch them with torch.distributed or mpi:

```bash
USE_DEEPSPEED=1 python -m torch.distributed.launch --nproc_per_node 1 ru-gpts/pretrain_gpt3.py \
--train-data-path "train.list" \
--test-data-path "valid.list" \
--max-files-per-process 100 \
--save model \
--load-huggingface sberbank-ai/rugpt3small_based_on_gpt2 \
--model-parallel-size 1 \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--fp16 \
--checkpoint-activations \
--deepspeed-activation-checkpointing \
--deepspeed \
--deepspeed_config ru-gpts/src/deepspeed_config/gpt3_small_2048.json
```
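
The `USE_DEEPSPEED` variable is what makes the scripts import and use deepspeed at all. A minimal sketch of how such a gate can work is shown below; the names (`DEEPSPEED_WRAP` built via `SimpleNamespace`) are illustrative, and the actual ru-gpts implementation may differ:

```python
import os
import types

# deepspeed is imported only when explicitly requested via USE_DEEPSPEED,
# so the same scripts still run on machines where deepspeed is not installed
DEEPSPEED_WRAP = None
if os.environ.get("USE_DEEPSPEED"):
    import deepspeed
    DEEPSPEED_WRAP = types.SimpleNamespace(deepspeed=deepspeed)

# deepspeed-specific code paths are then guarded like this:
# if DEEPSPEED_WRAP and args.deepspeed:
#     model, _, _, _ = DEEPSPEED_WRAP.deepspeed.initialize(model=model, args=args, mpu=mpu)
```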

#### Data issues
We use a custom implementation of a distributed dataset. For training and evaluation, specify a file `file.list` with a list of paths to txt files. All files from `file.list` will be split between the available GPUs. The splitting logic is described by the following code:

```python
shard_size = len(files) // world_size   # files per process (integer division)
shard_start = rank * shard_size         # first file index for this rank
shard_end = (rank + 1) * shard_size     # one past the last file index for this rank
files = files[shard_start:shard_end]    # each rank keeps only its own shard
# note: with integer division, the last len(files) % world_size files are not assigned to any rank
```

For more details, please see the full dataset code in `src.dataset_rugpt3.RuGpt3TextDataset` and the example.
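
A small helper for building such list files is sketched below. It assumes the list format is simply one path to a txt file per line (matching how `train.list`/`valid.list` are passed above); the directory and file names are only examples:

```python
from pathlib import Path

# collect all txt files from an example data directory and write one path per line
paths = sorted(str(p) for p in Path("data/train").glob("*.txt"))
Path("train.list").write_text("\n".join(paths) + "\n", encoding="utf-8")
```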

**Note!** This way is valid for all RuGPTs models except RuGPT3XL.

## Setup ruGPT3XL
11 changes: 11 additions & 0 deletions generate_samples.py
@@ -72,6 +72,17 @@ def setup_model(args):
"""Setup model and optimizer."""

model = get_model(args)
if DEEPSPEED_WRAP and args.deepspeed:
print_rank_0("DeepSpeed is enabled.")

model, optimizer, _, lr_scheduler = DEEPSPEED_WRAP.deepspeed.initialize(
model=model,
optimizer=None,
args=args,
lr_scheduler=None,
mpu=mpu,
dist_init_required=False
)

print("Load checkpoint from " + args.load)
_ = load_checkpoint(model, None, None, args, deepspeed=DEEPSPEED_WRAP and args.deepspeed)
