Skip to content

Commit

Permalink
PP API and nD Distributed Timeline Profiling (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
MackZackA authored Jul 30, 2024
1 parent c4afc72 commit aa95bb7
Show file tree
Hide file tree
Showing 98 changed files with 18,591 additions and 23 deletions.
Binary file added docs/pictures/ndtimeline_arch.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/pictures/ndtimeline_trace.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/pictures/pp.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
2 changes: 1 addition & 1 deletion examples/open_llama_4D_benchmark/llama_mfu_calculator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
2 changes: 1 addition & 1 deletion examples/open_llama_4D_benchmark/sharding_plan.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pytest
tqdm
optree
accelerate
transformers==4.37.2
transformers==4.40.2
flash_attn
matplotlib
mmh3
4 changes: 1 addition & 3 deletions test/checkpoint/nano_gpt/test_nano_gpt_load_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ def init_method(self):
@skip_unless_torch_gpu
@with_comms
def test_load(self):
ddp_gpt, dist_optimizer, _ = build_gpt_model_optimizer_and_dataset(
self.init_method, dp_size=2, tp_size=2
)
ddp_gpt, dist_optimizer, _ = build_gpt_model_optimizer_and_dataset(self.init_method, dp_size=2, tp_size=2)

# Load the model and optimizer after first data

Expand Down
2 changes: 1 addition & 1 deletion test/checkpoint/open_llama/test_open_llama_dp_reshard.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
2 changes: 1 addition & 1 deletion test/checkpoint/open_llama/test_open_llama_load_save.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
2 changes: 1 addition & 1 deletion test/checkpoint/open_llama/test_open_llama_tp_reshard.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
6 changes: 4 additions & 2 deletions test/model/open_llama/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def test_attention(self):
input.retain_grad()
non_parallel_attention, _ = get_model()
non_parallel_attention = non_parallel_attention.cuda()
golden_outputs = non_parallel_attention(input)
dummy_position_ids = torch.randint(low=0, high=s, size=(bsz, s)).cuda()
golden_outputs = non_parallel_attention(input, position_ids=dummy_position_ids)
golden_loss = golden_outputs[0].mean()
golden_loss.backward()

Expand Down Expand Up @@ -84,8 +85,9 @@ def test_attention(self):
d_input = distribute_tensor(input.detach(), device_mesh, [Shard(1)])
d_input.requires_grad_()
d_input.retain_grad()
d_position_id = distribute_tensor(dummy_position_ids.detach(), device_mesh, [Replicate()])

vescale_outputs = vescale_attention(d_input)
vescale_outputs = vescale_attention(d_input, position_ids=d_position_id)
vescale_outputs[0] = vescale_outputs[0].redistribute(placements=[Replicate()] * device_mesh.ndim)
vescale_loss = vescale_outputs[0].mean()

Expand Down
6 changes: 4 additions & 2 deletions test/model/open_llama/test_decoder_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def test_decoder(self):
input.retain_grad()
non_parallel_decoder, _ = get_model()
non_parallel_decoder = non_parallel_decoder.cuda()
golden_outputs = non_parallel_decoder(input)
dummy_position_id = torch.randint(low=0, high=s, size=(bsz, s)).cuda()
golden_outputs = non_parallel_decoder(input, position_ids=dummy_position_id)
golden_loss = golden_outputs[0].mean()
golden_loss.backward()

Expand Down Expand Up @@ -95,8 +96,9 @@ def test_decoder(self):
d_input = distribute_tensor(input.detach(), device_mesh, [Shard(1)])
d_input.requires_grad_()
d_input.retain_grad()
d_position_id = distribute_tensor(dummy_position_id.detach(), device_mesh, [Replicate()])

vescale_outputs = vescale_decoder(d_input)
vescale_outputs = vescale_decoder(d_input, position_ids=d_position_id)
vescale_outputs[0] = vescale_outputs[0].redistribute(placements=[Replicate()] * device_mesh.ndim)
vescale_loss = vescale_outputs[0].mean()

Expand Down
1 change: 1 addition & 0 deletions test/ndtimeline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# make pylint happy
37 changes: 37 additions & 0 deletions test/ndtimeline/test_local_raw_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import os
from vescale.ndtimeline.world_info import WorldInfo
from vescale.ndtimeline.handlers import LocalRawNDHandler
from vescale.ndtimeline.variables import LOCAL_LOGGING_PATH


def test_basic_usage():
    """Smoke-test the files written by ``LocalRawNDHandler`` and clean them up.

    The raw log file must exist after the first record; after five more
    records a ``.2`` backup must exist while a ``.4`` backup must not. The log
    file and backups ``.1`` to ``.3`` are removed at the end.
    """
    # presumably chunk_sz / backup_cnt configure log rotation -- confirm
    # against LocalRawNDHandler.
    handler = LocalRawNDHandler(run_id=0, chunk_sz=10, backup_cnt=3)
    file_name = "timeline_run0_raw.log"
    log_path = os.path.join(LOCAL_LOGGING_PATH, file_name)

    def emit(metric_name, elapsed):
        # Every record shares the same single-part payload and world info.
        handler(metric_name, elapsed, [1.0], [1.0], [{}], range(0, 1), WorldInfo(0, 0), {})

    emit("test_metric", 1.0)
    assert os.path.exists(log_path)

    # NOTE(review): indentation was lost in the reviewed source; the
    # "test_metric2" record is assumed to be written once, after the loop.
    for _ in range(4):
        emit("test_metric", 1.0)
    emit("test_metric2", 2.0)

    assert os.path.exists(f"{log_path}.2")
    assert not os.path.exists(f"{log_path}.4")

    # Remove the live log first, then each numbered backup.
    stale_paths = [log_path] + [f"{log_path}.{index}" for index in range(1, 4)]
    for stale_path in stale_paths:
        os.remove(stale_path)
    assert not os.path.exists(f"{log_path}.2")
30 changes: 30 additions & 0 deletions test/ndtimeline/test_metric_level.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

from vescale.ndtimeline import NDMetricLevel


def test_cmp_level():
    """Check the rich comparison operators between ``NDMetricLevel`` members."""
    level = NDMetricLevel
    info = level.INFO
    trace = level.TRACE

    # Members expected to compare greater than (or equal to) INFO.
    assert level.FRAMEWORK_DEBUG >= info
    assert level.USER_DEBUG >= info
    assert level.USER_DEBUG > info

    # Members expected to compare less than (or equal to) INFO.
    assert level.USER_INFO < info
    assert level.USER_INFO <= info

    # INFO itself sits below DEBUG.
    assert info < level.DEBUG

    # A member compared with itself satisfies <=, >= and ==.
    assert trace <= trace
    assert trace >= trace
    assert trace == trace
61 changes: 61 additions & 0 deletions test/ndtimeline/test_parser_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import pytest
from vescale.ndtimeline.world_info import WorldInfo
from vescale.ndtimeline.handlers import ParserNDHandler
from vescale.ndtimeline.exceptions import NDHandlerError


def test_normal_input_with_tags():
    """Feed ``ParserNDHandler`` three timed parts with tags for a single step.

    Expects exactly one record back, and that record's ``step`` to be 0.
    """
    elapsed_parts = [1.0, 3.2, 1.4]
    since_start_parts = [1710332816.6118143, 1710332833.2222, 1710332846.1313]

    # All parts but the last share one tag dict; the last part gets its own.
    shared_tag = {"is_test": True}
    tags = [shared_tag for _ in range(len(elapsed_parts) - 1)]
    tags.append({"is_test": False})

    steps = range(0, 1)
    world = WorldInfo(0, 0)
    handler = ParserNDHandler()

    records = handler(
        "test_metric",
        sum(elapsed_parts),
        elapsed_parts,
        since_start_parts,
        tags,
        steps,
        world,
        {},
    )

    assert len(records) == 1
    assert records[0].step == 0


def test_normal_invalid_input():
    """Expect ``NDHandlerError`` when ``ParserNDHandler`` gets inconsistent input.

    Three elapsed parts and three tags are supplied, but only two
    "since start" timestamps.
    """
    elapsed_parts = [1.0, 3.2, 1.4]
    # One timestamp fewer than elapsed parts; presumably this length mismatch
    # is what the handler rejects -- confirm against ParserNDHandler.
    since_start_parts = [1710332816.6118143, 1710332846.1313]

    shared_tag = {"is_test": True}
    tags = [shared_tag for _ in range(len(elapsed_parts) - 1)]
    tags.append({"is_test": False})

    steps = range(0, 1)
    world = WorldInfo(0, 0)
    handler = ParserNDHandler()

    with pytest.raises(NDHandlerError):
        handler("test_metric", sum(elapsed_parts), elapsed_parts, since_start_parts, tags, steps, world, {})
53 changes: 53 additions & 0 deletions test/parallel/pipeline/api/four_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import torch
import torch.nn as nn
import os


class MLP(nn.Module):
    """Two bias-free 1024x1024 linear layers with a GELU in between.

    Every forward pass also saves its output tensor to disk and bumps a call
    counter, so successive outputs land in separate files.
    """

    def __init__(self, features_in, feature_middle, features_out, value):
        """Build the layers and fill their weights with constants.

        Args:
            features_in: Accepted but not used; layer sizes are fixed at 1024.
            feature_middle: Accepted but not used.
            features_out: Accepted but not used.
            value: Fill value for ``fc1`` weights (``fc2`` gets ``value * 2``);
                also embedded in the name of each saved output file.
        """
        super().__init__()
        self.value = value
        self.counter = 0  # number of forward passes completed so far

        first = nn.Linear(1024, 1024, bias=False)
        first.weight.data.fill_(value)
        self.fc1 = first

        second = nn.Linear(1024, 1024, bias=False)
        second.weight.data.fill_(value * 2)
        self.fc2 = second

        self.gelu = nn.GELU()

    def forward(self, x):
        """Return ``fc2(gelu(fc1(x)))`` after saving that result to disk.

        The file name is prefixed with the ``model_name`` environment
        variable, which must be set (a missing key raises ``KeyError``).
        """
        out = self.fc2(self.gelu(self.fc1(x)))
        dump_path = f"{os.environ['model_name']}_mlp{self.value}_fwd{self.counter}_out_tensor.pt"
        torch.save(out, dump_path)
        self.counter += 1
        return out


class FourMLP(nn.Module):
    """Four ``MLP`` stages chained one after another.

    Each stage is reachable both as ``mlp1`` .. ``mlp4`` and through the
    ``sequence`` container that runs them in order.
    """

    def __init__(self, hidden):
        """Create stages ``mlp1``..``mlp4`` (fill values 0..3) and chain them.

        Args:
            hidden: Base size; stage ``k`` (1-based) is constructed with
                ``hidden * (2k - 1)``, ``hidden * 2k`` and ``hidden * (2k + 1)``.
        """
        super().__init__()
        stages = []
        for index in range(4):
            multiplier = 2 * index + 1
            stage = MLP(
                hidden * multiplier,
                hidden * (multiplier + 1),
                hidden * (multiplier + 2),
                index,
            )
            # Register under the attribute names mlp1 .. mlp4.
            setattr(self, f"mlp{index + 1}", stage)
            stages.append(stage)
        self.sequence = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the four stages in order and return the result."""
        return self.sequence(x)
Loading

0 comments on commit aa95bb7

Please sign in to comment.