[Example] add an example of running open llama model in 4D using veScale #22

Merged
merged 1 commit on Apr 4, 2024
32 changes: 32 additions & 0 deletions python/example/open_llama_4D_benchmark/README.md
@@ -0,0 +1,32 @@
# veScale Open Llama Example
## Overview
In this directory, we provide a 4D parallelism example of using veScale to run
an [open llama model](https://huggingface.co/openlm-research/open_llama_7b) that is directly imported
from HuggingFace without any model code modifications.


## Run
### Single Machine (8 cards)
```
torchrun --standalone --nnodes=1 --nproc-per-node=8 ./run_open_llama_w_vescale.py --dp=4 --tp=2 --warmup=10 --iter=40
```
This will start an 8-card MFU benchmark for open Llama with veScale, using dp=4 and tp=2.

### Distributed Environment (4 machines, 32 cards)
```
torchrun --nnodes=4 --nproc-per-node=8 --node_rank=$node_rank --master_addr=$master_addr --master_port=$master_port ./run_open_llama_w_vescale.py --dp=16 --tp=2 --warmup=10 --iter=40
```
This will start a 32-card MFU benchmark for open Llama with veScale, using dp=16 and tp=2.

### Options
1. `--total_bsz`: the total batch size for one iteration. The default is 16.
2. `--dp`: the degree of data parallelism (DDP). This arg has no default value.
3. `--tp`: the degree of tensor parallelism. This arg has no default value.
4. `--warmup`: the number of warmup iterations performed. The default is 5.
5. `--iter`: the number of iterations used for calculating the MFU. The default is 10.
6. `--no-ckpt`: turns off loading checkpoints from HuggingFace; the model is randomly initialized from the local `config.json` instead (see the example below).
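For a quick benchmark that skips downloading the 7B checkpoint, the flags above can be combined, e.g.:
```
torchrun --standalone --nnodes=1 --nproc-per-node=8 ./run_open_llama_w_vescale.py --dp=4 --tp=2 --warmup=10 --iter=40 --no-ckpt
```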

## Caveats
1. The scripts are purely for demonstration purposes and MFU calculation. You need to write your own training script
in order to fine-tune open llama with your data.
2. There is a known issue with transformers versions greater than 4.37.2; we will fix it later.
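Until then, pinning the library version should avoid the issue:
```
pip install "transformers==4.37.2"
```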
22 changes: 22 additions & 0 deletions python/example/open_llama_4D_benchmark/config.json
@@ -0,0 +1,22 @@
{
"architectures": [
"LlamaForCausalLM"
],
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 11008,
"max_position_embeddings": 2048,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"pad_token_id": 0,
"rms_norm_eps": 1e-06,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.30.0.dev0",
"use_cache": true,
"vocab_size": 32000
}
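As a sanity check, the sizes in this config imply roughly 7B parameters. A quick sketch (ignoring the small norm weights):

```python
import json

# Approximate parameter count implied by the config above.
with open("config.json") as f:
    cfg = json.load(f)

h, ff = cfg["hidden_size"], cfg["intermediate_size"]
layers, vocab = cfg["num_hidden_layers"], cfg["vocab_size"]
per_layer = 4 * h * h + 3 * h * ff  # attention (q, k, v, o) + gated MLP (gate, up, down)
total = layers * per_layer + 2 * vocab * h  # plus input embedding and untied LM head
print(f"~{total / 1e9:.2f}B parameters")  # ~6.74B
```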
23 changes: 23 additions & 0 deletions python/example/open_llama_4D_benchmark/download_open_llama_ckpt.py
@@ -0,0 +1,23 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################


from transformers import AutoModelForCausalLM

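# Fetch the open_llama_7b checkpoint from the HuggingFace Hub; it is cached
# locally, so later benchmark runs can load it without re-downloading.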
model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_7b")

print(model)
29 changes: 29 additions & 0 deletions python/example/open_llama_4D_benchmark/llama_mfu_calculator.py
@@ -0,0 +1,29 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

# reference: https://www.adamcasson.com/posts/transformer-flops
# reference: https://arxiv.org/pdf/2001.08361.pdf


def estimate_llama(config, bsz, sequence_length):
    # embedding lookups: 4 * d_model per token (Kaplan et al., Table 1)
    embed = 4 * bsz * sequence_length * config.hidden_size
    # gated MLP: three hidden_size x intermediate_size projections (gate, up, down)
    ff = 3 * 2 * config.hidden_size * config.intermediate_size * bsz * sequence_length
    # Q, K, V projections: three hidden_size x hidden_size matmuls
    attn_qkv = 2 * bsz * sequence_length * config.hidden_size * 3 * config.hidden_size
    # attention scores (QK^T): 2 * seq * d_model per token
    attn_mask = 2 * bsz * sequence_length * sequence_length * config.hidden_size
    # output projection: one hidden_size x hidden_size matmul
    attn_proj = 2 * config.hidden_size * config.hidden_size * bsz * sequence_length
    attn = attn_qkv + attn_mask + attn_proj
    return embed + (ff + attn) * config.num_hidden_layers
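A minimal usage sketch of the estimator (hypothetical invocation, assuming the working directory contains the example's `config.json`):

```python
from transformers import AutoConfig

from llama_mfu_calculator import estimate_llama

# per-rank micro-batch of 4 (e.g. --total_bsz=16 --dp=4) at sequence length 2048
config = AutoConfig.from_pretrained("config.json")
fwd_flops = estimate_llama(config, 4, 2048)
print(f"forward FLOPs per iteration: {fwd_flops:.3e}")
```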
123 changes: 123 additions & 0 deletions python/example/open_llama_4D_benchmark/run_open_llama_w_vescale.py
@@ -0,0 +1,123 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import os
import torch
import argparse

os.environ["VESCALE_DISABLE_RUN_CHECK"] = "1"

from vescale.dtensor.device_mesh import init_device_mesh
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer
from vescale.dmodule.api import parallelize_module
from sharding_plan import sharding_plan

from transformers import AutoModelForCausalLM, AutoConfig, LlamaModel

from llama_mfu_calculator import estimate_llama

local_rank = int(os.environ["LOCAL_RANK"])
parser = argparse.ArgumentParser()
parser.add_argument("--total_bsz", type=int, default=16)
parser.add_argument("--dp", type=int)
parser.add_argument("--tp", type=int)
parser.add_argument("--warmup", type=int, default=5)
parser.add_argument("--iter", type=int, default=10)
parser.add_argument("--no-ckpt", action="store_true")

args = parser.parse_args()

assert args.total_bsz % args.dp == 0, f"total batch size {args.total_bsz} is not divisible by dp size {args.dp}"
bsz = args.total_bsz // args.dp
s = 2048

# init model
if args.no_ckpt:
    dir_path = os.path.dirname(os.path.realpath(__file__))
    config = AutoConfig.from_pretrained(os.path.join(dir_path, "config.json"))
    model = LlamaModel(config)
else:
    model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_7b")
    model = model.model
    config = model.config
assert s <= config.max_position_embeddings

# -------- training config --------
device_mesh = init_device_mesh(
    "cuda",
    (
        args.dp,
        args.tp,
    ),
    mesh_dim_names=("DP", "TP"),
)

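# random token ids are sufficient here: MFU measurement does not need real data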
input = torch.randint(low=0, high=config.vocab_size, size=(bsz, s)).cuda()

model = model.cuda().bfloat16()
vescale_model = parallelize_module(model, device_mesh["TP"], sharding_plan)

ddp_model = DDP(
    vescale_model,
    data_pg_or_device_mesh=device_mesh["DP"],
    use_distributed_optimizer=True,
)
orig_optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

ve_optimizer = DistributedOptimizer(
    orig_optimizer,
    overlap_param_gather=True,
    models=[ddp_model],
)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# -------- warm up --------
for _ in range(args.warmup):
    ve_optimizer.zero_grad()
    vescale_output = ddp_model(input).last_hidden_state
    vescale_loss = vescale_output.mean()
    vescale_loss.backward()
    ve_optimizer.step()

# -------- training loop --------
start.record()
for _ in range(args.iter):
    ve_optimizer.zero_grad()
    vescale_output = ddp_model(input).last_hidden_state
    vescale_loss = vescale_output.mean()
    vescale_loss.backward()
    ve_optimizer.step()
end.record()
torch.cuda.synchronize()
exec_t = start.elapsed_time(end) / 1000 / args.iter

if local_rank == 0:
    flops_dict = {
        "A100": 312,
        "H100": 1000,
    }
    d_name = torch.cuda.get_device_name()
    total_flops = flops_dict["A100"] * (10**12) * device_mesh.ndevice
    for k, v in flops_dict.items():
        if k in d_name:
            total_flops = v * (10**12) * device_mesh.ndevice
            break
print(f"1 iter time: {exec_t}")
    # backward costs roughly 2x forward, so total FLOPs ~= 3x the forward estimate
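    # estimate_llama() returns forward FLOPs for this rank's micro-batch; multiplying
    # by args.dp covers the global batch, and dividing by elapsed time and the
    # cluster's peak FLOPs (scaled by 100) yields the MFU percentage.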
print("mfu:", estimate_llama(config, bsz, s) * 3 * args.dp * 100 / exec_t / total_flops)
63 changes: 63 additions & 0 deletions python/example/open_llama_4D_benchmark/sharding_plan.py
@@ -0,0 +1,63 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

from vescale.dtensor.placement_types import Replicate, Shard

# forward resharding plan for a single open llama decoder
_decoder_fwd_resharding_plan = {
    "input": {"hidden_states": [Shard(1)], "attention_mask": [Replicate()], "position_ids": [Replicate()]},
    # attention
    "self_attn.input": {"hidden_states": [Replicate()], "attention_mask": [Replicate()], "position_ids": [Replicate()]},
    "self_attn.o_proj.output": [[Shard(1)]],
    "self_attn.output": [[Shard(1)], None, None],
    # feedforward (mlp)
    "mlp.input": [[Replicate()]],
    "mlp.output": [[Shard(1)]],
    "output": [[Shard(1)], None],
}

# parameter sharding plan for a single open llama decoder
_decoder_param_sharding_plan = {
    # attention weights, no bias
    "self_attn.q_proj.weight": [Shard(0)],
    "self_attn.k_proj.weight": [Shard(0)],
    "self_attn.v_proj.weight": [Shard(0)],
    "self_attn.o_proj.weight": [Shard(1)],
    # feedforward (mlp)
    "mlp.up_proj.weight": [Shard(0)],
    "mlp.gate_proj.weight": [Shard(0)],
    "mlp.down_proj.weight": [Shard(1)],
}

# forward resharding plan for the whole open llama model
model_fwd_resharding_plan = {
    ".input": [[Replicate()]],
    "embed_tokens.output": [[Shard(1)]],
    "norm.input": [[Shard(1)]],
    ".output": {
        "last_hidden_state": [Replicate()],
    },
    **{rf"layers.\d+.{k}": v for k, v in _decoder_fwd_resharding_plan.items()},
}

# model parameter sharding plan for the whole open llama model
model_param_sharding_plan = {
    "embed_tokens.weight": [Shard(1)],
    **{rf"layers.\d+.{k}": v for k, v in _decoder_param_sharding_plan.items()},
}

sharding_plan = {"parameter": model_param_sharding_plan, "forward": model_fwd_resharding_plan}
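For intuition, `Shard(d)` places a tensor split along dimension `d` across the TP mesh. A plain-PyTorch sketch (not a veScale API call) of what the attention placements above mean with `tp=2`:

```python
import torch

hidden = 4096
tp = 2
q_proj_weight = torch.empty(hidden, hidden)
# [Shard(0)] splits dim 0: each TP rank holds a (2048, 4096) slice (column-parallel)
print(torch.chunk(q_proj_weight, tp, dim=0)[0].shape)
o_proj_weight = torch.empty(hidden, hidden)
# [Shard(1)] splits dim 1: each TP rank holds a (4096, 2048) slice (row-parallel)
print(torch.chunk(o_proj_weight, tp, dim=1)[0].shape)
```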
22 changes: 22 additions & 0 deletions test/model/open_llama/config.json
@@ -0,0 +1,22 @@
{
"architectures": [
"LlamaForCausalLM"
],
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 11008,
"max_position_embeddings": 2048,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"pad_token_id": 0,
"rms_norm_eps": 1e-06,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.30.0.dev0",
"use_cache": true,
"vocab_size": 32000
}