From 6107bcf3b766422bf18f0b92c257a533bc2ece1b Mon Sep 17 00:00:00 2001
From: jiachengyang
Date: Wed, 3 Apr 2024 16:58:09 -0700
Subject: [PATCH] add open llama example and tests

---
 .../example/open_llama_4D_benchmark/README.md |  32 ++++
 .../open_llama_4D_benchmark/config.json       |  22 +++
 .../download_open_llama_ckpt.py               |  23 +++
 .../llama_mfu_calculator.py                   |  29 ++++
 .../run_open_llama_w_vescale.py               | 123 +++++++++++++
 .../open_llama_4D_benchmark/sharding_plan.py  |  63 +++++++
 test/model/open_llama/config.json             |  22 +++
 test/model/open_llama/test_attention.py       |  96 ++++++++++
 test/model/open_llama/test_decoder_layer.py   | 108 ++++++++++++
 test/model/open_llama/test_mlp.py             |  94 ++++++++++
 test/model/open_llama/test_open_llama.py      | 164 ++++++++++++++++++
 test/model/open_llama/test_rms_norm.py        |  93 ++++++++++
 12 files changed, 869 insertions(+)
 create mode 100644 python/example/open_llama_4D_benchmark/README.md
 create mode 100644 python/example/open_llama_4D_benchmark/config.json
 create mode 100644 python/example/open_llama_4D_benchmark/download_open_llama_ckpt.py
 create mode 100644 python/example/open_llama_4D_benchmark/llama_mfu_calculator.py
 create mode 100644 python/example/open_llama_4D_benchmark/run_open_llama_w_vescale.py
 create mode 100644 python/example/open_llama_4D_benchmark/sharding_plan.py
 create mode 100644 test/model/open_llama/config.json
 create mode 100644 test/model/open_llama/test_attention.py
 create mode 100644 test/model/open_llama/test_decoder_layer.py
 create mode 100644 test/model/open_llama/test_mlp.py
 create mode 100644 test/model/open_llama/test_open_llama.py
 create mode 100644 test/model/open_llama/test_rms_norm.py

diff --git a/python/example/open_llama_4D_benchmark/README.md b/python/example/open_llama_4D_benchmark/README.md
new file mode 100644
index 0000000..a42d18f
--- /dev/null
+++ b/python/example/open_llama_4D_benchmark/README.md
@@ -0,0 +1,32 @@
+# veScale Open Llama Example
+## Overview
+In this directory, we provide a 4D parallelism example of using veScale to run
+an [open llama model](https://huggingface.co/openlm-research/open_llama_7b) that is directly imported
+from HuggingFace without any model code modifications.
+
+
+## Run
+### Single Machine 8 cards
+```
+torchrun --standalone --nnodes=1 --nproc-per-node=8 ./run_open_llama_w_vescale.py --dp=4 --tp=2 --warmup=10 --iter=40
+```
+This will start an 8-card MFU benchmark for open Llama with veScale, with dp=4 and tp=2.
+
+### Distributed Environment (4 Machines, 32 cards example)
+```
+torchrun --nnodes=4 --nproc-per-node=8 --node_rank=$node_rank --master_addr=$master_addr --master_port=$master_port ./run_open_llama_w_vescale.py --dp=16 --tp=2 --warmup=10 --iter=40
+```
+This will start a 32-card MFU benchmark for open Llama with veScale, with dp=16 and tp=2.
+
+### Options
+1. `--total_bsz`: the total batch size for one iteration. The default is 16.
+2. `--dp`: the degree of data parallelism (DDP). This argument has no default value.
+3. `--tp`: the degree of tensor parallelism. This argument has no default value.
+4. `--warmup`: the number of warmup iterations performed. The default is 5.
+5. `--iter`: the number of iterations used for calculating the MFU. The default is 10.
+6. `--no-ckpt`: turn off loading checkpoints from HuggingFace and use a randomly initialized model built from the local `config.json` instead.
+
+## Caveats
+1. The scripts are purely for demonstration purposes and MFU calculation. You need to write your own training script
+   in order to fine-tune open llama with your data.
+2. There is a known issue with `transformers` versions greater than 4.37.2.
We will be fixing it later. \ No newline at end of file diff --git a/python/example/open_llama_4D_benchmark/config.json b/python/example/open_llama_4D_benchmark/config.json new file mode 100644 index 0000000..bef60c9 --- /dev/null +++ b/python/example/open_llama_4D_benchmark/config.json @@ -0,0 +1,22 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "pad_token_id": 0, + "rms_norm_eps": 1e-06, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.30.0.dev0", + "use_cache": true, + "vocab_size": 32000 + } \ No newline at end of file diff --git a/python/example/open_llama_4D_benchmark/download_open_llama_ckpt.py b/python/example/open_llama_4D_benchmark/download_open_llama_ckpt.py new file mode 100644 index 0000000..876228a --- /dev/null +++ b/python/example/open_llama_4D_benchmark/download_open_llama_ckpt.py @@ -0,0 +1,23 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + + +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_7b") + +print(model) diff --git a/python/example/open_llama_4D_benchmark/llama_mfu_calculator.py b/python/example/open_llama_4D_benchmark/llama_mfu_calculator.py new file mode 100644 index 0000000..9bacdd5 --- /dev/null +++ b/python/example/open_llama_4D_benchmark/llama_mfu_calculator.py @@ -0,0 +1,29 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+################################################################################
+
+# reference: https://www.adamcasson.com/posts/transformer-flops
+# reference: https://arxiv.org/pdf/2001.08361.pdf
+
+
+def estimate_llama(config, bsz, sequence_length):
+    embed = 4 * bsz * sequence_length * config.hidden_size
+    ff = 3 * 2 * config.hidden_size * config.intermediate_size * bsz * sequence_length
+    attn_qkv = 2 * bsz * sequence_length * config.hidden_size * 3 * config.hidden_size
+    attn_mask = 2 * sequence_length * config.hidden_size
+    attn_proj = 2 * config.hidden_size * config.intermediate_size * bsz * sequence_length
+    attn = attn_qkv + attn_mask + attn_proj
+    return embed + (ff + attn) * config.num_hidden_layers
diff --git a/python/example/open_llama_4D_benchmark/run_open_llama_w_vescale.py b/python/example/open_llama_4D_benchmark/run_open_llama_w_vescale.py
new file mode 100644
index 0000000..8117551
--- /dev/null
+++ b/python/example/open_llama_4D_benchmark/run_open_llama_w_vescale.py
@@ -0,0 +1,123 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+import os
+import torch
+import argparse
+
+os.environ["VESCALE_DISABLE_RUN_CHECK"] = "1"
+
+from vescale.dtensor.device_mesh import init_device_mesh
+from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
+from vescale.optim.distributed_optimizer import DistributedOptimizer
+from vescale.dmodule.api import parallelize_module
+from sharding_plan import sharding_plan
+
+from transformers import AutoModelForCausalLM, AutoConfig, LlamaModel
+
+from llama_mfu_calculator import estimate_llama
+
+local_rank = int(os.environ["LOCAL_RANK"])
+parser = argparse.ArgumentParser()
+parser.add_argument("--total_bsz", type=int, default=16)
+parser.add_argument("--dp", type=int)
+parser.add_argument("--tp", type=int)
+parser.add_argument("--warmup", type=int, default=5)
+parser.add_argument("--iter", type=int, default=10)
+parser.add_argument("--no-ckpt", action="store_true")
+
+args = parser.parse_args()
+
+assert args.total_bsz % args.dp == 0, f"total batch size {args.total_bsz} is not divisible by dp size {args.dp}"
+bsz = args.total_bsz // args.dp
+s = 2048
+
+# init model
+if args.no_ckpt:
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    config = AutoConfig.from_pretrained(os.path.join(dir_path, "config.json"))
+    model = LlamaModel(config)
+else:
+    model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_7b")
+    model = model.model
+    config = model.config
+assert s <= config.max_position_embeddings

+# -------- training config --------
+device_mesh = init_device_mesh(
+    "cuda",
+    (
+        args.dp,
+        args.tp,
+    ),
+    mesh_dim_names=("DP", "TP"),
+)
+
+input = torch.randint(low=0, high=config.vocab_size, size=(bsz, s)).cuda()
+
+model = model.cuda().bfloat16()
+vescale_model =
parallelize_module(model, device_mesh["TP"], sharding_plan) + +ddp_model = DDP( + vescale_model, + data_pg_or_device_mesh=device_mesh["DP"], + use_distributed_optimizer=True, +) +orig_optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01) + +ve_optimizer = DistributedOptimizer( + orig_optimizer, + overlap_param_gather=True, + models=[ddp_model], +) + +start = torch.cuda.Event(enable_timing=True) +end = torch.cuda.Event(enable_timing=True) + +# -------- warm up -------- +for _ in range(args.warmup): + ve_optimizer.zero_grad() + vescale_output = ddp_model(input).last_hidden_state + vescale_loss = vescale_output.mean() + vescale_loss.backward() + ve_optimizer.step() + +# -------- training loop -------- +start.record() +for _ in range(args.iter): + ve_optimizer.zero_grad() + vescale_output = ddp_model(input).last_hidden_state + vescale_loss = vescale_output.mean() + vescale_loss.backward() + ve_optimizer.step() +end.record() +torch.cuda.synchronize() +exec_t = start.elapsed_time(end) / 1000 / args.iter + +if local_rank == 0: + flops_dict = { + "A100": 312, + "H100": 1000, + } + d_name = torch.cuda.get_device_name() + total_flops = flops_dict["A100"] * (10**12) * device_mesh.ndevice + for k, v in flops_dict.items(): + if k in d_name: + total_flops = v * (10**12) * device_mesh.ndevice + break + print(f"1 iter time: {exec_t}") + # fwd + bwd =3 + print("mfu:", estimate_llama(config, bsz, s) * 3 * args.dp * 100 / exec_t / total_flops) diff --git a/python/example/open_llama_4D_benchmark/sharding_plan.py b/python/example/open_llama_4D_benchmark/sharding_plan.py new file mode 100644 index 0000000..12bcd65 --- /dev/null +++ b/python/example/open_llama_4D_benchmark/sharding_plan.py @@ -0,0 +1,63 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +from vescale.dtensor.placement_types import Replicate, Shard + +# forward resharding plan for a single open llama decoder +_decoder_fwd_resharding_plan = { + "input": {"hidden_states": [Shard(1)], "attention_mask": [Replicate()], "position_ids": [Replicate()]}, + # atten + "self_attn.input": {"hidden_states": [Replicate()], "attention_mask": [Replicate()], "position_ids": [Replicate()]}, + "self_attn.o_proj.output": [[Shard(1)]], + "self_attn.output": [[Shard(1)], None, None], + # feedforward(mlp) + "mlp.input": [[Replicate()]], + "mlp.output": [[Shard(1)]], + "output": [[Shard(1)], None], +} + +# parameter sharding plan for a single open llama decoder +_decoder_param_sharding_plan = { + # atten weight, no bias + "self_attn.q_proj.weight": [Shard(0)], + "self_attn.k_proj.weight": [Shard(0)], + "self_attn.v_proj.weight": [Shard(0)], + "self_attn.o_proj.weight": [Shard(1)], + # feedforward(mlp) + "mlp.up_proj.weight": [Shard(0)], + "mlp.gate_proj.weight": [Shard(0)], + "mlp.down_proj.weight": [Shard(1)], +} + +# forward resharding plan for the whole open llama model +model_fwd_resharding_plan = { + ".input": [[Replicate()]], + "embed_tokens.output": [[Shard(1)]], + "norm.input": [[Shard(1)]], + ".output": { + "last_hidden_state": [Replicate()], + }, + **{rf"layers.\d+.{k}": v for k, v in _decoder_fwd_resharding_plan.items()}, +} + +# model parameter sharding plan for the whole open llama model +model_param_sharding_plan = { + "embed_tokens.weight": [Shard(1)], + **{rf"layers.\d+.{k}": v for k, v in _decoder_param_sharding_plan.items()}, +} + +sharding_plan = {"parameter": model_param_sharding_plan, "forward": model_fwd_resharding_plan} diff --git a/test/model/open_llama/config.json b/test/model/open_llama/config.json new file mode 100644 index 0000000..bef60c9 --- /dev/null +++ b/test/model/open_llama/config.json @@ -0,0 +1,22 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "pad_token_id": 0, + "rms_norm_eps": 1e-06, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.30.0.dev0", + "use_cache": true, + "vocab_size": 32000 + } \ No newline at end of file diff --git a/test/model/open_llama/test_attention.py b/test/model/open_llama/test_attention.py new file mode 100644 index 0000000..f014531 --- /dev/null +++ b/test/model/open_llama/test_attention.py @@ -0,0 +1,96 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +import os +import torch +from torch.testing._internal.common_utils import run_tests + +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.dtensor.api import distribute_tensor +from vescale.dtensor.placement_types import Replicate, Shard +from vescale.dmodule.api import parallelize_module + +from common_dtensor import DTensorTestBase, with_comms + +from transformers import AutoConfig, LlamaModel + +dir_path = os.path.dirname(os.path.realpath(__file__)) + + +def get_model(): + config = AutoConfig.from_pretrained(os.path.join(dir_path, "config.json")) + config.num_hidden_layers = 1 + model = LlamaModel(config) + attn = model.layers[0].self_attn + return attn, config + + +class AttentionTest(DTensorTestBase): + @property + def world_size(self): + return 4 + + @with_comms + def test_attention(self): + bsz = 6 + s = 18 + hidden_size = 4096 + # -----------golden----------- + + input = torch.rand(bsz, s, hidden_size).cuda() + input.requires_grad_() + input.retain_grad() + non_parallel_attention, _ = get_model() + non_parallel_attention = non_parallel_attention.cuda() + golden_outputs = non_parallel_attention(input) + golden_loss = golden_outputs[0].mean() + golden_loss.backward() + + # -----------vescale---------- + device_mesh = DeviceMesh(self.device_type, range(self.world_size)) + vescale_attention, config = get_model() + fwd_resharding_plan = { + # atten + ".input": {"hidden_states": [Replicate()]}, + "o_proj.output": [[Shard(1)]], + ".output": [[Shard(1)], None, None], + } + param_sharding_plan = { + # atten weight, no bias + "q_proj.weight": [Shard(0)], + "k_proj.weight": [Shard(0)], + "v_proj.weight": [Shard(0)], + "o_proj.weight": [Shard(1)], + } + + vescale_attention = parallelize_module( + vescale_attention, device_mesh, {"parameter": param_sharding_plan, "forward": fwd_resharding_plan} + ) + + d_input = distribute_tensor(input.detach(), device_mesh, [Shard(1)]) + d_input.requires_grad_() + d_input.retain_grad() + + vescale_outputs = vescale_attention(d_input) + vescale_outputs[0] = vescale_outputs[0].redistribute(placements=[Replicate()] * device_mesh.ndim) + vescale_loss = vescale_outputs[0].mean() + + vescale_loss.backward() + + +if __name__ == "__main__": + run_tests() diff --git a/test/model/open_llama/test_decoder_layer.py b/test/model/open_llama/test_decoder_layer.py new file mode 100644 index 0000000..c55ac9a --- /dev/null +++ b/test/model/open_llama/test_decoder_layer.py @@ -0,0 +1,108 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+################################################################################
+
+import os
+import torch
+from torch.testing._internal.common_utils import run_tests
+
+from vescale.dtensor.device_mesh import DeviceMesh
+from vescale.dtensor.api import distribute_tensor
+from vescale.dtensor.placement_types import Replicate, Shard
+from vescale.dmodule.api import parallelize_module
+
+from common_dtensor import DTensorTestBase, with_comms
+
+from transformers import AutoConfig, LlamaModel
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+
+
+def get_model():
+    config = AutoConfig.from_pretrained(os.path.join(dir_path, "config.json"))
+    config.num_hidden_layers = 1
+    model = LlamaModel(config)
+    decoder = model.layers[0]
+    return decoder, config
+
+
+class DecoderTest(DTensorTestBase):
+    @property
+    def world_size(self):
+        return 4
+
+    @with_comms
+    def test_decoder(self):
+        bsz = 6
+        s = 18
+        hidden_size = 4096
+        # -----------golden-----------
+
+        input = torch.rand(bsz, s, hidden_size).cuda()
+        input.requires_grad_()
+        input.retain_grad()
+        non_parallel_decoder, _ = get_model()
+        non_parallel_decoder = non_parallel_decoder.cuda()
+        golden_outputs = non_parallel_decoder(input)
+        golden_loss = golden_outputs[0].mean()
+        golden_loss.backward()
+
+        # -----------vescale----------
+        device_mesh = DeviceMesh(self.device_type, range(self.world_size))
+        vescale_decoder, config = get_model()
+        fwd_resharding_plan = {
+            ".input": [[Shard(1)]],
+            # atten
+            "self_attn.input": {"hidden_states": [Replicate()]},
+            "self_attn.o_proj.output": [[Shard(1)]],
+            "self_attn.output": [[Shard(1)], None, None],
+            # feedforward(mlp) no bias
+            "mlp.input": [[Replicate()]],
+            "mlp.output": [[Shard(1)]],
+            ".output": [[Shard(1)]],
+        }
+        param_sharding_plan = {
+            # atten weight, no bias
+            "self_attn.q_proj.weight": [Shard(0)],
+            "self_attn.k_proj.weight": [Shard(0)],
+            "self_attn.v_proj.weight": [Shard(0)],
+            "self_attn.o_proj.weight": [Shard(1)],
+            # feedforward(mlp)
+            "mlp.up_proj.weight": [Shard(0)],
+            "mlp.gate_proj.weight": [Shard(0)],
+            "mlp.down_proj.weight": [Shard(1)],
+        }
+
+        vescale_decoder = parallelize_module(
+            vescale_decoder,
+            device_mesh,
+            {"parameter": param_sharding_plan, "forward": fwd_resharding_plan},
+        )
+
+        d_input = distribute_tensor(input.detach(), device_mesh, [Shard(1)])
+        d_input.requires_grad_()
+        d_input.retain_grad()
+
+        vescale_outputs = vescale_decoder(d_input)
+        vescale_outputs[0] = vescale_outputs[0].redistribute(placements=[Replicate()] * device_mesh.ndim)
+        vescale_loss = vescale_outputs[0].mean()
+
+        vescale_loss.backward()
+        vescale_decoder.finish_grad_sync()
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/model/open_llama/test_mlp.py b/test/model/open_llama/test_mlp.py
new file mode 100644
index 0000000..382802e
--- /dev/null
+++ b/test/model/open_llama/test_mlp.py
@@ -0,0 +1,94 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import os +import torch +from torch.testing._internal.common_utils import run_tests + +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.dtensor.api import distribute_tensor +from vescale.dtensor.placement_types import Replicate, Shard +from vescale.dmodule.api import parallelize_module + +from common_dtensor import DTensorTestBase, with_comms + +from transformers import AutoConfig, LlamaModel + +dir_path = os.path.dirname(os.path.realpath(__file__)) + + +def get_model(): + config = AutoConfig.from_pretrained(os.path.join(dir_path, "config.json")) + config.num_hidden_layers = 1 + model = LlamaModel(config) + mlp = model.layers[0].mlp + return mlp, config + + +class FeedForwardTest(DTensorTestBase): + @property + def world_size(self): + return 4 + + @with_comms + def test_feed_forward(self): + bsz = 6 + s = 18 + hidden_size = 4096 + # -----------golden----------- + + input = torch.rand(bsz, s, hidden_size).cuda() + input.requires_grad_() + input.retain_grad() + non_parallel_mlp, _ = get_model() + non_parallel_mlp = non_parallel_mlp.cuda() + golden_output = non_parallel_mlp(input) + golden_loss = golden_output.mean() + golden_loss.backward() + + # -----------vescale---------- + device_mesh = DeviceMesh(self.device_type, range(self.world_size)) + fwd_resharding_plan = { + # feedforward(mlp) no bias + ".input": [[Replicate()]], + ".output": [[Shard(1)]], + } + param_sharding_plan = { + # feedforward(mlp) + "up_proj.weight": [Shard(0)], + "gate_proj.weight": [Shard(0)], + "down_proj.weight": [Shard(1)], + } + + vescale_mlp, _ = get_model() + vescale_mlp = parallelize_module( + vescale_mlp, device_mesh, {"parameter": param_sharding_plan, "forward": fwd_resharding_plan} + ) + + d_input = distribute_tensor(input.detach(), device_mesh, [Shard(1)]) + d_input.requires_grad_() + d_input.retain_grad() + + vescale_output = vescale_mlp(d_input) + vescale_output = vescale_output.redistribute(placements=[Replicate()] * device_mesh.ndim) + vescale_loss = vescale_output.mean() + + vescale_loss.backward() + + +if __name__ == "__main__": + run_tests() diff --git a/test/model/open_llama/test_open_llama.py b/test/model/open_llama/test_open_llama.py new file mode 100644 index 0000000..4090b77 --- /dev/null +++ b/test/model/open_llama/test_open_llama.py @@ -0,0 +1,164 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +import os +import torch +from torch.testing._internal.common_utils import run_tests + +from vescale.dtensor.device_mesh import DeviceMesh, init_device_mesh +from vescale.dtensor.api import distribute_tensor +from vescale.dtensor.placement_types import Replicate, Shard +from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP +from vescale.optim.distributed_optimizer import DistributedOptimizer +from vescale.dmodule.api import parallelize_module + +from common_dtensor import DTensorTestBase, with_comms + +from transformers import AutoConfig, LlamaModel + + +dir_path = os.path.dirname(os.path.realpath(__file__)) + + +def get_model(layer_number=None): + config = AutoConfig.from_pretrained(os.path.join(dir_path, "config.json")) + if layer_number is None: + config.num_hidden_layers = 1 + else: + config.num_hidden_layers = layer_number + model = LlamaModel(config) + model = model + return model, config + + +decoder_fwd_resharding_plan = { + "input": { + "hidden_states": [Shard(1)], + # "attention_mask": [Replicate()], + "position_ids": [Replicate()], + }, + # atten + "self_attn.input": { + "hidden_states": [Replicate()], + }, + "self_attn.o_proj.output": [[Shard(1)]], + "self_attn.output": [[Shard(1)], None, None], + # feedforward(mlp) no bias + "mlp.input": [[Replicate()]], + "mlp.output": [[Shard(1)]], + "output": [[Shard(1)], None], +} +decoder_param_sharding_plan = { + # atten weight, no bias + "self_attn.q_proj.weight": [Shard(0)], + "self_attn.k_proj.weight": [Shard(0)], + "self_attn.v_proj.weight": [Shard(0)], + "self_attn.o_proj.weight": [Shard(1)], + # feedforward(mlp) + "mlp.up_proj.weight": [Shard(0)], + "mlp.gate_proj.weight": [Shard(0)], + "mlp.down_proj.weight": [Shard(1)], +} + +model_fwd_resharding_plan = { + ".input": [[Replicate()]], + "norm.input": [[Shard(1)]], + ".output": { + "last_hidden_state": [Replicate()], + }, + **{rf"layers.\d+.{k}": v for k, v in decoder_fwd_resharding_plan.items()}, +} +model_param_sharding_plan = { + "embed_tokens.weight": [Shard(1)], + **{rf"layers.\d+.{k}": v for k, v in decoder_param_sharding_plan.items()}, +} + + +class llama2Test(DTensorTestBase): + @property + def world_size(self): + return 8 + + @with_comms + def test_llama2_layer4(self): + bsz = 6 + s = 18 + # -----------golden----------- + + non_parallel_llama2, config = get_model(layer_number=4) + input = torch.randint(low=0, high=config.vocab_size, size=(bsz, s)).cuda() + non_parallel_llama2 = non_parallel_llama2.cuda() + golden_output = non_parallel_llama2(input).last_hidden_state + golden_loss = golden_output.mean() + golden_loss.backward() + + # -----------vescale---------- + device_mesh = DeviceMesh(self.device_type, range(self.world_size)) + vescale_model, config = get_model(layer_number=4) + + vescale_model = parallelize_module( + vescale_model, + device_mesh, + {"parameter": model_param_sharding_plan, "forward": model_fwd_resharding_plan}, + ) + + d_input = distribute_tensor(input.detach(), device_mesh, [Shard(1)]) + + vescale_output = vescale_model(d_input).last_hidden_state + vescale_output = vescale_output.redistribute(placements=[Replicate()] * device_mesh.ndim) + vescale_loss = vescale_output.mean() + + vescale_loss.backward() + vescale_model.finish_grad_sync() + + @with_comms + def test_llama2_layer32_with_ddp(self): + bsz = 6 + s = 18 + device_mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("DP", "TP")) + vescale_model, config = get_model() + 
input = torch.randint(low=0, high=config.vocab_size, size=(bsz, s)).cuda() + + vescale_model = parallelize_module( + vescale_model, + device_mesh["TP"], + {"parameter": model_param_sharding_plan, "forward": model_fwd_resharding_plan}, + ) + + ddp_model = DDP( + vescale_model, + data_pg_or_device_mesh=device_mesh["DP"], + use_distributed_optimizer=True, + ) + orig_optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01) + + ve_optimizer = DistributedOptimizer( + orig_optimizer, + clip_grad=0.0, + overlap_param_gather=True, + models=[ddp_model], + ) + + ve_optimizer.zero_grad() + vescale_output = ddp_model(input.detach()).last_hidden_state + vescale_loss = vescale_output.mean() + vescale_loss.backward() + ve_optimizer.step() + + +if __name__ == "__main__": + run_tests() diff --git a/test/model/open_llama/test_rms_norm.py b/test/model/open_llama/test_rms_norm.py new file mode 100644 index 0000000..3e3824b --- /dev/null +++ b/test/model/open_llama/test_rms_norm.py @@ -0,0 +1,93 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import os +import torch +from torch.testing._internal.common_utils import run_tests + +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.dtensor.api import distribute_tensor +from vescale.dtensor.placement_types import Replicate, Shard +from vescale.dmodule.api import parallelize_module + +from common_dtensor import DTensorTestBase, with_comms + +from transformers import AutoConfig, LlamaModel + +dir_path = os.path.dirname(os.path.realpath(__file__)) + + +def get_model(): + config = AutoConfig.from_pretrained(os.path.join(dir_path, "config.json")) + config.num_hidden_layers = 1 + model = LlamaModel(config) + rms_norm = model.layers[0].input_layernorm + return rms_norm + + +class RMSNormTest(DTensorTestBase): + @property + def world_size(self): + return 4 + + @with_comms + def test_rms_norm(self): + bsz = 6 + s = 18 + hidden_size = 4096 + # -----------golden----------- + + input = torch.rand(bsz, s, hidden_size).cuda() + input.requires_grad_() + input.retain_grad() + non_parallel_norm = get_model().cuda() + golden_output = non_parallel_norm(input) + golden_loss = golden_output.mean() + golden_loss.backward() + + # -----------vescale---------- + device_mesh = DeviceMesh(self.device_type, range(self.world_size)) + vescale_norm = get_model().cuda() + fwd_resharding_plan = { + ".input": [[Shard(1)]], + } + param_sharding_plan = {} + + vescale_norm = parallelize_module( + vescale_norm, + device_mesh, + {"parameter": param_sharding_plan, "forward": fwd_resharding_plan}, + ) + d_input = distribute_tensor(input.detach(), device_mesh, [Shard(1)]) + d_input.requires_grad_() + d_input.retain_grad() + + vescale_output = vescale_norm(d_input) + vescale_output = vescale_output.redistribute(placements=[Replicate()] * 
device_mesh.ndim)
+        vescale_loss = vescale_output.mean()
+
+        vescale_loss.backward()
+        vescale_norm.finish_grad_sync()
+        d_grad = d_input.grad.redistribute(placements=[Replicate()] * device_mesh.ndim)
+        torch.testing.assert_close(vescale_output._local_tensor, golden_output)
+        torch.testing.assert_close(vescale_loss._local_tensor, golden_loss)
+        torch.testing.assert_close(d_grad._local_tensor, input.grad)
+        torch.testing.assert_close(vescale_norm.weight.grad._local_tensor, non_parallel_norm.weight.grad)
+
+
+if __name__ == "__main__":
+    run_tests()
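
---

Note on the MFU arithmetic: the benchmark prints MFU as `estimate_llama(config, bsz, s) * 3 * dp * 100 / exec_t / total_flops`. The sketch below works through that formula offline; it is illustrative only — the model shape mirrors `config.json` in this patch, while the step time, device count, and parallel sizes are made-up placeholders, not measured results.

```python
# Minimal sketch of the MFU calculation used in run_open_llama_w_vescale.py.
# Assumes this file sits next to llama_mfu_calculator.py; all numbers below
# are hypothetical placeholders, not benchmark results.
from types import SimpleNamespace

from llama_mfu_calculator import estimate_llama

# Model shape taken from config.json in this patch (open_llama_7b).
config = SimpleNamespace(hidden_size=4096, intermediate_size=11008, num_hidden_layers=32)

bsz, s = 4, 2048        # per-DP-rank batch size and sequence length
dp, n_devices = 4, 8    # e.g. one 8-GPU node run with --dp=4 --tp=2
exec_t = 1.2            # hypothetical seconds per iteration (what the CUDA events measure)
total_flops = 312e12 * n_devices  # A100 BF16 peak, matching flops_dict["A100"]

fwd_flops = estimate_llama(config, bsz, s)  # forward-pass FLOPs for one DP rank
# Forward + backward is approximated as 3x the forward cost, summed over DP ranks.
mfu = fwd_flops * 3 * dp * 100 / exec_t / total_flops
print(f"estimated MFU: {mfu:.2f}%")
```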