[DTensor&DModule&DDP&Examples] feature updates and new examples (#35)
In this PR, we add two examples and update some features in DTensor,
DModule, and DDP.

## Examples

1. 4D finetuning of the llama2_3b model.
2. 4D pretraining of a Mixtral MoE-based model.

## DTensor

1. Update op strategies on `Partial`ed and `InterleavedShard`ed dtensors.
2. Add all-to-all communication (see the sketch after this list).
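
A minimal sketch of the resharding pattern that exercises the new all-to-all path. The imports follow the example code in this PR; `Shard` and `redistribute` are assumed to behave as in upstream PyTorch DTensor, and the mesh construction is illustrative.

```
import torch

from vescale import distribute_tensor
from vescale.dtensor.device_mesh import DeviceMesh
from vescale.dtensor.placement_types import Shard

# Illustrative 1-D mesh over 4 ranks; a real run builds this after initializing torch.distributed.
mesh = DeviceMesh("cuda", [0, 1, 2, 3])

# Shard a tensor along dim 0 across the mesh.
x = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

# Re-sharding from dim 0 to dim 1 is the pattern that lowers to an all-to-all.
y = x.redistribute(mesh, [Shard(1)])
```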

## DModule

1. Support factory methods for nested submodules.

## DDP

1. Unblock gradient allreduce for sparse modules in DDP.
lichen225 authored May 21, 2024
1 parent 9047a73 commit 2a072bf
Showing 62 changed files with 3,132 additions and 693 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -21,6 +21,8 @@
/**/*.tar.gz
/**/*.json.gz
/**/*.log
/**/.DS_Store
*_checkpoint_dir

# pre-commit config
./.pre-commit-config.yaml
32 changes: 32 additions & 0 deletions examples/llama2_4D_finetune/README.md
@@ -0,0 +1,32 @@
# Finetune a Llama2 3b model in 4D parallelism using veScale

## Overview

Finetune a pretrained llama2_3b model on a small Shakespeare dataset.
Dropout is set to 0 for this model, so no randomness is involved during finetuning.
We choose llama2_3b rather than the 7b model because it fits on a single GPU, which lets us check the correctness of veScale against a single-GPU run.

## Prerequisite

```
pip3 install sentencepiece
```

## Run

```
cd data/shakespeare/ && python3 prepare.py && cd ../..
torchrun --standalone --nproc_per_node={GPU_CNT} llama_train.py --dp={dp_size} --tp={tp_size} --max_iters={max_iters}
```

## Experiments

Like nanoGPT, we finetune the model with a constant learning rate of `3e-5` and set `grad_clip = 1`.
The model parameters, gradients, and optimizer states are all kept in `bf16`.
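
A minimal sketch of these settings in plain PyTorch (the tiny stand-in model and the choice of `AdamW` are assumptions for illustration; the actual training loop lives in `llama_train.py`):

```
import torch
import torch.nn as nn

# Illustrative stand-in for the real model; llama_train.py builds the actual llama2_3b module.
model = nn.Linear(16, 16).to(torch.bfloat16)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)  # optimizer choice is an assumption

x = torch.randn(4, 16, dtype=torch.bfloat16)
y = torch.randn(4, 16, dtype=torch.bfloat16)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # grad_clip = 1
optimizer.step()
optimizer.zero_grad()
```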

![](./figures/llama2_3b_train_losses.jpg)


## Caveats

1. Currently, this example does not work with `transformers==4.38.2`. The error occurs during the backward step: the `aten._scaled_dot_product_efficient_attention` operator fails with the message `attn_bias: wrong shape (head dimension)`.
57 changes: 57 additions & 0 deletions examples/llama2_4D_finetune/data/shakespeare/prepare.py
@@ -0,0 +1,57 @@
################################################################################
# Copyright (c) 2022 Andrej Karpathy

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
################################################################################
# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
################################################################################

import os
import requests
import numpy as np
from transformers import LlamaTokenizer

# download the tiny shakespeare dataset
input_file_path = os.path.join(os.path.dirname(__file__), "input.txt")
if not os.path.exists(input_file_path):
    data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    with open(input_file_path, "w", encoding="utf-8") as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, encoding="utf-8") as f:
    data = f.read()
n = len(data)
train_data = data[: int(n * 0.9)]
val_data = data[int(n * 0.9) :]

# tokenize with llama2 tokenizer
tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_7b")
train_ids = tokenizer.encode(train_data)
val_ids = tokenizer.encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname(__file__), "train.bin"))
val_ids.tofile(os.path.join(os.path.dirname(__file__), "val.bin"))

# train.bin has 318,905 tokens
# val.bin has 37,782 tokens
9 changes: 9 additions & 0 deletions examples/llama2_4D_finetune/data/shakespeare/readme.md
@@ -0,0 +1,9 @@

# tiny shakespeare

Tiny shakespeare, of the good old char-rnn fame :)

After running `prepare.py`:

- train.bin has 318,905 tokens
- val.bin has 37,782 tokens
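
A minimal sketch of reading the binary files back, mirroring what the example's `data_loader.py` does (`prepare.py` writes raw `uint16` token ids):

```
import numpy as np

# memmap avoids loading the whole file; the dtype matches how prepare.py wrote the ids.
train_data = np.memmap("train.bin", dtype=np.uint16, mode="r")
val_data = np.memmap("val.bin", dtype=np.uint16, mode="r")
print(f"train has {len(train_data):,} tokens, val has {len(val_data):,} tokens")
```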
64 changes: 64 additions & 0 deletions examples/llama2_4D_finetune/data_loader.py
@@ -0,0 +1,64 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import os
from typing import Optional

import numpy as np
import torch

from vescale.dtensor.device_mesh import DeviceMesh
from vescale import distribute_tensor
from vescale.dtensor.placement_types import Replicate
from vescale.dtensor import empty as d_empty


class DataLoader:
    def __init__(self, dataset: str, seqlen: int, mesh: Optional[DeviceMesh] = None, dp_rank: int = 0):
        self.data_dir = os.path.join("data", dataset)
        self.seqlen = seqlen
        self.mesh = mesh
        self.dp_rank = dp_rank
        if mesh is not None:
            self.device_type = mesh.device_type
        else:
            self.device_type = "cuda"

    def get_batch(self, split, bsz, lbsz):
        # We recreate np.memmap every batch to avoid a memory leak, as per
        # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
        if split == "train":
            data = np.memmap(os.path.join(self.data_dir, "train.bin"), dtype=np.uint16, mode="r")
        else:
            data = np.memmap(os.path.join(self.data_dir, "val.bin"), dtype=np.uint16, mode="r")
        if self.mesh is not None:
            ix = d_empty((bsz,), device_mesh=self.mesh, placements=[Replicate()])
        else:
            ix = torch.empty((bsz,), device="cuda")
        ix = torch.randint_like(ix, len(data) - self.seqlen, dtype=torch.int64)
        if self.mesh is not None:
            ix = ix.to_local()
        if self.mesh is None or self.mesh.get_rank() == 0:
            print(f"sum(ix) {sum(ix)}")
        ix = torch.split(ix, lbsz)[self.dp_rank]
        x = torch.stack([torch.from_numpy((data[i : i + self.seqlen]).astype(np.int64)) for i in ix])
        y = torch.stack([torch.from_numpy((data[i + 1 : i + 1 + self.seqlen]).astype(np.int64)) for i in ix])
        x, y = x.to(self.device_type), y.to(self.device_type)
        if self.mesh is not None:
            x = distribute_tensor(x, self.mesh["TP"], [Replicate()])
            y = distribute_tensor(y, self.mesh["TP"], [Replicate()])
        return x, y
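
A hedged usage sketch for this loader (not part of the file above): it assumes a 2-D device mesh with named `"DP"` and `"TP"` dimensions, constructed as in upstream PyTorch's `DeviceMesh`, since `get_batch` indexes the mesh by the `"TP"` name; the sizes here are illustrative.

```
from vescale.dtensor.device_mesh import DeviceMesh

from data_loader import DataLoader

# Illustrative 2x2 mesh over 4 GPUs: outer dim for data parallel, inner for tensor parallel.
# Assumes the constructor accepts named mesh dimensions, as in upstream PyTorch DeviceMesh.
mesh = DeviceMesh("cuda", [[0, 1], [2, 3]], mesh_dim_names=("DP", "TP"))

loader = DataLoader("shakespeare", seqlen=256, mesh=mesh, dp_rank=0)
x, y = loader.get_batch("train", bsz=8, lbsz=4)  # global batch of 8, 4 samples per DP rank
```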
98 changes: 98 additions & 0 deletions examples/llama2_4D_finetune/exp.py
@@ -0,0 +1,98 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import os
import re


def parse_train_loss(log_fn, name=None):
    lines = open(log_fn).readlines()
    train_losses = []
    for line in lines:
        if "loss" in line and "iter" in line:
            token = line.split()[line.split().index("loss") + 1]
            train_loss = float(token)
            train_losses.append(train_loss)
    if name is None:
        name = log_fn
    print(f'"{name}": {train_losses},')


def parse(log_fn, name=None):
    lines = open(log_fn).readlines()
    val_losses = []
    for line in lines:
        if "val_loss" in line:
            token = line.split()[line.split().index("val_loss:") + 1]
            val_loss = float(token)
            val_losses.append(val_loss)
    if name is None:
        name = log_fn
    print(f'"{name}": {val_losses},')


GPU_CNT = 4
DP_SIZES = [1, 2, 4]
# DP_SIZES = [4]
SINGLE_GPU_RUN = "python3"
MULTI_GPU_RUN = f"torchrun --standalone --nproc_per_node={GPU_CNT}"
CODE = "llama_train.py"
LOG_PREFIX = "llama2"
TRAIN_BIN_PATH = "data/shakespeare/train.bin"


def run_exps(max_iters, dtypes, run=True):
    if not os.path.isfile(TRAIN_BIN_PATH):
        os.system(f"cd data/shakespeare/ && python3 prepare.py && cd ../..")
    os.makedirs("logs", exist_ok=True)
    if run:
        for dtype in dtypes:
            dt = "bfloat16" if dtype == "bf16" else "float32"
            cmd = f"{SINGLE_GPU_RUN} {CODE} --dp=1 --tp=1 --max_iters={max_iters} --dtype='{dt}'"
            log_fn = f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log"
            # print(f"run {cmd} > {log_fn} 2> {log_fn}.err")
            # os.system(f"{cmd} > {log_fn} 2> {log_fn}.err")
            for dp_size in DP_SIZES:
                tp_size = GPU_CNT // dp_size
                dt = "bfloat16" if dtype == "bf16" else "float32"
                cmd = f"{MULTI_GPU_RUN} {CODE} --dp={dp_size} --tp={tp_size} --max_iters={max_iters} --dtype='{dt}'"
                log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
                print(f"run {cmd} > {log_fn} 2> {log_fn}.err")
                os.system(f"{cmd} > {log_fn} 2> {log_fn}.err")

    print("train_loss = {")
    for dtype in dtypes:
        parse_train_loss(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}")
        for dp_size in DP_SIZES:
            tp_size = GPU_CNT // dp_size
            log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
            parse_train_loss(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}")
    print("}")

    # print("val_loss = {")
    # for dtype in dtypes:
    #     # parse(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}")
    #     for dp_size in DP_SIZES:
    #         tp_size = GPU_CNT // dp_size
    #         log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
    #         parse(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}")
    # print("}")


if __name__ == "__main__":
    run_exps(100000, ["bf16"], run=True)
    # run_exps(10, ["bf16"], run=False)
