[Example] add an example of running open llama model in 4D using veScale (#22)

This PR adds a 4D parallelism example of using veScale to run an [open
llama model](https://huggingface.co/openlm-research/open_llama_7b) that
is directly imported from HuggingFace without any model code
modifications.
jc-bytedance authored Apr 4, 2024
1 parent 364c3b2 commit 028806f
Showing 12 changed files with 869 additions and 0 deletions.
32 changes: 32 additions & 0 deletions python/example/open_llama_4D_benchmark/README.md
@@ -0,0 +1,32 @@
# veScale Open Llama Example
## Overview
In this directory, we provide a 4D parallelism example of using veScale to run
an [open llama model](https://huggingface.co/openlm-research/open_llama_7b) that is directly imported
from HuggingFace without any model code modifications.


## Run
### Single Machine (8 Cards)
```
torchrun --standalone --nnodes=1 --nproc-per-node=8 ./run_open_llama_w_vescale.py --dp=4 --tp=2 --warmup=10 --iter=40
```
This will start an 8-card MFU benchmark for Open Llama with veScale, using dp=4 and tp=2.

### Distributed Environment (4 Machines, 32 Cards)
```
torchrun --nnodes=4 --nproc-per-node=8 --node_rank=$node_rank --master_addr=$master_addr --master_port=$master_port ./run_open_llama_w_vescale.py --dp=16 --tp=2 --warmup=10 --iter=40
```
This will start a 32-card MFU benchmark for Open Llama with veScale, using dp=16 and tp=2.

### Options
1. `--total_bsz`: the total batch size for one iteration. The default is 16. (A per-rank batch size sketch follows this list.)
2. `--dp`: the data parallelism (DDP) size. This argument has no default value.
3. `--tp`: the tensor parallelism size. This argument has no default value.
4. `--warmup`: the number of warmup iterations performed. The default is 5.
5. `--iter`: the number of iterations used for calculating the MFU. The default is 10.
6. `--no-ckpt`: turns off loading checkpoints from HuggingFace.
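
The script splits `--total_bsz` evenly across the data-parallel ranks to get the per-rank batch size; a minimal sketch of that arithmetic, using the values from the 8-card command above:

```python
# Per-rank batch size as computed in run_open_llama_w_vescale.py.
total_bsz, dp = 16, 4            # --total_bsz=16 (default), --dp=4
assert total_bsz % dp == 0, "total batch size must be divisible by dp"
bsz = total_bsz // dp            # each data-parallel rank trains on a batch of 4
print(bsz)                       # -> 4
```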

## Caveats
1. The scripts are purely for demonstration purposes and MFU calculation. You need to write your own training script
in order to fine-tune Open Llama with your data.
2. There is a known issue with transformers versions greater than 4.37.2. We will fix it later.
22 changes: 22 additions & 0 deletions python/example/open_llama_4D_benchmark/config.json
@@ -0,0 +1,22 @@
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.30.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}
23 changes: 23 additions & 0 deletions python/example/open_llama_4D_benchmark/download_open_llama_ckpt.py
@@ -0,0 +1,23 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################


from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_7b")

print(model)
29 changes: 29 additions & 0 deletions python/example/open_llama_4D_benchmark/llama_mfu_calculator.py
@@ -0,0 +1,29 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

# reference: https://www.adamcasson.com/posts/transformer-flops
# reference: https://arxiv.org/pdf/2001.08361.pdf


def estimate_llama(config, bsz, sequence_length):
    # FLOPs for one forward pass of a Llama model (see the references above).
    embed = 4 * bsz * sequence_length * config.hidden_size
    ff = 3 * 2 * config.hidden_size * config.intermediate_size * bsz * sequence_length
    attn_qkv = 2 * bsz * sequence_length * config.hidden_size * 3 * config.hidden_size
    attn_mask = 2 * sequence_length * config.hidden_size
    attn_proj = 2 * config.hidden_size * config.intermediate_size * bsz * sequence_length
    attn = attn_qkv + attn_mask + attn_proj
    return embed + (ff + attn) * config.num_hidden_layers
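
For reference, a minimal usage sketch of this estimator (not part of the diff), assuming `transformers` is installed and the example's `config.json` is in the working directory:

```python
# Hypothetical sanity check for estimate_llama; the path and batch/sequence values
# are for illustration only.
from transformers import AutoConfig

from llama_mfu_calculator import estimate_llama

config = AutoConfig.from_pretrained("./config.json")  # the open_llama_7b config shipped above
fwd_flops = estimate_llama(config, 4, 2048)  # bsz=4, sequence length 2048
print(f"estimated forward FLOPs per iteration: {fwd_flops:.3e}")
```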
123 changes: 123 additions & 0 deletions python/example/open_llama_4D_benchmark/run_open_llama_w_vescale.py
@@ -0,0 +1,123 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import os
import torch
import argparse

os.environ["VESCALE_DISABLE_RUN_CHECK"] = "1"

from vescale.dtensor.device_mesh import init_device_mesh
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer
from vescale.dmodule.api import parallelize_module
from sharding_plan import sharding_plan

from transformers import AutoModelForCausalLM, AutoConfig, LlamaModel

from llama_mfu_calculator import estimate_llama

local_rank = int(os.environ["LOCAL_RANK"])
parser = argparse.ArgumentParser()
parser.add_argument("--total_bsz", type=int, default=16)
parser.add_argument("--dp", type=int)
parser.add_argument("--tp", type=int)
parser.add_argument("--warmup", type=int, default=5)
parser.add_argument("--iter", type=int, default=10)
parser.add_argument("--no-ckpt", action="store_true")

args = parser.parse_args()

assert args.total_bsz % args.dp == 0, f"total batch size {args.total_bsz} is not divisible by dp size {args.dp}"
bsz = args.total_bsz // args.dp
s = 2048

# init model
if args.no_ckpt:
    dir_path = os.path.dirname(os.path.realpath(__file__))
    config = AutoConfig.from_pretrained(os.path.join(dir_path, "config.json"))
    model = LlamaModel(config)
else:
    model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_7b")
    model = model.model
    config = model.config
assert s <= config.max_position_embeddings

# -------- training config --------
device_mesh = init_device_mesh(
    "cuda",
    (
        args.dp,
        args.tp,
    ),
    mesh_dim_names=("DP", "TP"),
)

input = torch.randint(low=0, high=config.vocab_size, size=(bsz, s)).cuda()

model = model.cuda().bfloat16()
vescale_model = parallelize_module(model, device_mesh["TP"], sharding_plan)

ddp_model = DDP(
    vescale_model,
    data_pg_or_device_mesh=device_mesh["DP"],
    use_distributed_optimizer=True,
)
orig_optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

ve_optimizer = DistributedOptimizer(
    orig_optimizer,
    overlap_param_gather=True,
    models=[ddp_model],
)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# -------- warm up --------
for _ in range(args.warmup):
    ve_optimizer.zero_grad()
    vescale_output = ddp_model(input).last_hidden_state
    vescale_loss = vescale_output.mean()
    vescale_loss.backward()
    ve_optimizer.step()

# -------- training loop --------
start.record()
for _ in range(args.iter):
    ve_optimizer.zero_grad()
    vescale_output = ddp_model(input).last_hidden_state
    vescale_loss = vescale_output.mean()
    vescale_loss.backward()
    ve_optimizer.step()
end.record()
torch.cuda.synchronize()
exec_t = start.elapsed_time(end) / 1000 / args.iter

if local_rank == 0:
    flops_dict = {
        "A100": 312,
        "H100": 1000,
    }
    d_name = torch.cuda.get_device_name()
    total_flops = flops_dict["A100"] * (10**12) * device_mesh.ndevice
    for k, v in flops_dict.items():
        if k in d_name:
            total_flops = v * (10**12) * device_mesh.ndevice
            break
    print(f"1 iter time: {exec_t}")
    # forward + backward is counted as 3x the forward FLOPs
    print("mfu:", estimate_llama(config, bsz, s) * 3 * args.dp * 100 / exec_t / total_flops)
63 changes: 63 additions & 0 deletions python/example/open_llama_4D_benchmark/sharding_plan.py
@@ -0,0 +1,63 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

from vescale.dtensor.placement_types import Replicate, Shard

# forward resharding plan for a single open llama decoder
_decoder_fwd_resharding_plan = {
    "input": {"hidden_states": [Shard(1)], "attention_mask": [Replicate()], "position_ids": [Replicate()]},
    # atten
    "self_attn.input": {"hidden_states": [Replicate()], "attention_mask": [Replicate()], "position_ids": [Replicate()]},
    "self_attn.o_proj.output": [[Shard(1)]],
    "self_attn.output": [[Shard(1)], None, None],
    # feedforward(mlp)
    "mlp.input": [[Replicate()]],
    "mlp.output": [[Shard(1)]],
    "output": [[Shard(1)], None],
}

# parameter sharding plan for a single open llama decoder
_decoder_param_sharding_plan = {
    # atten weight, no bias
    "self_attn.q_proj.weight": [Shard(0)],
    "self_attn.k_proj.weight": [Shard(0)],
    "self_attn.v_proj.weight": [Shard(0)],
    "self_attn.o_proj.weight": [Shard(1)],
    # feedforward(mlp)
    "mlp.up_proj.weight": [Shard(0)],
    "mlp.gate_proj.weight": [Shard(0)],
    "mlp.down_proj.weight": [Shard(1)],
}

# forward resharding plan for the whole open llama model
model_fwd_resharding_plan = {
    ".input": [[Replicate()]],
    "embed_tokens.output": [[Shard(1)]],
    "norm.input": [[Shard(1)]],
    ".output": {
        "last_hidden_state": [Replicate()],
    },
    **{rf"layers.\d+.{k}": v for k, v in _decoder_fwd_resharding_plan.items()},
}

# model parameter sharding plan for the whole open llama model
model_param_sharding_plan = {
    "embed_tokens.weight": [Shard(1)],
    **{rf"layers.\d+.{k}": v for k, v in _decoder_param_sharding_plan.items()},
}

sharding_plan = {"parameter": model_param_sharding_plan, "forward": model_fwd_resharding_plan}
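
A minimal sketch of how this plan is consumed, mirroring run_open_llama_w_vescale.py (run under torchrun; the 4x2 mesh shape is an assumption for illustration):

```python
# Hypothetical driver snippet: build the model from the local config and shard it
# over the tensor-parallel sub-mesh with the plan defined above.
import os

from transformers import AutoConfig, LlamaModel
from vescale.dtensor.device_mesh import init_device_mesh
from vescale.dmodule.api import parallelize_module

from sharding_plan import sharding_plan

dir_path = os.path.dirname(os.path.realpath(__file__))
config = AutoConfig.from_pretrained(os.path.join(dir_path, "config.json"))
model = LlamaModel(config).cuda().bfloat16()

# 2D device mesh: data-parallel ("DP") x tensor-parallel ("TP") dimensions.
device_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("DP", "TP"))

# Apply the parameter sharding plan and forward resharding plan on the TP sub-mesh.
vescale_model = parallelize_module(model, device_mesh["TP"], sharding_plan)
```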
22 changes: 22 additions & 0 deletions test/model/open_llama/config.json
@@ -0,0 +1,22 @@
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.30.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}