diff --git a/docs/pictures/ndtimeline_arch.jpg b/docs/pictures/ndtimeline_arch.jpg
new file mode 100644
index 0000000..cef58a9
Binary files /dev/null and b/docs/pictures/ndtimeline_arch.jpg differ
diff --git a/docs/pictures/ndtimeline_trace.png b/docs/pictures/ndtimeline_trace.png
new file mode 100644
index 0000000..51517d4
Binary files /dev/null and b/docs/pictures/ndtimeline_trace.png differ
diff --git a/docs/pictures/pp.png b/docs/pictures/pp.png
new file mode 100644
index 0000000..edabe7f
Binary files /dev/null and b/docs/pictures/pp.png differ
diff --git a/examples/open_llama_4D_benchmark/download_open_llama_ckpt.py b/examples/open_llama_4D_benchmark/download_open_llama_ckpt.py
index 876228a..5fcf32e 100644
--- a/examples/open_llama_4D_benchmark/download_open_llama_ckpt.py
+++ b/examples/open_llama_4D_benchmark/download_open_llama_ckpt.py
@@ -1,6 +1,6 @@
################################################################################
#
-# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
diff --git a/examples/open_llama_4D_benchmark/llama_mfu_calculator.py b/examples/open_llama_4D_benchmark/llama_mfu_calculator.py
index 9bacdd5..b67908d 100644
--- a/examples/open_llama_4D_benchmark/llama_mfu_calculator.py
+++ b/examples/open_llama_4D_benchmark/llama_mfu_calculator.py
@@ -1,6 +1,6 @@
################################################################################
#
-# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
diff --git a/examples/open_llama_4D_benchmark/run_open_llama_w_vescale.py b/examples/open_llama_4D_benchmark/run_open_llama_w_vescale.py
index 8117551..22f7cf8 100644
--- a/examples/open_llama_4D_benchmark/run_open_llama_w_vescale.py
+++ b/examples/open_llama_4D_benchmark/run_open_llama_w_vescale.py
@@ -1,6 +1,6 @@
################################################################################
#
-# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
diff --git a/examples/open_llama_4D_benchmark/sharding_plan.py b/examples/open_llama_4D_benchmark/sharding_plan.py
index 12bcd65..2aefd81 100644
--- a/examples/open_llama_4D_benchmark/sharding_plan.py
+++ b/examples/open_llama_4D_benchmark/sharding_plan.py
@@ -1,6 +1,6 @@
################################################################################
#
-# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
diff --git a/requirements.txt b/requirements.txt
index 5e4d40f..82b132a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@ pytest
tqdm
optree
accelerate
-transformers==4.37.2
+transformers==4.40.2
flash_attn
+matplotlib
mmh3
\ No newline at end of file
diff --git a/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py b/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py
index 3c487b5..5e1e8aa 100644
--- a/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py
+++ b/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py
@@ -101,9 +101,7 @@ def init_method(self):
@skip_unless_torch_gpu
@with_comms
def test_load(self):
- ddp_gpt, dist_optimizer, _ = build_gpt_model_optimizer_and_dataset(
- self.init_method, dp_size=2, tp_size=2
- )
+ ddp_gpt, dist_optimizer, _ = build_gpt_model_optimizer_and_dataset(self.init_method, dp_size=2, tp_size=2)
# Load the model and optimizer after first data
diff --git a/test/checkpoint/open_llama/test_open_llama_dp_reshard.py b/test/checkpoint/open_llama/test_open_llama_dp_reshard.py
index b1f6cb3..370dadd 100644
--- a/test/checkpoint/open_llama/test_open_llama_dp_reshard.py
+++ b/test/checkpoint/open_llama/test_open_llama_dp_reshard.py
@@ -1,6 +1,6 @@
################################################################################
#
-# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
diff --git a/test/checkpoint/open_llama/test_open_llama_load_save.py b/test/checkpoint/open_llama/test_open_llama_load_save.py
index 0a3a29a..c0a8377 100644
--- a/test/checkpoint/open_llama/test_open_llama_load_save.py
+++ b/test/checkpoint/open_llama/test_open_llama_load_save.py
@@ -1,6 +1,6 @@
################################################################################
#
-# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
diff --git a/test/checkpoint/open_llama/test_open_llama_tp_reshard.py b/test/checkpoint/open_llama/test_open_llama_tp_reshard.py
index 5096062..2a85cae 100644
--- a/test/checkpoint/open_llama/test_open_llama_tp_reshard.py
+++ b/test/checkpoint/open_llama/test_open_llama_tp_reshard.py
@@ -1,6 +1,6 @@
################################################################################
#
-# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
diff --git a/test/model/open_llama/test_attention.py b/test/model/open_llama/test_attention.py
index f014531..ad7c281 100644
--- a/test/model/open_llama/test_attention.py
+++ b/test/model/open_llama/test_attention.py
@@ -56,7 +56,8 @@ def test_attention(self):
input.retain_grad()
non_parallel_attention, _ = get_model()
non_parallel_attention = non_parallel_attention.cuda()
- golden_outputs = non_parallel_attention(input)
+ dummy_position_ids = torch.randint(low=0, high=s, size=(bsz, s)).cuda()
+ golden_outputs = non_parallel_attention(input, position_ids=dummy_position_ids)
golden_loss = golden_outputs[0].mean()
golden_loss.backward()
@@ -84,8 +85,9 @@ def test_attention(self):
d_input = distribute_tensor(input.detach(), device_mesh, [Shard(1)])
d_input.requires_grad_()
d_input.retain_grad()
+ d_position_id = distribute_tensor(dummy_position_ids.detach(), device_mesh, [Replicate()])
- vescale_outputs = vescale_attention(d_input)
+ vescale_outputs = vescale_attention(d_input, position_ids=d_position_id)
vescale_outputs[0] = vescale_outputs[0].redistribute(placements=[Replicate()] * device_mesh.ndim)
vescale_loss = vescale_outputs[0].mean()
diff --git a/test/model/open_llama/test_decoder_layer.py b/test/model/open_llama/test_decoder_layer.py
index c55ac9a..b32292c 100644
--- a/test/model/open_llama/test_decoder_layer.py
+++ b/test/model/open_llama/test_decoder_layer.py
@@ -56,7 +56,8 @@ def test_decoder(self):
input.retain_grad()
non_parallel_decoder, _ = get_model()
non_parallel_decoder = non_parallel_decoder.cuda()
- golden_outputs = non_parallel_decoder(input)
+ dummy_position_id = torch.randint(low=0, high=s, size=(bsz, s)).cuda()
+ golden_outputs = non_parallel_decoder(input, position_ids=dummy_position_id)
golden_loss = golden_outputs[0].mean()
golden_loss.backward()
@@ -95,8 +96,9 @@ def test_decoder(self):
d_input = distribute_tensor(input.detach(), device_mesh, [Shard(1)])
d_input.requires_grad_()
d_input.retain_grad()
+ d_position_id = distribute_tensor(dummy_position_id.detach(), device_mesh, [Replicate()])
- vescale_outputs = vescale_decoder(d_input)
+ vescale_outputs = vescale_decoder(d_input, position_ids=d_position_id)
vescale_outputs[0] = vescale_outputs[0].redistribute(placements=[Replicate()] * device_mesh.ndim)
vescale_loss = vescale_outputs[0].mean()
diff --git a/test/ndtimeline/__init__.py b/test/ndtimeline/__init__.py
new file mode 100644
index 0000000..98f6b56
--- /dev/null
+++ b/test/ndtimeline/__init__.py
@@ -0,0 +1 @@
+# make pylint happy
diff --git a/test/ndtimeline/test_local_raw_handler.py b/test/ndtimeline/test_local_raw_handler.py
new file mode 100644
index 0000000..28253c0
--- /dev/null
+++ b/test/ndtimeline/test_local_raw_handler.py
@@ -0,0 +1,37 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import os
+from vescale.ndtimeline.world_info import WorldInfo
+from vescale.ndtimeline.handlers import LocalRawNDHandler
+from vescale.ndtimeline.variables import LOCAL_LOGGING_PATH
+
+
+def test_basic_usage():
+ h = LocalRawNDHandler(run_id=0, chunk_sz=10, backup_cnt=3)
+ file_name = "timeline_run0_raw.log"
+ h("test_metric", 1.0, [1.0], [1.0], [{}], range(0, 1), WorldInfo(0, 0), {})
+ assert os.path.exists(os.path.join(LOCAL_LOGGING_PATH, file_name))
+ for _ in range(4):
+ h("test_metric", 1.0, [1.0], [1.0], [{}], range(0, 1), WorldInfo(0, 0), {})
+ h("test_metric2", 2.0, [1.0], [1.0], [{}], range(0, 1), WorldInfo(0, 0), {})
+ assert os.path.exists(os.path.join(LOCAL_LOGGING_PATH, file_name + ".2"))
+ assert not os.path.exists(os.path.join(LOCAL_LOGGING_PATH, file_name + ".4"))
+ os.remove(os.path.join(LOCAL_LOGGING_PATH, file_name))
+ for i in range(1, 4):
+ os.remove(os.path.join(LOCAL_LOGGING_PATH, file_name + "." + str(i)))
+ assert not os.path.exists(os.path.join(LOCAL_LOGGING_PATH, file_name + ".2"))
diff --git a/test/ndtimeline/test_metric_level.py b/test/ndtimeline/test_metric_level.py
new file mode 100644
index 0000000..96f755c
--- /dev/null
+++ b/test/ndtimeline/test_metric_level.py
@@ -0,0 +1,30 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from vescale.ndtimeline import NDMetricLevel
+
+
+def test_cmp_level():
+ assert NDMetricLevel.FRAMEWORK_DEBUG >= NDMetricLevel.INFO
+ assert NDMetricLevel.USER_DEBUG >= NDMetricLevel.INFO
+ assert NDMetricLevel.USER_DEBUG > NDMetricLevel.INFO
+ assert NDMetricLevel.USER_INFO < NDMetricLevel.INFO
+ assert NDMetricLevel.USER_INFO <= NDMetricLevel.INFO
+ assert NDMetricLevel.INFO < NDMetricLevel.DEBUG
+ assert NDMetricLevel.TRACE <= NDMetricLevel.TRACE
+ assert NDMetricLevel.TRACE >= NDMetricLevel.TRACE
+ assert NDMetricLevel.TRACE == NDMetricLevel.TRACE
diff --git a/test/ndtimeline/test_parser_handler.py b/test/ndtimeline/test_parser_handler.py
new file mode 100644
index 0000000..b745ccf
--- /dev/null
+++ b/test/ndtimeline/test_parser_handler.py
@@ -0,0 +1,61 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import pytest
+from vescale.ndtimeline.world_info import WorldInfo
+from vescale.ndtimeline.handlers import ParserNDHandler
+from vescale.ndtimeline.exceptions import NDHandlerError
+
+
+def test_normal_input_with_tags():
+ metric_name = "test_metric"
+ recent_elapsed_raw_parts = [1.0, 3.2, 1.4]
+ elapsed = sum(recent_elapsed_raw_parts)
+ recent_since_start_raw_parts = [1710332816.6118143, 1710332833.2222, 1710332846.1313]
+ single_tag = {"is_test": True}
+ tags = [single_tag] * (len(recent_elapsed_raw_parts) - 1) + [{"is_test": False}]
+ step_range = range(0, 1)
+ world_info = WorldInfo(0, 0)
+ callback = ParserNDHandler()
+ records = callback(
+ metric_name, elapsed, recent_elapsed_raw_parts, recent_since_start_raw_parts, tags, step_range, world_info, {}
+ )
+ assert len(records) == 1
+ assert records[0].step == 0
+
+
+def test_normal_invalid_input():
+ metric_name = "test_metric"
+ recent_elapsed_raw_parts = [1.0, 3.2, 1.4]
+ elapsed = sum(recent_elapsed_raw_parts)
+ recent_since_start_raw_parts = [1710332816.6118143, 1710332846.1313]
+ single_tag = {"is_test": True}
+ tags = [single_tag] * (len(recent_elapsed_raw_parts) - 1) + [{"is_test": False}]
+ step_range = range(0, 1)
+ world_info = WorldInfo(0, 0)
+ callback = ParserNDHandler()
+ with pytest.raises(NDHandlerError):
+ callback(
+ metric_name,
+ elapsed,
+ recent_elapsed_raw_parts,
+ recent_since_start_raw_parts,
+ tags,
+ step_range,
+ world_info,
+ {},
+ )
diff --git a/test/parallel/pipeline/api/four_mlp.py b/test/parallel/pipeline/api/four_mlp.py
new file mode 100644
index 0000000..44d2d49
--- /dev/null
+++ b/test/parallel/pipeline/api/four_mlp.py
@@ -0,0 +1,53 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import torch
+import torch.nn as nn
+import os
+
+
+class MLP(nn.Module):
+ def __init__(self, features_in, feature_middle, features_out, value):
+ super().__init__()
+ self.value = value
+ self.counter = 0
+ self.fc1 = nn.Linear(1024, 1024, bias=False)
+ self.fc1.weight.data.fill_(value)
+ self.fc2 = nn.Linear(1024, 1024, bias=False)
+ self.fc2.weight.data.fill_(value * 2)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ t = self.fc1(x)
+ t = self.gelu(t)
+ t = self.fc2(t)
+ torch.save(t, f"{os.environ['model_name']}_mlp{self.value}_fwd{self.counter}_out_tensor.pt")
+ self.counter += 1
+ return t
+
+
+class FourMLP(nn.Module):
+ def __init__(self, hidden):
+ super().__init__()
+ self.mlp1 = MLP(hidden * 1, hidden * 2, hidden * 3, 0)
+ self.mlp2 = MLP(hidden * 3, hidden * 4, hidden * 5, 1)
+ self.mlp3 = MLP(hidden * 5, hidden * 6, hidden * 7, 2)
+ self.mlp4 = MLP(hidden * 7, hidden * 8, hidden * 9, 3)
+ self.sequence = nn.Sequential(self.mlp1, self.mlp2, self.mlp3, self.mlp4)
+
+ def forward(self, x):
+ return self.sequence(x)
diff --git a/test/parallel/pipeline/api/test_pipe_engine_api.py b/test/parallel/pipeline/api/test_pipe_engine_api.py
new file mode 100644
index 0000000..5075bd5
--- /dev/null
+++ b/test/parallel/pipeline/api/test_pipe_engine_api.py
@@ -0,0 +1,417 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import os
+from common_dtensor import DTensorTestBase, with_comms
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from vescale.optim.base_optimizer import BasicOptimizer
+from vescale.pipe.pipe_stage import PipeModule, construct_stage_modules
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
+from vescale.engine import PipeEngine
+from vescale.plan import (
+ PipelineParallelPlan,
+ PipelineScheduleType,
+ ModeType,
+ PipelineSplitMethodType,
+)
+
+
+class MLP(nn.Module):
+ def __init__(self, n_features):
+ super().__init__()
+ self.fc1 = nn.Linear(n_features, n_features * 2, bias=False)
+ torch.nn.init.uniform_(self.fc1.weight, 0, 1)
+ self.fc2 = nn.Linear(n_features * 2, n_features)
+ torch.nn.init.uniform_(self.fc2.weight, 0, 1)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ t = self.fc1(x)
+ t = self.gelu(t)
+ t = self.fc2(t)
+ return t
+
+
+class FourMLP(nn.Module):
+ def __init__(self, hidden):
+ super().__init__()
+ self.mlp1 = MLP(hidden)
+ self.mlp2 = MLP(hidden)
+ self.mlp3 = MLP(hidden)
+ self.mlp4 = MLP(hidden)
+ self.sequence = nn.Sequential(self.mlp1, self.mlp2, self.mlp3, self.mlp4)
+
+ def forward(self, x):
+ return self.sequence(x)
+
+
+class EightMLP(nn.Module):
+ def __init__(self, hidden):
+ super().__init__()
+ self.mlp1 = MLP(hidden)
+ self.mlp2 = MLP(hidden)
+ self.mlp3 = MLP(hidden)
+ self.mlp4 = MLP(hidden)
+ self.mlp5 = MLP(hidden)
+ self.mlp6 = MLP(hidden)
+ self.mlp7 = MLP(hidden)
+ self.mlp8 = MLP(hidden)
+
+ def forward(self, x):
+ x = self.mlp1(x)
+ x.retain_grad()
+ x = self.mlp2(x)
+ x.retain_grad()
+ x = self.mlp3(x)
+ x.retain_grad()
+ x = self.mlp4(x)
+ x.retain_grad()
+ x = self.mlp5(x)
+ x.retain_grad()
+ x = self.mlp6(x)
+ x.retain_grad()
+ x = self.mlp7(x)
+ x.retain_grad()
+ x = self.mlp8(x)
+ return x
+
+
+class ScheduleTest(DTensorTestBase):
+ @property
+ def world_size(self):
+ return 4
+
+ @staticmethod
+ def loss_fn(x):
+ return torch.sum(x)
+
+ def _prepare_runtime_engine(self, model, forward_only: bool = False):
+ pipe_config = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.MANUAL,
+ num_stages=4,
+ virtual_chunks=1,
+ smallest_unsplittable_units=["mlp1", "mlp2", "mlp3", "mlp4"],
+ split_points=["mlp1", "mlp2", "mlp3", "mlp4"],
+ batch_p2p_comm=False,
+ overlap_p2p_comm=True,
+ schedule_type=PipelineScheduleType.SIMPLE_1F1B,
+ forward_only=forward_only,
+ )
+
+ optimizer_fn_kwargs = {
+ "lr": 0.01,
+ "momentum": 0,
+ "dampening": 0,
+ "weight_decay": 0,
+ "nesterov": False,
+ "maximize": False,
+ "foreach": None,
+ "differentiable": False,
+ }
+
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(4, 1, 1),
+ mesh_dim_names=["PP", "DP", "TP"],
+ )
+ stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
+ model,
+ pipe_config,
+ VESCALE_DEVICE_MESH,
+ update_split_points=True,
+ )
+ _parameters = list(stage_modules[0].parameters())
+ optimizer = torch.optim.SGD(_parameters, **optimizer_fn_kwargs)
+ basic_optimizer = BasicOptimizer(optimizer, models=stage_modules)
+ pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, pipe_config)
+ engine = PipeEngine(
+ pipe_module,
+ VESCALE_DEVICE_MESH,
+ self.loss_fn,
+ pipe_config,
+ )
+
+ return engine, optimizer
+
+ def _prepare_runtime_interleaved_engine(self, model, forward_only: bool = False):
+ num_layer = 8
+ pipe_config = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.MANUAL,
+ num_stages=4,
+ virtual_chunks=2,
+ smallest_unsplittable_units=[f"mlp{i + 1}" for i in range(num_layer)],
+ split_points=["mlp2", "mlp4", "mlp6", "mlp8"],
+ batch_p2p_comm=True,
+ overlap_p2p_comm=False,
+ schedule_type=PipelineScheduleType.INTERLEAVED_1F1B,
+ forward_only=forward_only,
+ )
+
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(4, 1, 1),
+ mesh_dim_names=["PP", "DP", "TP"],
+ )
+
+ optimizer_fn_kwargs = {
+ "lr": 0.01,
+ "momentum": 0,
+ "dampening": 0,
+ "weight_decay": 0,
+ "nesterov": False,
+ "maximize": False,
+ "foreach": None,
+ "differentiable": False,
+ }
+
+ stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
+ model,
+ pipe_config,
+ VESCALE_DEVICE_MESH,
+ update_split_points=True,
+ )
+ _parameters = list(stage_modules[0].parameters()) + list(stage_modules[1].parameters())
+ optimizer = torch.optim.SGD(_parameters, **optimizer_fn_kwargs)
+ pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, pipe_config)
+ engine = PipeEngine(
+ pipe_module,
+ VESCALE_DEVICE_MESH,
+ self.loss_fn,
+ pipe_config,
+ )
+ return engine, optimizer
+
+ @with_comms
+ def test_runtime_engine(self):
+ """
+ Tests pipeline engine.
+ """
+ local_rank = self.rank
+ device = f"cuda:{local_rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ os.environ["LOCAL_RANK"] = str(local_rank)
+ n_hidden = 3
+ batches = 8
+ model = FourMLP(n_hidden).cuda()
+
+ all_batches_out = []
+ if local_rank == 3:
+ for i in range(batches):
+ print(f" ===========batch: {i}================= ")
+ data = torch.zeros(1, 1, n_hidden) + i
+ data = data.float().cuda(3)
+ model.cuda(3)
+ out = model(data)
+ loss = out.sum()
+ all_batches_out.append(loss)
+ loss.backward(create_graph=True)
+ print(loss)
+ print(" ====================================== ")
+
+ engine, optimizer = self._prepare_runtime_engine(model)
+
+ data_iterator = []
+ for i in range(batches):
+ data = torch.zeros(1, 1, n_hidden) + i
+ data_iterator.append(data.to(device))
+
+ minibatch_loss, _ = engine(data_iterator)
+
+ if local_rank == 3:
+ self.assertEqual(minibatch_loss, sum(all_batches_out))
+
+ @with_comms
+ def test_simple_inference_schedule(self):
+ """
+ Tests pipeline engine's inference mode.
+ """
+ local_rank = self.rank
+ device = f"cuda:{local_rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ os.environ["LOCAL_RANK"] = str(local_rank)
+ n_hidden = 3
+ batches = 8
+ model = FourMLP(n_hidden).cuda()
+
+ all_batches_out = []
+ if local_rank == 3:
+ for i in range(batches):
+ print(f" ===========batch: {i}================= ")
+ data = torch.zeros(1, 1, n_hidden) + i
+ data = data.float().cuda(3)
+ model.cuda(3)
+ out = model(data)
+ loss = out.sum()
+ all_batches_out.append(loss)
+ loss.backward(create_graph=True)
+ print(loss)
+ print(" ====================================== ")
+
+ engine, optimizer = self._prepare_runtime_engine(model, forward_only=True)
+
+ data_iterator = []
+ for i in range(batches):
+ data = torch.zeros(1, 1, n_hidden) + i
+ data_iterator.append(data.to(device))
+
+ minibatch_loss, _ = engine(data_iterator)
+
+ if local_rank == 3:
+ self.assertEqual(minibatch_loss, sum(all_batches_out))
+
+ @with_comms
+ def test_runtime_interleaved_1f1b_engine_batch(self):
+ """
+ Tests pipeline engine with interleaved 1f1b schedule under
+ batch p2p communication.
+ """
+ global local_rank
+ local_rank = self.rank
+ device = f"cuda:{local_rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ os.environ["LOCAL_RANK"] = str(local_rank)
+ n_hidden = 3
+ batches = 8
+ model = EightMLP(n_hidden).cuda()
+ single_model_data = []
+ all_batches_out = []
+ if local_rank == 3:
+ true_model = model
+ true_model = true_model.cuda()
+ true_model.train()
+ for i in range(batches):
+ print(f" ===========batch: {i}================= ")
+ data = torch.zeros(1, 1, n_hidden) + i % 8
+ data = data.float().cuda(3)
+ single_model_data.append(data)
+ out = true_model(data)
+ loss = out.sum()
+ all_batches_out.append(loss)
+ loss.backward(create_graph=True)
+ print(" ====================================== ")
+
+ pipe_engine, optimizer = self._prepare_runtime_interleaved_engine(model)
+
+ data_iterator = []
+ for j in range(batches):
+ data = torch.zeros(1, 1, n_hidden) + j
+ data_iterator.append(data.to(device))
+
+ minibatch_loss, _ = pipe_engine(data_iterator)
+
+ if local_rank == 3:
+ ground_truth_loss = sum(all_batches_out)
+ self.assertEqual(minibatch_loss, ground_truth_loss)
+
+ @with_comms
+ def test_runtime_interleaved_1f1b_engine_p2p(self):
+ """
+ Tests pipeline engine with interleaved 1f1b schedule under
+ overlapped p2p communication.
+ """
+ global local_rank
+ local_rank = self.rank
+ device = f"cuda:{local_rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ os.environ["LOCAL_RANK"] = str(local_rank)
+ n_hidden = 3
+ batches = 8
+ model = EightMLP(n_hidden).cuda()
+ single_model_data = []
+ all_batches_out = []
+ if local_rank == 3:
+ true_model = model
+ true_model.train()
+ for i in range(batches):
+ print(f" ===========batch: {i}================= ")
+ data = torch.zeros(1, 1, n_hidden) + i % 8 # + i
+ data = data.float().cuda(3)
+ single_model_data.append(data)
+ out = true_model(data)
+ loss = out.sum()
+ all_batches_out.append(loss)
+ loss.backward(create_graph=True)
+ print(" ====================================== ")
+
+ num_layer = 8
+ pipe_config = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.MANUAL,
+ num_stages=4,
+ virtual_chunks=2,
+ smallest_unsplittable_units=[f"mlp{i + 1}" for i in range(num_layer)],
+ split_points=["mlp2", "mlp4", "mlp6", "mlp8"],
+ batch_p2p_comm=False,
+ overlap_p2p_comm=True,
+ schedule_type=PipelineScheduleType.INTERLEAVED_1F1B,
+ )
+
+ optimizer_fn_kwargs = {
+ "lr": 0.01,
+ "momentum": 0,
+ "dampening": 0,
+ "weight_decay": 0,
+ "nesterov": False,
+ "maximize": False,
+ "foreach": None,
+ "differentiable": False,
+ }
+
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(4, 1, 1),
+ mesh_dim_names=["PP", "DP", "TP"],
+ )
+ stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
+ model,
+ pipe_config,
+ VESCALE_DEVICE_MESH,
+ update_split_points=True,
+ )
+ _parameters = list(stage_modules[0].parameters()) + list(stage_modules[1].parameters())
+ optimizer = torch.optim.SGD(_parameters, **optimizer_fn_kwargs)
+ basic_optimizer = BasicOptimizer(optimizer, models=stage_modules)
+ pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, pipe_config)
+ engine = PipeEngine(
+ pipe_module,
+ VESCALE_DEVICE_MESH,
+ self.loss_fn,
+ pipe_config,
+ )
+
+ data_iterator = []
+ for j in range(batches):
+ data = torch.zeros(1, 1, n_hidden) + j
+ data_iterator.append(data.to(device))
+
+ minibatch_loss, _ = engine.forward_backward(data_iterator)
+
+ if local_rank == 3:
+ ground_truth_loss = sum(all_batches_out)
+ self.assertEqual(minibatch_loss, ground_truth_loss)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/api/test_pipe_single_stage_ops.py b/test/parallel/pipeline/api/test_pipe_single_stage_ops.py
new file mode 100644
index 0000000..9d0922b
--- /dev/null
+++ b/test/parallel/pipeline/api/test_pipe_single_stage_ops.py
@@ -0,0 +1,219 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import os
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
+from vescale.plan import PipelineScheduleType, PipelineParallelPlan, ModeType, PipelineSplitMethodType
+from vescale.pipe.pipe_stage import PipeModule, construct_stage_modules
+from vescale.engine import PipeEngine
+from common_dtensor import DTensorTestBase, with_comms
+from torch.optim import SGD
+
+microbatch_size = 16
+factor = 8
+batch_size = microbatch_size * factor
+RANDOM_SEED = 9999
+
+
+class MLP(nn.Module):
+ def __init__(self, value):
+ super().__init__()
+ self.value = value
+ self.counter = 0
+ self.fc1 = nn.Linear(32, 32, bias=False)
+ self.fc1.weight.data.fill_(value)
+ self.fc2 = nn.Linear(32, 32, bias=False)
+ self.fc2.weight.data.fill_(value * 2)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ t = self.fc1(x)
+ t = self.gelu(t)
+ t = self.fc2(t)
+ torch.save(t, f"{os.environ['model_name']}_mlp{self.value}_fwd{self.counter}_out_tensor.pt")
+ self.counter += 1
+ return t
+
+
+class MLPWithForwardUtil(nn.Module):
+ def __init__(self, value):
+ super().__init__()
+ self.value = value
+ self.counter = 0
+ self.fc1 = nn.Linear(32, 32, bias=False)
+ self.fc1.weight.data.fill_(value)
+ self.fc2 = nn.Linear(32, 32, bias=False)
+ self.fc2.weight.data.fill_(value * 2)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ t = self.fc1(x)
+ t = self.gelu(t)
+ t = self.fc2(t)
+ torch.save(t, f"{os.environ['model_name']}_mlp{self.value}_fwd{self.counter}_out_tensor.pt")
+ self.counter += 1
+ return t
+
+ def forward_util(self, p2p_input, local_input=None):
+ print("This is an auxilary forward_util() provided by the user")
+ if p2p_input is not None:
+ print("Modified p2p_input value!")
+ p2p_input *= 2
+ else:
+ print("Load local input as p2p input")
+ p2p_input = local_input
+ if local_input is not None:
+ print("Handling local inputs")
+ return [p2p_input]
+
+
+class EightMLP(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.mlp1 = MLPWithForwardUtil(0)
+ self.mlp2 = MLP(1)
+ self.mlp3 = MLP(2)
+ self.mlp4 = MLP(3)
+ self.mlp5 = MLPWithForwardUtil(3)
+ self.mlp6 = MLP(3)
+ self.mlp7 = MLP(3)
+ self.mlp8 = MLP(3)
+ self.sequence = nn.Sequential(
+ self.mlp1,
+ self.mlp2,
+ self.mlp3,
+ self.mlp4,
+ self.mlp5,
+ self.mlp6,
+ self.mlp7,
+ self.mlp8,
+ )
+
+ def forward(self, x):
+ return self.sequence(x)
+
+
+class PipelineSingleStageOpsTest(DTensorTestBase):
+ @property
+ def world_size(self):
+ return 4
+
+ @staticmethod
+ def loss_fn(x):
+ return x.mean()
+
+ def test_stage_forward(self):
+ """
+ Test single stage forward.
+ """
+ if self.rank == 0:
+ self._run_no_pp_model()
+ n_gpus = torch.cuda.device_count()
+ assert n_gpus >= 2, "Requires at least 2 GPUs to run model with pp engine"
+ self._run_stage_forward()
+
+ def _run_no_pp_model(self):
+ os.environ["model_name"] = "golden"
+ model = EightMLP().to("cuda:0")
+ optimizer = torch.optim.SGD(
+ model.parameters(), lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False
+ )
+ torch.manual_seed(9999)
+ batch = [torch.ones(microbatch_size, 128, 32, dtype=torch.float32).to("cuda:0") for _ in range(factor)]
+ for mb in batch:
+ out = model(mb)
+
+ @with_comms
+ def _run_stage_forward(self):
+ os.environ["model_name"] = "pp"
+ device = f"cuda:{self.rank}"
+ torch.cuda.set_device(device)
+ model = EightMLP().cuda()
+
+ num_layers = 8
+ config = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.MANUAL,
+ num_stages=4,
+ virtual_chunks=2,
+ smallest_unsplittable_units=[f"mlp{i + 1}" for i in range(num_layers)],
+ split_points=["mlp2", "mlp4", "mlp6", "mlp8"],
+ batch_p2p_comm=False,
+ overlap_p2p_comm=True,
+ schedule_type=PipelineScheduleType.INTERLEAVED_1F1B,
+ )
+
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(4, 1, 1),
+ mesh_dim_names=("PP", "DP", "TP"),
+ )
+
+ stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
+ model,
+ config,
+ VESCALE_DEVICE_MESH,
+ update_split_points=True,
+ )
+
+ optimizer_fn_kwargs = {
+ "lr": 0.01,
+ "momentum": 0,
+ "dampening": 0,
+ "weight_decay": 0,
+ "nesterov": False,
+ "maximize": False,
+ "foreach": None,
+ "differentiable": False,
+ }
+ _parameters = list(stage_modules[0].parameters()) + list(stage_modules[1].parameters())
+ optimizer = SGD(_parameters, **optimizer_fn_kwargs)
+ pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, config)
+
+ engine = PipeEngine(
+ pipe_module,
+ VESCALE_DEVICE_MESH,
+ self.loss_fn,
+ config,
+ )
+ torch.manual_seed(9999)
+ batch = [torch.ones(microbatch_size, 128, 32, dtype=torch.float32).to(device) for _ in range(factor)]
+ if self.rank == 0:
+ # first stage only receives inputs from dataloader
+ chunk_id = 0
+ print(f"Chunk ID: {chunk_id}")
+ output_chunk_one = engine.module(None, local_inputs=batch[0], chunk_id=chunk_id)
+ chunk_id = 1
+ print(f"Chunk ID: {chunk_id}")
+ output_chunk_two = engine.module(batch[1], local_inputs=None, chunk_id=chunk_id)
+ assert not torch.equal(output_chunk_one, output_chunk_two)
+ if self.rank == 2:
+ # other stages can receive inputs communicated by their peers
+ chunk_id = 0
+ print(f"Chunk ID: {chunk_id}")
+ output_chunk_three = engine.module(batch[2], local_inputs=None, chunk_id=chunk_id)
+ chunk_id = 1
+ print(f"Chunk ID: {chunk_id}")
+ output_chunk_four = engine.module(batch[3], local_inputs=None, chunk_id=chunk_id)
+ assert not torch.equal(output_chunk_three, output_chunk_four)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/api/test_schedule_engine.py b/test/parallel/pipeline/api/test_schedule_engine.py
new file mode 100644
index 0000000..c508511
--- /dev/null
+++ b/test/parallel/pipeline/api/test_schedule_engine.py
@@ -0,0 +1,121 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import os
+import torch
+from common_dtensor import DTensorTestBase, with_comms
+from torch.testing._internal.common_utils import run_tests
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
+from vescale.pipe.pipe_stage import PipeModule, construct_stage_modules
+from vescale.pipe._schedules.instruction_base import StageDeps
+from vescale.pipe.pipe_emmiter import ScheduleEngine
+from vescale.plan.spec import PipelineScheduleType, ModeType, PipelineSplitMethodType
+from vescale.plan.pipeline_parallel import PipelineParallelPlan
+from four_mlp import FourMLP
+from torch.optim import SGD
+
+
+class ScheduleEngineRuntimeTest(DTensorTestBase):
+ @property
+ def world_size(self):
+ return 2
+
+ @staticmethod
+ def loss_fn(x):
+ return x.mean()
+
+ def _setup(self):
+ os.environ["model_name"] = "pp"
+ global local_rank
+ local_rank = self.rank
+ device = f"cuda:{local_rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+
+ torch.manual_seed(9999)
+ microbatch_size = 2
+ factor = 4
+ batch = [torch.ones(microbatch_size, 128, 1024, dtype=torch.float32).to(device) for _ in range(factor)]
+ return batch, microbatch_size
+
+ @with_comms
+ def test_simple_1f1b(self):
+ """
+ Test simple 1f1b schedule with schedule runtime.
+ """
+ batch, microbatch_size = self._setup()
+
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(2, 1, 1),
+ mesh_dim_names=("PP", "DP", "TP"),
+ )
+
+ model = FourMLP(1024).cuda()
+ num_layers = 4
+
+ config = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.UNIFORM,
+ num_stages=2,
+ virtual_chunks=1,
+ smallest_unsplittable_units=[f"mlp{i + 1}" for i in range(num_layers)],
+ batch_p2p_comm=False,
+ overlap_p2p_comm=True,
+ schedule_type=PipelineScheduleType.SIMPLE_1F1B,
+ )
+
+ stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
+ model,
+ config,
+ VESCALE_DEVICE_MESH,
+ update_split_points=True,
+ )
+
+ optimizer_fn_kwargs = {
+ "lr": 0.01,
+ "momentum": 0,
+ "dampening": 0,
+ "weight_decay": 0,
+ "nesterov": False,
+ "maximize": False,
+ "foreach": None,
+ "differentiable": False,
+ }
+ _parameters = list(stage_modules[0].parameters())
+ optimizer = SGD(_parameters, **optimizer_fn_kwargs)
+ pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, config)
+
+ dep = pipe_module.stage_deps
+ device_mesh_list = VESCALE_DEVICE_MESH.get_global_tensor_parallel_meshes()
+ stage_deps = StageDeps(dep, device_mesh_list, pipe_module)
+
+ pipe_engine = ScheduleEngine(
+ stage_deps,
+ meshes=VESCALE_DEVICE_MESH.get_global_tensor_parallel_meshes(),
+ schedule=PipelineScheduleType.SIMPLE_1F1B,
+ batches=len(batch),
+ data_iterator=iter(batch),
+ stage_id=VESCALE_DEVICE_MESH.get_pipeline_parallel_rank(),
+ shape=(microbatch_size, 128, 1024),
+ dtype=torch.float32,
+ )
+ minibatch_loss, all_forward_outputs = ScheduleEngine.execute(pipe_engine)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/api/test_simple_api.py b/test/parallel/pipeline/api/test_simple_api.py
new file mode 100644
index 0000000..97b3d3b
--- /dev/null
+++ b/test/parallel/pipeline/api/test_simple_api.py
@@ -0,0 +1,195 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import os
+from common_dtensor import DTensorTestBase, with_comms
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from vescale.debug.pdb import ForkedPdb
+from vescale.optim.base_optimizer import BasicOptimizer
+from vescale.pipe.pipe_stage import construct_pipeline_stage
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
+from vescale.engine import PipeEngine
+from vescale.plan import (
+ PipelineParallelPlan,
+ PipelineScheduleType,
+ ModeType,
+ PipelineSplitMethodType,
+)
+
+
+class MLP(nn.Module):
+ def __init__(self, n_features):
+ super().__init__()
+ self.fc1 = nn.Linear(n_features, n_features * 2, bias=False)
+ torch.nn.init.uniform_(self.fc1.weight, 0, 1)
+ self.fc2 = nn.Linear(n_features * 2, n_features)
+ torch.nn.init.uniform_(self.fc2.weight, 0, 1)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ t = self.fc1(x)
+ t = self.gelu(t)
+ t = self.fc2(t)
+ return t
+
+
+class FourMLP(nn.Module):
+ def __init__(self, hidden):
+ super().__init__()
+ self.mlp1 = MLP(hidden)
+ self.mlp2 = MLP(hidden)
+ self.mlp3 = MLP(hidden)
+ self.mlp4 = MLP(hidden)
+ self.sequence = nn.Sequential(self.mlp1, self.mlp2, self.mlp3, self.mlp4)
+
+ def forward(self, x):
+ return self.sequence(x)
+
+
+class EightMLP(nn.Module):
+ def __init__(self, hidden):
+ super().__init__()
+ self.mlp1 = MLP(hidden)
+ self.mlp2 = MLP(hidden)
+ self.mlp3 = MLP(hidden)
+ self.mlp4 = MLP(hidden)
+ self.mlp5 = MLP(hidden)
+ self.mlp6 = MLP(hidden)
+ self.mlp7 = MLP(hidden)
+ self.mlp8 = MLP(hidden)
+
+ def forward(self, x):
+ x = self.mlp1(x)
+ x.retain_grad()
+ x = self.mlp2(x)
+ x.retain_grad()
+ x = self.mlp3(x)
+ x.retain_grad()
+ x = self.mlp4(x)
+ x.retain_grad()
+ x = self.mlp5(x)
+ x.retain_grad()
+ x = self.mlp6(x)
+ x.retain_grad()
+ x = self.mlp7(x)
+ x.retain_grad()
+ x = self.mlp8(x)
+ return x
+
+
+class SimpleAPITest(DTensorTestBase):
+ @property
+ def world_size(self):
+ return 4
+
+ @staticmethod
+ def loss_fn(x):
+ return torch.sum(x)
+
+ def _prepare_runtime_engine(self, model, forward_only: bool = False):
+ pipe_plan = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.MANUAL,
+ num_stages=4,
+ virtual_chunks=1,
+ smallest_unsplittable_units=["mlp1", "mlp2", "mlp3", "mlp4"],
+ split_points=["mlp1", "mlp2", "mlp3", "mlp4"],
+ batch_p2p_comm=False,
+ overlap_p2p_comm=True,
+ schedule_type=PipelineScheduleType.SIMPLE_1F1B,
+ forward_only=forward_only,
+ )
+
+ optimizer_fn_kwargs = {
+ "lr": 0.01,
+ "momentum": 0,
+ "dampening": 0,
+ "weight_decay": 0,
+ "nesterov": False,
+ "maximize": False,
+ "foreach": None,
+ "differentiable": False,
+ }
+
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(4, 1, 1),
+ mesh_dim_names=["PP", "DP", "TP"],
+ )
+ pipe_module = construct_pipeline_stage(
+ model,
+ pipe_plan,
+ VESCALE_DEVICE_MESH,
+ lr_scheduler=None,
+ update_split_points=True,
+ )
+ optimizer = torch.optim.SGD(pipe_module.parameters(), **optimizer_fn_kwargs)
+ basic_optimizer = BasicOptimizer(optimizer, models=pipe_module)
+ engine = PipeEngine(
+ pipe_module,
+ VESCALE_DEVICE_MESH,
+ self.loss_fn,
+ pipe_plan,
+ )
+
+ return engine, optimizer
+
+ @with_comms
+ def test_simple_api(self):
+ """
+ Tests pipeline engine with simple API.
+ """
+ local_rank = self.rank
+ device = f"cuda:{local_rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ os.environ["LOCAL_RANK"] = str(local_rank)
+ n_hidden = 3
+ batches = 8
+ model = FourMLP(n_hidden).cuda()
+
+ all_batches_out = []
+ if local_rank == 3:
+ for i in range(batches):
+ print(f" ===========batch: {i}================= ")
+ data = torch.zeros(1, 1, n_hidden) + i
+ data = data.float().cuda(3)
+ model.cuda(3)
+ out = model(data)
+ loss = out.sum()
+ all_batches_out.append(loss)
+ loss.backward(create_graph=True)
+ print(loss)
+ print(" ====================================== ")
+
+ engine, optimizer = self._prepare_runtime_engine(model)
+
+ data_iterator = []
+ for i in range(batches):
+ data = torch.zeros(1, 1, n_hidden) + i
+ data_iterator.append(data.to(device))
+
+ minibatch_loss, _ = engine(data_iterator)
+
+ if local_rank == 3:
+ self.assertEqual(minibatch_loss, sum(all_batches_out))
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/backend/eight_mlp.py b/test/parallel/pipeline/backend/eight_mlp.py
new file mode 100644
index 0000000..b4d2e0f
--- /dev/null
+++ b/test/parallel/pipeline/backend/eight_mlp.py
@@ -0,0 +1,288 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import torch.nn as nn
+from vescale.dtensor.placement_types import Shard, Replicate
+
+
+class Embed(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.embedding = nn.Embedding(8, 64)
+
+ def forward(self, x):
+ return self.embedding(x)
+
+ def get_word_embeddings_weight(self):
+ return self.embedding.weight
+
+
+class EmbedTwo(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.embedding = nn.Embedding(8, 64)
+
+ def forward(self, x):
+ return self.embedding(x)
+
+ def get_word_embeddings_weight(self):
+ return self.embedding.weight
+
+
+class MLP(nn.Module):
+ def __init__(self, features_in, features_out, value):
+ super().__init__()
+ self.value = value
+ self.fc1 = nn.Linear(features_in, 16, bias=False)
+ self.fc1.weight.data.fill_(value)
+ self.fc2 = nn.Linear(16, features_out, bias=False)
+ self.fc2.weight.data.fill_(value * 2)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ t = self.fc1(x)
+ t = self.gelu(t)
+ t = self.fc2(t)
+ return t
+
+
+class SmallMLP(nn.Module):
+ def __init__(self, features_in, features_out, value):
+ super().__init__()
+ self.value = value
+ self.fc1 = nn.Linear(features_in, features_out, bias=False)
+ self.fc1.weight.data.fill_(value)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ t = self.fc1(x)
+ t = self.gelu(t)
+ return t
+
+
+class HierachicalMLP(nn.Module):
+ def __init__(self, features_in, features_out, value):
+ super().__init__()
+ self.value = value
+ self.fc0 = SmallMLP(features_in, features_in, value)
+ self.fc1 = nn.Linear(features_in, 16, bias=False)
+ self.fc2 = nn.Linear(16, features_out, bias=False)
+ self.fc3 = SmallMLP(features_out, features_out, value)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ x = x + x
+ x = self.fc0(x)
+ t = self.fc1(x)
+ t = self.gelu(t)
+ t = self.fc2(t)
+ t = self.fc3(t)
+ return t
+
+
+class EightMLP(nn.Module):
+ def __init__(self, hidden=64, fixed_size=True, embedded_module=False):
+ super().__init__()
+ module = HierachicalMLP if embedded_module else MLP
+ if fixed_size:
+ self.mlp1 = module(hidden, hidden, 0)
+ self.mlp2 = module(hidden, hidden, 1)
+ self.mlp3 = module(hidden, hidden, 2)
+ self.mlp4 = module(hidden, hidden, 3)
+ self.mlp5 = module(hidden, hidden, 4)
+ self.mlp6 = module(hidden, hidden, 5)
+ self.mlp7 = module(hidden, hidden, 6) # tranformerlayer7 = TransformerLayer(hidden)
+ self.mlp8 = module(hidden, hidden, 7) # tranformerlayer8 = TransformerLayer(hidden)
+ else:
+ self.mlp1 = module(hidden * 1, hidden * 2, 0)
+ self.mlp2 = module(hidden * 2, hidden * 3, 1)
+ self.mlp3 = module(hidden * 3, hidden * 4, 2)
+ self.mlp4 = module(hidden * 4, hidden * 5, 3)
+ self.mlp5 = module(hidden * 5, hidden * 6, 4)
+ self.mlp6 = module(hidden * 6, hidden * 7, 5)
+ self.mlp7 = module(hidden * 7, hidden * 8, 6)
+ self.mlp8 = module(hidden * 8, hidden * 9, 7)
+
+ def forward(self, x):
+ x = self.mlp1(x)
+ x = self.mlp2(x)
+ x = self.mlp3(x)
+ x = self.mlp4(x)
+ x = self.mlp5(x)
+ x = self.mlp6(x)
+ x = self.mlp7(x)
+ x = self.mlp8(x)
+ return x
+
+
+class EightMLPDiffNames(nn.Module):
+ def __init__(self, hidden=64):
+ super().__init__()
+ self.mlp1 = MLP(hidden, hidden, 0)
+ self.mlp2 = MLP(hidden, hidden, 1)
+ self.mlp3 = MLP(hidden, hidden, 2)
+ self.layer1 = MLP(hidden, hidden, 3)
+ self.layer2 = MLP(hidden, hidden, 4)
+ self.layer3 = MLP(hidden, hidden, 5)
+ self.layer4 = MLP(hidden, hidden, 6)
+ self.more_layer1 = MLP(hidden, hidden, 7)
+
+ def forward(self, x):
+ x = self.mlp1(x)
+ x = self.mlp2(x)
+ x = self.mlp3(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.more_layer1(x)
+ return x
+
+
+class EightMLPWithOps(nn.Module):
+ def __init__(self, hidden=64):
+ super().__init__()
+ self.mlp1 = MLP(hidden, hidden, 0)
+ self.mlp2 = MLP(hidden, hidden, 1)
+ self.mlp3 = MLP(hidden, hidden, 2)
+ self.mlp4 = MLP(hidden, hidden, 3)
+ self.mlp5 = MLP(hidden, hidden, 4)
+ self.mlp6 = MLP(hidden, hidden, 5)
+ self.mlp7 = MLP(hidden, hidden, 6)
+ self.mlp8 = MLP(hidden, hidden, 7)
+
+ def forward(self, x):
+ x = x + x
+ x = self.mlp1(x)
+ x = x * 2
+ x = self.mlp2(x)
+ x = x * 2
+ x = x * 2
+ x = x * 2
+ x = self.mlp3(x)
+ x = self.mlp4(x)
+ x = self.mlp5(x)
+ x = self.mlp6(x)
+ x = self.mlp7(x)
+ x = self.mlp8(x)
+ return x
+
+
+class EightMLPWithOpsTail(nn.Module):
+ def __init__(self, hidden=64):
+ super().__init__()
+ self.mlp1 = MLP(hidden, hidden, 0)
+ self.mlp2 = MLP(hidden, hidden, 1)
+ self.mlp3 = MLP(hidden, hidden, 2)
+ self.mlp4 = MLP(hidden, hidden, 3)
+ self.mlp5 = MLP(hidden, hidden, 4)
+ self.mlp6 = MLP(hidden, hidden, 5)
+ self.mlp7 = MLP(hidden, hidden, 6)
+ self.mlp8 = MLP(hidden, hidden, 7)
+
+ def forward(self, x):
+ x = x + x
+ x = self.mlp1(x)
+ x = x * 2
+ x = self.mlp2(x)
+ x = x * 2
+ x = self.mlp3(x)
+ x = self.mlp4(x)
+ x = self.mlp5(x)
+ x = self.mlp6(x)
+ x = self.mlp7(x)
+ x = self.mlp8(x)
+ x = x * 2
+ x = x * 4
+ x = x + 4
+ return x
+
+
+class EightMLPSharedEmbed(nn.Module):
+ def __init__(self, hidden=64):
+ super().__init__()
+ self.embed1 = Embed()
+ self.mlp1 = MLP(hidden, hidden, 0)
+ self.mlp2 = MLP(hidden, hidden, 1)
+ self.mlp3 = MLP(hidden, hidden, 2)
+ self.mlp4 = MLP(hidden, hidden, 3)
+ self.mlp5 = MLP(hidden, hidden, 4)
+ self.mlp6 = MLP(hidden, hidden, 5)
+ self.mlp7 = MLP(hidden, hidden, 6)
+ self.mlp8 = MLP(hidden, hidden, 7)
+ self.embed2 = EmbedTwo()
+
+ def forward(self, x):
+ x = self.embed1(x).float()
+ x = self.mlp1(x)
+ x = self.mlp2(x)
+ x = self.mlp3(x)
+ x = self.mlp4(x)
+ x = self.mlp5(x)
+ x = self.mlp6(x)
+ x = self.mlp7(x)
+ x = self.mlp8(x).long()
+ x = self.embed2(x)
+ return x
+
+
+sharding_plan = {
+ "forward": {
+ r"mlp\d.input": [[Replicate()]],
+ r"mlp\d.output": [[Replicate()]],
+ },
+ "parameter": {
+ r"mlp\d.fc1.weight": [Shard(0)],
+ r"mlp\d.fc2.weight": [Shard(1)],
+ },
+}
+
+sharding_plan_two = {
+ "forward": {
+ r"mlp\d.input": [[Replicate()]],
+ r"mlp\d.output": [[Replicate()]],
+ },
+ "parameter": {
+ r"mlp\d.weight": [Shard(1)],
+ },
+}
+
+sharding_plan_combo = {
+ "forward": {
+ r"mlp\d.input": [[Replicate()]],
+ r"mlp\d.output": [[Replicate()]],
+ r"layer\d.input": [[Replicate()]],
+ r"layer\d.output": [[Replicate()]],
+ r"more_layer\d.input": [[Replicate()]],
+ r"more_layer\d.output": [[Replicate()]],
+ },
+ "parameter": {
+ r"mlp\d.weight": [Shard(1)],
+ r"layer\d.weight": [[Replicate()]],
+ },
+}
+
+sharding_plan_fc = {
+ "forward": {
+ r"mlp\d.fc\d.input": [[Replicate()]],
+ r"mlp\d.fc\d.output": [[Replicate()]],
+ },
+ "parameter": {
+ r"mlp\d.fc1.weight": [Shard(0)],
+ r"mlp\d.fc2.weight": [Shard(1)],
+ },
+}
diff --git a/test/parallel/pipeline/backend/test_p2p_comm.py b/test/parallel/pipeline/backend/test_p2p_comm.py
new file mode 100644
index 0000000..55d47bb
--- /dev/null
+++ b/test/parallel/pipeline/backend/test_p2p_comm.py
@@ -0,0 +1,994 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import os
+import torch
+import torch.distributed as dist
+from torch.testing._internal.common_utils import run_tests
+from vescale import DeviceMesh, distribute_tensor
+from vescale.dtensor.placement_types import Replicate
+from vescale.pipe.p2p_communication import (
+ _communicate,
+ _communicate_shapes,
+ _mapping_local_rank_to_target_rank_by_device_mesh,
+ recv_forward,
+ recv_backward,
+ send_forward,
+ send_backward,
+ send_forward_recv_backward,
+ send_backward_recv_forward,
+ send_forward_recv_forward,
+ send_backward_recv_backward,
+ send_forward_backward_recv_forward_backward,
+ drain_recv_reqs,
+)
+from common_dtensor import (
+ DTensorTestBase,
+ with_comms,
+)
+
+
+class PipeP2PTest(DTensorTestBase):
+ @staticmethod
+ def set_up_device_mesh_stages(world_size, device, n):
+ assert world_size % n == 0, "world size must be divisible by the number of stages"
+ n_device = world_size // n
+ return (DeviceMesh(device, list(range(n_device * i, n_device * (i + 1)))) for i in range(n))
+
+ @staticmethod
+ def apply_xavier_normal_with_seed(tensor, seed=99999):
+ torch.manual_seed(seed)
+ torch.nn.init.xavier_normal_(tensor)
+
+ @property
+ def world_size(self):
+ return 8
+
+ @property
+ def sequence_len(self):
+ return 8
+
+ @property
+ def batch_size(self):
+ return 4
+
+ @property
+ def input_size(self):
+ return 2
+
+ @property
+ def stages(self):
+ return 4
+
+ def _generate_device_meshes(self):
+ device = f"cuda:{self.rank}"
+ # stage1
+ device_mesh_stage1 = DeviceMesh(device, list(range(self.world_size // 2)))
+ # stage2
+ device_mesh_stage2 = DeviceMesh(device, list(range(self.world_size // 2, self.world_size)))
+ return device_mesh_stage1, device_mesh_stage2
+
+ def _generate_three_device_meshes(self):
+ device = f"cuda:{self.rank}"
+ # stage1
+ device_mesh_stage1 = DeviceMesh(device, list(range(self.world_size // 4)))
+ # stage2
+ device_mesh_stage2 = DeviceMesh(device, list(range(self.world_size // 4, self.world_size // 2)))
+ # stage3
+ device_mesh_stage3 = DeviceMesh(device, list(range(self.world_size // 2, self.world_size // 4 * 3)))
+ return device_mesh_stage1, device_mesh_stage2, device_mesh_stage3
+
+ @with_comms
+ def test_communicate_shapes(self):
+ """
+ Test correctness function of _communicate_shapes().
+ """
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2 = self._generate_device_meshes()
+
+ # stage 1 tensor
+ tensor_stage1 = torch.empty(self.sequence_len, self.batch_size, self.input_size, device=device)
+ torch.nn.init.xavier_normal_(tensor_stage1)
+ dist.all_reduce(tensor_stage1, async_op=False)
+ dtensor_stage1 = distribute_tensor(tensor_stage1, device_mesh_stage1, placements=[Replicate()])
+ if self.rank in device_mesh_stage1.mesh.tolist():
+ target_rank = _mapping_local_rank_to_target_rank_by_device_mesh(
+ local_rank=self.rank, current_device_mesh=device_mesh_stage1, target_device_mesh=device_mesh_stage2
+ )
+ _communicate_shapes(
+ local_rank=self.rank,
+ tensor_send_next=dtensor_stage1,
+ tensor_send_prev=None,
+ next_rank=target_rank,
+ prev_rank=None,
+ recv_prev=False,
+ recv_next=False,
+ )
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ target_rank = _mapping_local_rank_to_target_rank_by_device_mesh(
+ local_rank=self.rank, current_device_mesh=device_mesh_stage2, target_device_mesh=device_mesh_stage1
+ )
+ recv_prev_shape, _ = _communicate_shapes(
+ local_rank=self.rank,
+ tensor_send_next=None,
+ tensor_send_prev=None,
+ prev_rank=target_rank,
+ next_rank=None,
+ recv_prev=True,
+ recv_next=False,
+ )
+ self.assertTrue(recv_prev_shape == [self.sequence_len, self.batch_size, self.input_size])
+
+ @with_comms
+ def test_communicate_no_batch_p2p_comm(self):
+ """
+ Test correctness of p2p communication ops.
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2 = self._generate_device_meshes()
+ # stage 1 tensor
+ tensor_stage1 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ dtensor_stage1 = distribute_tensor(tensor_stage1, device_mesh_stage1, placements=[Replicate()])
+ if self.rank in device_mesh_stage1.mesh.tolist():
+ _communicate(
+ tensor_send_next=dtensor_stage1._local_tensor,
+ tensor_send_prev=None,
+ current_device_mesh=device_mesh_stage1,
+ next_device_mesh=device_mesh_stage2,
+ recv_prev=False,
+ recv_next=False,
+ tensor_shape=None,
+ batch_p2p_comm=False,
+ wait_on_reqs=True,
+ dtype=None,
+ )
+
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ recv_prev_tensor, _, _ = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=None,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=device_mesh_stage1,
+ recv_prev=True,
+ recv_next=False,
+ tensor_shape=None,
+ batch_p2p_comm=False,
+ wait_on_reqs=True,
+ dtype=torch.float32,
+ )
+ self.assertTrue(
+ torch.equal(
+ recv_prev_tensor, torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ )
+ )
+
+ @with_comms
+ def test_communicate_batch_p2p_comm(self):
+ """
+ Test correctness of batch communication ops.
+ """
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2 = self._generate_device_meshes()
+ # stage 1 tensor
+ tensor_stage1 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ dtensor_stage1 = distribute_tensor(tensor_stage1, device_mesh_stage1, placements=[Replicate()])
+ if self.rank in device_mesh_stage1.mesh.tolist():
+ _communicate(
+ tensor_send_next=dtensor_stage1._local_tensor,
+ tensor_send_prev=None,
+ current_device_mesh=device_mesh_stage1,
+ next_device_mesh=device_mesh_stage2,
+ recv_prev=False,
+ recv_next=False,
+ tensor_shape=None,
+ batch_p2p_comm=True,
+ wait_on_reqs=True,
+ dtype=None,
+ )
+
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ recv_prev_tensor, _, _ = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=None,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=device_mesh_stage1,
+ recv_prev=True,
+ recv_next=False,
+ tensor_shape=None,
+ batch_p2p_comm=True,
+ wait_on_reqs=True,
+ dtype=torch.float32,
+ )
+ self.assertTrue(
+ torch.equal(
+ recv_prev_tensor, torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ )
+ )
+
+ @with_comms
+ def test_send_forward_and_recv_forward(self):
+ """
+ Test correctness of send_forward() and recv_forward().
+ """
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ stage_list = list(self.set_up_device_mesh_stages(self.world_size, device, self.stages))
+ seed_list = list(range(99990, 99990 + self.stages))
+ stage_n_dict = {(self.rank in stage.mesh.tolist()): i for i, stage in enumerate(stage_list)}
+ stage_n = stage_n_dict[True]
+ send_seed = seed_list[stage_n]
+ recv_seed = seed_list[stage_n - 1]
+ prev_stage = stage_list[stage_n - 1]
+ curr_stage = stage_list[stage_n]
+ next_stage = stage_list[(stage_n + 1) % len(stage_list)]
+ send_t = torch.empty(self.sequence_len, self.batch_size, self.input_size, device=device)
+ expt_t = torch.empty(self.sequence_len, self.batch_size, self.input_size, device=device)
+ self.apply_xavier_normal_with_seed(send_t, seed=send_seed)
+ self.apply_xavier_normal_with_seed(expt_t, seed=recv_seed)
+
+ if stage_n % 2 == 0:
+ send_forward(
+ output_tensor=send_t,
+ current_device_mesh=curr_stage,
+ peer_device_mesh=next_stage,
+ tensor_shape=send_t.shape,
+ )
+ else:
+ recv_prev_tensor = recv_forward(
+ tensor_shape=expt_t.shape,
+ recv_dtype=expt_t.dtype,
+ current_device_mesh=curr_stage,
+ peer_device_mesh=prev_stage,
+ )
+ self.assertTrue(torch.equal(recv_prev_tensor, expt_t))
+
+ @with_comms
+ def test_send_backward_and_recv_backward(self):
+ """
+ Test correctness of send_backward() and recv_backward().
+ """
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ stage_list = list(self.set_up_device_mesh_stages(self.world_size, device, self.stages))
+ seed_list = list(range(99990, 99990 + self.stages))
+ stage_n_dict = {(self.rank in stage.mesh.tolist()): i for i, stage in enumerate(stage_list)}
+ stage_n = stage_n_dict[True]
+ send_seed = seed_list[stage_n]
+ recv_seed = seed_list[(stage_n + 1) % len(seed_list)]
+ prev_stage = stage_list[stage_n - 1]
+ curr_stage = stage_list[stage_n]
+ next_stage = stage_list[(stage_n + 1) % len(stage_list)]
+ send_t = torch.empty(self.sequence_len, self.batch_size, self.input_size, device=device)
+ expt_t = torch.empty(self.sequence_len, self.batch_size, self.input_size, device=device)
+ self.apply_xavier_normal_with_seed(send_t, seed=send_seed)
+ self.apply_xavier_normal_with_seed(expt_t, seed=recv_seed)
+
+ if stage_n % 2 == 0:
+ send_backward(
+ input_tensor_grad=send_t,
+ current_device_mesh=curr_stage,
+ peer_device_mesh=prev_stage,
+ tensor_shape=send_t.shape,
+ )
+ else:
+ recv_prev_tensor = recv_backward(
+ tensor_shape=expt_t.shape,
+ recv_dtype=expt_t.dtype,
+ current_device_mesh=curr_stage,
+ peer_device_mesh=next_stage,
+ )
+ self.assertTrue(torch.equal(recv_prev_tensor, expt_t))
+
+ @with_comms
+ def test_send_forward_recv_backward_and_send_backward_recv_forward(self):
+ """
+ Test correctness of send_backward() and recv_backward().
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ stage_list = list(self.set_up_device_mesh_stages(self.world_size, device, self.stages))
+ fwd_seed_list = list(range(99990, 99990 + self.stages))
+ bwd_seed_list = list(range(77770, 77770 + self.stages))
+ stage_n_dict = {(self.rank in stage.mesh.tolist()): i for i, stage in enumerate(stage_list)}
+ stage_n = stage_n_dict[True]
+ fwd_send_seed = fwd_seed_list[stage_n]
+ fwd_recv_seed = fwd_seed_list[stage_n - 1]
+ bwd_send_seed = bwd_seed_list[stage_n]
+ bwd_recv_seed = bwd_seed_list[(stage_n + 1) % len(bwd_seed_list)]
+ prev_stage = stage_list[stage_n - 1]
+ curr_stage = stage_list[stage_n]
+ next_stage = stage_list[(stage_n + 1) % len(stage_list)]
+ fwd_send_t = torch.empty(self.sequence_len, self.batch_size, self.input_size, device=device)
+ fwd_expt_t = torch.empty(self.sequence_len, self.batch_size, self.input_size, device=device)
+ bwd_send_t = torch.empty(self.sequence_len, self.batch_size, self.input_size, device=device)
+ bwd_expt_t = torch.empty(self.sequence_len, self.batch_size, self.input_size, device=device)
+ self.apply_xavier_normal_with_seed(fwd_send_t, seed=fwd_send_seed)
+ self.apply_xavier_normal_with_seed(fwd_expt_t, seed=fwd_recv_seed)
+ self.apply_xavier_normal_with_seed(bwd_send_t, seed=bwd_send_seed)
+ self.apply_xavier_normal_with_seed(bwd_expt_t, seed=bwd_recv_seed)
+ if stage_n % 2 == 0:
+ recv_bwd_tensor = send_forward_recv_backward(
+ output_tensor=fwd_send_t,
+ tensor_shape=bwd_expt_t.shape,
+ recv_dtype=bwd_expt_t.dtype,
+ current_device_mesh=curr_stage,
+ peer_device_mesh=next_stage,
+ )
+ self.assertTrue(torch.equal(recv_bwd_tensor, bwd_expt_t))
+ else:
+ recv_fwd_tensor = send_backward_recv_forward(
+ input_tensor_grad=bwd_send_t,
+ tensor_shape=fwd_expt_t.shape,
+ recv_dtype=fwd_expt_t.dtype,
+ current_device_mesh=curr_stage,
+ peer_device_mesh=prev_stage,
+ )
+ self.assertTrue(torch.equal(recv_fwd_tensor, fwd_expt_t))
+
+ @with_comms
+ def test_send_forward_recv_forward_no_shape(self):
+ """
+ Test correctness of send_forward_recv_forward without sharing tensor shape in advance.
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2, device_mesh_stage3 = self._generate_three_device_meshes()
+ # stage 1 tensor
+ tensor_stage1 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ dtensor_stage1 = distribute_tensor(tensor_stage1, device_mesh_stage1, placements=[Replicate()])
+ tensor_stage2 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ dtensor_stage2 = distribute_tensor(tensor_stage2, device_mesh_stage2, placements=[Replicate()])
+ # send to stage 2
+ if self.rank in device_mesh_stage1.mesh.tolist():
+ send_forward(
+ output_tensor=dtensor_stage1.to_local(),
+ current_device_mesh=device_mesh_stage1,
+ peer_device_mesh=device_mesh_stage2,
+ tensor_shape=None,
+ )
+
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ stage2_recv_tensor = send_forward_recv_forward(
+ output_tensor=dtensor_stage2._local_tensor,
+ recv_prev=True,
+ tensor_shape=None,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=device_mesh_stage1,
+ next_device_mesh=device_mesh_stage3,
+ recv_dtype=torch.float32,
+ )
+ self.assertTrue(
+ torch.equal(
+ stage2_recv_tensor, torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ )
+ )
+ if self.rank in device_mesh_stage3.mesh.tolist():
+ stage3_recv_tensor = recv_forward(
+ tensor_shape=None,
+ recv_dtype=torch.float32,
+ current_device_mesh=device_mesh_stage3,
+ peer_device_mesh=device_mesh_stage2,
+ )
+ self.assertTrue(
+ torch.equal(
+ stage3_recv_tensor,
+ torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1,
+ )
+ )
+
+ @with_comms
+ def test_send_forward_recv_forward_with_shape(self):
+ """
+ Test correctness of send_forward_recv_forward with known tensor shape.
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2, device_mesh_stage3 = self._generate_three_device_meshes()
+ # stage 1 tensor
+ shape = (self.sequence_len, self.batch_size, self.input_size)
+ tensor_stage1 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ dtensor_stage1 = distribute_tensor(tensor_stage1, device_mesh_stage1, placements=[Replicate()])
+ tensor_stage2 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ dtensor_stage2 = distribute_tensor(tensor_stage2, device_mesh_stage2, placements=[Replicate()])
+ # send to stage 2
+ if self.rank in device_mesh_stage1.mesh.tolist():
+ send_forward(
+ output_tensor=dtensor_stage1.to_local(),
+ current_device_mesh=device_mesh_stage1,
+ peer_device_mesh=device_mesh_stage2,
+ tensor_shape=shape,
+ )
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ stage2_recv_tensor = send_forward_recv_forward(
+ output_tensor=dtensor_stage2._local_tensor,
+ recv_prev=True,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=device_mesh_stage1,
+ next_device_mesh=device_mesh_stage3,
+ recv_dtype=torch.float32,
+ )
+ self.assertTrue(
+ torch.equal(
+ stage2_recv_tensor, torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ )
+ )
+ if self.rank in device_mesh_stage3.mesh.tolist():
+ stage3_recv_tensor = recv_forward(
+ tensor_shape=shape,
+ recv_dtype=torch.float32,
+ current_device_mesh=device_mesh_stage3,
+ peer_device_mesh=device_mesh_stage2,
+ )
+ self.assertTrue(
+ torch.equal(
+ stage3_recv_tensor,
+ torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1,
+ )
+ )
+
+ @with_comms
+ def test_send_backward_recv_backward_no_shape(self):
+ """
+ Test correctness of send_backward_recv_backward().
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2, device_mesh_stage3 = self._generate_three_device_meshes()
+ # stage 1 tensor
+ shape = None
+ tensor_stage2 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ dtensor_stage2 = distribute_tensor(tensor_stage2, device_mesh_stage2, placements=[Replicate()])
+ tensor_stage3 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2
+ dtensor_stage3 = distribute_tensor(tensor_stage3, device_mesh_stage3, placements=[Replicate()])
+ # send to stage 2
+ if self.rank in device_mesh_stage1.mesh.tolist():
+ stage1_recv_tensor = recv_backward(
+ tensor_shape=shape,
+ recv_dtype=torch.float32,
+ current_device_mesh=device_mesh_stage1,
+ peer_device_mesh=device_mesh_stage2,
+ )
+ self.assertTrue(
+ torch.equal(
+ stage1_recv_tensor,
+ torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1,
+ )
+ )
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ stage2_recv_tensor = send_backward_recv_backward(
+ input_tensor_grad=dtensor_stage2._local_tensor,
+ recv_next=True,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=device_mesh_stage1,
+ next_device_mesh=device_mesh_stage3,
+ recv_dtype=torch.float32,
+ )
+ self.assertTrue(
+ torch.equal(
+ stage2_recv_tensor,
+ torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2,
+ )
+ )
+ if self.rank in device_mesh_stage3.mesh.tolist():
+ send_backward(
+ input_tensor_grad=dtensor_stage3.to_local(),
+ current_device_mesh=device_mesh_stage3,
+ peer_device_mesh=device_mesh_stage2,
+ tensor_shape=shape,
+ )
+
+ @with_comms
+ def test_send_backward_recv_backward_with_shape(self):
+ """
+ Test correctness of send_backward_recv_backward().
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2, device_mesh_stage3 = self._generate_three_device_meshes()
+ # stage 1 tensor
+ shape = (self.sequence_len, self.batch_size, self.input_size)
+ tensor_stage2 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ dtensor_stage2 = distribute_tensor(tensor_stage2, device_mesh_stage2, placements=[Replicate()])
+ tensor_stage3 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2
+ dtensor_stage3 = distribute_tensor(tensor_stage3, device_mesh_stage3, placements=[Replicate()])
+ # send to stage 2
+ if self.rank in device_mesh_stage1.mesh.tolist():
+ stage1_recv_tensor = recv_backward(
+ tensor_shape=shape,
+ recv_dtype=torch.float32,
+ current_device_mesh=device_mesh_stage1,
+ peer_device_mesh=device_mesh_stage2,
+ )
+ self.assertTrue(
+ torch.equal(
+ stage1_recv_tensor,
+ torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1,
+ )
+ )
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ stage2_recv_tensor = send_backward_recv_backward(
+ input_tensor_grad=dtensor_stage2._local_tensor,
+ recv_next=True,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=device_mesh_stage1,
+ next_device_mesh=device_mesh_stage3,
+ recv_dtype=torch.float32,
+ )
+ self.assertTrue(
+ torch.equal(
+ stage2_recv_tensor,
+ torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2,
+ )
+ )
+ if self.rank in device_mesh_stage3.mesh.tolist():
+ send_backward(
+ input_tensor_grad=dtensor_stage3.to_local(),
+ current_device_mesh=device_mesh_stage3,
+ peer_device_mesh=device_mesh_stage2,
+ tensor_shape=shape,
+ )
+
+ @with_comms
+ def test_send_forward_backward_recv_forward_backward_with_shape(self):
+ """
+ Test correctness of send_forward_backward_recv_forward_backward().
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2, device_mesh_stage3 = self._generate_three_device_meshes()
+ # stage 1 tensor
+ shape = (self.sequence_len, self.batch_size, self.input_size)
+ tensor_stage1 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ dtensor_stage1 = distribute_tensor(tensor_stage1, device_mesh_stage1, placements=[Replicate()])
+ tensor_stage2 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ dtensor_stage2 = distribute_tensor(tensor_stage2, device_mesh_stage2, placements=[Replicate()])
+ tensor_stage3 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2
+ dtensor_stage3 = distribute_tensor(tensor_stage3, device_mesh_stage3, placements=[Replicate()])
+ # send to stage 2
+ if self.rank in device_mesh_stage1.mesh.tolist():
+ recv_bwd_tensor = send_forward_recv_backward(
+ output_tensor=dtensor_stage1._local_tensor,
+ tensor_shape=shape,
+ recv_dtype=torch.float32,
+ current_device_mesh=device_mesh_stage1,
+ peer_device_mesh=device_mesh_stage2,
+ )
+ self.assertTrue(
+ torch.equal(
+ recv_bwd_tensor, torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ )
+ )
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ input_tensor, output_tensor_grad = send_forward_backward_recv_forward_backward(
+ output_tensor=dtensor_stage2._local_tensor,
+ input_tensor_grad=dtensor_stage2._local_tensor,
+ recv_prev=True,
+ recv_next=True,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=device_mesh_stage1,
+ next_device_mesh=device_mesh_stage3,
+ recv_dtype=torch.float32,
+ )
+ self.assertTrue(
+ torch.equal(
+ input_tensor, torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ )
+ )
+ self.assertTrue(
+ torch.equal(
+ output_tensor_grad,
+ torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2,
+ )
+ )
+ if self.rank in device_mesh_stage3.mesh.tolist():
+ recv_fwd_tensor = send_backward_recv_forward(
+ input_tensor_grad=dtensor_stage3.to_local(),
+ tensor_shape=shape,
+ recv_dtype=torch.float32,
+ current_device_mesh=device_mesh_stage3,
+ peer_device_mesh=device_mesh_stage2,
+ )
+ self.assertTrue(
+ torch.equal(
+ recv_fwd_tensor, torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ )
+ )
+
+ @with_comms
+ def test_send_forward_backward_recv_forward_backward_no_shape(self):
+ """
+ Test correctness of send_forward_backward_recv_forward_backward()
+ without sharing tensor shapes in advance.
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2, device_mesh_stage3 = self._generate_three_device_meshes()
+ # stage 1 tensor
+ tensor_stage1 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ dtensor_stage1 = distribute_tensor(tensor_stage1, device_mesh_stage1, placements=[Replicate()])
+ tensor_stage2 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ dtensor_stage2 = distribute_tensor(tensor_stage2, device_mesh_stage2, placements=[Replicate()])
+ tensor_stage3 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2
+ dtensor_stage3 = distribute_tensor(tensor_stage3, device_mesh_stage3, placements=[Replicate()])
+ # send to stage 2
+ if self.rank in device_mesh_stage1.mesh.tolist():
+ recv_bwd_tensor = send_forward_recv_backward(
+ output_tensor=dtensor_stage1._local_tensor,
+ tensor_shape=None,
+ recv_dtype=torch.float32,
+ current_device_mesh=device_mesh_stage1,
+ peer_device_mesh=device_mesh_stage2,
+ )
+ self.assertTrue(
+ torch.equal(
+ recv_bwd_tensor, torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ )
+ )
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ input_tensor, output_tensor_grad = send_forward_backward_recv_forward_backward(
+ output_tensor=dtensor_stage2._local_tensor,
+ input_tensor_grad=dtensor_stage2._local_tensor,
+ recv_prev=True,
+ recv_next=True,
+ tensor_shape=None,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=device_mesh_stage1,
+ next_device_mesh=device_mesh_stage3,
+ recv_dtype=torch.float32,
+ )
+ self.assertTrue(
+ torch.equal(
+ input_tensor, torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ )
+ )
+ self.assertTrue(
+ torch.equal(
+ output_tensor_grad,
+ torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2,
+ )
+ )
+ if self.rank in device_mesh_stage3.mesh.tolist():
+ recv_fwd_tensor = send_backward_recv_forward(
+ input_tensor_grad=dtensor_stage3.to_local(),
+ tensor_shape=None,
+ recv_dtype=torch.float32,
+ current_device_mesh=device_mesh_stage3,
+ peer_device_mesh=device_mesh_stage2,
+ )
+ self.assertTrue(
+ torch.equal(
+ recv_fwd_tensor, torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ )
+ )
+
+ @with_comms
+ def test_send_forward_recv_forward_with_shape_next_device_mesh_none(self):
+ """
+ Test correctness of send_forward_recv_forward() with tensor shapes known.
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2, _ = self._generate_three_device_meshes()
+ # stage 1 tensor
+ shape = (self.sequence_len, self.batch_size, self.input_size)
+ tensor_stage1 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ dtensor_stage1 = distribute_tensor(tensor_stage1, device_mesh_stage1, placements=[Replicate()])
+ tensor_stage2 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ dtensor_stage2 = distribute_tensor(tensor_stage2, device_mesh_stage2, placements=[Replicate()])
+ # send to stage 2
+ if self.rank in device_mesh_stage1.mesh.tolist():
+ send_forward(
+ output_tensor=dtensor_stage1.to_local(),
+ current_device_mesh=device_mesh_stage1,
+ peer_device_mesh=device_mesh_stage2,
+ tensor_shape=shape,
+ )
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ stage2_recv_tensor = send_forward_recv_forward(
+ output_tensor=dtensor_stage2._local_tensor,
+ recv_prev=True,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=device_mesh_stage1,
+ next_device_mesh=None,
+ recv_dtype=torch.float32,
+ )
+ self.assertTrue(
+ torch.equal(
+ stage2_recv_tensor, torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ )
+ )
+
+ @with_comms
+ def test_send_backward_recv_backward_with_shape_device_mesh_none(self):
+ """
+ Test correctness of send_backward_recv_backward() with tensor shapes known.
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2, device_mesh_stage3 = self._generate_three_device_meshes()
+ # stage 1 tensor
+ shape = (self.sequence_len, self.batch_size, self.input_size)
+ tensor_stage1 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ tensor_stage2 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ dtensor_stage2 = distribute_tensor(tensor_stage2, device_mesh_stage2, placements=[Replicate()])
+ tensor_stage3 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2
+ dtensor_stage3 = distribute_tensor(tensor_stage3, device_mesh_stage3, placements=[Replicate()])
+ # send to stage 2
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ stage2_recv_tensor = send_backward_recv_backward(
+ input_tensor_grad=dtensor_stage2._local_tensor,
+ recv_next=True,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=None,
+ next_device_mesh=device_mesh_stage3,
+ recv_dtype=torch.float32,
+ )
+ self.assertTrue(
+ torch.equal(
+ stage2_recv_tensor,
+ torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2,
+ )
+ )
+ if self.rank in device_mesh_stage3.mesh.tolist():
+ send_backward(
+ input_tensor_grad=dtensor_stage3.to_local(),
+ current_device_mesh=device_mesh_stage3,
+ peer_device_mesh=device_mesh_stage2,
+ tensor_shape=shape,
+ )
+
+ @with_comms
+ def test_send_backward_recv_backward_with_shape_p2p_overlap(self):
+ """
+ Test correctness of send_backward_recv_backward() with overlapped p2p on.
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2, device_mesh_stage3 = self._generate_three_device_meshes()
+ # stage 1 tensor
+ shape = (self.sequence_len, self.batch_size, self.input_size)
+ tensor_stage1 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ tensor_stage2 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ dtensor_stage2 = distribute_tensor(tensor_stage2, device_mesh_stage2, placements=[Replicate()])
+ tensor_stage3 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2
+ dtensor_stage3 = distribute_tensor(tensor_stage3, device_mesh_stage3, placements=[Replicate()])
+ # send to stage 2
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ stage2_recv_tensor, bwd_wait_handles = send_backward_recv_backward(
+ input_tensor_grad=dtensor_stage2._local_tensor,
+ recv_next=True,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=None,
+ next_device_mesh=device_mesh_stage3,
+ recv_dtype=torch.float32,
+ overlap_p2p_comm=True,
+ batch_p2p_comm=False,
+ )
+
+ if self.rank in device_mesh_stage3.mesh.tolist():
+ stage3_recv_tensor, bwd_wait_handles = send_backward_recv_backward(
+ input_tensor_grad=dtensor_stage3._local_tensor,
+ recv_next=False,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage3,
+ prev_device_mesh=device_mesh_stage2,
+ next_device_mesh=None,
+ recv_dtype=torch.float32,
+ overlap_p2p_comm=True,
+ batch_p2p_comm=False,
+ )
+ drain_recv_reqs("backward")
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ self.assertTrue(
+ torch.equal(
+ stage2_recv_tensor,
+ torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2,
+ )
+ )
+
+ @with_comms
+ def test_send_forward_recv_forward_with_shape_p2p_overlap(self):
+ """
+ Test correctness of send_forward_recv_forward() with overlapped p2p on.
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2, device_mesh_stage3 = self._generate_three_device_meshes()
+ # stage 1 tensor
+ shape = (self.sequence_len, self.batch_size, self.input_size)
+ tensor_stage1 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ dtensor_stage1 = distribute_tensor(tensor_stage1, device_mesh_stage1, placements=[Replicate()])
+ tensor_stage2 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ dtensor_stage2 = distribute_tensor(tensor_stage2, device_mesh_stage2, placements=[Replicate()])
+ # send to stage 2
+ if self.rank in device_mesh_stage1.mesh.tolist():
+ stage1_recv_tensor, fwd_wait_handles = send_forward_recv_forward(
+ output_tensor=dtensor_stage1._local_tensor,
+ recv_prev=False,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage1,
+ prev_device_mesh=None,
+ next_device_mesh=device_mesh_stage2,
+ recv_dtype=torch.float32,
+ overlap_p2p_comm=True,
+ batch_p2p_comm=False,
+ )
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ stage2_recv_tensor, fwd_wait_handles = send_forward_recv_forward(
+ output_tensor=dtensor_stage2._local_tensor,
+ recv_prev=True,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=device_mesh_stage1,
+ next_device_mesh=None,
+ recv_dtype=torch.float32,
+ overlap_p2p_comm=True,
+ batch_p2p_comm=False,
+ )
+ drain_recv_reqs("forward")
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ self.assertTrue(
+ torch.equal(
+ stage2_recv_tensor, torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ )
+ )
+
+ @with_comms
+ def test_send_backward_recv_backward_with_shape_p2p_overlap_auto_modify(self):
+ """
+ Test correctness of send_backward_recv_backward() with overlapped p2p on.
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2, device_mesh_stage3 = self._generate_three_device_meshes()
+ # stage 1 tensor
+ shape = (self.sequence_len, self.batch_size, self.input_size)
+ tensor_stage1 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ dtensor_stage1 = distribute_tensor(tensor_stage1, device_mesh_stage1, placements=[Replicate()])
+ tensor_stage2 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ dtensor_stage2 = distribute_tensor(tensor_stage2, device_mesh_stage2, placements=[Replicate()])
+ tensor_stage3 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2
+ dtensor_stage3 = distribute_tensor(tensor_stage3, device_mesh_stage3, placements=[Replicate()])
+ # send to stage 2
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ stage2_recv_tensor, bwd_wait_handles = send_backward_recv_backward(
+ input_tensor_grad=dtensor_stage2._local_tensor,
+ recv_next=True,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=None,
+ next_device_mesh=device_mesh_stage3,
+ recv_dtype=torch.float32,
+ overlap_p2p_comm=True,
+ batch_p2p_comm=False,
+ )
+
+ if self.rank in device_mesh_stage3.mesh.tolist():
+ stage3_recv_tensor, bwd_wait_handles = send_backward_recv_backward(
+ input_tensor_grad=dtensor_stage3._local_tensor,
+ recv_next=True,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage3,
+ prev_device_mesh=device_mesh_stage2,
+ next_device_mesh=None,
+ recv_dtype=torch.float32,
+ overlap_p2p_comm=True,
+ batch_p2p_comm=False,
+ )
+
+ drain_recv_reqs("backward")
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ self.assertTrue(
+ torch.equal(
+ stage2_recv_tensor,
+ torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 2,
+ )
+ )
+
+ @with_comms
+ def test_send_forward_recv_forward_with_shape_p2p_overlap_auto_modify(self):
+ """
+ Test correctness of send_forward_recv_forward() with overlapped p2p on.
+ """
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1, device_mesh_stage2, device_mesh_stage3 = self._generate_three_device_meshes()
+ # stage 1 tensor
+ shape = (self.sequence_len, self.batch_size, self.input_size)
+ tensor_stage1 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ dtensor_stage1 = distribute_tensor(tensor_stage1, device_mesh_stage1, placements=[Replicate()])
+ tensor_stage2 = torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device) + 1
+ dtensor_stage2 = distribute_tensor(tensor_stage2, device_mesh_stage2, placements=[Replicate()])
+ # send to stage 2
+ if self.rank in device_mesh_stage1.mesh.tolist():
+ stage1_recv_tensor, fwd_wait_handles = send_forward_recv_forward(
+ output_tensor=dtensor_stage1._local_tensor,
+ recv_prev=True,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage1,
+ prev_device_mesh=None,
+ next_device_mesh=device_mesh_stage2,
+ recv_dtype=torch.float32,
+ overlap_p2p_comm=True,
+ batch_p2p_comm=False,
+ )
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ stage2_recv_tensor, fwd_wait_handles = send_forward_recv_forward(
+ output_tensor=dtensor_stage2._local_tensor,
+ recv_prev=True,
+ tensor_shape=shape,
+ current_device_mesh=device_mesh_stage2,
+ prev_device_mesh=device_mesh_stage1,
+ next_device_mesh=None,
+ recv_dtype=torch.float32,
+ overlap_p2p_comm=True,
+ batch_p2p_comm=False,
+ )
+
+ drain_recv_reqs("forward")
+ if self.rank in device_mesh_stage2.mesh.tolist():
+ self.assertTrue(
+ torch.equal(
+ stage2_recv_tensor, torch.ones(self.sequence_len, self.batch_size, self.input_size, device=device)
+ )
+ )
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/backend/test_pipe.py b/test/parallel/pipeline/backend/test_pipe.py
new file mode 100644
index 0000000..65614ac
--- /dev/null
+++ b/test/parallel/pipeline/backend/test_pipe.py
@@ -0,0 +1,342 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import torch
+import numpy as np
+import torch.fx as fx
+import re
+from torch.testing._internal.common_utils import run_tests
+from common_dtensor import DTensorTestBase, with_comms
+from vescale.pipe import PipeModule, construct_stage_modules, construct_pipeline_split_graph
+from vescale.plan import (
+ PipelineParallelPlan,
+ PipelineScheduleType,
+ PipelineSplitMethodType,
+ ModeType,
+ TracerType,
+)
+from vescale.initialize.deferred_init import deferred_init, is_deferred
+from eight_mlp import EightMLP, sharding_plan, sharding_plan_fc
+from vescale.dmodule._dmodule import DModule
+from vescale.dmodule.api import parallelize_module
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
+import torch.distributed as dist
+from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
+from vescale.optim.distributed_optimizer import DistributedOptimizer
+from vescale.dtensor.api import distribute_tensor
+from vescale.dtensor.placement_types import Replicate
+from torch.fx.passes.split_utils import split_by_tags
+
+
+class PipeModuleTest(DTensorTestBase):
+ @property
+ def world_size(self):
+ return 4
+
+ @staticmethod
+ def loss_fn(x):
+ return x.mean()
+
+ def _setup(self, pp_size: int = 2, dp_size: int = 1, tp_size: int = 2, virtual_chunks: int = 1):
+ num_layers = 8
+ VESCALE_DEVICE_MESH.init_device_mesh("cuda", (pp_size, dp_size, tp_size), mesh_dim_names=("PP", "DP", "TP"))
+ deferred_mlp = deferred_init(EightMLP, hidden=8)
+ pipe_config = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.UNIFORM,
+ num_stages=2,
+ virtual_chunks=virtual_chunks,
+ smallest_unsplittable_units=[f"mlp{i + 1}" for i in range(num_layers)],
+ batch_p2p_comm=False,
+ overlap_p2p_comm=True,
+ schedule_type=PipelineScheduleType.SIMPLE_1F1B
+ if virtual_chunks == 1
+ else PipelineScheduleType.INTERLEAVED_1F1B,
+ )
+ return deferred_mlp, pipe_config
+
+ @with_comms
+ def test_generate_stage_dependency(self):
+ """
+ Tests PipeModule's ability to generate inter-stage dependency.
+ """
+ deferred_mlp, config = self._setup()
+ num_stages = 2
+
+ _, stage_dependency, p2p_index_mapping = construct_stage_modules(
+ deferred_mlp, config, VESCALE_DEVICE_MESH, update_split_points=True
+ )
+
+ target_deps = np.zeros((num_stages, num_stages))
+ target_deps[0, 1] = 1
+ target_p2p_mapping = {0: [(0, 0)], 1: [(0, 0)]}
+ self.assertEqual(stage_dependency, target_deps)
+ flattened_index_mapping = {
+ i: [(spec[0].peer_stage_idx, spec[0].peer_output_idx)] for i, spec in p2p_index_mapping.items()
+ }
+ self.assertEqual(flattened_index_mapping, target_p2p_mapping)
+
+ @with_comms
+ def test_generate_stage_dependency_four_stages(self):
+ """
+ Tests PipeModule's ability to generate inter-stage dependency among four pipeline stages.
+ """
+ deferred_mlp, config = self._setup(pp_size=4, dp_size=1, tp_size=1, virtual_chunks=1)
+ num_stages = 4
+ config.num_stages = num_stages
+
+ _, stage_dependency, p2p_index_mapping = construct_stage_modules(
+ deferred_mlp, config, VESCALE_DEVICE_MESH, update_split_points=True
+ )
+
+ target_deps = np.zeros((num_stages, num_stages))
+ target_deps[0, 1] = 1
+ target_deps[1, 2] = 1
+ target_deps[2, 3] = 1
+ target_p2p_mapping = {0: [(0, 0)], 1: [(0, 0)], 2: [(1, 0)], 3: [(2, 0)]}
+ self.assertEqual(stage_dependency, target_deps)
+ flattened_index_mapping = {
+ i: [(spec[0].peer_stage_idx, spec[0].peer_output_idx)] for i, spec in p2p_index_mapping.items()
+ }
+ self.assertEqual(flattened_index_mapping, target_p2p_mapping)
+
+ @with_comms
+ def test_forward(self):
+ """
+ Tests PipeModule's forward function.
+ """
+ deferred_mlp, _ = self._setup(virtual_chunks=2)
+ num_layers = 8
+ pipe_config = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.UNIFORM,
+ num_stages=4,
+ virtual_chunks=2,
+ smallest_unsplittable_units=[f"mlp{i + 1}" for i in range(num_layers)],
+ batch_p2p_comm=False,
+ overlap_p2p_comm=True,
+ schedule_type=PipelineScheduleType.SIMPLE_1F1B,
+ )
+
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(4, 1, 1),
+ mesh_dim_names=["PP", "DP", "TP"],
+ )
+ tp_mesh = VESCALE_DEVICE_MESH["TP"]
+
+ stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
+ deferred_mlp,
+ pipe_config,
+ VESCALE_DEVICE_MESH,
+ update_split_points=True,
+ )
+ for i in range(len(stage_modules)):
+ parallelized_module = parallelize_module(
+ stage_modules[i],
+ tp_mesh,
+ sharding_plan,
+ factory=False,
+ )
+ stage_modules[i] = parallelized_module
+
+ optimizer_fn_kwargs = {
+ "lr": 0.01,
+ "momentum": 0,
+ "dampening": 0,
+ "weight_decay": 0,
+ "nesterov": False,
+ "maximize": False,
+ "foreach": None,
+ "differentiable": False,
+ }
+ _parameters = list(stage_modules[0].parameters()) + list(stage_modules[1].parameters())
+ optimizer = torch.optim.SGD(_parameters, **optimizer_fn_kwargs)
+ pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, pipe_config)
+
+ model_chunk_one = pipe_module[0]
+ model_chunk_two = pipe_module[1]
+ assert DModule.is_dmodule(pipe_module.stage_modules[0])
+ assert DModule.is_dmodule(pipe_module.stage_modules[1])
+ input = torch.randn((3, 8))
+ out_chunk_one = pipe_module(input, chunk_id=0)
+ out_chunk_two = pipe_module(input, chunk_id=1)
+ assert torch.equal(out_chunk_one, model_chunk_one(input))
+ assert torch.equal(out_chunk_two, model_chunk_two(input))
+
+
+class PipeModuleTraceTest(DTensorTestBase):
+ @with_comms
+ def test_compile_mode(self):
+ """
+ Tests correctness of registering hooks on partitioned model graphs.
+ """
+ model = EightMLP(8)
+
+ def hook(sel, args):
+ print(f"{torch.distributed.get_rank()}: call hook")
+ return args
+
+ graph = fx.symbolic_trace(model)
+ input = torch.randn((3, 8))
+ rule = r"mlp\d+.*"
+ for node in graph.graph.nodes:
+ if re.match(rule, node.name):
+ if int(node.name[3]) <= 4:
+ node.tag = "stage0"
+ else:
+ node.tag = "stage1"
+ global_graph = split_by_tags(graph, ["stage0", "stage1"])
+ splited_module = global_graph.get_submodule("stage0")
+ splited_module.mlp1.fc1.register_forward_pre_hook(hook)
+ splited_module.mlp1.gelu.register_forward_pre_hook(hook)
+ splited_module.mlp1.fc2.register_forward_pre_hook(hook)
+ splited_module.mlp2.fc1.register_forward_pre_hook(hook)
+ splited_module.mlp2.gelu.register_forward_pre_hook(hook)
+ splited_module.mlp2.fc2.register_forward_pre_hook(hook)
+ splited_module.mlp3.fc1.register_forward_pre_hook(hook)
+ splited_module.mlp3.gelu.register_forward_pre_hook(hook)
+ splited_module.mlp3.fc2.register_forward_pre_hook(hook)
+ splited_module.mlp4.fc1.register_forward_pre_hook(hook)
+ splited_module.mlp4.gelu.register_forward_pre_hook(hook)
+ splited_module.mlp4.fc2.register_forward_pre_hook(hook)
+ splited_module(input)
+
+ @with_comms
+ def test_compile_equivalent(self):
+ """
+ Tests correctness of registering hooks on partitioned model graphs.
+ """
+ model = EightMLP(8)
+
+ def hook(sel, args):
+ print(f"{torch.distributed.get_rank()}: call hook")
+ return args
+
+ graph = fx.symbolic_trace(model)
+ input = torch.randn((3, 8))
+ rule = r"mlp\d+.*"
+ for node in graph.graph.nodes:
+ if re.match(rule, node.name):
+ if int(node.name[3]) <= 4:
+ node.tag = "stage0"
+ else:
+ node.tag = "stage1"
+ global_graph = split_by_tags(graph, ["stage0", "stage1"])
+ splited_module = global_graph.get_submodule("stage0")
+ call_modules_fqns = [node.target for node in splited_module.graph.nodes if node.op == "call_module"]
+ for submodule_path in call_modules_fqns:
+ splited_module.get_submodule(submodule_path).register_forward_pre_hook(hook)
+ splited_module(input)
+
+ @with_comms
+ def test_decomposable_5d_parallelization(self):
+ """
+ Tests decomposable API of writing 5D parallelization from plan to parallelization.
+ """
+ # build device mesh
+ device_mesh = VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda", mesh_shape=(2, 1, 2), mesh_dim_names=["PP", "DP", "TP"]
+ )
+ # deferred init mlp module
+ deferred_mlp = deferred_init(EightMLP, hidden=8)
+ # pipe module config
+ boundaries = ["mlp4", "mlp8"]
+ num_layers = 8
+ pipe_config = PipelineParallelPlan(
+ num_stages=2,
+ split_method=PipelineSplitMethodType.MANUAL,
+ smallest_unsplittable_units=[f"mlp{i + 1}" for i in range(num_layers)],
+ split_points=boundaries,
+ tracer_type=TracerType.TORCH_FX,
+ tracer_kwargs={"shard_plan": sharding_plan},
+ )
+ split_graph = construct_pipeline_split_graph(deferred_mlp, pipe_config, update_split_points=True)
+
+ # parallelize and materialize module
+ model_chunks = []
+ for i in range(pipe_config.num_stages):
+ stage = getattr(split_graph, f"stage{i}")
+ stage = parallelize_module(
+ stage, VESCALE_DEVICE_MESH.get_tensor_parallel_mesh(), sharding_plan, factory=False
+ )
+ assert not is_deferred(stage)
+ model_chunks.append(stage)
+ if dist.get_rank() == 0:
+ assert model_chunks[0].mlp1.fc1.weight._spec.placements[0].is_shard()
+
+ # make ddp module
+ ddp_models = []
+ for model_chunk in model_chunks:
+ ddp_models.append(
+ DDP(
+ model_chunk,
+ VESCALE_DEVICE_MESH.get_data_parallel_mesh(),
+ accumulate_allreduce_grads_in_fp32=True,
+ overlap_grad_reduce=True,
+ use_distributed_optimizer=True,
+ )
+ )
+
+ if dist.get_rank() == 0:
+ assert model_chunks[0].mlp1.fc1.weight._spec.placements[0].is_shard()
+
+ # make optimizer
+ doptim = DistributedOptimizer(
+ torch.optim.Adam(split_graph.parameters(), lr=0.01),
+ models=ddp_models,
+ overlap_param_gather=False,
+ )
+ tp_mesh = VESCALE_DEVICE_MESH.get_tensor_parallel_mesh()
+ stage_id = VESCALE_DEVICE_MESH.get_pipeline_parallel_rank()
+
+ num_layers = 8
+ dataloader = [distribute_tensor(torch.zeros((5, 8)), tp_mesh, [Replicate()]) * i for i in range(num_layers)]
+ for sample in dataloader:
+ doptim.zero_grad()
+ output = ddp_models[stage_id](sample)
+ loss = output.mean()
+ loss.backward()
+ doptim.step()
+
+ @with_comms
+ def test_manual_split_various_boundary_level(self):
+ """
+ Tests PipeModule's ability to split stage by boundaries of various depths.
+ """
+ VESCALE_DEVICE_MESH.init_device_mesh("cuda", (2, 1, 2), mesh_dim_names=("PP", "DP", "TP"))
+ deferred_mlp = deferred_init(EightMLP, hidden=8)
+ pipe_config = PipelineParallelPlan(
+ num_stages=2,
+ split_method=PipelineSplitMethodType.MANUAL,
+ smallest_unsplittable_units=["mlp7", "mlp8"],
+ split_points=["mlp4.fc1", "mlp8"],
+ tracer_type=TracerType.TORCH_FX,
+ tracer_kwargs={"partition_units": ["mlp7", "mlp8"]},
+ )
+
+ split_graph = construct_pipeline_split_graph(deferred_mlp, pipe_config, update_split_points=True)
+ for i in range(pipe_config.num_stages):
+ stage = getattr(split_graph, f"stage{i}")
+ stage = parallelize_module(
+ stage, VESCALE_DEVICE_MESH.get_tensor_parallel_mesh(), sharding_plan_fc, factory=False
+ )
+ assert not is_deferred(stage)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/backend/test_pipe_parser.py b/test/parallel/pipeline/backend/test_pipe_parser.py
new file mode 100644
index 0000000..565b349
--- /dev/null
+++ b/test/parallel/pipeline/backend/test_pipe_parser.py
@@ -0,0 +1,172 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from torch.testing._internal.common_utils import run_tests
+from common_dtensor import DTensorTestBase, with_comms
+from vescale.pipe.pipe_parser import PipeParser
+from vescale.initialize.deferred_init import deferred_init
+from vescale.plan import PipelineParallelPlan, PipelineScheduleType, ModeType, PipelineSplitMethodType
+from eight_mlp import EightMLP, EightMLPWithOps, EightMLPWithOpsTail
+
+
+class TestPipeParser(DTensorTestBase):
+ @with_comms
+ def test_parse_naive_model(self):
+ """
+ Tests trace capture with torch.fx symbolic tracer under user-defined granularity.
+ """
+ deferred_mlp = deferred_init(EightMLP, hidden=8)
+ partition_units = ["mlp4", "mlp8"]
+ pipe_parser = PipeParser()
+ model_graph = pipe_parser.parse(deferred_mlp)
+ print(model_graph)
+ assert not all(node.target in partition_units for node in model_graph.graph.nodes)
+
+ pipe_config = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ smallest_unsplittable_units=partition_units,
+ )
+ model_graph_partition_units = pipe_parser.parse(deferred_mlp, pipe_config)
+ print(model_graph_partition_units)
+ assert any(node.target in partition_units for node in model_graph_partition_units.graph.nodes)
+
+ @with_comms
+ def test_parse_huggingface_model(self):
+ """
+ Tests trace capture with huggingface symbolic tracer under user-defined granularity.
+ """
+ from transformers import LlamaModel, LlamaConfig
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
+
+ configuration = LlamaConfig()
+ configuration.hidden_size = 256
+ configuration.intermediate_size = 1376
+ configuration.num_attention_heads = 1
+ configuration.num_hidden_layers = 2
+ model = LlamaModel(configuration)
+
+ # below two lists of partition units refer to the same submodules we never wish to partition
+ partition_units = ["layers.0", "layers.1", "norm"]
+ partition_units_equivalent = [LlamaDecoderLayer, LlamaRMSNorm]
+ pipe_config = PipelineParallelPlan(smallest_unsplittable_units=partition_units)
+ pipe_config_equivalent = PipelineParallelPlan(smallest_unsplittable_units=partition_units_equivalent)
+
+ pipe_parser = PipeParser()
+ model_graph = pipe_parser.parse(model)
+ print(model_graph)
+ assert not all(node.target in partition_units for node in model_graph.graph.nodes)
+
+ model_graph_partition_units = pipe_parser.parse(model, pipe_config)
+ print(model_graph_partition_units)
+ result = [node.target in partition_units for node in model_graph_partition_units.graph.nodes]
+ assert any(result)
+
+ # the resulting graph should be identical to the one parsed by model_graph_partition_units
+ model_graph_partition_units_equivalent = pipe_parser.parse(model, pipe_config_equivalent)
+ print(model_graph_partition_units_equivalent)
+ result_two = [node.target in partition_units for node in model_graph_partition_units_equivalent.graph.nodes]
+ assert any(result_two)
+ self.assertEqual(result, result_two)
+
+ @with_comms
+ def test_uniform_split(self):
+ """
+ Tests uniform stage split.
+ """
+ deferred_mlp = deferred_init(EightMLP, hidden=8)
+ layers = 8
+ pipe_parser = PipeParser()
+ pipe_config = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.UNIFORM,
+ num_stages=2,
+ virtual_chunks=1,
+ smallest_unsplittable_units=[f"mlp{i + 1}" for i in range(layers)],
+ batch_p2p_comm=False,
+ overlap_p2p_comm=True,
+ schedule_type=PipelineScheduleType.SIMPLE_1F1B,
+ )
+ model_graph_partition_units = pipe_parser.parse(deferred_mlp, pipe_config)
+ print(model_graph_partition_units)
+ splited_graph = pipe_parser.partition_stage(deferred_mlp, model_graph_partition_units, pipe_config)
+ self.assertEqual(
+ [node.name for node in splited_graph.stage0.graph.nodes][1:-1], ["mlp1", "mlp2", "mlp3", "mlp4"]
+ )
+ self.assertEqual(
+ [node.name for node in splited_graph.stage1.graph.nodes][1:-1], ["mlp5", "mlp6", "mlp7", "mlp8"]
+ )
+
+ @with_comms
+ def test_uniform_split_model_with_ops(self):
+ """
+ Tests uniform stage split with torch operators as graph components.
+ """
+ deferred_mlp = deferred_init(EightMLPWithOpsTail, hidden=8)
+ layers = 8
+ pipe_config = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.UNIFORM,
+ num_stages=2,
+ virtual_chunks=1,
+ smallest_unsplittable_units=[f"mlp{i + 1}" for i in range(layers)],
+ batch_p2p_comm=False,
+ overlap_p2p_comm=True,
+ schedule_type=PipelineScheduleType.SIMPLE_1F1B,
+ )
+ pipe_parser = PipeParser()
+ model_graph_partition_units = pipe_parser.parse(deferred_mlp, pipe_config)
+ print(model_graph_partition_units)
+ splited_graph = pipe_parser.partition_stage(deferred_mlp, model_graph_partition_units, pipe_config)
+ self.assertEqual(
+ [node.name for node in splited_graph.stage0.graph.nodes][1:-1],
+ ["add", "mlp1", "mul", "mlp2", "mul_1", "mlp3", "mlp4"],
+ )
+ self.assertEqual(
+ [node.name for node in splited_graph.stage1.graph.nodes][1:-1],
+ ["mlp5", "mlp6", "mlp7", "mlp8", "mul_2", "mul_3", "add_1"],
+ )
+
+ @with_comms
+ def test_uniform_split_on_modules(self):
+ """
+ Tests uniform stage split on modules with modules and torch operators.
+ """
+ deferred_mlp = deferred_init(EightMLPWithOps, hidden=8)
+ layers = 8
+ pipe_parser = PipeParser()
+ pipe_config = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.UNIFORM,
+ num_stages=2,
+ virtual_chunks=1,
+ smallest_unsplittable_units=[f"mlp{i + 1}" for i in range(layers)],
+ batch_p2p_comm=False,
+ overlap_p2p_comm=True,
+ schedule_type=PipelineScheduleType.SIMPLE_1F1B,
+ uniform_split_ops=True,
+ )
+ model_graph_partition_units = pipe_parser.parse(deferred_mlp, pipe_config)
+ print(model_graph_partition_units)
+ splited_graph = pipe_parser.partition_stage(deferred_mlp, model_graph_partition_units, pipe_config)
+ stage_one_modules = ["add", "mlp1", "mul", "mlp2", "mul_1", "mul_2", "mul_3", "mlp3", "mlp4"]
+ stage_two_modules = ["mlp5", "mlp6", "mlp7", "mlp8"]
+ self.assertEqual([node.name for node in splited_graph.stage0.graph.nodes][1:-1], stage_one_modules)
+ self.assertEqual([node.name for node in splited_graph.stage1.graph.nodes][1:-1], stage_two_modules)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/backend/test_shard_plan.py b/test/parallel/pipeline/backend/test_shard_plan.py
new file mode 100644
index 0000000..28f663f
--- /dev/null
+++ b/test/parallel/pipeline/backend/test_shard_plan.py
@@ -0,0 +1,72 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import torch
+from torch.testing._internal.common_utils import run_tests
+from common_dtensor import DTensorTestBase, with_comms
+from vescale.pipe.pipe_parser import PipeParser
+from vescale.initialize.deferred_init import deferred_init
+from eight_mlp import EightMLP, sharding_plan
+from vescale.dmodule.api import parallelize_module
+from vescale.dtensor.api import distribute_tensor
+from vescale.devicemesh_api.api import VESCALE_DEVICE_MESH
+from vescale.dtensor.placement_types import Replicate
+from vescale.plan import PipelineParallelPlan, PipelineSplitMethodType
+
+
+class ShardPlanRegistrationTest(DTensorTestBase):
+ @with_comms
+ def test_manual_split_register_hook(self):
+ """
+ Tests manual stage split and registers hooks.
+ """
+ VESCALE_DEVICE_MESH.init_device_mesh("cuda", (2, 1, 2), mesh_dim_names=("PP", "DP", "TP"))
+ deferred_mlp = deferred_init(EightMLP, hidden=8)
+ partition_units = ["mlp1", "mlp8"]
+ pipe_config = PipelineParallelPlan(
+ num_stages=2,
+ split_method=PipelineSplitMethodType.UNIFORM,
+ smallest_unsplittable_units=partition_units,
+ )
+ pipe_parser = PipeParser()
+ input = torch.randn((3, 8))
+ model_graph = pipe_parser.parse(
+ deferred_mlp,
+ pipe_config,
+ **{"shard_plan": sharding_plan},
+ )
+ pipe_spec = pipe_parser.partition_stage(deferred_mlp, model_graph, pipe_config)
+ model_chunks = []
+ model_partition = pipe_spec.stage0
+ model = parallelize_module(
+ model_partition, VESCALE_DEVICE_MESH.get_tensor_parallel_mesh(), sharding_plan, factory=False
+ )
+
+ # hooks are successfully registered on target modules, as they now have been hierarchically flattened!
+ def hook(sel, args):
+ print("hook registered. Successful registration will trigger this printout!")
+ return args
+
+ model.get_submodule("mlp1").register_forward_pre_hook(hook)
+ d_input = distribute_tensor(input, VESCALE_DEVICE_MESH.get_tensor_parallel_mesh(), [Replicate()])
+ d_out = model(d_input)
+ model_chunks.append(model)
+ assert model_chunks[0].mlp1.fc1.weight._spec.placements[0].is_shard()
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/backend/test_shared_params.py b/test/parallel/pipeline/backend/test_shared_params.py
new file mode 100644
index 0000000..6e3a695
--- /dev/null
+++ b/test/parallel/pipeline/backend/test_shared_params.py
@@ -0,0 +1,301 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
+from common_dtensor import DTensorTestBase, with_comms
+from vescale.dtensor.api import distribute_tensor
+from vescale.optim.base_optimizer import BasicOptimizer
+from vescale.initialize import materialize_module
+from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
+from vescale.plan import (
+ PipelineParallelPlan,
+ PipelineSplitMethodType,
+ PipelineScheduleType,
+ ModeType,
+)
+from vescale.pipe import PipeModule, build_shared_module_group, construct_stage_modules, construct_pipeline_split_graph
+from vescale.initialize.deferred_init import deferred_init
+from eight_mlp import sharding_plan, EightMLPSharedEmbed
+from vescale.dtensor.placement_types import Replicate
+from vescale.dmodule.api import parallelize_module
+
+
+microbatch_size = 16
+factor = 16
+batch_size = microbatch_size * factor
+RANDOM_SEED = 9999
+
+
+class SharedParamsTest(DTensorTestBase):
+ @property
+ def world_size(self):
+ return 8
+
+ @with_comms
+ def test_sync_embedding_weights_two_stages(self):
+ """
+ Test correctness of synchronizing "shared_units" (embedding)
+ weights upon engine initialization.
+ """
+ pp_size = 2
+ dp_size = 2
+ tp_size = 2
+ deferred_mlp = deferred_init(EightMLPSharedEmbed, hidden=8)
+ partition_units = [f"mlp{i + 1}" for i in range(8)] + ["embed1", "embed2"]
+ pp_plan = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.UNIFORM,
+ num_stages=2,
+ virtual_chunks=1,
+ smallest_unsplittable_units=partition_units,
+ schedule_type=PipelineScheduleType.SIMPLE_1F1B,
+ shared_modules=[
+ ["embed1", "embed2"]
+ ], # each sublist represents a group of modules to synchronize params/grads
+ )
+ split_graph = construct_pipeline_split_graph(deferred_mlp, pp_plan, update_split_points=True)
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(pp_size, dp_size, tp_size),
+ mesh_dim_names=["PP", "DP", "TP"],
+ )
+
+ stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
+ deferred_mlp,
+ pp_plan,
+ VESCALE_DEVICE_MESH,
+ update_split_points=True,
+ )
+ for module in stage_modules:
+ materialize_module(module)
+ module.cuda()
+
+ combined_parameters = list(stage_modules[0].parameters())
+ optimizer_fn_kwargs = {
+ "lr": 0.01,
+ "momentum": 0,
+ "dampening": 0,
+ "weight_decay": 0,
+ "nesterov": False,
+ "maximize": False,
+ "foreach": None,
+ "differentiable": False,
+ }
+ optimizer = torch.optim.SGD(combined_parameters, **optimizer_fn_kwargs)
+ pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, pp_plan)
+
+ build_shared_module_group(
+ pipe_module,
+ split_graph,
+ pp_plan.num_stages,
+ pp_plan.virtual_chunks,
+ pp_plan.shared_modules,
+ VESCALE_DEVICE_MESH,
+ )
+ if VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() == 0:
+ embedding = pipe_module[0].get_submodule("embed1").get_word_embeddings_weight().data
+ else:
+ embedding = pipe_module[0].get_submodule("embed2").get_word_embeddings_weight().data
+ pipe_module.sync_shared_params(VESCALE_DEVICE_MESH, group_id=0, share_params=True, chunk_id=0)
+ if VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() == 0:
+ sync_embedding = pipe_module[0].get_submodule("embed1").get_word_embeddings_weight().data
+ else:
+ sync_embedding = pipe_module[0].get_submodule("embed2").get_word_embeddings_weight().data
+ assert not torch.testing.assert_close(embedding, sync_embedding)
+
+ @with_comms
+ def test_sync_embedding_weights_four_stages(self):
+ """
+ Test correctness of synchronizing "shared_units" (embedding)
+ weights given four stages partitioned.
+ """
+ pp_size = 4
+ dp_size = 2
+ tp_size = 1
+ model = EightMLPSharedEmbed(hidden=8).cuda()
+ partition_units = [f"mlp{i + 1}" for i in range(8)] + ["embed1", "embed2"]
+ pp_plan = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.MANUAL,
+ num_stages=4,
+ virtual_chunks=1,
+ smallest_unsplittable_units=partition_units,
+ schedule_type=PipelineScheduleType.SIMPLE_1F1B,
+ split_points=["mlp2", "mlp5", "mlp7", "embed2"],
+ shared_modules=[
+ ["embed1", "embed2"]
+ ], # each sublist represents a group of modules to synchronize params/grads
+ )
+
+ split_graph = construct_pipeline_split_graph(model, pp_plan, update_split_points=True)
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(pp_size, dp_size, tp_size),
+ mesh_dim_names=["PP", "DP", "TP"],
+ )
+
+ stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
+ model,
+ pp_plan,
+ VESCALE_DEVICE_MESH,
+ update_split_points=True,
+ )
+ combined_parameters = list(stage_modules[0].parameters())
+ optimizer_fn_kwargs = {
+ "lr": 0.01,
+ "momentum": 0,
+ "dampening": 0,
+ "weight_decay": 0,
+ "nesterov": False,
+ "maximize": False,
+ "foreach": None,
+ "differentiable": False,
+ }
+ optimizer = torch.optim.SGD(combined_parameters, **optimizer_fn_kwargs)
+ basic_optimizer = BasicOptimizer(optimizer, models=stage_modules)
+ pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, pp_plan)
+
+ build_shared_module_group(
+ pipe_module,
+ split_graph,
+ pp_plan.num_stages,
+ pp_plan.virtual_chunks,
+ pp_plan.shared_modules,
+ VESCALE_DEVICE_MESH,
+ )
+ if VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() == 0:
+ embedding = pipe_module[0].get_submodule("embed1").get_word_embeddings_weight().data
+ elif VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() == 3:
+ embedding = pipe_module[0].get_submodule("embed2").get_word_embeddings_weight().data
+ else:
+ embedding = None
+ if VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() in [0, 3]:
+ pipe_module.sync_shared_params(VESCALE_DEVICE_MESH, group_id=0, share_params=True, chunk_id=0)
+ if VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() == 0:
+ sync_embedding = pipe_module[0].get_submodule("embed1").get_word_embeddings_weight().data
+ assert not torch.testing.assert_close(embedding, sync_embedding)
+ elif VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() == 3:
+ sync_embedding = pipe_module[0].get_submodule("embed2").get_word_embeddings_weight().data
+ assert not torch.testing.assert_close(embedding, sync_embedding)
+
+ @with_comms
+ def test_sync_embedding_gradients(self):
+ """
+ Test correctness of synchronizing "shared_units" (embedding)
+ weights given uniform partition results.
+ """
+ pp_size = 2
+ dp_size = 4
+ tp_size = 1
+ model = EightMLPSharedEmbed(hidden=8).cuda()
+ partition_units = [f"mlp{i + 1}" for i in range(8)] + ["embed1", "embed2"]
+
+ pp_plan = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.UNIFORM,
+ num_stages=2,
+ virtual_chunks=1,
+ smallest_unsplittable_units=partition_units,
+ schedule_type=PipelineScheduleType.SIMPLE_1F1B,
+ shared_modules=[
+ ["embed1", "embed2"]
+ ], # each sublist represents a group of modules to synchronize params/grads
+ )
+
+ optimizer_fn_kwargs = {
+ "lr": 0.01,
+ "momentum": 0,
+ "dampening": 0,
+ "weight_decay": 0,
+ "nesterov": False,
+ "maximize": False,
+ "foreach": None,
+ "differentiable": False,
+ }
+
+ split_graph = construct_pipeline_split_graph(model, pp_plan, update_split_points=True)
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(pp_size, dp_size, tp_size),
+ mesh_dim_names=["PP", "DP", "TP"],
+ )
+ tp_mesh = VESCALE_DEVICE_MESH["TP"]
+ dp_mesh = VESCALE_DEVICE_MESH["DP"]
+
+ stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
+ model,
+ pp_plan,
+ VESCALE_DEVICE_MESH,
+ update_split_points=True,
+ )
+ for i in range(len(stage_modules)):
+ parallelized_module = parallelize_module(
+ stage_modules[i],
+ tp_mesh,
+ sharding_plan,
+ factory=False,
+ )
+ ddp_module = DDP(
+ parallelized_module,
+ dp_mesh,
+ accumulate_allreduce_grads_in_fp32=True,
+ overlap_grad_reduce=True,
+ use_distributed_optimizer=False,
+ disable_bucketing=False,
+ bucket_size=40000000,
+ )
+ stage_modules[i] = ddp_module
+ combined_parameters = list(stage_modules[0].parameters())
+ optimizer = torch.optim.SGD(combined_parameters, **optimizer_fn_kwargs)
+ basic_optimizer = BasicOptimizer(optimizer, models=stage_modules)
+ pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, pp_plan)
+
+ build_shared_module_group(
+ pipe_module,
+ split_graph,
+ pp_plan.num_stages,
+ pp_plan.virtual_chunks,
+ pp_plan.shared_modules,
+ VESCALE_DEVICE_MESH,
+ )
+ loss_fn = nn.MSELoss()
+ input_tensor = distribute_tensor(torch.ones(3).long().cuda(), tp_mesh, [Replicate()])
+
+ if VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() == 0:
+ embed = pipe_module[0].module.embed1
+ else:
+ embed = pipe_module[0].module.embed2
+ output = embed(input_tensor)
+ target = torch.zeros_like(output)
+ target = distribute_tensor(target, tp_mesh, [Replicate()])
+ losses = loss_fn(output, target)
+ losses.backward()
+ old_grad = embed.embedding.weight.main_grad.clone()
+ pipe_module.sync_shared_params(VESCALE_DEVICE_MESH, group_id=0, share_params=False, chunk_id=0)
+ if VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() == 0:
+ embed = pipe_module[0].module.embed1
+ else:
+ embed = pipe_module[0].module.embed2
+ new_grad = embed.embedding.weight.main_grad.clone()
+ assert not torch.equal(old_grad, new_grad)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/backend/test_trace_parser.py b/test/parallel/pipeline/backend/test_trace_parser.py
new file mode 100644
index 0000000..df32f87
--- /dev/null
+++ b/test/parallel/pipeline/backend/test_trace_parser.py
@@ -0,0 +1,133 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from vescale.pipe.tracer import ModelTracer, HFModelTracer, register_partition_module, hf_symbolic_trace
+from common_dtensor import DTensorTestBase, with_comms
+from transformers import LlamaModel, LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
+
+
+class MLP(nn.Module):
+ def __init__(self, features_in, features_out, value):
+ super().__init__()
+ self.value = value
+ self.fc1 = nn.Linear(features_in, 2 * features_in, bias=False)
+ self.fc1.weight.data.fill_(value)
+ self.fc2 = nn.Linear(2 * features_in, features_out, bias=False)
+ self.fc2.weight.data.fill_(value * 2)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ t = self.fc1(x)
+ t = self.gelu(t)
+ t = self.fc2(t)
+ return t
+
+
+class EightMLP(nn.Module):
+ def __init__(self, hidden=1024):
+ super().__init__()
+ self.mlp1 = MLP(hidden, hidden, 0)
+ self.mlp2 = MLP(hidden, hidden, 1)
+ self.mlp3 = MLP(hidden, hidden, 2)
+ self.mlp4 = MLP(hidden, hidden, 3)
+ self.mlp5 = MLP(hidden, hidden, 4)
+ self.mlp6 = MLP(hidden, hidden, 5)
+ self.mlp7 = MLP(hidden, hidden, 6)
+ self.mlp8 = MLP(hidden, hidden, 7)
+ self.sequence = nn.Sequential(
+ self.mlp1,
+ self.mlp2,
+ self.mlp3,
+ self.mlp4,
+ self.mlp5,
+ self.mlp6,
+ self.mlp7,
+ self.mlp8,
+ )
+
+ def forward(self, x):
+ return self.sequence(x)
+
+
+class TracerTest(DTensorTestBase):
+ @property
+ def world_size(self):
+ return 1
+
+ @with_comms
+ def test_simple_model_tracer(self):
+ """
+ Test fx tracer to capture native symbolic trace
+ of simple model.
+ """
+ model = EightMLP(16)
+ tracer = ModelTracer()
+ traced_graph = tracer.trace(model)
+ print("Simple Model Graph Trace:")
+ print(traced_graph)
+
+ @with_comms
+ def test_simple_model_tracer_with_partition_units(self):
+ """
+ Test fx tracer to capture symbolic trace with granularity of
+ MLP level (do not dive into operators of MLP) of simple model.
+ """
+ model = EightMLP(16)
+ register_partition_module(model.mlp1)
+ register_partition_module(model.mlp2)
+ register_partition_module(model.mlp3)
+ register_partition_module(model.mlp4)
+ register_partition_module(model.mlp5)
+ register_partition_module(model.mlp6)
+ register_partition_module(model.mlp7)
+ register_partition_module(model.mlp8)
+ tracer = ModelTracer()
+ traced_graph = tracer.trace(model)
+ print(traced_graph)
+
+ @with_comms
+ def test_huggingface_model_tracer_with_partition_units(self):
+ """
+ Test huggingface tracer to capture symbolic trace with granularity
+ of LlamaDecoderLayer and LlamaRMSNorm.
+ """
+ configuration = LlamaConfig()
+ configuration.hidden_size = 1024
+ configuration.intermediate_size = 5504
+ configuration.num_attention_heads = 1
+ configuration.num_hidden_layers = 2
+
+ model = LlamaModel(configuration)
+ submodule_qualified_names = ["layers.0", "layers.1", "norm"]
+ # submodules indicated by submodule_qualified_names are modules that have the classes below
+ partition_unit_modules = [LlamaDecoderLayer, LlamaRMSNorm] + submodule_qualified_names
+ traced_graph = hf_symbolic_trace(
+ model,
+ input_names=["input_ids", "attention_mask"],
+ tracer_cls=HFModelTracer,
+ partition_modules=partition_unit_modules,
+ )
+ print("HF Model Graph Trace:")
+ print(traced_graph)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/e2e/test_pp_accuracy_alignment.py b/test/parallel/pipeline/e2e/test_pp_accuracy_alignment.py
new file mode 100644
index 0000000..163f453
--- /dev/null
+++ b/test/parallel/pipeline/e2e/test_pp_accuracy_alignment.py
@@ -0,0 +1,247 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import os
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from vescale.plan import (
+ PipelineParallelPlan,
+ PipelineScheduleType,
+ ModeType,
+ PipelineSplitMethodType,
+)
+from vescale.pipe import PipeModule, construct_stage_modules
+from vescale.engine import PipeEngine
+from common_dtensor import DTensorTestBase, with_comms
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
+
+microbatch_size = 2
+factor = 8
+batch_size = microbatch_size * factor
+stage = 4
+RANDOM_SEED = 9999
+
+
+class MLP(nn.Module):
+ def __init__(self, features_in, feature_middle, features_out, value, idx=1):
+ super().__init__()
+ self.value = value
+ self.idx = idx
+ self.counter = 0
+ self.fc1 = nn.Linear(features_in, feature_middle, bias=False)
+ self.fc2 = nn.Linear(feature_middle, features_out, bias=False)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ t = self.fc1(x)
+ t = self.gelu(t)
+ t = self.fc2(t)
+ torch.save(t, f"{os.environ['model_name']}_mlp{self.value}_fwd{self.counter}_out_tensor.pt")
+ self.counter += 1
+ return t
+
+
+class EightMLP(nn.Module):
+ def __init__(self, hidden=1024, fixed_size=True):
+ super().__init__()
+ self.mlp1 = MLP(hidden, hidden, hidden, 1, 1)
+ self.mlp2 = MLP(hidden, hidden, hidden, 2, 2)
+ self.mlp3 = MLP(hidden, hidden, hidden, 1, 3)
+ self.mlp4 = MLP(hidden, hidden, hidden, 2, 4)
+ self.mlp5 = MLP(hidden, hidden, hidden, 1, 5)
+ self.mlp6 = MLP(hidden, hidden, hidden, 2, 6)
+ self.mlp7 = MLP(hidden, hidden, hidden, 1, 7)
+ self.mlp8 = MLP(hidden, hidden, hidden, 2, 8)
+
+ def forward(self, x):
+ x = self.mlp1(x)
+ x = self.mlp2(x)
+ x = self.mlp3(x)
+ x = self.mlp4(x)
+ x = self.mlp5(x)
+ x = self.mlp6(x)
+ x = self.mlp7(x)
+ x = self.mlp8(x)
+ return x
+
+
+class PipelineAccuracyAlignmentTest(DTensorTestBase):
+ @property
+ def world_size(self):
+ return 4
+
+ @staticmethod
+ def loss_fn(x):
+ return x.mean()
+
+ @staticmethod
+ def save_mlp_parameter(model: MLP, f_name):
+ torch.save(model.fc1.weight, f"{f_name}.fc1")
+ torch.save(model.fc2.weight, f"{f_name}.fc2")
+
+ @staticmethod
+ def load_mlp_parameter(f_prefix):
+ fc1_weight = torch.load(f"{f_prefix}.fc1").to("cuda:0")
+ fc2_weight = torch.load(f"{f_prefix}.fc2").to("cuda:0")
+ return (fc1_weight, fc2_weight)
+
+ def check_model_weight_diff(self, f_prefix):
+ def helper(f1, f2):
+ golden_weights = self.load_mlp_parameter(f1)
+ pp_weights = self.load_mlp_parameter(f2)
+ torch.testing.assert_close(golden_weights[0], pp_weights[0])
+ torch.testing.assert_close(golden_weights[1], pp_weights[1])
+
+ helper(f"golden_mlp{self.rank + 1}", f"{f_prefix}_mlp{self.rank + 1}")
+
+ def check_out_tensors(self, model_name):
+ def helper(f1, f2):
+ golden_out = torch.load(f1).to("cuda:0")
+ pp_out = torch.load(f2).to("cuda:0")
+ torch.testing.assert_close(golden_out, pp_out)
+
+ for i in range(1, 3):
+ for j in range(8):
+ helper(f"golden_mlp{i}_fwd{j}_out_tensor.pt", f"{model_name}_mlp{i}_fwd{j}_out_tensor.pt")
+ torch.cuda.synchronize()
+
+ def test_accuracy_alignment(self, fixed_size=True):
+ """
+ Tests alignment of updated parameter and output activations of single device model and
+ the model partitioned into four stages with pipeline parallelism API.
+ """
+ if self.rank == 0:
+ self._run_no_pp_model(fixed_size=fixed_size)
+ torch.cuda.synchronize()
+ n_gpus = torch.cuda.device_count()
+ assert n_gpus >= 2, "Requires at least 2 GPUs to run model with pp engine"
+ self._run_engine_with_1f1b(fixed_size=fixed_size)
+ if self.rank == 0:
+ self.check_out_tensors("pp")
+ self.check_model_weight_diff("engine_1f1b")
+
+ def _run_no_pp_model(self, fixed_size=True):
+ os.environ["model_name"] = "golden"
+ model = EightMLP(16, fixed_size=fixed_size).to("cuda:0")
+ torch.save(model.state_dict(), "baseline_model.pt")
+ optimizer = torch.optim.SGD(
+ model.parameters(),
+ lr=0.01,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ maximize=False,
+ foreach=None,
+ differentiable=False,
+ )
+ torch.manual_seed(9999)
+ batch = [torch.ones(microbatch_size, 128, 16, dtype=torch.float32).to("cuda:0") for _ in range(factor)]
+ for mb in batch:
+ out = model(mb)
+ loss = self.loss_fn(out)
+ loss.backward()
+ optimizer.step()
+ torch.save(out, "golden_out.pt")
+ torch.save(loss, "golden_loss.pt")
+ self.save_mlp_parameter(model.mlp1, "golden_mlp1")
+ self.save_mlp_parameter(model.mlp2, "golden_mlp2")
+ self.save_mlp_parameter(model.mlp3, "golden_mlp3")
+ self.save_mlp_parameter(model.mlp4, "golden_mlp4")
+ self.save_mlp_parameter(model.mlp5, "golden_mlp5")
+ self.save_mlp_parameter(model.mlp6, "golden_mlp6")
+ self.save_mlp_parameter(model.mlp7, "golden_mlp7")
+ self.save_mlp_parameter(model.mlp8, "golden_mlp8")
+
+ @with_comms
+ def _run_engine_with_1f1b(self, fixed_size=True):
+ os.environ["model_name"] = "pp"
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ model = EightMLP(16, fixed_size=fixed_size).cuda()
+ model.load_state_dict(torch.load("baseline_model.pt"))
+
+ pipe_config = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.MANUAL,
+ num_stages=4,
+ virtual_chunks=2,
+ smallest_unsplittable_units=["mlp1", "mlp2", "mlp3", "mlp4", "mlp5", "mlp6", "mlp7", "mlp8"],
+ split_points=["mlp2", "mlp4", "mlp6", "mlp8"],
+ batch_p2p_comm=False,
+ overlap_p2p_comm=True,
+ schedule_type=PipelineScheduleType.INTERLEAVED_1F1B,
+ )
+
+ optimizer_fn_kwargs = {
+ "lr": 0.01,
+ "momentum": 0,
+ "dampening": 0,
+ "weight_decay": 0,
+ "nesterov": False,
+ "maximize": False,
+ "foreach": None,
+ "differentiable": False,
+ }
+
+ torch.manual_seed(9999)
+ with torch.no_grad():
+ batch = [torch.ones(microbatch_size, 128, 16, dtype=torch.float32).to(device) for _ in range(factor)]
+
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(4, 1, 1),
+ mesh_dim_names=["PP", "DP", "TP"],
+ )
+ stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
+ model,
+ pipe_config,
+ VESCALE_DEVICE_MESH,
+ update_split_points=True,
+ )
+ _parameters = list(stage_modules[0].parameters()) + list(stage_modules[1].parameters())
+ optimizer = torch.optim.SGD(_parameters, **optimizer_fn_kwargs)
+ pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, pipe_config)
+ engine = PipeEngine(
+ pipe_module,
+ VESCALE_DEVICE_MESH,
+ self.loss_fn,
+ pipe_config,
+ )
+
+ engine.forward_backward(batch)
+ optimizer = engine.get_optimizer
+ optimizer.step()
+
+ if self.rank == 0:
+ self.save_mlp_parameter(engine.module[0].get_submodule("mlp1"), "engine_1f1b_mlp1")
+ self.save_mlp_parameter(engine.module[1].get_submodule("mlp5"), "engine_1f1b_mlp5")
+ if self.rank == 1:
+ self.save_mlp_parameter(engine.module[0].get_submodule("mlp2"), "engine_1f1b_mlp2")
+ self.save_mlp_parameter(engine.module[1].get_submodule("mlp6"), "engine_1f1b_mlp6")
+ if self.rank == 2:
+ self.save_mlp_parameter(engine.module[0].get_submodule("mlp3"), "engine_1f1b_mlp3")
+ self.save_mlp_parameter(engine.module[1].get_submodule("mlp7"), "engine_1f1b_mlp7")
+ if self.rank == 3:
+ self.save_mlp_parameter(engine.module[0].get_submodule("mlp4"), "engine_1f1b_mlp4")
+ self.save_mlp_parameter(engine.module[1].get_submodule("mlp8"), "engine_1f1b_mlp8")
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/instruction/four_mlp.py b/test/parallel/pipeline/instruction/four_mlp.py
new file mode 100644
index 0000000..b6ac2f4
--- /dev/null
+++ b/test/parallel/pipeline/instruction/four_mlp.py
@@ -0,0 +1,71 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import torch.nn as nn
+from vescale.dtensor.placement_types import Shard, Replicate
+
+
+class MLP(nn.Module):
+ def __init__(self, features_in, features_out, value):
+ super().__init__()
+ self.value = value
+ self.fc1 = nn.Linear(features_in, 16, bias=False)
+ self.fc1.weight.data.fill_(value)
+ self.fc2 = nn.Linear(16, features_out, bias=False)
+ self.fc2.weight.data.fill_(value * 2)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ t = self.fc1(x)
+ t = self.gelu(t)
+ t = self.fc2(t)
+ return t
+
+
+class FourMLP(nn.Module):
+ def __init__(self, hidden=64, fixed_size=True):
+ super().__init__()
+ if fixed_size:
+ self.mlp1 = MLP(hidden, hidden, 1)
+ self.mlp2 = MLP(hidden, hidden, 2)
+ self.mlp3 = MLP(hidden, hidden, 3)
+ self.mlp4 = MLP(hidden, hidden, 4)
+ else:
+ self.mlp1 = MLP(hidden * 1, hidden * 2, 1)
+ self.mlp2 = MLP(hidden * 2, hidden * 3, 2)
+ self.mlp3 = MLP(hidden * 3, hidden * 4, 3)
+ self.mlp4 = MLP(hidden * 4, hidden * 5, 4)
+
+ def forward(self, x):
+ x = self.mlp1(x)
+ x = self.mlp2(x)
+ x = self.mlp3(x)
+ x = self.mlp4(x)
+ return x
+
+
+sharding_plan = {
+ "forward": {
+ ".input": [[Replicate()]],
+ r"mlp\d.fc1.input": [[Replicate()]],
+ r"mlp\d.fc2.output": [[Replicate()]],
+ },
+ "parameter": {
+ r"mlp\d.fc1.weight": [Shard(0)],
+ r"mlp\d.fc2.weight": [Shard(1)],
+ },
+}
diff --git a/test/parallel/pipeline/instruction/test_multistage_schedule.py b/test/parallel/pipeline/instruction/test_multistage_schedule.py
new file mode 100644
index 0000000..8695c0c
--- /dev/null
+++ b/test/parallel/pipeline/instruction/test_multistage_schedule.py
@@ -0,0 +1,190 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from common_dtensor import DTensorTestBase, with_comms
+from vescale.pipe._schedules.instruction_base import StageDeps
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from vescale.plan.spec import PipelineP2PSpec
+
+
+class MLP(nn.Module):
+ def __init__(self, n_features):
+ super().__init__()
+ self.fc1 = nn.Linear(n_features, n_features * 2, bias=False)
+ torch.nn.init.uniform_(self.fc1.weight, 0, 1)
+ self.fc2 = nn.Linear(n_features * 2, n_features)
+ torch.nn.init.uniform_(self.fc2.weight, 0, 1)
+ self.gelu = nn.GELU()
+
+ def forward(self, x, y=None):
+ out = self.fc2(self.gelu(self.fc1(x)))
+ if y is not None:
+ out = out + y
+ return out
+
+
+class FourMLP(nn.Module):
+ def __init__(self, hidden):
+ super().__init__()
+ self.mlp1 = MLP(hidden)
+ self.mlp2 = MLP(hidden)
+ self.mlp3 = MLP(hidden)
+ self.mlp4 = MLP(hidden)
+
+ def forward(self, x):
+ stage1 = self.mlp1(x)
+ stage2 = self.mlp2(stage1)
+ stage3 = self.mlp3(stage2, x)
+ stage4 = self.mlp4(stage3)
+ return stage4
+
+
+class MultiStageCommTest(DTensorTestBase):
+ def test_send_order(self):
+ """
+ Tests send order.
+
+ stage 0: a , c
+ stage 1: b
+ stage 2: dataloader
+
+ stage 2: forward(c,b,dataloader,a)
+
+ """
+ a = torch.tensor(0)
+ b = torch.tensor(1)
+ c = torch.tensor(2)
+ d = torch.tensor(3)
+ p2p_tensors = [a, c, b]
+ p2p_index = [PipelineP2PSpec(0, 2), PipelineP2PSpec(1, 0), PipelineP2PSpec(2, 0), PipelineP2PSpec(0, 0)]
+ local_inputs = [d]
+
+ p2p_index_without_local = list(filter(lambda item: item.peer_stage_idx != 2, p2p_index))
+ p2p_send_order = sorted(p2p_index_without_local, key=lambda x: (x.peer_stage_idx, x.peer_output_idx))
+ p2p_tensor_order = [p2p_send_order.index(item) for item in p2p_index_without_local]
+ ordered_p2p_tensors = [p2p_tensors[x] for x in p2p_tensor_order]
+
+ assert ordered_p2p_tensors == [c, b, a]
+
+ args = []
+ local_input_mapping = list(filter(lambda item: item.peer_stage_idx == 2, p2p_index))
+ for item in p2p_index:
+ if item.peer_stage_idx == 2:
+ index = local_input_mapping.index(item)
+ args.append(local_inputs[index])
+ else:
+ index = p2p_send_order.index(item)
+ args.append(p2p_tensors[index])
+ assert args == [c, b, d, a]
+
+ @with_comms
+ def test_stage_deps(self):
+ """
+ Tests abstraction of inter-stage communication dependency.
+ """
+ # initialize global device mesh
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(4, 1, 1),
+ mesh_dim_names=("PP", "DP", "TP"),
+ )
+ print(VESCALE_DEVICE_MESH.get())
+
+ # case 1 - sequential input is one
+ single_deps = np.array([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 0]])
+ stage = StageDeps(
+ single_deps,
+ VESCALE_DEVICE_MESH.get_global_pipeline_parallel_meshes(),
+ [],
+ )
+ if torch.distributed.distributed_c10d.get_rank() == 0:
+ print(stage)
+
+ # case 2 - sequential multi input
+ single_deps = np.array([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 0]])
+ p2p_index_mapping = {1: [PipelineP2PSpec(0, 0), PipelineP2PSpec(0, 1)]}
+ stage = StageDeps(
+ single_deps,
+ VESCALE_DEVICE_MESH.get_global_pipeline_parallel_meshes(),
+ [],
+ p2p_index_mapping=p2p_index_mapping,
+ )
+ if torch.distributed.distributed_c10d.get_rank() == 0:
+ print(stage)
+
+ # case 3 - sequential multi input with local_dataloader
+ """
+ The adjacency matrix for 4 stages is formulated as a 4x4 matrix. The meaning can be interpreted as followed:
+ Row (Stage) 0: [0, 1, 0, 0]. stage 0 sends output to stage 1 (index position 1).
+ Row (Stage) 1: [0, 0, 1, 0]: stage 1 sends output to stage 2 (index position 2).
+ Row (Stage) 2: [0, 0, 0, 1]: stage 2 sends output to stage 3 (index position 3).
+ Row (Stage) 3: [0, 0, 0, 0]: stage 3 sends no output to any other stage.
+ """
+ single_deps = np.array([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 0]])
+ p2p_index_mapping = {1: [PipelineP2PSpec(0, 2), PipelineP2PSpec(1, 0), PipelineP2PSpec(0, 0)]}
+ stage = StageDeps(
+ single_deps,
+ VESCALE_DEVICE_MESH.get_global_pipeline_parallel_meshes(),
+ [],
+ p2p_index_mapping=p2p_index_mapping,
+ )
+ if torch.distributed.distributed_c10d.get_rank() == 0:
+ print(stage)
+
+ # case 4 - multi branch input with single data
+ """
+ The adjacency matrix for 4 stages is formulated as a 4x4 matrix. The meaning can be interpreted as followed:
+ Row (Stage) 0: [0, 1, 0, 0]. stage 0 sends output to stage 1 (index position 1).
+ Row (Stage) 1: [0, 0, 1, 0]: stage 1 sends output to stage 2 (index position 2).
+ Row (Stage) 2: [0, 0, 0, 1]: stage 2 sends output to stage 3 (index position 3).
+ Row (Stage) 3: [0, 0, 0, 0]: stage 3 sends no output to any other stage.
+ """
+ single_deps = np.array([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 0]])
+ p2p_index_mapping = {2: [PipelineP2PSpec(0, 0), PipelineP2PSpec(1, 0)]}
+ stage = StageDeps(
+ single_deps,
+ VESCALE_DEVICE_MESH.get_global_pipeline_parallel_meshes(),
+ [],
+ p2p_index_mapping=p2p_index_mapping,
+ )
+ if torch.distributed.distributed_c10d.get_rank() == 0:
+ print(stage)
+
+ # case 5 - vpp test
+ """
+ The adjacency matrix for 4 stages is formulated as a 4x4 matrix. The meaning can be interpreted as followed:
+ Row (Stage) 0: [0, 1, 0, 0]. stage 0 sends output to stage 1 (index position 1).
+ Row (Stage) 1: [0, 0, 1, 0]: stage 1 sends output to stage 2 (index position 2).
+ Row (Stage) 2: [0, 0, 0, 1]: stage 2 sends output to stage 3 (index position 3).
+ Row (Stage) 3: [0, 0, 0, 0]: stage 3 sends no output to any other stage.
+ """
+ single_deps = np.array([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 0]])
+ stage = StageDeps(
+ single_deps,
+ VESCALE_DEVICE_MESH.get_global_pipeline_parallel_meshes(),
+ [0, 1],
+ )
+ if torch.distributed.distributed_c10d.get_rank() == 0:
+ print(stage)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/instruction/test_pipe_instruction_register.py b/test/parallel/pipeline/instruction/test_pipe_instruction_register.py
new file mode 100644
index 0000000..ba4b36f
--- /dev/null
+++ b/test/parallel/pipeline/instruction/test_pipe_instruction_register.py
@@ -0,0 +1,60 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import unittest
+from unittest import TestCase
+from vescale.pipe._schedules.instruction_base import register_instruction, registed_functions, InstructionBuilder
+
+
+class InstructionRegistrationTest(TestCase):
+ def test_pp_registed_function(self):
+ """
+ Tests instruction registration.
+ """
+
+ @register_instruction(name="instruction_one")
+ def instruction_one(input):
+ print(input)
+ return input
+
+ assert "instruction_one" in registed_functions
+
+ def test_instruction_constructor(self):
+ """
+ Tests instruction construction.
+ """
+
+ @register_instruction(name="I1")
+ def instruction_one(input):
+ return input + 1
+
+ @register_instruction(name="I2")
+ def instruction_two(input):
+ return input * 2
+
+ @register_instruction(name="B")
+ def bubble(input):
+ return input
+
+ instructions = {0: "B,I1,I1,I1,I1,I2,I2", 1: "B,I2,I2,I2,I2,I1,I1,I1"}
+ builder = InstructionBuilder()
+ builder.build_from_dict(instructions)
+ builder.draw_instructions()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/parallel/pipeline/instruction/test_schedule.py b/test/parallel/pipeline/instruction/test_schedule.py
new file mode 100644
index 0000000..9a185dc
--- /dev/null
+++ b/test/parallel/pipeline/instruction/test_schedule.py
@@ -0,0 +1,529 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import os
+from common_dtensor import DTensorTestBase, with_comms
+from vescale.pipe._schedules.instruction_base import get_linear_pp_module_dep2
+from vescale.pipe._schedules.pipedream_flush import PipeDream
+from vescale.pipe._schedules.looping_bfs import InterleavedPipeDreramFlush
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from vescale.dtensor.api import distribute_tensor
+from vescale.dtensor.device_mesh import DeviceMesh
+from vescale.dmodule.api import parallelize_module
+from vescale.dtensor.placement_types import Replicate
+from vescale.plan.spec import PipelineScheduleType
+from vescale.pipe.pipe_emmiter import ScheduleEngine
+
+
+class MLP(nn.Module):
+ def __init__(self, n_features):
+ super().__init__()
+ self.fc1 = nn.Linear(n_features, n_features * 2, bias=False)
+ torch.nn.init.uniform_(self.fc1.weight, 0, 1)
+ self.fc2 = nn.Linear(n_features * 2, n_features)
+ torch.nn.init.uniform_(self.fc2.weight, 0, 1)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ return self.fc2(self.gelu(self.fc1(x)))
+
+ def forward_utils(p2p, dataloader):
+ if p2p is not None:
+ return p2p
+ else:
+ return dataloader
+
+
+class FourMLP(nn.Module):
+ def __init__(self, hidden):
+ super().__init__()
+ self.mlp1 = MLP(hidden)
+ self.mlp2 = MLP(hidden)
+ self.mlp3 = MLP(hidden)
+ self.mlp4 = MLP(hidden)
+
+ def forward(self, x):
+ return self.mlp4(self.mlp3(self.mlp2(self.mlp1(x))))
+
+
+class EightMLP(nn.Module):
+ def __init__(self, hidden):
+ super().__init__()
+ self.mlps = [MLP(hidden) for _ in range(8)]
+
+ def forward(self, x):
+ all_input_x = []
+ for idx, mlp in enumerate(self.mlps):
+ x = mlp(x)
+ x.retain_grad()
+ all_input_x.append(x)
+ print(f"mlp: {idx} output : {x}")
+ return x, all_input_x
+
+
+class PipelineScheduleTest(DTensorTestBase):
+ @property
+ def world_size(self):
+ return 4
+
+ @staticmethod
+ def loss_fn(x):
+ return x.sum()
+
+ @with_comms
+ def test_1f1b_schedules(self):
+ """
+ Test generation of simple 1f1b schedule.
+ """
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1 = DeviceMesh(device, [0])
+ device_mesh_stage2 = DeviceMesh(device, [1])
+ device_mesh_stage3 = DeviceMesh(device, [2])
+ device_mesh_stage4 = DeviceMesh(device, [3])
+ meshes = (device_mesh_stage1, device_mesh_stage2, device_mesh_stage3, device_mesh_stage4)
+ microbatch = 8
+ batch = 8
+ stage = 4
+ schedule = PipeDream(stage, meshes, batch)
+ if torch.distributed.distributed_c10d.get_rank() == 0:
+ print(schedule)
+
+ @with_comms
+ def test_interleaved_1f1b_schedules(self):
+ """
+ Test generation of interleaved 1f1b schedule.
+ """
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ device_mesh_stage1 = DeviceMesh(device, [0])
+ device_mesh_stage2 = DeviceMesh(device, [1])
+ device_mesh_stage3 = DeviceMesh(device, [2])
+ device_mesh_stage4 = DeviceMesh(device, [3])
+ meshes = (device_mesh_stage1, device_mesh_stage2, device_mesh_stage3, device_mesh_stage4)
+ batches = 8
+ num_chunks = 2
+ schedule = InterleavedPipeDreramFlush(
+ num_chunks, meshes, default_shape=[1, 1, 3], default_dtype=torch.float32, batches=batches
+ )
+ if self.rank == 0:
+ print(schedule)
+
+ @with_comms
+ def test_runtime_engine_with_profiling(self):
+ """
+ Tests runtime engine with distributed nD timeline profiling.
+ """
+ # initialize global device mesh
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(4, 1, 1),
+ mesh_dim_names=("PP", "DP", "TP"),
+ )
+ global local_rank
+ local_rank = self.rank
+ device = f"cuda:{local_rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ os.environ["LOCAL_RANK"] = str(local_rank)
+ from vescale.ndtimeline import init_ndtimers, flush, wait
+
+ init_ndtimers(rank=int(local_rank), local=int(local_rank), enable_streamer=True)
+ n_hidden = 3
+ batches = 8
+ model = FourMLP(n_hidden)
+ all_batches_out = []
+ if self.rank == 3:
+ for i in range(batches):
+ print(f" ===========batch: {i}================= ")
+ data = torch.zeros(1, 1, n_hidden) + i
+ data = data.float().cuda(3)
+ model.cuda(3)
+ out = model(data)
+ loss = out.sum()
+ all_batches_out.append(loss)
+ loss.backward(create_graph=True)
+ print(loss)
+ print(" ====================================== ")
+ fwd_plan = {
+ ".input": [[Replicate()]],
+ ".output": [[Replicate()]],
+ }
+ model_list = []
+
+ tp_mesh = VESCALE_DEVICE_MESH.get_tensor_parallel_mesh()
+ if local_rank == 0:
+ model.mlp1 = parallelize_module(model.mlp1, tp_mesh, {"parameter": None, "forward": fwd_plan})
+ model_list = [model.mlp1]
+ elif self.rank == 1:
+ model.mlp2 = parallelize_module(model.mlp2, tp_mesh, {"parameter": None, "forward": fwd_plan})
+ model_list = [model.mlp2]
+ elif self.rank == 2:
+ model.mlp3 = parallelize_module(model.mlp3, tp_mesh, {"parameter": None, "forward": fwd_plan})
+ model_list = [model.mlp3]
+ elif self.rank == 3:
+ model.mlp4 = parallelize_module(model.mlp4, tp_mesh, {"parameter": None, "forward": fwd_plan})
+ model_list = [model.mlp4]
+ deps = get_linear_pp_module_dep2(model_list, VESCALE_DEVICE_MESH.get_global_tensor_parallel_meshes())
+ data_iterator = []
+ for i in range(batches):
+ data = torch.zeros(1, 1, n_hidden) + i
+ data_iterator.append(distribute_tensor(data.float(), tp_mesh, placements=[Replicate()]))
+ pipe_engine = ScheduleEngine(
+ deps=deps,
+ meshes=VESCALE_DEVICE_MESH.get_global_tensor_parallel_meshes(),
+ schedule=PipelineScheduleType.SIMPLE_1F1B,
+ batches=batches,
+ data_iterator=data_iterator,
+ stage_id=local_rank,
+ shape=(1, 1, 3),
+ dtype=torch.float32,
+ )
+ _, all_forward = ScheduleEngine.execute(pipe_engine)
+ if self.rank == 3:
+ loss_per_microbatch = [item[1] for item in all_forward]
+ for t1, t2 in zip(loss_per_microbatch, all_batches_out):
+ self.assertEqual(t1._local_tensor, t2)
+ flush()
+ wait()
+
+ @with_comms
+ def test_interleaved_1f1b_emmiter(self):
+ """
+ Test schedule instructions generated by ScheduleEngine's pipeline emitter.
+ """
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ n_hidden = 3
+ batches = 8
+ num_chunks = 2
+ meshes = [DeviceMesh(device, [i]) for i in range(self.world_size)]
+ model = EightMLP(n_hidden)
+ fwd_plan = {
+ ".input": [[Replicate()]],
+ ".output": [[Replicate()]],
+ }
+ vpp_module_chunk_list = []
+ if self.rank == 0:
+ model.mlps[0] = parallelize_module(model.mlps[0], meshes[0], {"parameter": None, "forward": fwd_plan})
+ model.mlps[4] = parallelize_module(model.mlps[4], meshes[0], {"parameter": None, "forward": fwd_plan})
+ vpp_module_chunk_list = [model.mlps[0], model.mlps[4]]
+ elif self.rank == 1:
+ model.mlps[1] = parallelize_module(model.mlps[1], meshes[1], {"parameter": None, "forward": fwd_plan})
+ model.mlps[5] = parallelize_module(model.mlps[5], meshes[1], {"parameter": None, "forward": fwd_plan})
+ vpp_module_chunk_list = [model.mlps[1], model.mlps[5]]
+ elif self.rank == 2:
+ model.mlps[2] = parallelize_module(model.mlps[2], meshes[2], {"parameter": None, "forward": fwd_plan})
+ model.mlps[6] = parallelize_module(model.mlps[6], meshes[2], {"parameter": None, "forward": fwd_plan})
+ vpp_module_chunk_list = [model.mlps[2], model.mlps[6]]
+ elif self.rank == 3:
+ model.mlps[3] = parallelize_module(model.mlps[3], meshes[3], {"parameter": None, "forward": fwd_plan})
+ model.mlps[7] = parallelize_module(model.mlps[7], meshes[3], {"parameter": None, "forward": fwd_plan})
+ vpp_module_chunk_list = [model.mlps[3], model.mlps[7]]
+
+ deps = get_linear_pp_module_dep2(vpp_module_chunk_list, meshes)
+ data_iterator = []
+ for i in range(batches):
+ data = torch.zeros(1, 1, n_hidden) + i
+ data_iterator.append(
+ distribute_tensor(
+ data.float(), DeviceMesh(device, [self.rank], _validate_mesh=False), placements=[Replicate()]
+ )
+ )
+ pipe_engine = ScheduleEngine(
+ deps,
+ meshes,
+ PipelineScheduleType.INTERLEAVED_1F1B,
+ batches,
+ iter(data_iterator),
+ self.rank,
+ (1, 1, 3),
+ dtype=torch.float32,
+ num_chunks=num_chunks,
+ )
+
+ @with_comms
+ def test_runtime_interleaved_1f1b_engine_batch(self):
+ """
+ Test parallelized DModules to perform interleaved 1f1b training.
+ """
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ n_hidden = 3
+ batches = 8
+ num_chunks = 2
+ meshes = [DeviceMesh(device, [i]) for i in range(self.world_size)]
+ model = EightMLP(n_hidden)
+ all_batches_out = []
+ if self.rank == 3:
+ true_model = model
+ for i in range(8):
+ true_model.mlps[i] = true_model.mlps[i].cuda(3)
+ true_model.train()
+ for i in range(batches):
+ print(f" ===========batch: {i}================= ")
+ data = torch.zeros(1, 1, n_hidden) + i
+ data = data.float().cuda(3)
+ out, all_output_x = true_model(data)
+ loss = out.sum()
+ all_batches_out.append(loss)
+ loss.backward(create_graph=True)
+ for idx, output in enumerate(all_output_x):
+ print(f"mlp{idx}.grad is {output.grad}")
+ print(" ====================================== ")
+ fwd_plan = {
+ ".input": [[Replicate()]],
+ ".output": [[Replicate()]],
+ }
+ vpp_module_chunk_list = []
+ if self.rank == 0:
+ model.mlps[0] = parallelize_module(model.mlps[0], meshes[0], {"parameter": None, "forward": fwd_plan})
+ model.mlps[4] = parallelize_module(model.mlps[4], meshes[0], {"parameter": None, "forward": fwd_plan})
+ vpp_module_chunk_list = [model.mlps[0], model.mlps[4]]
+ elif self.rank == 1:
+ model.mlps[1] = parallelize_module(model.mlps[1], meshes[1], {"parameter": None, "forward": fwd_plan})
+ model.mlps[5] = parallelize_module(model.mlps[5], meshes[1], {"parameter": None, "forward": fwd_plan})
+ vpp_module_chunk_list = [model.mlps[1], model.mlps[5]]
+ elif self.rank == 2:
+ model.mlps[2] = parallelize_module(model.mlps[2], meshes[2], {"parameter": None, "forward": fwd_plan})
+ model.mlps[6] = parallelize_module(model.mlps[6], meshes[2], {"parameter": None, "forward": fwd_plan})
+ vpp_module_chunk_list = [model.mlps[2], model.mlps[6]]
+ elif self.rank == 3:
+ model.mlps[3] = parallelize_module(model.mlps[3], meshes[3], {"parameter": None, "forward": fwd_plan})
+ model.mlps[7] = parallelize_module(model.mlps[7], meshes[3], {"parameter": None, "forward": fwd_plan})
+ vpp_module_chunk_list = [model.mlps[3], model.mlps[7]]
+ deps = get_linear_pp_module_dep2(vpp_module_chunk_list, meshes)
+ data_iterator = []
+ for i in range(batches):
+ data = torch.zeros(1, 1, n_hidden) + i
+ data_iterator.append(
+ distribute_tensor(
+ data.float(), DeviceMesh(device, [self.rank], _validate_mesh=False), placements=[Replicate()]
+ )
+ )
+ pipe_engine = ScheduleEngine(
+ deps,
+ meshes,
+ PipelineScheduleType.INTERLEAVED_1F1B,
+ batches,
+ [iter(data_iterator) for _ in range(num_chunks)],
+ self.rank,
+ (1, 1, 3),
+ dtype=torch.float32,
+ num_chunks=num_chunks,
+ loss_fn=self.loss_fn,
+ )
+ if self.rank == 0:
+ print("schedule", pipe_engine.p_emmiter.instruction_generator.schema)
+ _, forward_datas = ScheduleEngine.execute(pipe_engine)
+ if self.rank == 3:
+ loss_per_microbatch = [item[1] for item in forward_datas]
+ for t1, t2 in zip(loss_per_microbatch, all_batches_out):
+ self.assertEqual(t1._local_tensor, t2)
+
+ @with_comms
+ def test_runtime_interleaved_1f1b_engine_p2p(self):
+ """
+ Test step-by-step initialization of pipeline engine, generation
+ of simple 1f1b schedule and execution of pipeline engine with
+ p2p overlapped communication.
+ """
+ device = f"cuda:{self.rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ os.environ["LOCAL_RANK"] = str(self.rank)
+ n_hidden = 3
+ batches = 8
+ num_chunks = 2
+ meshes = [DeviceMesh(device, [i]) for i in range(self.world_size)]
+ model = EightMLP(n_hidden)
+ all_batches_out = []
+ if self.rank == 3:
+ true_model = model
+ for i in range(8):
+ true_model.mlps[i] = true_model.mlps[i].cuda(3)
+ true_model.train()
+ for i in range(batches):
+ print(f" ===========batch: {i}================= ")
+ data = torch.zeros(1, 1, n_hidden) + i
+ data = data.float().cuda(3)
+ out, all_output_x = true_model(data)
+ loss = out.sum()
+ all_batches_out.append(loss)
+ loss.backward(create_graph=True)
+ for idx, output in enumerate(all_output_x):
+ print(f"mlp{idx}.grad is {output.grad}")
+ print(" ====================================== ")
+ fwd_plan = {
+ ".input": [[Replicate()]],
+ ".output": [[Replicate()]],
+ }
+ vpp_module_chunk_list = []
+ if self.rank == 0:
+ model.mlps[0] = parallelize_module(model.mlps[0], meshes[0], {"parameter": None, "forward": fwd_plan})
+ model.mlps[4] = parallelize_module(model.mlps[4], meshes[0], {"parameter": None, "forward": fwd_plan})
+ vpp_module_chunk_list = [model.mlps[0], model.mlps[4]]
+ elif self.rank == 1:
+ model.mlps[1] = parallelize_module(model.mlps[1], meshes[1], {"parameter": None, "forward": fwd_plan})
+ model.mlps[5] = parallelize_module(model.mlps[5], meshes[1], {"parameter": None, "forward": fwd_plan})
+ vpp_module_chunk_list = [model.mlps[1], model.mlps[5]]
+ elif self.rank == 2:
+ model.mlps[2] = parallelize_module(model.mlps[2], meshes[2], {"parameter": None, "forward": fwd_plan})
+ model.mlps[6] = parallelize_module(model.mlps[6], meshes[2], {"parameter": None, "forward": fwd_plan})
+ vpp_module_chunk_list = [model.mlps[2], model.mlps[6]]
+ elif self.rank == 3:
+ model.mlps[3] = parallelize_module(model.mlps[3], meshes[3], {"parameter": None, "forward": fwd_plan})
+ model.mlps[7] = parallelize_module(model.mlps[7], meshes[3], {"parameter": None, "forward": fwd_plan})
+ vpp_module_chunk_list = [model.mlps[3], model.mlps[7]]
+ deps = get_linear_pp_module_dep2(vpp_module_chunk_list, meshes)
+ data_iterator = []
+ for i in range(batches):
+ data = torch.zeros(1, 1, n_hidden) + i
+ data_iterator.append(
+ distribute_tensor(data.float(), DeviceMesh(device, [0], _validate_mesh=False), placements=[Replicate()])
+ )
+ pipe_engine = ScheduleEngine(
+ deps,
+ meshes,
+ PipelineScheduleType.INTERLEAVED_1F1B,
+ batches,
+ [iter(data_iterator) for _ in range(num_chunks)],
+ self.rank,
+ (1, 1, 3),
+ dtype=torch.float32,
+ num_chunks=num_chunks,
+ overlap_p2p_comm=True,
+ batch_p2p_comm=False,
+ loss_fn=self.loss_fn,
+ )
+ if self.rank == 0:
+ print("schedule", pipe_engine.p_emmiter.instruction_generator.schema)
+ _, forward_datas = ScheduleEngine.execute(pipe_engine)
+ if self.rank == 3:
+ loss_per_microbatch = [item[1] for item in forward_datas]
+ print(loss_per_microbatch, all_batches_out)
+ for t1, t2 in zip(loss_per_microbatch, all_batches_out):
+ self.assertEqual(t1._local_tensor, t2)
+
+ @with_comms
+ def test_zerobubble_engine(self):
+ """
+ Tests zero-bubble pipeline schedule with profiling.
+ """
+ # initialize global device mesh
+ VESCALE_DEVICE_MESH.init_device_mesh(
+ device_type="cuda",
+ mesh_shape=(4, 1, 1),
+ mesh_dim_names=("PP", "DP", "TP"),
+ )
+ global local_rank
+ local_rank = self.rank
+ device = f"cuda:{local_rank}"
+ # must do this: https://pytorch.org/docs/stable/distributed.html
+ torch.cuda.set_device(device)
+ os.environ["LOCAL_RANK"] = str(local_rank)
+ from vescale.ndtimeline import init_ndtimers, flush, wait
+
+ init_ndtimers(rank=int(local_rank), local_rank=int(local_rank), enable_streamer=True)
+ num_chunks = 2
+ n_hidden = 3
+ batches = 8
+ model = EightMLP(n_hidden)
+ for i in range(8):
+ model.mlps[i] = model.mlps[i].cuda()
+ all_batches_out = []
+ if self.rank == 0:
+ true_model = model
+ for i in range(8):
+ true_model.mlps[i] = true_model.mlps[i].cuda(0)
+ true_model.train()
+ for i in range(batches):
+ print(f" ===========batch: {i}================= ")
+ data = torch.zeros(1, 1, n_hidden) + i
+ data = data.float().cuda(0)
+ out, all_output_x = true_model(data)
+ loss = out.sum()
+ all_batches_out.append(loss)
+ loss.backward(create_graph=True)
+ for idx, output in enumerate(all_output_x):
+ print(f"mlp{idx}.grad is {output.grad}")
+ print(" ====================================== ")
+ fwd_plan = {
+ ".input": [[Replicate()]],
+ ".output": [[Replicate()]],
+ }
+ model_list = []
+
+ if self.rank == 0:
+ model_list = [model.mlps[0], model.mlps[7]]
+ elif self.rank == 1:
+ model_list = [model.mlps[1], model.mlps[6]]
+ elif self.rank == 2:
+ model_list = [model.mlps[2], model.mlps[5]]
+ elif self.rank == 3:
+ model_list = [model.mlps[3], model.mlps[4]]
+ deps = get_linear_pp_module_dep2(model_list, VESCALE_DEVICE_MESH.get_global_tensor_parallel_meshes())
+ data_iterator = []
+ for i in range(batches):
+ data = torch.zeros(1, 1, n_hidden) + i
+ data_iterator.append(data.float().cuda())
+
+ w = n_hidden * 2 * 4
+ a = n_hidden * 4
+ mem_f = 2 * w + 2 * a # forward weight size
+ mem_w = -2 * a
+ mem_b = -mem_w - mem_f
+ pipe_engine = ScheduleEngine(
+ deps=deps,
+ meshes=VESCALE_DEVICE_MESH.get_global_tensor_parallel_meshes(),
+ schedule=PipelineScheduleType.ZERO_BUBBLE,
+ batches=batches,
+ data_iterator=[iter(data_iterator) for _ in range(num_chunks)],
+ stage_id=local_rank,
+ shape=(1, 1, 3),
+ dtype=torch.float32,
+ f_cost=6,
+ b_cost=4,
+ w_cost=4,
+ c_cost=1,
+ f_mem=mem_f,
+ b_mem=mem_b,
+ w_mem=mem_w,
+ max_mem=mem_f * 4 * 2,
+ )
+ _, all_forward = ScheduleEngine.execute(pipe_engine)
+ if self.rank == 0:
+ loss_per_microbatch = [item[1] for item in all_forward]
+ print(loss_per_microbatch, all_batches_out)
+ for t1, t2 in zip(loss_per_microbatch, all_batches_out):
+ self.assertEqual(t1, t2)
+
+ flush()
+ wait()
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/instruction/test_userdefine_schedule.py b/test/parallel/pipeline/instruction/test_userdefine_schedule.py
new file mode 100644
index 0000000..680a4d9
--- /dev/null
+++ b/test/parallel/pipeline/instruction/test_userdefine_schedule.py
@@ -0,0 +1,197 @@
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from common_dtensor import DTensorTestBase, with_comms
+from torch.testing._internal.common_utils import run_tests
+from vescale.pipe._schedules.instruction_base import (
+ register_instruction,
+ VESCALE_INTRUCTION_BUILDER as builder,
+ StageDeps,
+)
+from vescale.initialize.deferred_init import deferred_init
+from vescale.pipe import PipeParser
+from vescale.pipe.pipe_stage import _generate_stage_dependencies
+from vescale.dmodule.api import parallelize_module
+from vescale.dtensor.device_mesh import DeviceMesh
+import torch
+from four_mlp import FourMLP, sharding_plan
+from vescale.pipe._schedules.pipedream_flush import maybe_tensor, cross_mesh_send, cross_mesh_recv
+
+from torch.distributed._functional_collectives import send, recv
+
+from vescale.plan.pipeline_parallel import PipelineParallelPlan
+from vescale.plan.spec import PipelineSplitMethodType
+
+
+class PowerUserScheduleTest(DTensorTestBase):
+ @with_comms
+ def test_user_define_schedule(self):
+ """
+ Tests user-defined pipeline schedule.
+ """
+ global_mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])
+ torch.cuda.set_device(self.rank)
+
+ @register_instruction(name="send")
+ def send_forward():
+ topo = builder.topo
+ send_data = builder.last
+ send_comms = topo.send_tables[builder.stage_id]
+ send_comm = send_comms[0]
+ mapping_group = send_comm.cur_mesh.get_mapping_rank(send_comm.peer_mesh)
+ send(maybe_tensor(send_data), mapping_group, torch.distributed.distributed_c10d._get_default_group())
+ cross_mesh_send(send_comm, send_data)
+
+ @register_instruction(name="recv")
+ def recv_forward():
+ topo = builder.topo
+ recv_comms = topo.recv_tables[builder.stage_id]
+ recv_comm = recv_comms[0]
+ recv_tensor = torch.empty((1, 1, 8), requires_grad=True, dtype=torch.float32).cuda()
+ mapping_group = recv_comm.cur_mesh.get_mapping_rank(recv_comm.peer_mesh)
+ recv_tensor = recv(recv_tensor, mapping_group, torch.distributed.distributed_c10d._get_default_group())
+ recv_dtensor = cross_mesh_recv(recv_comm, recv_tensor)
+ return recv_dtensor
+
+ @register_instruction(name="forward")
+ def forward():
+ model = builder.model
+ last_data = builder.last
+ activation = model(last_data)
+ return activation
+
+ @register_instruction(name="load_data")
+ def load_data():
+ dataloader = builder.dataloader
+ pos = builder.pos
+ data_id = pos // 3
+ return dataloader[data_id]
+
+ instruction_list = {
+ 0: "load_data,forward,send,load_data,forward,send,load_data,forward,send",
+ 1: "recv,forward,recv,forward,recv,forward",
+ }
+ builder.build_from_dict(instruction_list)
+ builder.draw_instructions()
+
+ deferred_model = deferred_init(FourMLP, hidden=8)
+ parser = PipeParser()
+ pipe_config = PipelineParallelPlan(
+ num_stages=2,
+ split_method=PipelineSplitMethodType.MANUAL,
+ smallest_unsplittable_units=[f"mlp{i + 1}" for i in range(4)],
+ split_points=["mlp2", "mlp4"],
+ )
+ parser_args = {"shard_plan": sharding_plan}
+ graph = parser.parse(deferred_model, pipe_config, **parser_args)
+ root_graph = parser.partition_stage(deferred_model, graph, pipe_config)
+
+ if self.rank in [0, 1]:
+ pipeline_stage_id = 0
+ elif self.rank in [2, 3]:
+ pipeline_stage_id = 1
+
+ stage_model_pp = root_graph.get_submodule(f"stage{pipeline_stage_id}")
+
+ stage_model_pp_tp = parallelize_module(
+ stage_model_pp,
+ global_mesh.get_submesh([1]),
+ sharding_plan,
+ factory=False,
+ )
+
+ global_tp_meshes = [
+ DeviceMesh("cuda", [0, 1], _validate_mesh=False),
+ DeviceMesh("cuda", [2, 3], _validate_mesh=False),
+ ]
+ np_deps, p2p_index_mapping = _generate_stage_dependencies(root_graph, 2, 1)
+
+ deps = StageDeps(np_deps, global_tp_meshes, [stage_model_pp_tp], p2p_index_mapping)
+ builder.topo = deps
+ builder.model = stage_model_pp_tp
+ builder.stage_id = pipeline_stage_id
+
+ data_iterator = []
+ if self.rank in [0, 1]:
+ for i in range(3):
+ data = torch.zeros((1, 1, 8), dtype=torch.float32) + i
+ data_iterator.append(data)
+ builder.dataloader = data_iterator
+ outputs = builder.run(pipeline_stage_id)
+ if self.rank in [2, 3]:
+ print(outputs)
+
+ def _define_instructions(self):
+ @register_instruction(name="send")
+ def send_forward(*args, **kwargs):
+ send_data = args[0]
+ dst = builder.send_dist
+ send(maybe_tensor(send_data), dst, torch.distributed.distributed_c10d._get_default_group())
+ return (send_data,), {}
+
+ @register_instruction(name="recv")
+ def recv_forward(*args, **kwargs):
+ dst = builder.recv_dist
+ recv_tensor = torch.empty_like(args[0])
+ recv_tensor = recv(recv_tensor, dst, torch.distributed.distributed_c10d._get_default_group())
+ return (recv_tensor,), {}
+
+ # instruction should be stateless.
+ @register_instruction(name="forward")
+ def forward(model, *args, **kwargs):
+ activation = model(*args, **kwargs)
+ return (activation,), {}
+
+ instruction_list = {
+ 0: "forward,send",
+ 1: "recv,forward",
+ }
+
+ builder.build_from_dict(instruction_list)
+ builder.draw_instructions()
+
+ def _parallelize_model(self, global_mesh):
+ deferred_model = deferred_init(FourMLP, hidden=8)
+ parser = PipeParser()
+ pipe_config = PipelineParallelPlan(
+ num_stages=2,
+ split_method=PipelineSplitMethodType.MANUAL,
+ smallest_unsplittable_units=[f"mlp{i + 1}" for i in range(4)],
+ split_points=["mlp2", "mlp4"],
+ )
+ parser_args = {"shard_plan": sharding_plan}
+ graph = parser.parse(deferred_model, **parser_args)
+ root_graph = parser.partition_stage(deferred_model, graph, pipe_config)
+
+ if self.rank in [0, 1]:
+ pipeline_stage_id = 0
+ elif self.rank in [2, 3]:
+ pipeline_stage_id = 1
+
+ stage_model_pp = root_graph.get_submodule(f"stage{pipeline_stage_id}")
+
+ tp_submesh = global_mesh.get_submesh([1])
+ stage_model_pp_tp = parallelize_module(
+ stage_model_pp,
+ tp_submesh,
+ sharding_plan,
+ factory=False,
+ )
+
+ return stage_model_pp_tp, root_graph
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/parallel/pipeline/instruction/test_zerobubble.py b/test/parallel/pipeline/instruction/test_zerobubble.py
new file mode 100644
index 0000000..3d1f6a3
--- /dev/null
+++ b/test/parallel/pipeline/instruction/test_zerobubble.py
@@ -0,0 +1,103 @@
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import unittest
+from four_mlp import FourMLP
+import torch
+import torch.optim as optim
+
+
+class ZeroBubbleTest(unittest.TestCase):
+ def test_split_backward(self):
+ """
+ Tests how to separately compute activation gradient and parameter gradient
+ in zero bubble pipeline schedule.
+ """
+ model = FourMLP(hidden=8)
+
+ stage0 = model.mlp1
+ stage1 = model.mlp2
+
+ input = torch.randn(8, 8, requires_grad=True)
+
+ stage0_out = stage0(input)
+ stage1_out = stage1(stage0_out)
+ loss = stage1_out.sum()
+ optimizer = optim.SGD(model.parameters(), lr=0.01)
+
+ # calc zbv grad
+ optimizer.zero_grad()
+
+ # calc activation grad (B)
+ activation_grad_output = torch.autograd.grad(loss, stage1_out, retain_graph=True)
+ activation_grad_stage1 = torch.autograd.grad(
+ stage1_out,
+ stage0_out,
+ grad_outputs=activation_grad_output,
+ retain_graph=True,
+ allow_unused=True,
+ materialize_grads=True,
+ )
+ activation_grad_stage0 = torch.autograd.grad(
+ stage0_out,
+ input,
+ grad_outputs=activation_grad_stage1,
+ retain_graph=True,
+ allow_unused=True,
+ materialize_grads=True,
+ )
+
+ # calc params grad (W)
+ nps1 = {}
+ for key, value in stage1.named_parameters():
+ nps1[key] = value
+
+ nps0 = {}
+ for key, value in stage0.named_parameters():
+ nps0[key] = value
+
+ parameters_grad_stage1 = torch.autograd.grad(
+ stage1_out,
+ nps1.values(),
+ grad_outputs=activation_grad_output,
+ retain_graph=True,
+ allow_unused=True,
+ materialize_grads=True,
+ )
+ parameters_grad_stage0 = torch.autograd.grad(
+ stage0_out,
+ nps0.values(),
+ grad_outputs=activation_grad_stage1,
+ retain_graph=True,
+ allow_unused=True,
+ materialize_grads=True,
+ )
+
+ # calc normal grad
+ optimizer.zero_grad()
+ loss.backward()
+
+ # validate grads are same
+ print("fc1.weight.grad", stage1.fc1.weight.grad)
+ print("fc2.weight.grad", stage1.fc2.weight.grad)
+
+ torch.testing.assert_close(stage1.fc1.weight.grad, parameters_grad_stage1[0])
+ torch.testing.assert_close(stage1.fc2.weight.grad, parameters_grad_stage1[1])
+ torch.testing.assert_close(stage0.fc1.weight.grad, parameters_grad_stage0[0])
+ torch.testing.assert_close(stage0.fc2.weight.grad, parameters_grad_stage0[1])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/vescale/devicemesh_api/__init__.py b/vescale/devicemesh_api/__init__.py
index a9c8ea7..ed86ee4 100644
--- a/vescale/devicemesh_api/__init__.py
+++ b/vescale/devicemesh_api/__init__.py
@@ -15,4 +15,4 @@
#
################################################################################
-from .api import VESCALE_DEVICE_MESH
+from .api import VESCALE_DEVICE_MESH, VeDeviceMesh
diff --git a/vescale/dtensor/__init__.py b/vescale/dtensor/__init__.py
index 2c2763d..2899061 100644
--- a/vescale/dtensor/__init__.py
+++ b/vescale/dtensor/__init__.py
@@ -25,7 +25,7 @@
)
from vescale.dtensor.device_mesh import DeviceMesh, mesh_resources
from vescale.dtensor.api import normalize_placements
-from vescale.dtensor.dtensor import DTensor
+from vescale.dtensor.dtensor import DTensor, make_dtensor
from vescale.dtensor.ops.utils import normalize_to_torch_size
from vescale.dtensor.placement_types import DTensorSpec, Placement, Replicate, TensorMeta
diff --git a/vescale/dtensor/_diff.py b/vescale/dtensor/_diff.py
index fed72c5..d60fd0d 100644
--- a/vescale/dtensor/_diff.py
+++ b/vescale/dtensor/_diff.py
@@ -21,8 +21,9 @@
import logging
-
VESCALE_DISABLE_REDISTRIBUTE = os.environ.get("VESCALE_DISABLE_REDISTRIBUTE", "1") == "1"
+VESCALE_DUMMY_P2P = os.environ.get("VESCALE_DUMMY_P2P", "0") == "1"
+VESCALE_DUMP_INSTRUCTION = os.environ.get("VESCALE_DUMP_INSTRUCTION", "0") == "1"
global VESCALE_SHARDING_SUGGETSION
VESCALE_SHARDING_SUGGETSION = []
diff --git a/vescale/dtensor/dtensor.py b/vescale/dtensor/dtensor.py
index ef2aeeb..719d0b9 100644
--- a/vescale/dtensor/dtensor.py
+++ b/vescale/dtensor/dtensor.py
@@ -493,15 +493,15 @@ def to_local(
grad_placements: Optional[Sequence[Placement]] = None,
async_output: bool = True,
) -> torch.Tensor:
-
# NOTE: moving impl code here for performance, as here is on the critial path but api function is NEVER used
if grad_placements is not None:
grad_placements: Tuple[Placement] = normalize_placements(
grad_placements, self._spec.mesh.ndim, tensor_ndim=self.ndim
)
-
+
return _ToTorchTensor.apply(self, grad_placements, async_output)
+
def redistribute(
self,
device_mesh: Optional[DeviceMesh] = None,
@@ -539,3 +539,24 @@ def tolist(self) -> Union[List, Number]:
- This operation is not dispatched but a torch function.
"""
return self._local_tensor.tolist()
+
+
+def make_dtensor(
+ local_tensor: torch.Tensor,
+ device_mesh: DeviceMesh,
+ placements: Tuple[Placement, ...],
+ *,
+ shape: torch.Size,
+ dtype: torch.dtype,
+ requires_grad: bool,
+ stride: Tuple[int, ...],
+):
+ return DTensor(
+ local_tensor,
+ device_mesh,
+ placements,
+ shape=shape,
+ dtype=dtype,
+ requires_grad=requires_grad,
+ stride=stride,
+ )
diff --git a/vescale/dtensor/placement_types.py b/vescale/dtensor/placement_types.py
index 5332102..bb4f474 100644
--- a/vescale/dtensor/placement_types.py
+++ b/vescale/dtensor/placement_types.py
@@ -56,9 +56,9 @@ def serialize_from_tensor(tensor: torch.Tensor):
elif tensor[0] == 1:
return Partial()
elif tensor[0] == 2:
- return Shard(dim=tensor[1])
+ return Shard(dim=tensor[1].item())
elif tensor[0] == 3:
- return InterleavedShard(dim=tensor[1], interleaved_size=tensor[2])
+ return InterleavedShard(dim=tensor[1].item(), interleaved_size=tensor[2].item())
class Shard(Placement):
diff --git a/vescale/engine/__init__.py b/vescale/engine/__init__.py
new file mode 100644
index 0000000..78249da
--- /dev/null
+++ b/vescale/engine/__init__.py
@@ -0,0 +1,18 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from .pipe import PipeEngine
diff --git a/vescale/engine/pipe.py b/vescale/engine/pipe.py
new file mode 100644
index 0000000..4cd66c8
--- /dev/null
+++ b/vescale/engine/pipe.py
@@ -0,0 +1,237 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from collections import defaultdict
+from typing import Any, List, Callable
+from vescale.pipe.pipe_stage import PipeModule
+from vescale.plan.pipeline_parallel import PipelineParallelPlan
+from vescale.pipe.pipe_emmiter import ScheduleEngine, StageDeps
+from vescale.devicemesh_api import VeDeviceMesh
+from vescale.plan.spec import PipelineScheduleType
+from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
+from copy import deepcopy
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+import os
+
+
+class PipeEngine:
+ def __init__(
+ self,
+ module: PipeModule,
+ global_mesh: VeDeviceMesh,
+ loss_fn: Callable,
+ config: PipelineParallelPlan,
+ ):
+ """
+ Training engine for pipeline parallelism and multi-dimensional
+ parallelism that underlies pipeline parallelism (distributed optimizer, data parallel,
+ tensor model parallel, and sequence parallel, etc).
+ The training engine is responsible for materializes stage partitioning, module registration,
+ training, and optimizer synchronization.
+ """
+ self.module = module
+ self.virtual_chunks_per_stage = config.virtual_chunks
+ self.engine_config = config
+ self.optimizer = self.module.get_optimizer
+ self.lr_scheduler = self.module.get_lr_scheduler
+ self.global_mesh = global_mesh
+ if isinstance(loss_fn, nn.Module):
+ self.loss_fn = loss_fn
+ else:
+ try:
+ self.loss_fn = deepcopy(loss_fn.__func__)
+ except: # noqa: E722
+ self.loss_fn = loss_fn
+ self.schedule_engine = None
+ self.reuse_comm_shape = self.engine_config.reuse_p2p_tensor_shape
+ if self.reuse_comm_shape:
+ os.environ["REUSE_COMM_SHAPE"] = "1"
+ if (
+ self.engine_config.schedule_type == PipelineScheduleType.INTERLEAVED_1F1B
+ and self.virtual_chunks_per_stage == 1
+ ):
+ print("[warning]: #virtual pipeline chunks is 1. Falling back to simple 1F1B schedule.")
+ self.engine_config.schedule_type = PipelineScheduleType.SIMPLE_1F1B
+ self.schedule_type = self.engine_config.schedule_type
+
+ def build_schedule(self, minibatches, data_shape=None):
+ """
+ Build pipeline parallel training schedules.
+ """
+ meshes = self.global_mesh.get_global_tensor_parallel_meshes()
+ dp_rank, tp_rank = self.global_mesh.get_data_parallel_rank(), self.global_mesh.get_tensor_parallel_rank()
+ tp_meshes_dict = defaultdict(list)
+
+ def _locate_tp_mesh(_rank):
+ for tp_mesh in meshes:
+ if _rank in tp_mesh.mesh.tolist():
+ return tp_mesh
+ else:
+ raise ValueError("TP submesh not found.")
+
+ for _rank in range(torch.distributed.get_world_size()):
+ _coordinate = self.global_mesh.get_strategy_coordinate(_rank)
+ tp_mesh = _locate_tp_mesh(_rank)
+ _dp_rank, _tp_rank = _coordinate[1], _coordinate[2]
+ tp_meshes_dict[(_dp_rank, _tp_rank)].append(tp_mesh)
+
+ new_meshes = tp_meshes_dict[(dp_rank, tp_rank)]
+ meshes = new_meshes
+ first_stage_rank = self.global_mesh.get_strategy_coordinate(local_rank=0)[0]
+ # FIXME: the input can either be PipeModule, or a sequence of DDP modules? In the latter case, how to get stage dependency
+ pipe_module = self.module
+ stage_dep_matrix, p2p_index_mapping = pipe_module.stage_deps, pipe_module.p2p_index_mapping
+ stage_dependency = StageDeps(
+ dep=stage_dep_matrix,
+ meshes=meshes,
+ vpp_module_list=pipe_module,
+ p2p_index_mapping=p2p_index_mapping,
+ )
+ num_minibatches = self._align_num_batches(first_stage_rank, len(minibatches))
+ # TODO: insert shape inference
+ batch_p2p_comm = self.engine_config.batch_p2p_comm
+ # if on interleaved 1f1b schedule, set batch_p2p_comm to False to execute p2p communication
+ schedule_type = self.schedule_type
+ if schedule_type in [PipelineScheduleType.INTERLEAVED_1F1B, PipelineScheduleType.ZERO_BUBBLE]:
+ data_iterator = [iter(minibatches) for _ in range(self.virtual_chunks_per_stage)]
+ batch_p2p_comm = False
+ elif schedule_type == PipelineScheduleType.SIMPLE_1F1B:
+ data_iterator = minibatches
+ else:
+ raise NotImplementedError(f"Schedule {schedule_type} not implemented yet.")
+ return ScheduleEngine(
+ stage_dependency,
+ meshes,
+ schedule_type,
+ num_minibatches,
+ data_iterator=data_iterator,
+ stage_id=self.global_mesh.get_pipeline_parallel_rank(),
+ shape=data_shape,
+ dtype=self.engine_config.p2p_tensor_dtype,
+ num_chunks=self.virtual_chunks_per_stage,
+ input_shapes=None,
+ input_shapes_unpad=None,
+ # send_dtypes_map=self.module.recv_dtypes_dict,
+ overlap_p2p_comm=self.engine_config.overlap_p2p_comm,
+ batch_p2p_comm=batch_p2p_comm,
+ loss_fn=self.loss_fn,
+ global_mesh=self.global_mesh,
+ forward_only=self.engine_config.forward_only,
+ )
+
+ def forward_backward(
+ self,
+ minibatch,
+ reuse_schedule=False,
+ data_shape=None,
+ debug_mode: bool = False,
+ ):
+ """
+ Execute the pipeline schedule to complete forward,
+ backward, and gradient step of one minibatch.
+
+ Invoke Scheduler's execute_pipeline() to run a minibatch.
+ """
+ assert isinstance(minibatch, List), "Input must be a list of microbatches"
+ if reuse_schedule:
+ if self.schedule_engine is None:
+ schedule_engine = self.build_schedule(minibatch, data_shape=data_shape)
+ else:
+ schedule_engine = self.schedule_engine
+ schedule_engine.set_data_iterator(minibatch, data_shape=data_shape)
+ else:
+ schedule_engine = self.build_schedule(minibatch, data_shape=data_shape)
+ # returns model output tensors and losses per microbatch
+ return ScheduleEngine.execute(schedule_engine, debug_mode=debug_mode)
+
+ def forward(self, *args: Any, **kwargs: Any):
+ raise ValueError("Forward is done in PipeEngine.forward_backward()!")
+
+ def __call__(self, *args: Any, **kwargs: Any):
+ return self.forward_backward(*args, **kwargs)
+
+ def backward(self, *args: Any, **kwargs: Any):
+ raise ValueError("Backward is done in PipeEngine.forward_backward()!")
+
+ @property
+ def get_optimizer(self):
+ """
+ Return this stage's optimizer.
+ """
+ return self.optimizer
+
+ @property
+ def get_lr_scheduler(self):
+ return self.lr_scheduler
+
+ def zero_grad_buffer(self, zero_buffer: bool = True):
+ for vpp_module in self.module.stage_modules.values():
+ if isinstance(vpp_module, DDP):
+ vpp_module.zero_grad_buffer(zero_buffer)
+
+ def finish_grad_sync(self):
+ for vpp_module in self.module.stage_modules.values():
+ if isinstance(vpp_module, DDP):
+ vpp_module.finish_grad_sync()
+
+ def train(self, mode: bool = True):
+ for vpp_module in self.module.stage_modules.values():
+ vpp_module.train(mode)
+
+ def eval(self):
+ for vpp_module in self.module.stage_modules.values():
+ vpp_module.eval()
+
+ def parameters(self, including_frozen=False):
+ """
+ Return meta information of the entire model's
+ parameters.
+ """
+ if including_frozen:
+ return self.module.parameters()
+ else:
+ return filter(lambda p: p.requires_grad, self.module.parameters())
+
+ def sync_shared_params(self, group_id: int = 0, share_params=True) -> None:
+ """
+ Synchronize gradients and weights among groups of specified units, dictated by
+ "partition_units" in PipelineConfig. Typically, this function is used for
+ synchronizing gradients and weights of embeddings layers in Transformer-based
+ architecture.
+ Args:
+ group_id (int): specify groups of modules across stages to synchronize. Default by 0.
+ share_params (bool): if True, sync weight parameters; otherwise, share gradients.
+ """
+ local_rank = dist.distributed_c10d.get_rank()
+ tp_coordinate = self.module.device_mesh_management.get_tensor_parallel_rank()
+ if self.module.shared_module_mapping and local_rank in dist.distributed_c10d.get_process_group_ranks(
+ self.module.shared_module_process_groups[group_id][tp_coordinate]
+ ):
+ self.module.sync_shared_params(self.global_mesh, group_id=group_id, share_params=share_params)
+
+ def _align_num_batches(self, first_stage_rank, batches):
+ """
+ Aligns all ranks must have the same number of mini-batches as rank 0.
+ """
+ num_batches = torch.tensor([batches], dtype=torch.int64).cuda(dist.get_rank())
+ dist.broadcast(num_batches, src=first_stage_rank)
+ is_consistent = num_batches.item() == batches
+ if not is_consistent:
+ batches = num_batches.item()
+ return batches
diff --git a/vescale/initialize/__init__.py b/vescale/initialize/__init__.py
index 8d130f9..e410817 100644
--- a/vescale/initialize/__init__.py
+++ b/vescale/initialize/__init__.py
@@ -15,4 +15,4 @@
#
################################################################################
-from .deferred_init import deferred_init, is_deferred, materialize_dtensor, materialize_dparameter
+from .deferred_init import deferred_init, is_deferred, materialize_dtensor, materialize_dparameter, materialize_module
diff --git a/vescale/initialize/deferred_init.py b/vescale/initialize/deferred_init.py
index fdab336..5498ca8 100644
--- a/vescale/initialize/deferred_init.py
+++ b/vescale/initialize/deferred_init.py
@@ -18,6 +18,7 @@
from torchdistx.deferred_init import deferred_init as _deferred_init
from torchdistx.deferred_init import is_deferred as _is_deferred
from torchdistx.deferred_init import _C
+ from torchdistx.deferred_init import materialize_module as _materialize_module
IMPORT_DEFER = True
except: # noqa: E722
@@ -81,6 +82,19 @@ def is_deferred(obj: Union[torch.Tensor, nn.Parameter, nn.Module]) -> bool:
return _is_deferred(obj)
+def materialize_module(obj: nn.Module):
+ """Materializes deferred initialized ``nn.Module`` object.
+
+ Args:
+ obj:
+ An ``nn.Module`` instance.
+ """
+ if not IMPORT_DEFER:
+ return False
+
+ _materialize_module(obj)
+
+
def materialize_dtensor(
tensor: torch.Tensor,
device_mesh: Optional[DeviceMesh] = None,
diff --git a/vescale/model/base_gpt/__init__.py b/vescale/model/base_gpt/__init__.py
new file mode 100644
index 0000000..f3b869e
--- /dev/null
+++ b/vescale/model/base_gpt/__init__.py
@@ -0,0 +1,5 @@
+################################################################################
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+################################################################################
+# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################
diff --git a/vescale/model/base_gpt/attention.py b/vescale/model/base_gpt/attention.py
new file mode 100644
index 0000000..66c615d
--- /dev/null
+++ b/vescale/model/base_gpt/attention.py
@@ -0,0 +1,531 @@
+################################################################################
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+################################################################################
+# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################
+
+import math
+
+import torch
+from torch import nn
+
+from vescale.dtensor.api import from_local
+from vescale.model.base_gpt.checkpoint import checkpoint
+from vescale.model.base_gpt.enums import AttnMaskType, AttnType
+from vescale.model.base_gpt.fuse_softmax import FusedScaleMaskSoftmax
+from vescale.model.base_gpt.rotary import apply_rotary_pos_emb
+from vescale.model.random import get_cuda_rng_tracker
+from vescale.model.utils import attention_mask_func, divide
+
+try:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+except ImportError:
+ try:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
+ except ImportError:
+ flash_attn_unpadded_func = None
+
+try:
+ from einops import rearrange
+except ImportError:
+ rearrange = None
+
+
+class CoreAttention(nn.Module):
+ def __init__(self, layer_number, config, attn_mask_type=AttnMaskType.padding):
+ super().__init__()
+ self.fp16 = config.fp16
+ self.bf16 = config.bf16
+
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
+ if self.apply_query_key_layer_scaling:
+ self.attention_softmax_in_fp32 = True
+ self.layer_number = max(1, layer_number)
+ self.attn_mask_type = attn_mask_type
+ self.sequence_parallel = config.sequence_parallel
+
+ self.config = config
+
+ # Per attention head and per partition values.
+
+ coeff = None
+ if self.apply_query_key_layer_scaling:
+ coeff = self.layer_number
+
+ self.scale_mask_softmax = FusedScaleMaskSoftmax(
+ self.fp16,
+ self.bf16,
+ self.attn_mask_type,
+ config.masked_softmax_fusion,
+ attention_mask_func,
+ self.attention_softmax_in_fp32,
+ coeff,
+ )
+
+ # Dropout. Note that for a single iteration, this layer will generate
+ # different outputs on different number of parallel partitions but
+ # on average it should not be partition dependent.
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
+
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
+ # ===================================
+ # Raw attention scores. [b, np, s, s]
+ # ===================================
+
+ # [b, np, sq, sk]
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
+
+ # [sq, b, np, hn] -> [sq, b * np, hn]
+ query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
+ # [sk, b, np, hn] -> [sk, b * np, hn]
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
+
+ q_t = query_layer.transpose(0, 1)
+ # preallocting input tensor: [b * np, sq, sk]
+ matmul_input_buffer = torch.empty(
+ (output_size[0] * output_size[1] // query_layer._spec.mesh.size(), output_size[2], output_size[3]),
+ dtype=query_layer.dtype,
+ device=query_layer.device,
+ )
+ matmul_input_buffer = from_local(matmul_input_buffer, query_layer._spec.mesh, q_t._spec.placements)
+
+ # Raw attention scores. [b * np, sq, sk]
+ projection_size = self.config.kv_channels * self.config.num_attention_heads
+ hidden_size_per_attention_head = divide(projection_size, self.config.num_attention_heads)
+ norm_factor = math.sqrt(hidden_size_per_attention_head)
+ norm_factor *= self.layer_number
+ matmul_result = torch.baddbmm(
+ matmul_input_buffer,
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
+ beta=0.0,
+ alpha=(1.0 / norm_factor),
+ )
+
+ # change view to [b, np, sq, sk]
+ attention_scores = matmul_result.view(*output_size)
+
+ # ===========================
+ # Attention probs and dropout
+ # ===========================
+
+ # attention scores and attention mask [b, np, sq, sk]
+ attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ if not self.sequence_parallel:
+ with get_cuda_rng_tracker().fork():
+ attention_probs = self.attention_dropout(attention_probs)
+ else:
+ attention_probs = self.attention_dropout(attention_probs)
+
+ attention_probs = from_local(attention_probs, attention_scores._spec.mesh, attention_scores._spec.placements)
+
+ # =========================
+ # Context layer. [sq, b, hp]
+ # =========================
+
+ # value_layer -> context layer.
+ # [sk, b, np, hn] --> [b, np, sq, hn]
+
+ # context layer shape: [b, np, sq, hn]
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
+
+ # change view [sk, b * np, hn]
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
+
+ # change view [b * np, sq, sk]
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+
+ # matmul: [b * np, sq, hn]
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+
+ # change view [b, np, sq, hn]
+ context_layer = context_layer.view(*output_size)
+
+ # [b, np, sq, hn] --> [sq, b, np, hn]
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+ # [sq, b, np, hn] --> [sq, b, hp]
+ context_layer = context_layer.view(*context_layer.size()[:-2], -1)
+
+ return context_layer
+
+
+class FlashSelfAttention(nn.Module):
+ """Implement the scaled dot product attention with softmax.
+ Arguments
+ ---------
+ softmax_scale: The temperature to use for the softmax attention.
+ (default: 1/sqrt(d_keys) where d_keys is computed at
+ runtime)
+ attention_dropout: The dropout rate to apply to the attention
+ (default: 0.0)
+ """
+
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+ super().__init__()
+ assert flash_attn_unpadded_func is not None, (
+ "Please install FlashAttention first, " "e.g., with pip install flash-attn"
+ )
+ assert rearrange is not None, "Please install einops first, e.g., with pip install einops"
+ self.causal = causal
+ self.softmax_scale = softmax_scale
+ self.dropout_p = attention_dropout
+
+ def forward(self, q, k, v):
+ """Implements the multihead softmax attention.
+ Arguments
+ ---------
+ q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+ """
+
+ assert all(i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))
+ assert all(i.is_cuda for i in (q, k, v))
+
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
+ seqlen_k = k.shape[1]
+
+ q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
+ cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device)
+
+ if self.training:
+ # during training q,k,v always have same seqlen
+ assert seqlen_k == seqlen_q
+
+ is_causal = self.causal
+ cu_seqlens_k = cu_seqlens_q
+ dropout_p = self.dropout_p
+ else:
+ # turn off FA causal mask after first inference autoregressive iteration
+ # only on first autoregressive step q,k,v have same seqlen
+ is_causal = seqlen_q == seqlen_k
+ cu_seqlens_k = torch.arange(
+ 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=q.device
+ )
+ dropout_p = 0
+
+ output = flash_attn_unpadded_func(
+ q,
+ k,
+ v,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ seqlen_q,
+ seqlen_k,
+ dropout_p,
+ softmax_scale=self.softmax_scale,
+ causal=is_causal,
+ )
+
+ output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
+ return output
+
+
+class ParallelAttention(nn.Module):
+ """Parallel self-attention layer abstract class.
+
+ Self-attention layer takes input with size [s, b, h]
+ and returns output of the same size.
+ """
+
+ def __init__(
+ self,
+ config,
+ layer_number,
+ attention_type=AttnType.self_attn,
+ attn_mask_type=AttnMaskType.padding,
+ ):
+ super().__init__()
+ self.layer_number = max(1, layer_number)
+ self.attention_type = attention_type
+ self.attn_mask_type = attn_mask_type
+ self.config = config
+
+ self.group_query_attention = config.group_query_attention
+ self.num_query_groups = config.num_query_groups
+
+ query_projection_size = config.kv_channels * config.num_attention_heads
+ if self.group_query_attention:
+ kv_projection_size = config.kv_channels * config.num_query_groups
+ else:
+ kv_projection_size = config.kv_channels * config.num_attention_heads
+
+ self.use_flash_attn = (
+ config.use_flash_attn
+ and attention_type == AttnType.self_attn
+ and self.attn_mask_type == AttnMaskType.causal
+ )
+ if self.use_flash_attn:
+ if flash_attn_unpadded_func is None:
+ raise ImportError("FlashAttention is not installed, please install with " "pip install flash-attn")
+ assert attention_type == AttnType.self_attn, (
+ "FlashAttention code path only supports " "self-attention for now"
+ )
+ assert self.attn_mask_type == AttnMaskType.causal, (
+ "FlashAttention code path only " "supports causal mask for now"
+ )
+ if rearrange is None:
+ raise ImportError("einops is not installed, please install with pip install einops")
+
+ # Strided linear layer.
+ if attention_type == AttnType.self_attn:
+ self.query_key_value = nn.Linear(
+ config.hidden_size,
+ query_projection_size + 2 * kv_projection_size,
+ bias=config.add_bias_linear,
+ )
+ config.init_method(self.query_key_value.weight)
+ else:
+ assert attention_type == AttnType.cross_attn
+
+ if self.group_query_attention:
+ raise NotImplementedError("Grouped query attention not implemented for cross-attention.")
+ assert query_projection_size == kv_projection_size
+
+ self.query = nn.Linear(
+ config.hidden_size,
+ query_projection_size,
+ bias=self.add_bias_linear,
+ )
+ config.init_method(self.query.weight)
+ self.key_value = nn.Linear(config.hidden_size, 2 * kv_projection_size, bias=config.add_bias_linear)
+ config.init_method(self.key_value.weight)
+
+ self.core_attention = CoreAttention(self.layer_number, config, self.attn_mask_type)
+ self.checkpoint_core_attention = config.recompute_granularity == "selective"
+
+ if self.use_flash_attn:
+ self.core_attention_flash = FlashSelfAttention(causal=True, attention_dropout=config.attention_dropout)
+
+ # Output.
+ self.dense = nn.Linear(
+ query_projection_size,
+ config.hidden_size,
+ bias=False,
+ )
+ self.dense_bias = torch.empty(config.hidden_size) if config.add_bias_linear else None
+ config.output_layer_init_method(self.dense.weight)
+
+ def _checkpointed_attention_forward(self, query_layer, key_layer, value_layer, attention_mask, rotary_pos_emb=None):
+ """Forward method with activation checkpointing."""
+
+ def custom_forward(*inputs):
+ query_layer = inputs[0]
+ key_layer = inputs[1]
+ value_layer = inputs[2]
+ attention_mask = inputs[3]
+ output_ = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
+ return output_
+
+ q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None else rotary_pos_emb
+
+ hidden_states = checkpoint(
+ custom_forward, False, query_layer, key_layer, value_layer, attention_mask, q_pos_emb, k_pos_emb
+ )
+
+ return hidden_states
+
+ def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads):
+ query_projection_size = self.config.kv_channels * self.config.num_attention_heads
+ hidden_size_per_attention_head = divide(query_projection_size, self.config.num_attention_heads)
+ return torch.empty(
+ inference_max_sequence_len,
+ batch_size,
+ num_attention_heads,
+ hidden_size_per_attention_head,
+ dtype=self.params_dtype,
+ device=torch.cuda.current_device(),
+ )
+
+ def forward(
+ self, hidden_states, attention_mask=None, encoder_output=None, inference_params=None, rotary_pos_emb=None
+ ):
+ # hidden_states: [sq, b, h]
+
+ # Per attention head and per partition values.
+ world_size = self.hidden_states._spec.mesh.size() # TP
+ query_projection_size = self.config.kv_channels * self.config.num_attention_heads
+ self.hidden_size_per_attention_head = divide(query_projection_size, self.config.num_attention_heads)
+ self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
+
+ if self.group_query_attention:
+ self.num_query_groups_per_partition = divide(self.num_query_groups, world_size)
+ else:
+ self.num_query_groups_per_partition = self.num_attention_heads_per_partition
+
+ self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
+
+ # =================================================
+ # Pre-allocate memory for key-values for inference.
+ # =================================================
+ is_first_step = False
+ if inference_params:
+ if self.layer_number not in inference_params.key_value_memory_dict:
+ inf_max_seq_len = inference_params.max_sequence_length
+ inf_max_batch_size = inference_params.max_batch_size
+ inference_key_memory = self._allocate_memory(
+ inf_max_seq_len, inf_max_batch_size, self.num_query_groups_per_partition
+ )
+ inference_value_memory = self._allocate_memory(
+ inf_max_seq_len, inf_max_batch_size, self.num_query_groups_per_partition
+ )
+
+ inference_params.key_value_memory_dict[self.layer_number] = (
+ inference_key_memory,
+ inference_value_memory,
+ )
+ is_first_step = True
+ else:
+ inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[self.layer_number]
+
+ # =====================
+ # Query, Key, and Value
+ # =====================
+ if self.attention_type == AttnType.self_attn:
+ # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
+ mixed_x_layer, _ = self.query_key_value(hidden_states)
+
+ # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
+ new_tensor_shape = mixed_x_layer.size()[:-1] + (
+ self.num_query_groups_per_partition,
+ (
+ (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
+ * self.hidden_size_per_attention_head
+ ),
+ )
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+ # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
+ (query_layer, key_layer, value_layer) = torch.split(
+ mixed_x_layer,
+ [
+ (
+ self.num_attention_heads_per_partition
+ // self.num_query_groups_per_partition
+ * self.hidden_size_per_attention_head
+ ),
+ self.hidden_size_per_attention_head,
+ self.hidden_size_per_attention_head,
+ ],
+ dim=3,
+ )
+
+ # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] -
+ query_layer = query_layer.view(
+ query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head
+ )
+ else:
+ # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
+ mixed_kv_layer, _ = self.key_value(encoder_output)
+
+ # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
+ new_tensor_shape = mixed_kv_layer.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ 2 * self.hidden_size_per_attention_head,
+ )
+ mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
+
+ # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
+ (key_layer, value_layer) = split_tensor_along_last_dim(mixed_kv_layer, 2)
+
+ # Attention head [sq, b, h] --> [sq, b, hp]
+ query_layer, _ = self.query(hidden_states)
+ # [sq, b, hp] --> [sq, b, np, hn]
+ new_tensor_shape = query_layer.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head,
+ )
+ query_layer = query_layer.view(*new_tensor_shape)
+
+ # ==================================
+ # Adjust key and value for inference
+ # ==================================
+
+ # duplicate the pos_emb for self attention
+ if rotary_pos_emb is not None:
+ if isinstance(rotary_pos_emb, tuple):
+ rotary_pos_emb = rotary_pos_emb
+ else:
+ rotary_pos_emb = (rotary_pos_emb,) * 2
+
+ if inference_params:
+ batch_start = inference_params.batch_size_offset
+ batch_end = batch_start + key_layer.size(1)
+ assert batch_end <= inference_key_memory.size(1)
+ sequence_start = inference_params.sequence_len_offset
+ sequence_end = sequence_start + key_layer.size(0)
+ assert sequence_end <= inference_key_memory.size(0)
+ # Copy key and values.
+ inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
+ inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer
+ key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
+ value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
+
+ # adjust the key rotary positional embedding
+ if rotary_pos_emb is not None:
+ q_pos_emb, k_pos_emb = rotary_pos_emb
+ # need to cross check this condition during inference
+ # if not set_inference_key_value_memory:
+ if not is_first_step:
+ # In inference, we compute one token at a time.
+ # Select the correct positional embedding
+ # (only the last token in the sequence)
+ q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
+ else:
+ # In the first forward pass of inference,
+ # we use the entire provided prefix.
+ # q_pos_emb here has the rope embeddings of the entire
+ # prefix + to-be-generated output so
+ # we slice to just the prefix.
+ q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
+ k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
+ rotary_pos_emb = (q_pos_emb, k_pos_emb)
+
+ # ==================================
+ # core attention computation
+ # ==================================
+
+ # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
+ if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
+ key_layer = key_layer.repeat_interleave(
+ self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
+ )
+ value_layer = value_layer.repeat_interleave(
+ self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
+ )
+
+ # apply relative positional encoding (rotary embedding)
+ if rotary_pos_emb is not None:
+ q_pos_emb, k_pos_emb = rotary_pos_emb
+ query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
+ key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
+ # TODO, can apply positional embedding to value_layer so it has
+ # absolute positional embedding.
+ # otherwise, only relative positional embedding takes effect
+ # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
+
+ if not self.use_flash_attn:
+ if self.checkpoint_core_attention:
+ context_layer = self._checkpointed_attention_forward(
+ query_layer, key_layer, value_layer, attention_mask
+ )
+ else:
+ context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
+ else:
+ q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (query_layer, key_layer, value_layer))
+ if not self.sequence_parallel:
+ with get_cuda_rng_tracker().fork():
+ context_layer = self.core_attention_flash(q, k, v)
+ else:
+ context_layer = self.core_attention_flash(q, k, v)
+ context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
+
+ # =================
+ # Output. [sq, b, h]
+ # =================
+
+ output = self.dense(context_layer)
+
+ return output, self.dense_bias
diff --git a/vescale/model/base_gpt/checkpoint.py b/vescale/model/base_gpt/checkpoint.py
new file mode 100644
index 0000000..d6f6cd7
--- /dev/null
+++ b/vescale/model/base_gpt/checkpoint.py
@@ -0,0 +1,133 @@
+################################################################################
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+################################################################################
+# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################
+
+from typing import Any, Tuple
+
+import torch
+
+from vescale.dtensor.dtensor import DTensor
+from vescale.dtensor.placement_types import Replicate
+from vescale.model.random import _set_cuda_rng_state, get_cuda_rng_tracker
+
+
+def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
+ if isinstance(inputs, tuple):
+ out = []
+ for inp in inputs:
+ if not isinstance(inp, torch.Tensor):
+ out.append(inp)
+ continue
+
+ x = inp.detach()
+ x.requires_grad = inp.requires_grad
+ out.append(x)
+ return tuple(out)
+ else:
+ raise RuntimeError(
+ "Only tuple of tensors is supported. Got Unsupported input type: ",
+ type(inputs).__name__,
+ )
+
+
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
+ """Break a tensor into equal 1D chunks across tensor parallel ranks.
+
+ Returns a Tensor or View with this rank's portion of the data.
+
+ Arguments:
+ tensor: The tensor to split
+
+ Keyword Arguments:
+ new_buffer (bool): If True, returns a new Tensor.
+ If False, returns a view into the existing Tensor.
+ Default is False
+
+ """
+ device_mesh = tensor.device_mesh
+ partition_size = torch.numel(tensor) // device_mesh.size()
+ start_index = partition_size * device_mesh.get_rank()
+ end_index = start_index + partition_size
+ if new_buffer:
+ data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
+ data.copy_(tensor.view(-1)[start_index:end_index])
+ else:
+ data = tensor.view(-1)[start_index:end_index]
+ return data
+
+
+class CheckpointFunction(torch.autograd.Function):
+ """This function is adapted from torch.utils.checkpoint with
+ two main changes:
+ 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
+ 2) the states in the model parallel tracker are also properly
+ tracked/set/reset.
+ """
+
+ @staticmethod
+ def forward(ctx, run_function, distribute_saved_activations, *args):
+ ctx.run_function = run_function
+ ctx.distribute_saved_activations = distribute_saved_activations
+
+ # Copy the rng states.
+ ctx.fwd_cpu_rng_state = torch.get_rng_state()
+ ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
+ ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
+
+ with torch.no_grad():
+ outputs = run_function(*args)
+
+ # Divide hidden states across model parallel group and only keep
+ # the chunk corresponding to the current rank.
+ if distribute_saved_activations:
+ ctx.input_0_shape = args[0].data.shape
+ assert isinstance(args[0].data, DTensor)
+ args[0].data = split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)
+
+ # Store everything.
+ ctx.save_for_backward(*args)
+
+ return outputs
+
+ @staticmethod
+ def backward(ctx, *args):
+ if not torch.autograd._is_checkpoint_valid():
+ raise RuntimeError("Checkpointing is not compatible with .grad(), " "please use .backward() if possible")
+ inputs = ctx.saved_tensors
+ if ctx.distribute_saved_activations:
+ assert isinstance(inputs[0].data, DTensor)
+ inputs[0].data.redistribute(inputs[0].data.device_mesh, [Replicate()])
+
+ # Store the current states.
+ bwd_cpu_rng_state = torch.get_rng_state()
+ bwd_cuda_rng_state = torch.cuda.get_rng_state()
+ bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
+
+ # Set the states to what it used to be before the forward pass.
+ torch.set_rng_state(ctx.fwd_cpu_rng_state)
+ _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
+ get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
+
+ # Compute the forward pass.
+ detached_inputs = detach_variable(inputs)
+ with torch.enable_grad():
+ outputs = ctx.run_function(*detached_inputs)
+
+ # Set the states back to what it was at the start of this function.
+ torch.set_rng_state(bwd_cpu_rng_state)
+ _set_cuda_rng_state(bwd_cuda_rng_state)
+ get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
+
+ if isinstance(outputs, DTensor):
+ outputs = (outputs,)
+ torch.autograd.backward(outputs, args)
+ grads = tuple(inp.grad if isinstance(inp, DTensor) else inp for inp in detached_inputs)
+ return (None, None) + grads
+
+
+def checkpoint(function, distribute_saved_activations, *args):
+ """Checkpoint a model or part of the model.
+ This has been directly copied from torch.utils.checkpoint."""
+ return CheckpointFunction.apply(function, distribute_saved_activations, *args)
diff --git a/vescale/model/base_gpt/enums.py b/vescale/model/base_gpt/enums.py
new file mode 100644
index 0000000..841dffd
--- /dev/null
+++ b/vescale/model/base_gpt/enums.py
@@ -0,0 +1,27 @@
+################################################################################
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+################################################################################
+# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################
+
+import enum
+
+
+class ModelType(enum.Enum):
+ encoder_or_decoder = 1
+ encoder_and_decoder = 2
+
+
+class LayerType(enum.Enum):
+ encoder = 1
+ decoder = 2
+
+
+class AttnType(enum.Enum):
+ self_attn = 1
+ cross_attn = 2
+
+
+class AttnMaskType(enum.Enum):
+ padding = 1
+ causal = 2
diff --git a/vescale/model/base_gpt/fuse_layer_norm.py b/vescale/model/base_gpt/fuse_layer_norm.py
new file mode 100644
index 0000000..e1e5801
--- /dev/null
+++ b/vescale/model/base_gpt/fuse_layer_norm.py
@@ -0,0 +1,119 @@
+################################################################################
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+################################################################################
+# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################
+
+"""This code is copied fron NVIDIA apex:
+ https://github.com/NVIDIA/apex
+with some changes."""
+
+import importlib
+import numbers
+
+import torch
+from torch import nn
+from torch.nn import init
+from torch.nn.parameter import Parameter
+
+
+fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+
+try:
+ from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
+
+ HAVE_PERSIST_LAYER_NORM = True
+except ImportError:
+ HAVE_PERSIST_LAYER_NORM = False
+
+try:
+ from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
+
+ HAVE_FUSED_LAYER_NORM = True
+except ImportError:
+ HAVE_FUSED_LAYER_NORM = False
+
+
+class MixedFusedLayerNorm(nn.Module):
+ def __init__(
+ self,
+ normalized_shape,
+ eps=1e-5,
+ no_persist_layer_norm=True,
+ param_dtype=torch.float32,
+ sequence_parallel=False,
+ apply_layernorm_1p=False,
+ ):
+ super().__init__()
+
+ self.apply_layernorm_1p = apply_layernorm_1p
+
+ # List of hiddens sizes supported in the persistent layer norm kernel
+ # If the hidden size is not supported, fall back to the non-persistent
+ # kernel.
+ persist_ln_hidden_sizes = [
+ 1024,
+ 1536,
+ 2048,
+ 2304,
+ 3072,
+ 3840,
+ 4096,
+ 5120,
+ 6144,
+ 8192,
+ 10240,
+ 12288,
+ 12800,
+ 15360,
+ 16384,
+ 18432,
+ 20480,
+ 24576,
+ 25600,
+ 30720,
+ 32768,
+ 40960,
+ 49152,
+ 65536,
+ ]
+ if normalized_shape not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM:
+ no_persist_layer_norm = True
+
+ if isinstance(normalized_shape, numbers.Integral):
+ normalized_shape = (normalized_shape,)
+ self.normalized_shape = torch.Size(normalized_shape)
+ self.eps = eps
+ self.weight = Parameter(torch.Tensor(*normalized_shape).to(param_dtype))
+ self.bias = Parameter(torch.Tensor(*normalized_shape).to(param_dtype))
+ self.reset_parameters()
+ self.no_persist_layer_norm = no_persist_layer_norm
+ self.sequence_parallel = sequence_parallel
+
+ def reset_parameters(self):
+ if self.apply_layernorm_1p:
+ init.zeros_(self.weight)
+ init.zeros_(self.bias)
+ else:
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+ def forward(self, input):
+ weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
+
+ if self.no_persist_layer_norm:
+ assert (
+ FusedLayerNormAffineFunction is not None
+ ), "FusedLayerNormAffineFunction is not available, please install apex from https://github.com/NVIDIA/apex"
+ out = FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps, False)
+ return out
+ else:
+ output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
+
+ # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
+ # a populated '_base' field). This will result in schedule.py's
+ # deallocate_output_tensor() throwing an error, so a viewless tensor is
+ # created to prevent this.
+ # output = make_viewless_tensor(
+ # inp=output, requires_grad=input.requires_grad, keep_graph=True)
+ return output
diff --git a/vescale/model/base_gpt/fuse_softmax.py b/vescale/model/base_gpt/fuse_softmax.py
new file mode 100644
index 0000000..25f3021
--- /dev/null
+++ b/vescale/model/base_gpt/fuse_softmax.py
@@ -0,0 +1,203 @@
+################################################################################
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+################################################################################
+# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################
+
+import torch
+from torch import nn
+
+from vescale.model.base_gpt.enums import AttnMaskType
+
+
+class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
+ """
+ Fused operation which performs following three operations in sequence
+ 1. Scale the tensor.
+ 2. Apply upper triangular mask (typically used in gpt models).
+ 3. Perform softmax.
+ """
+
+ @staticmethod
+ def forward(ctx, inputs, scale):
+ import scaled_upper_triang_masked_softmax_cuda
+
+ scale_t = torch.tensor([scale])
+ softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
+
+ ctx.save_for_backward(softmax_results, scale_t)
+ return softmax_results
+
+ @staticmethod
+ def backward(ctx, output_grads):
+ import scaled_upper_triang_masked_softmax_cuda
+
+ softmax_results, scale_t = ctx.saved_tensors
+ input_grads = scaled_upper_triang_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
+
+ return input_grads, None
+
+
+class ScaledMaskedSoftmax(torch.autograd.Function):
+ """
+ Fused operation which performs following three operations in sequence
+ 1. Scale the tensor.
+ 2. Apply the mask.
+ 3. Perform softmax.
+ """
+
+ @staticmethod
+ def forward(ctx, inputs, mask, scale):
+ import scaled_masked_softmax_cuda
+
+ scale_t = torch.tensor([scale])
+ softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
+ ctx.save_for_backward(softmax_results, scale_t)
+ return softmax_results
+
+ @staticmethod
+ def backward(ctx, output_grads):
+ import scaled_masked_softmax_cuda
+
+ softmax_results, scale_t = ctx.saved_tensors
+
+ input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
+ return input_grads, None, None
+
+
+class ScaledSoftmax(torch.autograd.Function):
+ """
+ Fused operation which performs following two operations in sequence
+ 1. Scale the tensor.
+ 2. Perform softmax.
+ """
+
+ @staticmethod
+ def forward(ctx, inputs, scale):
+ import scaled_softmax_cuda
+
+ scale_t = torch.tensor([scale])
+
+ softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0])
+ ctx.save_for_backward(softmax_results, scale_t)
+ return softmax_results
+
+ @staticmethod
+ def backward(ctx, output_grads):
+ import scaled_softmax_cuda
+
+ softmax_results, scale_t = ctx.saved_tensors
+
+ input_grads = scaled_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
+ return input_grads, None, None
+
+
+class FusedScaleMaskSoftmax(nn.Module):
+ """
+ fused operation: scaling + mask + softmax
+
+ Arguments:
+ input_in_fp16: flag to indicate if input in fp16 data format.
+ input_in_bf16: flag to indicate if input in bf16 data format.
+ attn_mask_type: attention mask type (pad or causal)
+ scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
+ mask_func: mask function to be applied.
+ softmax_in_fp32: if true, softmax in performed at fp32 precision.
+ scale: scaling factor used in input tensor scaling.
+ """
+
+ def __init__(
+ self,
+ input_in_fp16,
+ input_in_bf16,
+ attn_mask_type,
+ scaled_masked_softmax_fusion,
+ mask_func,
+ softmax_in_fp32,
+ scale,
+ ):
+ super().__init__()
+ self.input_in_fp16 = input_in_fp16
+ self.input_in_bf16 = input_in_bf16
+ assert not (
+ self.input_in_fp16 and self.input_in_bf16
+ ), "both fp16 and bf16 flags cannot be active at the same time."
+ self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
+ self.attn_mask_type = attn_mask_type
+ self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
+ self.mask_func = mask_func
+ self.softmax_in_fp32 = softmax_in_fp32
+ self.scale = scale
+
+ assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
+
+ def forward(self, input, mask):
+ # [b, np, sq, sk]
+ assert input.dim() == 4
+
+ if self.is_kernel_available(mask, *input.size()):
+ return self.forward_fused_softmax(input, mask)
+ else:
+ return self.forward_torch_softmax(input, mask)
+
+ def is_kernel_available(self, mask, b, np, sq, sk):
+ attn_batches = b * np
+
+ if (
+ self.scaled_masked_softmax_fusion # user want to fuse
+ and self.input_in_float16 # input must be fp16
+ and 16 < sk <= 4096 # sk must be 16 ~ 2048
+ and sq % 4 == 0 # sq must be divisor of 4
+ and sk % 4 == 0 # sk must be divisor of 4
+ and attn_batches % 4 == 0 # np * b must be divisor of 4
+ ):
+ if 0 <= sk <= 4096:
+ batch_per_block = self.get_batch_per_block(sq, sk, b, np)
+
+ if self.attn_mask_type == AttnMaskType.causal:
+ if attn_batches % batch_per_block == 0:
+ return True
+ else:
+ if sq % batch_per_block == 0:
+ return True
+ return False
+
+ def forward_fused_softmax(self, input, mask):
+ b, np, sq, sk = input.size()
+ scale = self.scale if self.scale is not None else 1.0
+
+ if self.attn_mask_type == AttnMaskType.causal:
+ assert sq == sk, "causal mask is only for self attention"
+ # input is 3D tensor (attn_batches, sq, sk)
+ input = input.view(-1, sq, sk)
+ probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
+ return probs.view(b, np, sq, sk)
+ else:
+ # input is 4D tensor (b, np, sq, sk)
+ if mask is not None:
+ return ScaledMaskedSoftmax.apply(input, mask, scale)
+ else:
+ return ScaledSoftmax.apply(input, scale)
+
+ def forward_torch_softmax(self, input, mask):
+ if self.input_in_float16 and self.softmax_in_fp32:
+ input = input.float()
+
+ if self.scale is not None:
+ input = input * self.scale
+ mask_output = self.mask_func(input, mask) if mask is not None else input
+ probs = torch.nn.Softmax(dim=-1)(mask_output)
+
+ if self.input_in_float16 and self.softmax_in_fp32:
+ if self.input_in_fp16:
+ probs = probs.half()
+ else:
+ probs = probs.bfloat16()
+
+ return probs
+
+ @staticmethod
+ def get_batch_per_block(sq, sk, b, np):
+ import scaled_masked_softmax_cuda
+
+ return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
diff --git a/vescale/model/base_gpt/jit_func.py b/vescale/model/base_gpt/jit_func.py
new file mode 100644
index 0000000..c129688
--- /dev/null
+++ b/vescale/model/base_gpt/jit_func.py
@@ -0,0 +1,40 @@
+################################################################################
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+################################################################################
+# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################
+
+from typing import Optional
+
+import torch
+
+
+def bias_dropout_add(x, bias, residual, prob, training):
+ # type: (torch.Tensor, Optional[torch.Tensor], torch.Tensor, float, bool) -> torch.Tensor
+ if bias is not None:
+ x = x + bias
+ out = torch.nn.functional.dropout(x, p=prob, training=training)
+ out = residual + out
+ return out
+
+
+def get_bias_dropout_add(training):
+ def _bias_dropout_add(x, bias, residual, prob):
+ return bias_dropout_add(x, bias, residual, prob, training)
+
+ return _bias_dropout_add
+
+
+@torch.compile
+def bias_dropout_add_fused_train(
+ x: torch.Tensor, bias: Optional[torch.Tensor], residual: torch.Tensor, prob: float
+) -> torch.Tensor:
+ return bias_dropout_add(x, bias, residual, prob, True)
+
+
+# @torch.jit.script
+@torch.compile
+def bias_dropout_add_fused_inference(
+ x: torch.Tensor, bias: Optional[torch.Tensor], residual: torch.Tensor, prob: float
+) -> torch.Tensor:
+ return bias_dropout_add(x, bias, residual, prob, False)
diff --git a/vescale/model/base_gpt/mlp.py b/vescale/model/base_gpt/mlp.py
new file mode 100644
index 0000000..f2c33fc
--- /dev/null
+++ b/vescale/model/base_gpt/mlp.py
@@ -0,0 +1,101 @@
+################################################################################
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+################################################################################
+# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################
+
+import torch
+from torch import nn
+
+from vescale.model.utils import bias_gelu_impl, openai_gelu
+
+
+class SwitchMLP(nn.Module):
+ """
+ Routes input to one of N MLP "experts"
+ """
+
+ def __init__(self, hidden_size, num_experts):
+ super().__init__()
+ self.router = nn.Linear(hidden_size, num_experts)
+ self.experts = torch.nn.ModuleList()
+ for _ in range(num_experts):
+ self.experts.append(ParallelMLP(hidden_size))
+
+ def forward(self, hidden_states):
+ # hidden_states: [s, b, h]
+ s = hidden_states.size(0)
+ b = hidden_states.size(1)
+ h = hidden_states.size(2)
+ route = self.router(hidden_states)
+ route = torch.nn.functional.softmax(route, dim=2)
+ max_prob, max_ind = torch.max(route, dim=2)
+ max_prob = torch.unsqueeze(max_prob, 2) # [s b 1]
+
+ # TODO (rprenger) TODO this could be made easier to read
+ # Converting [s, b, h] to [s*b, h].
+ # Each vector could be routed differently
+ # [s*b h]
+ hidden_states = hidden_states.view(-1, hidden_states.size(2))
+ max_prob = max_prob.view(-1, max_prob.size(2)) # [s*b 1]
+ max_ind = max_ind.view(-1) # [s*b]
+
+ output_total = torch.empty_like(hidden_states)
+ output_bias_total = torch.empty_like(hidden_states)
+ # TODO (rprenger) This does each expert in serial, but it could be parallelized
+
+ for expert_num, expert in enumerate(self.experts):
+ local_indices = (max_ind == expert_num).nonzero()
+ hidden = hidden_states[local_indices, :]
+ output, output_bias = expert(hidden)
+ output_bias = output_bias.expand_as(output)
+ output_total[local_indices, :] = output
+ output_bias_total[local_indices, :] = output_bias
+
+ output_total = output_total * max_prob
+ output_bias_total = output_bias_total * max_prob
+ output_total = output_total.view(s, b, h)
+ output_bias_total = output_bias_total.view(s, b, h)
+
+ return output_total, output_bias_total
+
+
+class ParallelMLP(nn.Module):
+ """MLP.
+
+ MLP will take the input with h hidden state, project it to 4*h
+ hidden dimension, perform nonlinear transformation, and project the
+ state back into h hidden dimension.
+ """
+
+ def __init__(self, h, param_dtype=torch.float32, bias_gelu_fusion=None):
+ super().__init__()
+
+ # Project to 4h.
+ self.dense_h_to_4h = nn.Linear(h, h * 4, bias=False, dtype=param_dtype)
+ # torch.nn.init.normal_(self.dense_h_to_4h.weight, mean=0.0, std=0.02)
+ torch.nn.init.xavier_normal_(self.dense_h_to_4h.weight)
+ self.dense_h_to_4h_bias = nn.Parameter(torch.zeros(4 * h, dtype=param_dtype))
+
+ self.bias_gelu_fusion = bias_gelu_fusion
+ self.activation_func = openai_gelu
+
+ # Project back to h.
+ self.dense_4h_to_h = nn.Linear(4 * h, h, bias=False, dtype=param_dtype)
+ torch.nn.init.xavier_uniform_(self.dense_4h_to_h.weight)
+ # torch.nn.init.normal_(self.dense_4h_to_h.weight, mean=0.0, std=0.02)
+ self.dense_4h_to_h_bias = nn.Parameter(torch.zeros(h, dtype=param_dtype))
+
+ def forward(self, hidden_states):
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
+ bias_parallel = self.dense_h_to_4h_bias
+
+ if self.bias_gelu_fusion:
+ intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
+ else:
+ intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
+
+ # [s, b, h]
+ output = self.dense_4h_to_h(intermediate_parallel)
+ output_bias = self.dense_4h_to_h_bias
+ return output, output_bias
diff --git a/vescale/model/base_gpt/rotary.py b/vescale/model/base_gpt/rotary.py
new file mode 100644
index 0000000..eaa8d76
--- /dev/null
+++ b/vescale/model/base_gpt/rotary.py
@@ -0,0 +1,52 @@
+################################################################################
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+################################################################################
+# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################
+
+from typing import Union
+
+import torch
+from torch import Tensor
+
+from vescale.dtensor.dtensor import DTensor
+
+
+def _rotate_half(x: Union[Tensor, DTensor]) -> Union[Tensor, DTensor]:
+ """Change sign so the last dimension becomes [-odd, +even]
+
+ Args:
+ x (Tensor): Input tensor
+
+ Returns:
+ Tensor: Tensor rotated half
+ """
+
+ x1, x2 = torch.chunk(x, 2, dim=-1)
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(t: Union[Tensor, DTensor], freqs: Union[Tensor, DTensor]) -> Union[Tensor, DTensor]:
+ """Apply rotary positional embedding to input tensor T.
+
+ check https://kexue.fm/archives/8265 for detailed formulas
+
+ Args:
+ t (Tensor): Input tensor T is of shape [seq_length, ... , dim]
+ freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim]
+
+ Returns:
+ Tensor: The input tensor after applying RoPE
+ """
+ rot_dim = freqs.shape[-1]
+
+ # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
+ t, t_pass = t.narrow(-1, 0, rot_dim), t.narrow(-1, rot_dim, max(t.size()[-1] - rot_dim, 0))
+
+ # first part is cosine component
+ # second part is sine component, need to change signs with _rotate_half method
+ cos_ = torch.cos(freqs).to(t.dtype)
+ sin_ = torch.sin(freqs).to(t.dtype)
+
+ t = (t * cos_) + (_rotate_half(t) * sin_)
+ return torch.cat((t, t_pass), dim=-1)
diff --git a/vescale/model/base_gpt/transformer_block.py b/vescale/model/base_gpt/transformer_block.py
new file mode 100644
index 0000000..a2c09be
--- /dev/null
+++ b/vescale/model/base_gpt/transformer_block.py
@@ -0,0 +1,135 @@
+################################################################################
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+################################################################################
+# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################
+
+import torch
+from torch import nn
+from contextlib import nullcontext
+from vescale.dtensor.dtensor import DTensor
+from vescale.initialize.deferred_init import deferred_init
+from vescale.model.base_gpt.transformer_layer import ParallelTransformerLayer
+from vescale.model.random import get_cuda_rng_tracker
+
+
+class TransformerBlock(nn.Module):
+ """Transformer class."""
+
+ def __init__(
+ self,
+ num_layer,
+ args,
+ drop_path_rate=0.0,
+ pre_process=True,
+ deferred_init=False,
+ ):
+ super().__init__()
+
+ self.config = args
+ self.drop_path_rate = drop_path_rate
+ self.pre_process = pre_process
+ self.num_layer = num_layer
+ self.deferred_init = deferred_init
+
+ # required for pipeline parallel schedules
+ self.input_tensor = None
+ self._build_layers()
+
+ def _build_layers(self):
+ # Transformer layers.
+ # @jcasper can we improve how we deal with layer_number?
+ # currently it's only used in CoreAttention?
+ # if self.apply_query_key_layer_scaling:
+ # coeff = self.layer_number
+ # self.norm_factor *= coeff
+ def build_layer(layer_number):
+ if self.deferred_init:
+ layer_config = {
+ "init_method": self.config.init_method,
+ "output_layer_init_method": self.config.output_layer_init_method,
+ "layer_number": layer_number,
+ "args": self.config,
+ "drop_path_rate": self.drop_path_rate,
+ }
+ layer = deferred_init(ParallelTransformerLayer, **layer_config)
+ else:
+ layer = ParallelTransformerLayer(
+ self.config.init_method,
+ self.config.output_layer_init_method,
+ layer_number,
+ self.config,
+ self.drop_path_rate,
+ )
+
+ return layer
+
+ # offset is implicit in TransformerLayer
+ self.transformer_layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layer)])
+ self.layers = torch.nn.Sequential()
+ for i in range(len(self.transformer_layers)):
+ self.layers.append(self.transformer_layers[i])
+
+ def _get_layer(self, layer_number):
+ return self.transformer_layers[layer_number]
+
+ def set_input_tensor(self, input_tensor):
+ """Set input tensor to be used instead of forward()'s input.
+
+ When doing pipeline parallelism the input from the previous
+ stage comes from communication, not from the input, so the
+ model's forward_step_func won't have it. This function is thus
+ used by internal code to bypass the input provided by the
+ forward_step_func"""
+ self.input_tensor = input_tensor
+
+ def forward(
+ self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, inference_params=None
+ ):
+ # hidden_states (float): [s, b, h]
+ # attention_mask (bool): [1, 1, s, s]
+
+ if not self.pre_process:
+ # See set_input_tensor()
+ hidden_states = self.input_tensor
+
+ # Viewless tensor.
+ # - We only need to create a viewless tensor in the case of micro batch
+ # size (mbs) == 1, since in this case, 'hidden_states.transpose()'
+ # above creates a view tensor, and '.contiguous()' is a pass-through.
+ # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
+ # the need to make it viewless.
+ #
+ # However, we don't explicitly check mbs == 1 here because
+ # make_viewless_tensor() has negligible overhead when its input
+ # is already viewless.
+ #
+ # - For the 'else' case above, calling make_viewless_tensor() here is
+ # likely redundant, since p2p_communication.py (likely originator)
+ # already creates viewless tensors. That said, make_viewless_tensor()
+ # is called here to be future-proof and corner-case-proof.
+ # hidden_states = make_viewless_tensor(
+ # inp=hidden_states,
+ # requires_grad=True,
+ # keep_graph=True,
+ # )
+
+ rng_context = nullcontext()
+ if isinstance(hidden_states, DTensor):
+ placements = hidden_states.placements
+ # check sbh, for s
+ is_sp = any(placement.is_shard(dim=0) for placement in placements)
+ if is_sp:
+ rng_context = get_cuda_rng_tracker().fork()
+
+ with rng_context:
+ for layer in self.transformer_layers:
+ hidden_states = layer(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ encoder_output=encoder_output,
+ enc_dec_attn_mask=enc_dec_attn_mask,
+ inference_params=inference_params,
+ )
+
+ return hidden_states
diff --git a/vescale/model/base_gpt/transformer_layer.py b/vescale/model/base_gpt/transformer_layer.py
new file mode 100644
index 0000000..f9931d1
--- /dev/null
+++ b/vescale/model/base_gpt/transformer_layer.py
@@ -0,0 +1,194 @@
+################################################################################
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+################################################################################
+# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################
+
+from contextlib import nullcontext
+
+import torch
+from torch import nn
+
+from vescale.model.attention.dmodule_parallel_attention import ParallelAttention
+from vescale.model.base_gpt.fuse_layer_norm import MixedFusedLayerNorm as LayerNorm
+from vescale.model.base_gpt.jit_func import (
+ bias_dropout_add_fused_inference,
+ bias_dropout_add_fused_train,
+ get_bias_dropout_add,
+)
+from vescale.model.base_gpt.mlp import ParallelMLP, SwitchMLP
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample
+ (when applied in main path of residual blocks).
+ """
+
+ def __init__(self, drop_prob=0.0):
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_state):
+ if self.drop_prob == 0.0 or not self.training:
+ return hidden_state
+ keep_prob = 1 - self.drop_prob
+ # work with diff dim tensors, not just 2D ConvNets
+ # hidden_state: [s, b, h]
+ random_tensor = keep_prob + torch.rand_like(hidden_state)
+ random_tensor.floor_() # binarize
+ output = hidden_state.div(keep_prob) * random_tensor
+ return output
+
+
+class ParallelTransformerLayer(nn.Module):
+ """A single transformer layer.
+
+ Transformer layer takes input with size [s, b, h] and returns an
+ output of the same size.
+ """
+
+ def __init__(
+ self,
+ init_method,
+ output_layer_init_method,
+ layer_number,
+ args,
+ drop_path_rate=0.0,
+ ):
+ super().__init__()
+ self.layer_number = layer_number
+
+ self.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm
+
+ self.bf16 = args.bf16
+
+ # Layernorm on the input data.
+ self.input_layernorm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon,
+ no_persist_layer_norm=not args.persist_layer_norm,
+ param_dtype=args.param_dtype,
+ )
+
+ # Self attention.
+ self.self_attention = ParallelAttention(
+ args.hidden_size,
+ args.kv_channels,
+ args.num_attention_heads,
+ args.world_size,
+ 1, # n_shared_qhead
+ args.param_dtype,
+ )
+ self.hidden_dropout = args.hidden_dropout
+ self.bias_dropout_fusion = args.bias_dropout_fusion
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
+
+ # Layernorm on the attention output
+ self.post_attention_layernorm = LayerNorm(
+ args.hidden_size,
+ eps=args.layernorm_epsilon,
+ no_persist_layer_norm=not args.persist_layer_norm,
+ param_dtype=args.param_dtype,
+ )
+
+ # MLP
+ if args.num_experts is not None:
+ self.mlp = SwitchMLP(init_method, output_layer_init_method, args)
+ else:
+ self.mlp = ParallelMLP(args.hidden_size, param_dtype=args.param_dtype)
+
+ # Set bias+dropout+add fusion grad_enable execution handler.
+ TORCH_MAJOR = int(torch.__version__.split(".")[0])
+ TORCH_MINOR = int(torch.__version__.split(".")[1])
+ use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
+ self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
+
+ def forward(
+ self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, inference_params=None
+ ):
+ # hidden_states: [s, b, h]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+
+ # Self attention.
+ attention_output, attention_bias = self.self_attention(
+ layernorm_output, attention_mask, inference_params=inference_params
+ )
+
+ # assert not torch.isnan(attention_output.to_local()
+ # ).any(), attention_output
+ # assert not torch.isnan(attention_bias.to_local()).any(), attention_bias
+
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ if self.drop_path is None:
+ # jit scripting for a nn.module (with dropout) is not
+ # trigerring the fusion kernel. For now, we use two
+ # different nn.functional routines to account for varying
+ # dropout semantics during training and inference phases.
+ if self.bias_dropout_fusion:
+ if self.training:
+ bias_dropout_add_func = bias_dropout_add_fused_train
+ else:
+ bias_dropout_add_func = bias_dropout_add_fused_inference
+ else:
+ bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+ with self.bias_dropout_add_exec_handler():
+ layernorm_input = bias_dropout_add_func(
+ attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
+ )
+ else:
+ out = attention_output + attention_bias
+ out = torch.nn.functional.dropout(out, p=self.hidden_dropout, training=self.training)
+ layernorm_input = residual + self.drop_path(out)
+
+ # Layer norm post the self attention.
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+ # assert not torch.isnan(layernorm_output).any()
+
+ # MLP.
+ mlp_output, mlp_bias = self.mlp(layernorm_output)
+ # assert not torch.isnan(mlp_output.to_local()).any(), mlp_output
+ # assert not torch.isnan(mlp_bias.to_local()).any(), mlp_bias
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+ if self.drop_path is None:
+ with self.bias_dropout_add_exec_handler():
+ output = bias_dropout_add_func(mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout)
+
+ # Jit compiled function creates 'view' tensor. This tensor
+ # potentially gets saved in the MPU checkpoint function context,
+ # which rejects view tensors. While making a viewless tensor here
+ # won't result in memory savings (like the data loader, or
+ # p2p_communication), it serves to document the origin of this
+ # 'view' tensor.
+ # output = dtensor.utils.make_viewless_tensor(inp=output, requires_grad=output.requires_grad, keep_graph=True)
+ #
+ else:
+ out = mlp_output + mlp_bias
+ out = torch.nn.functional.dropout(out, p=self.hidden_dropout, training=self.training)
+ output = residual + self.drop_path(out)
+
+ return output
+
+ def forward_util(self, input_tensor, data):
+ ret = {
+ "hidden_states": input_tensor if input_tensor is not None else data["hidden_states"],
+ "attention_mask": data["attention_mask"],
+ }
+ return [ret["hidden_states"], ret["attention_mask"]]
+
+ def output_utils(self, p2p_tensor):
+ p2p_tensor = torch.permute(p2p_tensor, (0, 2, 1))
+ return p2p_tensor
diff --git a/vescale/model/base_gpt/utils.py b/vescale/model/base_gpt/utils.py
new file mode 100644
index 0000000..3a67817
--- /dev/null
+++ b/vescale/model/base_gpt/utils.py
@@ -0,0 +1,27 @@
+################################################################################
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+################################################################################
+# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################
+
+import functools
+from typing import Callable
+
+from optree import tree_map
+from vescale.dtensor.dtensor import DTensor
+
+
+def switch_dtensor(func: Callable):
+ @functools.wraps(func)
+ def wrap(*args, **kwargs):
+ def to_tensor(x):
+ if isinstance(x, DTensor):
+ return x.to_local()
+ return x
+
+ new_args = tree_map(to_tensor, args)
+ new_kwargs = tree_map(to_tensor, kwargs)
+ out = func(*new_args, **new_kwargs)
+ return out
+
+ return wrap
diff --git a/vescale/ndtimeline/README.md b/vescale/ndtimeline/README.md
new file mode 100644
index 0000000..9f21442
--- /dev/null
+++ b/vescale/ndtimeline/README.md
@@ -0,0 +1,55 @@
+# ndtimeline (N-Dimension Timeline)
+
+## Why ndtimeline?
+
+- When training LLM (Large Language Models) on an extremely large scale, several challenges need to be overcome:
+
+ - Sink machines (i.e. stragglers) tend to initiate CUDA kernels slowly, significantly reducing training efficiency.
+ - Traditional tools such as torch profile/nsys can only offer information within one physical machine, whereas communication occurs among multiple or even thousands of machines.
+ - Although torch profile can provide details about training, it comes at a considerable cost, making it impractical to be constantly enabled. The large size of the tracing file is also a significant issue, making analysis difficult.
+
+- We require a mechanism to jointly collect and visualize training details across numerous machines with low costs and a small tracing log to effectively detect stragglers and confirm the training status.
+
+## What is ndtimeline?
+
+### Insights
+- With `CUDA Events` provided by device, we can record durations of interesting parts.
+- We can utilize a reference `CUDA Event` as a clock with a Unix timestamp, allowing comparisons between events to provide a full span that includes both the starting time and duration.
+- Clocks among different machines are challenging to synchronize precisely, but we can simulate a global clock through communication to offer a consistent view of spans.
+- To minimize overhead, we can record events in multiple stages and flush them in another thread at specific intervals
+- To maximize flexibility, `ndtimeline` exposes handlers for users to inject during pre and post metric processing and perform any desired operations.
+- As metric collectors are located in each training process, they ensure the same semantics as parallelism on each rank, facilitating the easy extension of ndtimeline when new types of parallelism are introduced.
+
+### Architecture
+Assume there are two ranks on one machine.
+
+
+
+### important details
+ - Communication Stream
+ - Torch does not inherently offer an interface to obtain the stream for communication of nccl. `ProcessgroupNCCL` maintains a `CUDA Stream` pool and generates a stream from it when the user does not manually set the `CUDA Stream `before communication.
+ - We modify torch to establish a new interface for this purpose.
+ - An important **caveat**: The first communication operation will be lost in the tracing statistics because the `CUDA Stream` is allocated lazily when the first communication operation is initiated. Therefore, you may encounter some information logs such as `p2p stream is not available, skipped` as expected.
+ - is_internal_vescale
+ - We employ `ndtimeline` to emit metrics to an internal service, which may potentially disclose important secrets. Hence, we have to remove some `NDHandler` implementations and the corresponding meta information.
+
+## How to use ndtimeline?
+- Both **Higher** and **Lower** api are provided
+ - `ndtimeline.api` provides three key higher apis: `init_ndtimers`, `wait` and `flush`
+ - `init_ndtimers` Metrics injected in vescale are predefined in `ndtimeline.predefined`. The method for processing metrics, named as `NDHandler` defined in `ndtimeline.handlers`, can also be initialized using init_ndtimers.
+ - `wait` ensures that metrics are flushed and should be called at the end of main.py.
+ - `flush` flushes the collected metrics and calibrates the simulated global clock.
+ - `ndtimeline.api` provides another two api: `inc_step` and `set_global_step`
+ - They are introduced to align with the traditional approach for maintaining the global step instead of a `global_step_getter` function.
+- **Lower** api are rather complex
+ - Investigate `ndtimeline.timer.NDTimerManager` and `ndtimeline.sock_streamer.NDtimelineStreamer` to understand how to use them.
+
+- An Example
+
+ - Demo for default tracing file
+
+
+
+ - Users can utilize `init_ndtimers` and pass `ndtimeline.handlers.LocalTimelineNDHandler` as the post handler. A tracing file on the machine will then be generated in the current directory.
+ - Using the Chrome built-in tracing file viewer at https://ui.perfetto.dev/, one can visualize the tracing file.
+ - In case you need a tracing file related to ranks on different machines, you can implement an MQHandler by yourself and send all metrics to a central storage. This provides you with a method to filter and generate the tracing file for specified ranks.
\ No newline at end of file
diff --git a/vescale/ndtimeline/__init__.py b/vescale/ndtimeline/__init__.py
new file mode 100644
index 0000000..cf86cbb
--- /dev/null
+++ b/vescale/ndtimeline/__init__.py
@@ -0,0 +1,87 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from . import handlers # noqa: F401
+from . import exceptions # noqa: F401
+from . import logger # noqa: F401
+from . import predefined # noqa: F401
+
+from .binary_protocol import serialize_to_package, encode_package, loads_fn, dumps_fn
+from .pool import DefaultEventPool, CudaEventPool
+from .world_info import WorldInfo, TrainingInfo, TopoInfo
+from .timer import (
+ NDTimerManager,
+ NDTimerManagerSingleton,
+ DeviceTimerMeta,
+ ndtimeit,
+ NDMetricLevel,
+ ndtimer,
+ ndtimeit_p2p,
+)
+from .sock_streamer import NDtimelineStreamer
+from .variables import (
+ NDTIMELINE_INNER_GLOBAL_STEP_KEY,
+ SOCK_TIMEOUT_CLIENT,
+ SOCK_PARENT_DIR,
+ SOCK_PATH,
+ NDTIMELINE_STREAM_KEY,
+)
+from .stream import get_nccl_p2p_stream, get_nccl_coll_stream
+from .api import flush, wait, init_ndtimers, set_global_step, inc_step
+
+__all__ = [
+ "handlers",
+ "logger",
+ "exceptions",
+ "predefined",
+ "serialize_to_package",
+ "encode_package",
+ "loads_fn",
+ "dumps_fn",
+ "DefaultEventPool",
+ "CudaEventPool",
+ "WorldInfo",
+ "TrainingInfo",
+ "TopoInfo",
+ "NDTimerManager",
+ "NDTimerManagerSingleton",
+ "DeviceTimerMeta",
+ "ndtimeit",
+ "NDtimelineStreamer",
+ "NDTIMELINE_INNER_GLOBAL_STEP_KEY",
+ "SOCK_TIMEOUT_CLIENT",
+ "SOCK_PARENT_DIR",
+ "SOCK_PATH",
+ "NDTIMELINE_STREAM_KEY",
+ "NDMetricLevel",
+ "get_nccl_p2p_stream",
+ "get_nccl_coll_stream",
+ "ndtimer",
+ "ndtimeit_p2p",
+ "flush",
+ "wait",
+ "init_ndtimers",
+ "set_global_step",
+ "inc_step",
+]
+
+try:
+ import _internal
+
+ __all__.append("_internal")
+except ImportError:
+ pass
diff --git a/vescale/ndtimeline/api.py b/vescale/ndtimeline/api.py
new file mode 100644
index 0000000..6ca266a
--- /dev/null
+++ b/vescale/ndtimeline/api.py
@@ -0,0 +1,396 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from typing import Optional, List, Callable, Tuple
+import math
+from copy import deepcopy
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+import torch
+import torch.distributed as dist
+
+from .is_internal import is_internal_vescale
+
+if is_internal_vescale():
+ from vescale.fsdp import FullyShardedDataParallel, ShardingStrategy
+ from vescale.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
+ from ._internal import _get_ip_by_env, _get_role_id, _get_run_id, _get_trial_id
+else:
+ # make python happy
+ class FullyShardedDataParallel:
+ pass
+
+ class ShardingStrategy:
+ pass
+
+ HYBRID_SHARDING_STRATEGIES = ""
+
+
+from vescale.dtensor.device_mesh import DeviceMesh
+from vescale.devicemesh_api import VESCALE_DEVICE_MESH
+from .timer import NDTimerManagerSingleton, DeviceTimerMeta, NDMetricLevel
+from .handlers import NDHandler, SockNDHandler, LocalTimelineNDHandler
+from .world_info import WorldInfo
+from .sock_streamer import NDtimelineStreamer
+from .logger import NDTimelineLogger
+from .predefined import (
+ FORWARD_COMPUTE,
+ BACKWARD_COMPUTE,
+ CROSS_MESH_RECV,
+ CROSS_MESH_SEND,
+ RECV_FORWARD,
+ RECV_BACKWARD,
+ SEND_FORWARD,
+ SEND_BACKWARD,
+ SEND_FORWARD_RECV_BACKWARD,
+ SEND_BACKWARD_RECV_FORWARD,
+ UNSHARD_AG,
+ GRAD_RS,
+ GRAD_AR,
+)
+from .fsdp_patch import patch_fsdp
+
+
+def init_ndtimers(
+ rank: Optional[int] = None,
+ mode: Literal["fsdp", "hybrid"] = "hybrid",
+ wrapped_fsdp_module: Optional[FullyShardedDataParallel] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ mesh_shape: Optional[Tuple[int, ...]] = None,
+ local_rank: Optional[int] = None,
+ step_getter: Optional[Callable[[], int]] = None,
+ enable_streamer: bool = True,
+ n_rank_per_host: Optional[int] = None,
+ pre_handlers: Optional[List[NDHandler]] = None,
+ post_handlers: Optional[List[NDHandler]] = None,
+ user_spcified_timers: Optional[List[DeviceTimerMeta]] = None,
+ level: NDMetricLevel = NDMetricLevel.DEBUG,
+ ip: str = "0.0.0.0",
+ **kwargs,
+):
+ """
+ High level api to enable timers.
+ It MUST be called after both torch.cuda.set_device and default process group are initialized.
+
+ Args:
+ rank (int): rank id. If rank is None, it will be determined by torch.distributed.get_rank.
+
+ mode (str): `fsdp` or `hybrid` mode, `fsdp` currently is only supported in internal version.
+
+ wrapped_fsdp_module (FullyShardedDataParallel): `FullyShardedDataParallel` wrapped torch.nn.module,
+ only used in fsdp mode and only valid in internal version.
+
+ device_mesh (DeviceMesh): only used in fsdp mode and only valid in internal version.
+
+ mesh_shape (Tuple): only used in fsdp mode and only valid in internal version.
+
+ local_rank (int): local rank id. If local_rank is None, it will be determined by VESCALE_DEVICE_MESH.
+
+ step_getter (Callable[[], int]): func to get current global step. If it is None, steps will be always set as 0.
+ Another choice is to use `set_global_step` and `inc_step` to maintain step.
+
+ enable_streamer (bool): If set, a streamer process will be forked and then post_handlers can be enabled.
+
+ n_rank_per_host (int): number of devices on one machine. If it is None, it will be determined by torch.cuda.device_count.
+
+ pre_handlers (List[NDHandler]): List of NDHandlers triggered immediately after `flush` on each training process.
+ `SockNDHandler` will be automatically injected in pre_handlers when streamer enabled and no pre_handlers are given.
+
+ post_handlers (List[NDHandler]): List of NDHandlers triggered in streamer process.
+ `LocalTimelineNDHandler` will be automatically injected when streamer enabled and no post_handlers are given.
+
+ user_spcified_timers (List[DeviceTimerMeta]): List of DeviceTimerMeta registered by user.
+
+ level (NDMetricLevel): metrics of which the level is lower than this will be ignored.
+
+ ip (str): pod/host ip.
+
+ Returns:
+ Nothing
+ """
+
+ post_handlers = [] if post_handlers is None else post_handlers
+ pre_handlers = [] if pre_handlers is None else pre_handlers
+ user_spcified_timers = [] if user_spcified_timers is None else user_spcified_timers
+
+ if mode not in ["hybrid", "fsdp"]:
+ raise NotImplementedError(f"mode {mode} not implemented")
+
+ if mode == "fsdp" and not is_internal_vescale():
+ raise NotImplementedError("fsdp is not currently supported for opensource version")
+
+ if mode != "fsdp" and wrapped_fsdp_module is not None:
+ raise ValueError("wrapped_fsdp_module and mode should be set accordingly")
+
+ if NDTimerManagerSingleton.is_initialized():
+ NDTimelineLogger().warning("timers initialized, no need for initialization")
+ return
+
+ local_rank = VESCALE_DEVICE_MESH.get_local_rank() if local_rank is None else local_rank
+ rank = torch.distributed.get_rank() if rank is None else rank
+ n_rank_per_host = torch.cuda.device_count() if n_rank_per_host is None else n_rank_per_host
+
+ world_size = dist.get_world_size()
+ ddp_rank, ddp_size = 0, 1
+ if mode == "hybrid":
+ tp_size = VESCALE_DEVICE_MESH.get_strategy_size("TP")
+ dp_size = VESCALE_DEVICE_MESH.get_strategy_size("DP")
+ pp_size = VESCALE_DEVICE_MESH.get_strategy_size("PP")
+
+ tp_rank = VESCALE_DEVICE_MESH.get_tensor_parallel_rank()
+ pp_rank = VESCALE_DEVICE_MESH.get_pipeline_parallel_rank()
+ dp_rank = VESCALE_DEVICE_MESH.get_data_parallel_rank()
+
+ assert (
+ tp_size * dp_size * pp_size == world_size
+ ), f"tp_size: {tp_size}, dp_size: {dp_size}, pp_size: {pp_size}, world_size: {world_size}"
+ elif mode == "fsdp":
+ tp_size, pp_size = 1, 1
+ tp_rank, pp_rank = 0, 0
+
+ patch_fsdp()
+ if wrapped_fsdp_module is not None:
+ intra_node_group = wrapped_fsdp_module.process_group
+ inter_node_group = getattr(wrapped_fsdp_module, "_inter_node_pg", None)
+ dp_rank, dp_size, ddp_rank, ddp_size = _calculate_topo(
+ intra_node_group,
+ inter_node_group,
+ wrapped_fsdp_module.sharding_strategy,
+ world_size,
+ )
+ elif device_mesh is not None:
+ dp_rank, dp_size, ddp_rank, ddp_size = _calculate_topo_by_shape(tuple(device_mesh.mesh.shape), rank)
+ elif mesh_shape is not None:
+ dp_rank, dp_size, ddp_rank, ddp_size = _calculate_topo_by_shape(mesh_shape, rank)
+ else:
+ raise ValueError("for fsdp, device_mesh or wrapped_fsdp_module or mesh_shape must be given at least 1")
+
+ if enable_streamer:
+ if local_rank == 0:
+ if len(post_handlers) > 0:
+ NDtimelineStreamer.init(local_rank, post_handlers)
+ else:
+ NDtimelineStreamer.init(
+ local_rank,
+ [
+ LocalTimelineNDHandler(n_rank_per_host),
+ ],
+ )
+ if len(pre_handlers) == 0 or all(not isinstance(handler, SockNDHandler) for handler in pre_handlers):
+ pre_handlers.append(SockNDHandler())
+
+ trial_id, run_id, role_id = 0, 0, 0
+
+ if is_internal_vescale():
+ if ip == "0.0.0.0":
+ ip = _get_ip_by_env()
+ trial_id = _get_trial_id()
+ run_id = _get_run_id()
+ role_id = _get_role_id()
+
+ NDTimerManagerSingleton(
+ WorldInfo(
+ rank=rank,
+ local_rank=local_rank,
+ tp_rank=tp_rank,
+ pp_rank=pp_rank,
+ dp_rank=dp_rank,
+ ddp_rank=ddp_rank,
+ tp_size=tp_size,
+ pp_size=pp_size,
+ dp_size=dp_size,
+ ddp_size=ddp_size,
+ world_size=world_size,
+ ip=ip,
+ trial_id=trial_id,
+ run_id=run_id,
+ role_id=role_id,
+ ),
+ init_cuda_dist=True,
+ handlers=pre_handlers,
+ metric_level=level,
+ )
+
+ extra = {}
+ mq_sinks = []
+ if is_internal_vescale():
+ from ._internal import MQNDHandler
+
+ for handler in post_handlers:
+ if isinstance(handler, MQNDHandler):
+ mq_sinks.extend(handler.mq_sinks)
+ if len(mq_sinks) != 0:
+ extra = {"sinks": mq_sinks}
+
+ if mode == "hybrid":
+ predefined_timers = [
+ DeviceTimerMeta(SEND_BACKWARD, is_cpu_op=False, step_getter=step_getter),
+ DeviceTimerMeta(SEND_FORWARD, is_cpu_op=False, step_getter=step_getter),
+ DeviceTimerMeta(RECV_FORWARD, is_cpu_op=False, step_getter=step_getter),
+ DeviceTimerMeta(RECV_BACKWARD, is_cpu_op=False, step_getter=step_getter),
+ DeviceTimerMeta(SEND_FORWARD_RECV_BACKWARD, is_cpu_op=False, step_getter=step_getter),
+ DeviceTimerMeta(SEND_BACKWARD_RECV_FORWARD, is_cpu_op=False, step_getter=step_getter),
+ DeviceTimerMeta(CROSS_MESH_RECV, is_cpu_op=False, step_getter=step_getter),
+ DeviceTimerMeta(CROSS_MESH_SEND, is_cpu_op=False, step_getter=step_getter),
+ DeviceTimerMeta(FORWARD_COMPUTE, is_cpu_op=False, step_getter=step_getter),
+ DeviceTimerMeta(BACKWARD_COMPUTE, is_cpu_op=False, step_getter=step_getter),
+ ]
+ else:
+ predefined_timers = [
+ DeviceTimerMeta(
+ UNSHARD_AG,
+ is_cpu_op=False,
+ step_getter=step_getter,
+ common_extra=deepcopy(extra),
+ ),
+ DeviceTimerMeta(
+ GRAD_RS,
+ is_cpu_op=False,
+ step_getter=step_getter,
+ common_extra=deepcopy(extra),
+ ),
+ DeviceTimerMeta(
+ GRAD_AR,
+ is_cpu_op=False,
+ step_getter=step_getter,
+ common_extra=deepcopy(extra),
+ ),
+ DeviceTimerMeta(
+ FORWARD_COMPUTE,
+ is_cpu_op=False,
+ step_getter=step_getter,
+ common_extra=deepcopy(extra),
+ ),
+ DeviceTimerMeta(
+ BACKWARD_COMPUTE,
+ is_cpu_op=False,
+ step_getter=step_getter,
+ common_extra=deepcopy(extra),
+ ),
+ ]
+ predefined_timers.extend(user_spcified_timers)
+ NDTimerManagerSingleton().register_timers(predefined_timers)
+
+
+def wait():
+ """
+ High level api for timers to exit gracefully
+ """
+ if NDTimerManagerSingleton.is_initialized():
+ NDTimerManagerSingleton().wait()
+
+
+def set_global_step(global_step: int = 0):
+ """
+ Another choice to set global step when `global_step_getter` is None
+ """
+ if NDTimerManagerSingleton.is_initialized():
+ NDTimerManagerSingleton().global_step = global_step
+
+
+def inc_step(step: int = 1):
+ """
+ Another choice beside `global_step_getter` to increase global step when `global_step_getter` is None
+ """
+ if NDTimerManagerSingleton.is_initialized():
+ step_increased = NDTimerManagerSingleton().global_step + step
+ NDTimerManagerSingleton().global_step = step_increased
+
+
+def flush(
+ step_range: Optional[range] = None,
+ next_iter_enabled: bool = True,
+ submit2handler: bool = True,
+ dynamic_calibrate: bool = False,
+ keep_timer_state: bool = False,
+ sequential_calibrate: bool = True,
+):
+ """
+ High level api for timers to flush metrics to handlers.
+
+ Args:
+ step_range (range): global step range. Theorically, NO step_getter is acceptable if user use lower level api.
+ Therefore, step_range is used to allocating steps to metrics. If step_getter is given, step_range can be ignored.
+
+ next_iter_enabled (bool): whether timers continue to be enabled after flushed
+
+ submit2handler (bool): whether metrics should be dropped. False means dropping metrics.
+
+ dynamic_calibrate (bool): whether calibrate clocks at least every 20 minutes.
+
+ keep_timer_state (bool): keep timers being enable or disabled state after flushed, if True; next_iter_enabled ignored if True
+
+ sequential_calibrate (bool): calibrate clocks in main thread or other threads
+
+ Returns:
+ Nothing
+
+ """
+ if NDTimerManagerSingleton.is_initialized():
+ step_range = range(0, 1) if step_range is None else step_range
+ NDTimerManagerSingleton().async_flush(
+ step_range,
+ next_iter_enabled=next_iter_enabled,
+ submit2handler=submit2handler,
+ dynamic_calibrate=dynamic_calibrate,
+ keep_timer_state=keep_timer_state,
+ sequential_calibrate=sequential_calibrate,
+ )
+
+
+def _calculate_topo(
+ intra_node_group: dist.ProcessGroup,
+ inter_node_group: dist.ProcessGroup,
+ sharding_strategy: ShardingStrategy,
+ world_size: int,
+) -> Tuple[int, int, int, int]:
+ if sharding_strategy in HYBRID_SHARDING_STRATEGIES:
+ ddp_size = inter_node_group.size()
+ ddp_rank = inter_node_group.rank()
+ dp_size = intra_node_group.size()
+ dp_rank = intra_node_group.rank()
+ assert (
+ world_size == intra_node_group.size() * inter_node_group.size()
+ ), f"world_size: {world_size} intra_node_group: {dp_size} inter_node_group: {ddp_size}"
+ return dp_rank, dp_size, ddp_rank, ddp_size
+ elif sharding_strategy == ShardingStrategy.FULL_SHARD:
+ dp_size = intra_node_group.size()
+ dp_rank = intra_node_group.rank()
+ assert world_size == intra_node_group.size(), f"world_size: {world_size}"
+ return dp_rank, dp_size, 0, 1
+ else:
+ raise NotImplementedError("not implemented for ddp")
+
+
+def _calculate_topo_by_shape(mesh_shape: Tuple[int, ...], rank: int) -> Tuple[int, int, int, int]:
+ for m in mesh_shape:
+ assert m > 0 and isinstance(m, int)
+ if len(mesh_shape) == 2:
+ dim0, dim1 = mesh_shape[0], mesh_shape[1]
+ ddp_size, dp_size = dim0, dim1
+ mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape)
+ ddp_rank, dp_rank = torch.where(mesh == rank)
+ ddp_rank, dp_rank = int(ddp_rank), int(dp_rank)
+ return dp_rank, dp_size, ddp_rank, ddp_size
+ elif len(mesh_shape) == 1:
+ return rank, math.prod(mesh_shape), 0, 1
+ else:
+ raise ValueError(f"invalid mesh_shape {mesh_shape}")
diff --git a/vescale/ndtimeline/binary_protocol.py b/vescale/ndtimeline/binary_protocol.py
new file mode 100644
index 0000000..75a69bd
--- /dev/null
+++ b/vescale/ndtimeline/binary_protocol.py
@@ -0,0 +1,139 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import pickle
+import io
+import gc
+from typing import Any, Callable
+
+from .exceptions import ProtocolValidationError
+
+
+def dumps(v):
+ return pickle.dumps(v, protocol=4)
+
+
+def loads(binary):
+ gc.disable()
+ res = pickle.loads(binary)
+ gc.enable()
+ return res
+
+
+dumps_fn = dumps
+loads_fn = loads
+
+
+# +---------------------------------------------------------------+
+
+# | Magic Number 1Byte | Protocol Version 1Byte | Reserved 2Byte |
+
+# +---------------------------------------------------------------+
+
+# | Payload Length 4Byte |
+
+# +---------------------------------------------------------------+
+
+# | Payload |
+
+# +---------------------------------------------------------------+
+
+# | EOF Symbol 1Byte |
+
+# +------------------+
+
+# Both Payload Length and Maigc Number are Little Endian
+
+
+MAGIC_NUMBER = (0x9C).to_bytes(length=1, byteorder="little")
+MAGIC_BYTES_LEN = len(MAGIC_NUMBER)
+PROTOCOL_VERSION_0 = (0x0).to_bytes(length=1, byteorder="little")
+PROTOCOL_VERSION_BYTES_LEN = len(PROTOCOL_VERSION_0)
+RESERVED = b"00"
+RESERVED_BYTES_LEN = len(RESERVED)
+EOF_SYMBOL = b"\n"
+EOF_SYMBOL_BYTES_LEN = len(EOF_SYMBOL)
+MAX_PAYLOAD_LEN = 1024 * 1024 * 128 # 128MiB
+PAYLOAD_LEN_BYTES_LEN = 4
+
+
+# encode_package encode payload to package
+def encode_package(payload: bytes) -> bytes:
+ payload_len = len(payload)
+ if payload_len > MAX_PAYLOAD_LEN:
+ raise ValueError(f"payload size {payload_len}, larger than max size {MAX_PAYLOAD_LEN}")
+ payload_len_bytes = payload_len.to_bytes(length=PAYLOAD_LEN_BYTES_LEN, byteorder="little")
+ # memory efficient
+ return b"".join([MAGIC_NUMBER, PROTOCOL_VERSION_0, RESERVED, payload_len_bytes, payload, EOF_SYMBOL])
+
+
+# v: any pickable object
+def serialize_to_package(v: Any):
+ # payload = pickle.dumps(v, protocol=4)
+ payload = dumps_fn(v)
+ return encode_package(payload)
+
+
+def recv_and_validate(recv_func: Callable, preload_data: bytearray) -> bytes:
+ magic_bytes = read_or_recv(MAGIC_BYTES_LEN, recv_func, preload_data)
+ if magic_bytes != MAGIC_NUMBER:
+ raise ProtocolValidationError("MAGIC_NUMBER field is broken")
+ pt_version_bytes = read_or_recv(PROTOCOL_VERSION_BYTES_LEN, recv_func, preload_data)
+ if pt_version_bytes != PROTOCOL_VERSION_0:
+ raise ProtocolValidationError("PROTOCOL_VERSION_0 field is broken")
+ reserved_bytes = read_or_recv(RESERVED_BYTES_LEN, recv_func, preload_data)
+ if reserved_bytes != RESERVED:
+ raise ProtocolValidationError(f"RESERVED field is {reserved_bytes}, should be {RESERVED}")
+ payload_len_bytes = read_or_recv(PAYLOAD_LEN_BYTES_LEN, recv_func, preload_data)
+ payload_len = int.from_bytes(payload_len_bytes, byteorder="little")
+ if payload_len > MAX_PAYLOAD_LEN:
+ raise ProtocolValidationError(f"payload_len {payload_len} loger than {MAX_PAYLOAD_LEN}")
+ payload = read_or_recv(payload_len, recv_func, preload_data)
+ eof = read_or_recv(EOF_SYMBOL_BYTES_LEN, recv_func, preload_data)
+ if eof != EOF_SYMBOL:
+ raise ProtocolValidationError("EOF field is broken")
+ return payload
+
+
+def recv_to_buf(size: int, recv: Callable, preload_data: bytearray):
+ assert len(preload_data) <= size
+ buf = io.BytesIO()
+ buf.write(preload_data)
+ remaining = size - len(preload_data)
+ del preload_data[: len(preload_data)]
+ while remaining > 0:
+ chunk = recv(8192)
+ n = len(chunk)
+ if n == 0:
+ raise BrokenPipeError("recv 0 byte from socket")
+ if n <= remaining:
+ buf.write(chunk)
+ remaining -= n
+ else:
+ buf.write(chunk[:remaining])
+ preload_data.extend(chunk[remaining:])
+ return buf.getvalue()
+ return buf.getvalue()
+
+
+def read_or_recv(size: int, recv: Callable, preload_data: bytearray):
+ if len(preload_data) >= size:
+ res = bytes(preload_data[:size])
+ del preload_data[:size]
+ return res
+ else:
+ return recv_to_buf(size, recv, preload_data)
diff --git a/vescale/ndtimeline/exceptions.py b/vescale/ndtimeline/exceptions.py
new file mode 100644
index 0000000..35b995b
--- /dev/null
+++ b/vescale/ndtimeline/exceptions.py
@@ -0,0 +1,28 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+
+class ProtocolValidationError(ValueError):
+ def __init__(self, msg):
+ super().__init__(msg)
+ self.msg = msg
+
+
+class NDHandlerError(RuntimeError):
+ def __init__(self, msg):
+ super().__init__(msg)
+ self.msg = msg
diff --git a/vescale/ndtimeline/fsdp_patch.py b/vescale/ndtimeline/fsdp_patch.py
new file mode 100644
index 0000000..c3301a8
--- /dev/null
+++ b/vescale/ndtimeline/fsdp_patch.py
@@ -0,0 +1,28 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from .is_internal import is_internal_vescale
+
+if is_internal_vescale():
+ from ._internal import patch_fsdp, is_fsdp_patched
+else:
+
+ def patch_fsdp():
+ pass
+
+ def is_fsdp_patched():
+ return False
diff --git a/vescale/ndtimeline/handlers/__init__.py b/vescale/ndtimeline/handlers/__init__.py
new file mode 100644
index 0000000..171eac8
--- /dev/null
+++ b/vescale/ndtimeline/handlers/__init__.py
@@ -0,0 +1,34 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from .sock_handler import SockNDHandler
+from .parser_handler import ParserNDHandler
+from .logging_handler import LoggingNDHandler
+from .local_raw_handler import LocalRawNDHandler
+from .local_timeline_handler import LocalTimelineNDHandler
+from .handler_base import NDHandler
+from .do_nothing_handler import DoNothingNDHandler
+
+__all__ = [
+ "SockNDHandler",
+ "ParserNDHandler",
+ "NDHandler",
+ "LoggingNDHandler",
+ "LocalRawNDHandler",
+ "DoNothingNDHandler",
+ "LocalTimelineNDHandler",
+]
diff --git a/vescale/ndtimeline/handlers/chrome_trace_event.py b/vescale/ndtimeline/handlers/chrome_trace_event.py
new file mode 100644
index 0000000..809ea6d
--- /dev/null
+++ b/vescale/ndtimeline/handlers/chrome_trace_event.py
@@ -0,0 +1,291 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import random
+from dataclasses import dataclass
+from typing import Union, Optional, List, Dict, Tuple
+from abc import ABC, abstractmethod
+
+
+class TracingEvent(ABC):
+ """
+ chrome trace event format see doc:
+ https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview#
+ """
+
+ @abstractmethod
+ def to_objects(self) -> List[dict]:
+ pass
+
+
+@dataclass
+class CompleteEvent(TracingEvent):
+ name: str
+ cat: str
+ pid: Union[str, int]
+ tid: Union[str, int]
+
+ # 起始和持续时间长度(单位都是us)
+ ts: float
+ dur: float
+
+ args: Optional[dict] = None
+
+ def to_objects(self) -> List[dict]:
+ return [
+ {
+ "name": self.name,
+ "cat": self.cat,
+ "pid": self.pid,
+ "tid": self.tid,
+ "args": self.args or {},
+ "ts": self.ts,
+ "dur": self.dur,
+ "ph": "X",
+ }
+ ]
+
+
+@dataclass
+class BeginEvent(TracingEvent):
+ name: str
+ cat: str
+ pid: Union[str, int]
+ tid: Union[str, int]
+
+ # 起始和持续时间长度(单位都是us)
+ ts: float
+ stack: Optional[List[int]] = None
+
+ args: Optional[dict] = None
+
+ def to_objects(self) -> List[dict]:
+ return [
+ {
+ "name": self.name,
+ "cat": self.cat,
+ "pid": self.pid,
+ "tid": self.tid,
+ "args": self.args or {},
+ "ts": self.ts,
+ "ph": "B",
+ }
+ ]
+
+
+@dataclass
+class EndEvent(TracingEvent):
+ name: str
+ cat: str
+ pid: Union[str, int]
+ tid: Union[str, int]
+
+ # 起始和持续时间长度(单位都是us)
+ ts: float
+ stack: Optional[List[int]] = None
+
+ args: Optional[dict] = None
+
+ def to_objects(self) -> List[dict]:
+ return [
+ {
+ "name": self.name,
+ "cat": self.cat,
+ "pid": self.pid,
+ "tid": self.tid,
+ "args": self.args or {},
+ "ts": self.ts,
+ "ph": "E",
+ }
+ ]
+
+
+flow_event_id_counter = 0
+
+
+@dataclass
+class FlowEvent(TracingEvent):
+ # {"ph": "f", "id": 246, "pid": "172.20.133.93", "tid": 13, "ts": 1669171992173028, \
+ # "cat": "async_gpu", "name": "cudaLaunchKernel", "bp": "e"}
+ name: str
+ cat: str
+
+ # list of (pid, tid, ts)
+ flows: List[Tuple[Union[str, int], Union[str, int], float]]
+
+ def to_objects(self) -> List[dict]:
+ global flow_event_id_counter
+ flow_event_id_counter += 1
+ gen_id = flow_event_id_counter # use stable predictable id
+ ret = []
+ # 起始时间比结束时间更晚的话没意义,不会被渲染,所以修正一下
+ for i in range(1, len(self.flows)):
+ _, _, ts0 = self.flows[i - 1]
+ pid, tid, ts1 = self.flows[i]
+ if ts1 <= ts0:
+ self.flows[i] = (pid, tid, ts0 + 1)
+ for f in self.flows:
+ pid, tid, ts = f
+ ret.append(
+ {
+ "name": self.name,
+ "cat": self.cat,
+ "pid": pid,
+ "tid": tid,
+ "ts": ts,
+ "ph": "t",
+ "bp": "e",
+ "id": gen_id,
+ }
+ )
+ ret[0]["ph"] = "s"
+ ret[-1]["ph"] = "f"
+ ret[-1]["ts"] += 1
+ return ret
+
+
+@dataclass
+class CounterEvent(TracingEvent):
+ name: str
+ pid: Union[str, int]
+
+ # 起始和持续时间长度(单位都是us)
+ ts: float
+
+ # 计数的数据序列
+ data: Dict[str, Union[int, float]]
+
+ def to_objects(self) -> List[dict]:
+ return [
+ {
+ "name": self.name,
+ "pid": self.pid,
+ "args": self.data,
+ "ts": self.ts,
+ "ph": "C",
+ }
+ ]
+
+
+class CombinedEvents(TracingEvent):
+ """
+ 将多个tracing event合并一起,表示成1个event,最后按顺序展开每个object
+ """
+
+ def __init__(self, events: List[TracingEvent]):
+ self.events = events
+
+ def to_objects(self) -> List[dict]:
+ obj = []
+ for e in self.events:
+ obj.extend(e.to_objects())
+ return obj
+
+
+@dataclass
+class ProcessMetadataEvent(TracingEvent):
+ pid: Union[str, int]
+ sort_index: int
+ process_name: Optional[str] = None
+ process_labels: List[str] = None
+
+ def to_objects(self) -> List[dict]:
+ ret = [
+ {
+ "name": "process_sort_index",
+ "pid": self.pid,
+ "ph": "M",
+ "args": {
+ "sort_index": self.sort_index,
+ },
+ }
+ ]
+ if self.process_labels is not None:
+ ret.append(
+ {
+ "name": "process_labels",
+ "pid": self.pid,
+ "ph": "M",
+ "args": {
+ "labels": ",".join(self.process_labels),
+ },
+ }
+ )
+ if self.process_name is not None:
+ ret.append(
+ {
+ "name": "process_name",
+ "pid": self.pid,
+ "ph": "M",
+ "args": {
+ "name": self.process_name,
+ },
+ }
+ )
+ return ret
+
+
+@dataclass
+class ThreadMetadataEvent(TracingEvent):
+ pid: Union[str, int]
+ tid: Union[str, int]
+ sort_index: int
+ thread_name: Optional[str] = None
+
+ def to_objects(self) -> List[dict]:
+ ret = [
+ {
+ "name": "thread_sort_index",
+ "pid": self.pid,
+ "tid": self.tid,
+ "ph": "M",
+ "args": {
+ "sort_index": self.sort_index,
+ },
+ }
+ ]
+ if self.thread_name is not None:
+ ret.append(
+ {
+ "name": "thread_name",
+ "pid": self.pid,
+ "tid": self.tid,
+ "ph": "M",
+ "args": {
+ "name": self.thread_name,
+ },
+ }
+ )
+ return ret
+
+
+class DummyEvent(TracingEvent):
+ def to_objects(self) -> List[dict]:
+ return [
+ {
+ "name": "dummy",
+ "cat": "dummy",
+ "pid": random.randint(1, 100),
+ "tid": random.randint(1, 100),
+ "args": {
+ "content": "*" * random.randint(100, 1000),
+ },
+ "ts": random.randint(1, 9999),
+ "dur": random.randint(1, 100),
+ "ph": "i",
+ }
+ ]
diff --git a/vescale/ndtimeline/handlers/do_nothing_handler.py b/vescale/ndtimeline/handlers/do_nothing_handler.py
new file mode 100644
index 0000000..ad4b9eb
--- /dev/null
+++ b/vescale/ndtimeline/handlers/do_nothing_handler.py
@@ -0,0 +1,36 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from typing import Dict, List, Any
+
+from .handler_base import NDHandler
+from ..world_info import WorldInfo
+
+
+class DoNothingNDHandler(NDHandler):
+ def call_impl(
+ self,
+ metric_name: str,
+ elapsed: float,
+ recent_elapsed_raw_parts: List[float],
+ recent_since_start_raw_parts: List[float],
+ tags: List[Dict[str, Any]],
+ step_range: range,
+ world_info: WorldInfo,
+ extra: Dict[str, Any],
+ ) -> Any:
+ pass
diff --git a/vescale/ndtimeline/handlers/handler_base.py b/vescale/ndtimeline/handlers/handler_base.py
new file mode 100644
index 0000000..055240e
--- /dev/null
+++ b/vescale/ndtimeline/handlers/handler_base.py
@@ -0,0 +1,79 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from typing import List, Any, Dict
+from abc import ABC, abstractmethod
+from ..variables import NDTIMELINE_FLUSH_SEPCIAL
+from ..world_info import WorldInfo
+
+
+class NDHandler(ABC):
+ def __init__(self, designated_key="", ignore_metrics=None) -> None:
+ super().__init__()
+ self._dispatch_key = self.__class__.__name__
+ self._ignore_metrics = ignore_metrics if ignore_metrics is not None else [NDTIMELINE_FLUSH_SEPCIAL]
+ if designated_key != "":
+ self._dispatch_key = designated_key
+
+ @property
+ def dispatch_key(self):
+ return self._dispatch_key
+
+ @property
+ def ignore_metrics(self):
+ return self._ignore_metrics
+
+ def __repr__(self) -> str:
+ return f"NDHandler instance with dispatch key: {self._dispatch_key}"
+
+ def __call__(
+ self,
+ metric_name: str,
+ elapsed: float,
+ recent_elapsed_raw_parts: List[float],
+ recent_since_start_raw_parts: List[float],
+ tags: List[Dict[str, Any]],
+ step_range: range,
+ world_info: WorldInfo,
+ extra: Dict[str, Any],
+ ) -> Any:
+ if metric_name in self.ignore_metrics:
+ return
+ return self.call_impl(
+ metric_name,
+ elapsed,
+ recent_elapsed_raw_parts,
+ recent_since_start_raw_parts,
+ tags,
+ step_range,
+ world_info,
+ extra,
+ )
+
+ @abstractmethod
+ def call_impl(
+ self,
+ metric_name: str,
+ elapsed: float,
+ recent_elapsed_raw_parts: List[float],
+ recent_since_start_raw_parts: List[float],
+ tags: List[Dict[str, Any]],
+ step_range: range,
+ world_info: WorldInfo,
+ extra: Dict[str, Any],
+ ) -> Any:
+ pass
diff --git a/vescale/ndtimeline/handlers/local_raw_handler.py b/vescale/ndtimeline/handlers/local_raw_handler.py
new file mode 100644
index 0000000..6db4cc7
--- /dev/null
+++ b/vescale/ndtimeline/handlers/local_raw_handler.py
@@ -0,0 +1,67 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import logging
+import os
+from typing import List, Dict, Any
+from logging import Formatter
+from logging.handlers import RotatingFileHandler
+
+from .handler_base import NDHandler
+from ..world_info import WorldInfo
+from ..variables import LOCAL_LOGGING_PATH
+
+
+CHUNK_SZ = 1024 * 1024 * 128 # 128 MiB
+BACKUP_CNT = 8
+
+
+class LocalRawNDHandler(NDHandler):
+ def __init__(
+ self, run_id: int, log_path: str = LOCAL_LOGGING_PATH, chunk_sz: int = CHUNK_SZ, backup_cnt: int = BACKUP_CNT
+ ) -> None:
+ """if a trial of log exceeds `chunk_sz`, it will be dropped"""
+ super().__init__()
+ if not os.path.exists(log_path):
+ os.makedirs(log_path, exist_ok=True)
+ file_name = f"timeline_run{run_id}_raw.log"
+ formatter = Formatter("%(asctime)s - %(message)s")
+ handler = RotatingFileHandler(
+ filename=os.path.join(log_path, file_name), maxBytes=chunk_sz, backupCount=backup_cnt
+ )
+ handler.setFormatter(formatter)
+ self.logger = logging.getLogger("LocalRawNDHandler")
+ self.logger.propagate = False
+ self.logger.addHandler(handler)
+ self.logger.setLevel(logging.DEBUG)
+
+ def call_impl(
+ self,
+ metric_name: str,
+ elapsed: float,
+ recent_elapsed_raw_parts: List[float],
+ recent_since_start_raw_parts: List[float],
+ tags: List[Dict[str, Any]],
+ step_range: range,
+ world_info: WorldInfo,
+ extra: Dict[str, Any],
+ ) -> Any:
+ msg = (
+ f"metric_name: {metric_name}, elapsed: {elapsed}, recent_elapsed_raw_parts: {recent_elapsed_raw_parts}, recent_since_start_raw_parts: {recent_since_start_raw_parts},"
+ f" tags: {tags}, step_range: {step_range}, world_info: {world_info}"
+ )
+ self.logger.info(msg)
diff --git a/vescale/ndtimeline/handlers/local_timeline_handler.py b/vescale/ndtimeline/handlers/local_timeline_handler.py
new file mode 100644
index 0000000..907c5ef
--- /dev/null
+++ b/vescale/ndtimeline/handlers/local_timeline_handler.py
@@ -0,0 +1,201 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import json
+from typing import List, Dict, Any, Set, Deque, Optional
+from collections import deque
+
+import torch
+
+from .chrome_trace_event import CompleteEvent, ThreadMetadataEvent, ProcessMetadataEvent
+from ..world_info import WorldInfo
+from ..variables import NDTIMELINE_FLUSH_SEPCIAL
+from .handler_base import NDHandler
+from .parser_handler import parse_record, DeviceTimerStreamRecord
+
+
+# thread_index_table
+
+
+def build_thread_index_table(tab, metrics, index, index_name):
+ for m in metrics:
+ tab[m] = (index, index_name)
+
+
+major_metrics = {
+ "forward-compute",
+ "backward-compute",
+ "embedding-grads-all-reduce",
+ "optimizer",
+ "optimizer-clip-main-grad",
+ "optimizer-inner-step",
+ "optimizer-copy-to-main-grad",
+ "optimizer-copy-main-to-model-params",
+}
+
+tp_stream_metrics = {
+ "tp-allreduce",
+ "tp-allgather",
+ "tp-reducescatter",
+ "layernorm-grads-all-reduce",
+}
+
+dp_stream_metrics = {
+ "grads-reduce-scatter",
+ "params-all-gather",
+ "separate-grads-all-reduce",
+ "grads-reduce-scatter-nonoverlapping",
+ "params-all-gather-nonoverlapping",
+}
+
+pp_batch_stream_metrics = {
+ "backward-send-backward-recv",
+ "backward-send-forward-recv",
+ "forward-send-backward-recv",
+ "forward-send-forward-recv",
+ "forward-backward-send-forward-backward-recv",
+ "cross-mesh-recv",
+ "cross-mesh-send",
+}
+
+pp_forward_stream_metrics = {
+ "forward-recv",
+ "backward-send",
+}
+
+pp_backward_stream_metrics = {
+ "forward-send",
+ "backward-recv",
+}
+
+
+thread_sort_index = {}
+build_thread_index_table(thread_sort_index, major_metrics, 0, "main")
+build_thread_index_table(thread_sort_index, pp_forward_stream_metrics, 1, "pp ->")
+build_thread_index_table(thread_sort_index, pp_backward_stream_metrics, 2, "pp <-")
+build_thread_index_table(thread_sort_index, pp_batch_stream_metrics, 3, "pp send/recv")
+build_thread_index_table(thread_sort_index, tp_stream_metrics, 4, "tp collective")
+build_thread_index_table(thread_sort_index, dp_stream_metrics, 5, "dp collective")
+sort_index_other = 6
+index_name_other = "other"
+
+
+events = []
+tid_names = {} # tid -> (pid, name)
+
+MAX_UINT64 = 18446744073709551615
+NEGTIVE_ONE = -1
+
+
+class LocalTimelineNDHandler(NDHandler):
+ def __init__(self, n_rank_per_host: Optional[int] = None):
+ super().__init__(ignore_metrics=[])
+ if n_rank_per_host is None:
+ n_rank_per_host = torch.cuda.device_count()
+ self.n_rank_per_host = n_rank_per_host
+ self.rank2buffer: List[List[DeviceTimerStreamRecord]] = [[] for _ in range(n_rank_per_host)]
+ # rank -> deque(set(steps), set(steps), empty set)
+ self.rank2steps: List[Deque[Set[int]]] = [deque(set() for _ in range(1)) for _ in range(n_rank_per_host)]
+
+ def dump_records(self):
+ output_ranks = set()
+ events = []
+ min_step, max_step = MAX_UINT64, NEGTIVE_ONE
+ buffer = [record for rank in range(self.n_rank_per_host) for record in self.rank2buffer[rank]]
+ for record in buffer:
+ metric, step, rank, dp_rank = record.metric, record.step, record.rank, record.dp_rank
+ if step < 0:
+ continue
+ min_step = min(min_step, step)
+ max_step = max(max_step, step)
+ output_ranks.add((dp_rank, rank))
+ sort_index, index_name = thread_sort_index.get(metric, (sort_index_other, index_name_other))
+ tid = rank * 10 + sort_index # 乘以10表示让出个位数给thread_sort_index编码
+ tid_names[tid] = (dp_rank, f"rank[{rank}] {index_name}")
+ for ts, dur in zip(record.start_ts, record.duration):
+ args = {
+ "rank": rank,
+ "step": step,
+ "tp": record.tp_rank,
+ "pp": record.pp_rank,
+ }
+ ev = CompleteEvent(name=metric, cat=metric, pid=dp_rank, tid=tid, ts=ts * 1e6, dur=dur * 1e6, args=args)
+ events.append(ev)
+ for tid, (dp_rank, name) in tid_names.items():
+ ev = ThreadMetadataEvent(
+ pid=dp_rank,
+ tid=tid,
+ sort_index=tid,
+ thread_name=name,
+ )
+ events.append(ev)
+ for dp_rank in {dp_rank for dp_rank, _ in output_ranks}:
+ ev = ProcessMetadataEvent(pid=dp_rank, sort_index=dp_rank, process_name=f"dp rank[{dp_rank}]")
+ events.append(ev)
+ spans = []
+ for ev in events:
+ spans.extend(ev.to_objects())
+ with open(f"trace_step{min_step}_{max_step}", "w") as f:
+ json.dump(spans, f)
+
+ def call_impl(
+ self,
+ metric_name: str,
+ elapsed: float,
+ recent_elapsed_raw_parts: List[float],
+ recent_since_start_raw_parts: List[float],
+ tags: List[Dict[str, Any]],
+ step_range: range,
+ world_info: WorldInfo,
+ extra: Dict[str, Any],
+ ) -> Any:
+ local_rank = world_info["local_rank"]
+ if metric_name == NDTIMELINE_FLUSH_SEPCIAL:
+ self.rank2steps[local_rank].append(set())
+ if all(len(self.rank2steps[i]) >= 2 for i in range(self.n_rank_per_host)):
+ # split
+ new_rank2buffer: List[List[DeviceTimerStreamRecord]] = [[] for _ in range(self.n_rank_per_host)]
+ for rank in range(self.n_rank_per_host):
+ # use record.copy to avoid gc failure and memory leaking
+ new_rank2buffer[rank] = [
+ record.copy()
+ for record in self.rank2buffer[rank]
+ if record.step not in self.rank2steps[rank][0]
+ ]
+ self.rank2buffer[rank] = [
+ record for record in self.rank2buffer[rank] if record.step in self.rank2steps[rank][0]
+ ]
+ self.dump_records()
+ # update
+ self.rank2buffer = new_rank2buffer
+ for rank in range(self.n_rank_per_host):
+ self.rank2steps[rank].popleft()
+ else:
+ # assume local_rank is in [0...n_rank_per_device-1]
+ records = parse_record(
+ metric_name,
+ elapsed,
+ recent_elapsed_raw_parts,
+ recent_since_start_raw_parts,
+ tags,
+ step_range,
+ world_info,
+ extra,
+ )
+ self.rank2buffer[local_rank].extend(records)
+ for record in records:
+ self.rank2steps[local_rank][-1].add(record.step)
diff --git a/vescale/ndtimeline/handlers/logging_handler.py b/vescale/ndtimeline/handlers/logging_handler.py
new file mode 100644
index 0000000..ca21e4f
--- /dev/null
+++ b/vescale/ndtimeline/handlers/logging_handler.py
@@ -0,0 +1,47 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from typing import List, Dict, Any
+
+from .handler_base import NDHandler
+from ..world_info import WorldInfo
+from ..logger import NDTimelineLogger
+
+
+class LoggingNDHandler(NDHandler):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def call_impl(
+ self,
+ metric_name: str,
+ elapsed: float,
+ recent_elapsed_raw_parts: List[float],
+ recent_since_start_raw_parts: List[float],
+ tags: List[Dict[str, Any]],
+ step_range: range,
+ world_info: WorldInfo,
+ extra: Dict[str, Any],
+ ) -> Any:
+ NDTimelineLogger().debug(
+ f"#recent_elapsed_raw_parts: {len(recent_elapsed_raw_parts)}, #recent_since_start_raw_parts {len(recent_since_start_raw_parts)}"
+ )
+ if len(step_range) < 1:
+ raise ValueError(f"step_range length is {len(step_range)}")
+ NDTimelineLogger().info(
+ f"[rank{world_info.topo_info.rank}, step{step_range[0]}-{step_range[-1]}]: {len(recent_since_start_raw_parts)} times {metric_name} total cost: {elapsed*1000:.2f}ms"
+ )
diff --git a/vescale/ndtimeline/handlers/parser_handler.py b/vescale/ndtimeline/handlers/parser_handler.py
new file mode 100644
index 0000000..d56cccf
--- /dev/null
+++ b/vescale/ndtimeline/handlers/parser_handler.py
@@ -0,0 +1,206 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import time
+import itertools
+from dataclasses import dataclass
+from typing import List, Dict, Any
+
+from ..logger import NDTimelineLogger
+from .handler_base import NDHandler
+from ..exceptions import NDHandlerError
+from ..world_info import WorldInfo
+from ..variables import NDTIMELINE_INNER_GLOBAL_STEP_KEY
+
+
+@dataclass
+class DeviceTimerStreamRecord:
+ ts: int # record time for partition purpose
+ rank: int
+ metric: str
+ iteration: int # legacy field, no meaning
+ step: int
+ avg_dur: float # time elapsed, legacy name
+ start_ts: List[float]
+ duration: List[float]
+ model_chunk: int # vpp model chunk id, start from 0
+ pp_rank: int # pp_rank legacy problem
+ dp_rank: int # the rank of existing dp group
+ tp_rank: int # the rank of existing tp group
+ ip: str
+ role_id: int # multi-role in RL
+ trial_id: int # trial id
+ run_id: int # run_id
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "ts": self.ts,
+ "rank": self.rank,
+ "metric": self.metric,
+ "iteration": self.iteration,
+ "step": self.step,
+ "value": self.avg_dur,
+ "start_ts": self.start_ts,
+ "duration": self.duration,
+ "model_chunk": self.model_chunk,
+ "stage": self.pp_rank,
+ "dp_rank": self.dp_rank,
+ "tp_rank": self.tp_rank,
+ "ip": self.ip,
+ "role": self.role_id,
+ "trial": str(self.trial_id),
+ "run_id": self.run_id,
+ }
+
+ def copy(self):
+ return DeviceTimerStreamRecord(
+ self.ts,
+ self.rank,
+ self.metric,
+ self.iteration,
+ self.step,
+ self.avg_dur,
+ self.start_ts,
+ self.duration,
+ self.model_chunk,
+ self.pp_rank,
+ self.dp_rank,
+ self.tp_rank,
+ self.ip,
+ self.role_id,
+ self.trial_id,
+ self.run_id,
+ )
+
+
+def parse_record(
+ metric_name: str,
+ elapsed: float,
+ recent_elapsed_raw_parts: List[float],
+ recent_since_start_raw_parts: List[float],
+ tags: List[Dict[str, Any]],
+ step_range: range,
+ world_info: WorldInfo,
+ extra: Dict[str, Any],
+) -> List[DeviceTimerStreamRecord]:
+ if len(recent_elapsed_raw_parts) != len(recent_since_start_raw_parts):
+ raise NDHandlerError(
+ f"recent_elapsed_raw_parts {len(recent_elapsed_raw_parts)} not"
+ f"equal to recent_since_start_raw_parts {len(recent_since_start_raw_parts)}"
+ )
+ if len(recent_elapsed_raw_parts) != len(tags):
+ raise NDHandlerError(f"recent_elapsed_raw_parts {len(recent_elapsed_raw_parts)} not equal to tags {len(tags)}")
+
+ if len(recent_elapsed_raw_parts) == 0:
+ return []
+
+ specified_steps = [tag[NDTIMELINE_INNER_GLOBAL_STEP_KEY] for tag in tags if NDTIMELINE_INNER_GLOBAL_STEP_KEY in tag]
+
+ now = int(time.time())
+ records = []
+ if len(specified_steps) != 0:
+ # metric with `INNER_GLOBAL_STEP_KEY` does not respect `step_range`
+ # but it should always be set with `INNER_GLOBAL_STEP_KEY` and monotonically increasing
+ if len(specified_steps) != len(tags):
+ raise NDHandlerError("timer with INNER_GLOBAL_STEP_KEY's step is not always set")
+
+ # to understand the following codes,
+ # you can `print(list(itertools.groupby([21,22,23,23,23,46,46,49,50])))`
+ i = 0
+ # NDTimelineLogger().debug("{}: {}".format(metric_name, len(tags)))
+ for step, group_v in itertools.groupby(specified_steps):
+ op_counts = sum(1 for _ in group_v) # memory efficient version of `len(list(group_v))`
+ avg_dur = sum(recent_elapsed_raw_parts[i : i + op_counts]) / op_counts if op_counts != 0 else 0.0
+ record = DeviceTimerStreamRecord(
+ ts=now,
+ rank=world_info.topo_info.rank,
+ metric=metric_name,
+ iteration=0,
+ step=step,
+ avg_dur=avg_dur,
+ start_ts=recent_since_start_raw_parts[i : i + op_counts],
+ duration=recent_elapsed_raw_parts[i : i + op_counts],
+ model_chunk=0,
+ pp_rank=world_info.topo_info.pp_rank,
+ dp_rank=world_info.topo_info.dp_rank,
+ tp_rank=world_info.topo_info.tp_rank,
+ ip=world_info.topo_info.ip,
+ role_id=world_info["role_id"],
+ trial_id=world_info["trial_id"],
+ run_id=world_info["run_id"],
+ )
+ records.append(record)
+ i += op_counts
+ else:
+ if len(step_range) == 0:
+ raise NDHandlerError(f"step_range {step_range} length is zero")
+ if len(recent_elapsed_raw_parts) % len(step_range) != 0:
+ fmt_str = (
+ "len(recent_elapsed_raw_parts) {} of {} "
+ + "is not multiply of len(step_range) {}; "
+ + "if you can't ensure op counts in every step are equal,"
+ + "please explicitly use `step_getter`"
+ )
+ raise NDHandlerError(fmt_str.format(metric_name, len(recent_elapsed_raw_parts), len(step_range)))
+ NDTimelineLogger().debug(f"{metric_name}: {len(recent_elapsed_raw_parts)} in {step_range}")
+ num_step_ops = len(recent_elapsed_raw_parts) // len(step_range)
+ for i, step in enumerate(step_range):
+ avg_dur = sum(recent_elapsed_raw_parts[i * num_step_ops : (i + 1) * num_step_ops]) / num_step_ops
+ record = DeviceTimerStreamRecord(
+ ts=now,
+ rank=world_info.topo_info.rank,
+ metric=metric_name,
+ iteration=0,
+ step=step,
+ avg_dur=avg_dur,
+ start_ts=recent_since_start_raw_parts[i * num_step_ops : (i + 1) * num_step_ops],
+ duration=recent_elapsed_raw_parts[i * num_step_ops : (i + 1) * num_step_ops],
+ model_chunk=0,
+ pp_rank=world_info.topo_info.pp_rank,
+ dp_rank=world_info.topo_info.dp_rank,
+ tp_rank=world_info.topo_info.tp_rank,
+ ip=world_info.topo_info.ip,
+ role_id=world_info["role_id"],
+ trial_id=world_info["trial_id"],
+ run_id=world_info["run_id"],
+ )
+ records.append(record)
+ return records
+
+
+class ParserNDHandler(NDHandler):
+ def call_impl(
+ self,
+ metric_name: str,
+ elapsed: float,
+ recent_elapsed_raw_parts: List[float],
+ recent_since_start_raw_parts: List[float],
+ tags: List[Dict[str, Any]],
+ step_range: range,
+ world_info: WorldInfo,
+ extra: Dict[str, Any],
+ ) -> Any:
+ return parse_record(
+ metric_name,
+ elapsed,
+ recent_elapsed_raw_parts,
+ recent_since_start_raw_parts,
+ tags,
+ step_range,
+ world_info,
+ extra,
+ )
diff --git a/vescale/ndtimeline/handlers/sock_handler.py b/vescale/ndtimeline/handlers/sock_handler.py
new file mode 100644
index 0000000..6d6cc0a
--- /dev/null
+++ b/vescale/ndtimeline/handlers/sock_handler.py
@@ -0,0 +1,107 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import time
+import traceback
+import socket
+from typing import List, Dict, Any
+
+from ..logger import NDTimelineLogger
+from ..binary_protocol import serialize_to_package
+from .handler_base import NDHandler
+from ..world_info import WorldInfo
+from ..variables import SOCK_PATH, SOCK_TIMEOUT_CLIENT
+
+
+class SockNDHandler(NDHandler):
+ def __init__(self, timeout: float = SOCK_TIMEOUT_CLIENT, sock_path: str = SOCK_PATH):
+ super().__init__(ignore_metrics=[])
+ self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ self.sock.settimeout(timeout)
+ self.sock_path = sock_path
+ self.timeout = timeout
+ self.initialized = False
+ self.server_exited = False
+ self.try_to_connect()
+
+ def try_to_connect(self, must=False):
+ if self.initialized:
+ return
+ if must:
+ retry = 50
+ else:
+ retry = 1
+ backoff = 0.8 # seconds
+ for _ in range(retry + 1):
+ err_msg = ""
+ try:
+ self.sock.connect(self.sock_path)
+ self.initialized = True
+ break
+ except OSError as e:
+ if e.errno == 106 and e.strerror == "Transport endpoint is already connected":
+ # might be called in multiple threads
+ # but for one process, only one connection is required
+ self.initialized = True
+ break
+ else:
+ err_msg = traceback.format_exc()
+ time.sleep(backoff)
+ except Exception:
+ err_msg = traceback.format_exc()
+ time.sleep(backoff)
+
+ if must and not self.initialized:
+ NDTimelineLogger().error(f"initialize sock handler failed: {err_msg}")
+
+ def call_impl(
+ self,
+ metric_name: str,
+ elapsed: float,
+ recent_elapsed_raw_parts: List[float],
+ recent_since_start_raw_parts: List[float],
+ tags: List[Dict[str, Any]],
+ step_range: range,
+ world_info: WorldInfo,
+ extra: Dict[str, Any],
+ ) -> Any:
+ self.try_to_connect(True)
+ if self.server_exited:
+ return
+ try:
+ st = time.perf_counter()
+ pkg = serialize_to_package(
+ {
+ "metric_name": metric_name,
+ "elapsed": elapsed,
+ "recent_elapsed_raw_parts": recent_elapsed_raw_parts,
+ "recent_since_start_raw_parts": recent_since_start_raw_parts,
+ "tags": tags,
+ "step_range": step_range,
+ "world_info": world_info,
+ "extra": extra,
+ }
+ )
+ self.sock.sendall(pkg)
+ NDTimelineLogger().debug(f"serialize and send data: {(time.perf_counter() - st) * 1000:3.3f}ms")
+ except BrokenPipeError as e:
+ NDTimelineLogger().error(f"{e}, server exit")
+ self.server_exited = True
+ except socket.timeout:
+ NDTimelineLogger().warning(f"socket timeout {traceback.format_exc()}")
+ except Exception:
+ NDTimelineLogger().error(traceback.format_exc())
diff --git a/vescale/ndtimeline/is_internal.py b/vescale/ndtimeline/is_internal.py
new file mode 100644
index 0000000..0f140e3
--- /dev/null
+++ b/vescale/ndtimeline/is_internal.py
@@ -0,0 +1,23 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+try:
+ from ._internal import is_internal_vescale
+except ImportError:
+
+ def is_internal_vescale():
+ return False
diff --git a/vescale/ndtimeline/logger.py b/vescale/ndtimeline/logger.py
new file mode 100644
index 0000000..929c2a1
--- /dev/null
+++ b/vescale/ndtimeline/logger.py
@@ -0,0 +1,41 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import logging
+import sys
+import os
+
+
+class NDTimelineLogger:
+ def __new__(cls):
+ if not hasattr(cls, "instance"):
+ level = logging.getLevelName(os.getenv("VESCALE_NDTIMELINE_LOG_LEVEL", "INFO"))
+ if isinstance(level, str):
+ # if NDTIMELINE_LOG_LEVEL has an illegal value
+ # logging.getLevelName returns a str `Level xxx`
+ level = logging.WARNING
+ formatter = logging.Formatter(
+ "[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d][pid:%(process)d] - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+ handler = logging.StreamHandler(stream=sys.stderr)
+ handler.setFormatter(formatter)
+ cls.instance = logging.getLogger("ndtimeline")
+ cls.instance.addHandler(handler)
+ cls.instance.setLevel(level)
+ cls.instance.propagate = False
+ return cls.instance
diff --git a/vescale/ndtimeline/pool.py b/vescale/ndtimeline/pool.py
new file mode 100644
index 0000000..4cc5c34
--- /dev/null
+++ b/vescale/ndtimeline/pool.py
@@ -0,0 +1,78 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import threading
+from collections import deque
+from typing import Dict, Any, Optional
+
+import torch
+from torch.cuda import Event
+
+from .variables import DEFAULT_CUDA_EVENT_POOL_SIZE
+
+
+class CudaEventPool:
+ def __init__(
+ self, device: Optional[int] = None, init_sz: int = DEFAULT_CUDA_EVENT_POOL_SIZE, blocking: bool = False
+ ) -> None:
+ self._pool = deque()
+ self._device = device
+ self._event_attr = {"enable_timing": True, "blocking": blocking, "interprocess": False}
+
+ self._mtx = threading.Lock()
+
+ for _ in range(init_sz):
+ event = Event(**self._event_attr)
+ event.tag = {}
+ self._pool.append(event)
+ event.record() # warmup
+
+ def get(self, tag: Dict[str, Any]):
+ device = torch.cuda.current_device()
+ if self._device is not None:
+ device = self._device
+ with torch.cuda.device(device):
+ try:
+ with self._mtx:
+ event = self._pool.popleft()
+ except IndexError:
+ event = Event(**self._event_attr)
+ event.tag = tag.copy()
+ return event
+
+ def release(self, event: Event):
+ with self._mtx:
+ self._pool.append(event)
+
+
+class DefaultEventPool:
+ initialized = False
+
+ @classmethod
+ def init(cls, device: Optional[int] = None):
+ assert not cls.initialized
+ cls._default_cuda_event_pool = CudaEventPool(device=device, blocking=True)
+ cls.initialized = True
+
+ @classmethod
+ def get(cls, tag: Optional[Dict[str, Any]] = None):
+ tag = tag if tag is not None else {}
+ return cls._default_cuda_event_pool.get(tag)
+
+ @classmethod
+ def release(cls, event: Event):
+ cls._default_cuda_event_pool.release(event)
diff --git a/vescale/ndtimeline/predefined.py b/vescale/ndtimeline/predefined.py
new file mode 100644
index 0000000..1cccffb
--- /dev/null
+++ b/vescale/ndtimeline/predefined.py
@@ -0,0 +1,30 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+RECV_FORWARD = "forward-recv"
+RECV_BACKWARD = "backward-recv"
+SEND_FORWARD = "forward-send"
+SEND_BACKWARD = "backward-send"
+SEND_FORWARD_RECV_BACKWARD = "forward-send-backward-recv"
+SEND_BACKWARD_RECV_FORWARD = "backward-send-forward-recv"
+CROSS_MESH_RECV = "cross-mesh-recv"
+CROSS_MESH_SEND = "cross-mesh-send"
+FORWARD_COMPUTE = "forward-compute"
+BACKWARD_COMPUTE = "backward-compute"
+UNSHARD_AG = "unshard-all-gather"
+GRAD_RS = "grad-reduce-scatter"
+GRAD_AR = "grad-all-reduce"
diff --git a/vescale/ndtimeline/sock_streamer.py b/vescale/ndtimeline/sock_streamer.py
new file mode 100644
index 0000000..34b1983
--- /dev/null
+++ b/vescale/ndtimeline/sock_streamer.py
@@ -0,0 +1,132 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import socket
+import socketserver
+import os
+import traceback
+import queue
+import threading
+from typing import Callable, List, Optional
+
+import torch.multiprocessing as mp
+from torch.multiprocessing import ProcessContext
+
+from .logger import NDTimelineLogger
+from .binary_protocol import recv_and_validate, loads_fn
+from .exceptions import ProtocolValidationError, NDHandlerError
+from .variables import SOCK_PATH, SOCK_PARENT_DIR
+
+q = None
+
+
+def internal_queue_consume(handlers: Optional[List[Callable]] = None):
+ if handlers is None:
+ handlers = []
+ global q
+ while True:
+ try:
+ args = q.get(block=True)
+ for handler in handlers:
+ handler(
+ args["metric_name"],
+ args["elapsed"],
+ args["recent_elapsed_raw_parts"],
+ args["recent_since_start_raw_parts"],
+ args["tags"],
+ args["step_range"],
+ args["world_info"],
+ args["extra"],
+ )
+ except NDHandlerError as e:
+ NDTimelineLogger().error(e)
+ NDTimelineLogger().warning(traceback.format_exc())
+ continue
+ except queue.Empty:
+ continue
+ except Exception as e:
+ NDTimelineLogger().error(e)
+ NDTimelineLogger().error(traceback.format_exc())
+ continue
+
+
+class MsgHandler(socketserver.BaseRequestHandler):
+ def handle(self):
+ global q
+ # self.request is a socket, automatically closed after `handle`
+ assert q is not None
+ preload_data = bytearray()
+ while True:
+ try:
+ payload = recv_and_validate(self.request.recv, preload_data)
+ args = loads_fn(payload)
+ q.put(args)
+ except ProtocolValidationError:
+ pass
+ except ValueError as e:
+ NDTimelineLogger().error(e)
+ NDTimelineLogger().error(traceback.format_exc())
+ except socket.timeout:
+ NDTimelineLogger().error("socket.timeout")
+ NDTimelineLogger().error(traceback.format_exc())
+ except BrokenPipeError:
+ NDTimelineLogger().info("client exit")
+ break
+ except Exception:
+ NDTimelineLogger().error(traceback.format_exc())
+ break
+
+
+class NDtimelineStreamer:
+ p: ProcessContext
+ initialized: bool = False
+
+ @classmethod
+ def init(cls, local_rank: int, handlers: Optional[List[Callable]] = None):
+ if local_rank != 0:
+ return
+ if cls.initialized:
+ NDTimelineLogger().warning("NDtimelineStreamer has already been initialized, skipped")
+ return
+ handlers = handlers if handlers is not None else []
+ try:
+ if os.path.exists(SOCK_PATH):
+ os.remove(SOCK_PATH)
+ if not os.path.exists(SOCK_PARENT_DIR):
+ os.makedirs(SOCK_PARENT_DIR, exist_ok=True)
+ cls.p = mp.spawn(
+ fn=NDtimelineStreamer.run, args=(handlers,), nprocs=1, join=False, daemon=True, start_method="spawn"
+ )
+ NDTimelineLogger().info("ndtimeline streamer started")
+ cls.initialized = True
+ except Exception:
+ NDTimelineLogger().error("NDtimelineStreamer init failed")
+ NDTimelineLogger().error(traceback.format_exc())
+
+ @staticmethod
+ def run(process_index, handlers: List[Callable]):
+ global q
+ # in order to save memory of main process, `q` is initialized here
+ q = queue.Queue(500000)
+ mq_thread = threading.Thread(
+ target=internal_queue_consume, args=(handlers,), daemon=True, name="internal_queue_consume"
+ )
+ mq_thread.start()
+
+ with socketserver.ThreadingUnixStreamServer(SOCK_PATH, MsgHandler) as server:
+ server.daemon_threads = True
+ server.serve_forever()
diff --git a/vescale/ndtimeline/stream.py b/vescale/ndtimeline/stream.py
new file mode 100644
index 0000000..a06b072
--- /dev/null
+++ b/vescale/ndtimeline/stream.py
@@ -0,0 +1,79 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import torch
+from .logger import NDTimelineLogger
+
+NCCL_STREAMS = {}
+DEVICE = None
+
+
+def get_nccl_p2p_stream(name: str, nccl_pg: "torch.distributed.ProcessGroup", peer, is_batched):
+ global NCCL_STREAMS, DEVICE
+ if DEVICE is None:
+ DEVICE = torch.device("cuda", index=torch.cuda.current_device())
+ if name in NCCL_STREAMS and NCCL_STREAMS[name] is not None:
+ return NCCL_STREAMS[name]
+ if hasattr(nccl_pg, "_get_backend"):
+ nccl_backend = nccl_pg._get_backend(DEVICE)
+ else:
+ # before torch 2.x torch._C._distributed_c10d.ProcessGroupNCCL is a subclass of
+ # torch.distributed.ProcessGroup
+ nccl_backend = nccl_pg
+ if hasattr(nccl_backend, "get_p2p_cuda_stream_id"):
+ stream_id = nccl_backend.get_p2p_cuda_stream_id(DEVICE.index, peer, is_batched)
+ NDTimelineLogger().debug(f"[{DEVICE.index}]{name} [{peer}] stream_id={stream_id}")
+ if stream_id < 0:
+ rank = nccl_pg.rank()
+ NDTimelineLogger().info(f"[{rank}]{name} is_batched={is_batched} p2p stream is not available, skipped")
+ return None
+ _CUDA_DEVICE = 1
+ nccl_stream = torch.cuda.Stream(stream_id=stream_id, device_index=DEVICE.index, device_type=_CUDA_DEVICE)
+ rank = nccl_pg.rank()
+ msg = f"[{rank}]{name} nccl p2p stream id={stream_id} device={DEVICE} stream={nccl_stream}"
+ NDTimelineLogger().debug(msg)
+ NCCL_STREAMS[name] = nccl_stream
+ return nccl_stream
+ return None
+
+
+def get_nccl_coll_stream(name: str, nccl_pg: "torch.distributed.ProcessGroup", nccl_tensor: torch.Tensor):
+ global NCCL_STREAMS
+ if name in NCCL_STREAMS and NCCL_STREAMS[name] is not None:
+ return NCCL_STREAMS[name]
+ device = nccl_tensor.device
+ if hasattr(nccl_pg, "_get_backend"):
+ nccl_backend = nccl_pg._get_backend(device)
+ else:
+ # before torch 2.x torch._C._distributed_c10d.ProcessGroupNCCL is a subclass of
+ # torch.distributed.ProcessGroup
+ nccl_backend = nccl_pg
+ if hasattr(nccl_backend, "get_coll_cuda_stream_id"):
+ NDTimelineLogger().info(nccl_backend)
+ stream_id = nccl_backend.get_coll_cuda_stream_id([nccl_tensor])
+ if stream_id < 0:
+ rank = nccl_pg.rank()
+ NDTimelineLogger().info(f"[{rank}]{name} coll stream is not available, skipped")
+ return None
+ _CUDA_DEVICE = 1
+ nccl_stream = torch.cuda.Stream(stream_id=stream_id, device_index=device.index, device_type=_CUDA_DEVICE)
+ rank = nccl_pg.rank()
+ msg = f"[{rank}]{name} nccl coll stream id={stream_id} device={device} stream={nccl_stream}"
+ NDTimelineLogger().debug(msg)
+ NCCL_STREAMS[name] = nccl_stream
+ return nccl_stream
+ return None
diff --git a/vescale/ndtimeline/timer.py b/vescale/ndtimeline/timer.py
new file mode 100644
index 0000000..f913dd9
--- /dev/null
+++ b/vescale/ndtimeline/timer.py
@@ -0,0 +1,756 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import dataclasses
+import time
+import traceback
+import gc
+import contextlib
+from decimal import Decimal
+from enum import Enum, unique
+from dataclasses import dataclass
+from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED
+from typing import List, Dict, Any, Callable, Optional, Tuple, Union
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+from functools import wraps
+
+import torch
+
+from .pool import DefaultEventPool, Event
+from .world_info import WorldInfo
+from .stream import get_nccl_p2p_stream, get_nccl_coll_stream
+from .logger import NDTimelineLogger
+from .variables import (
+ NDTIMELINE_INNER_GLOBAL_STEP_KEY,
+ NDTIMELINE_STREAM_KEY,
+ NDTIMELINE_FLUSH_SEPCIAL,
+)
+
+
+class GlobalReferenceTime:
+ local_rank: int = 0
+ world_size: int = 0
+ device: torch.device = None
+ # global ref events
+ ref_events: List[torch.cuda.Event] = []
+ ref_pointer: int = 0
+ clock_diff: float = 0.0 # ms
+ initial_min_clock: int = 0 # ns
+ last_calibrated_at: float = 0.0 # ms
+ gpu_clock_residual_coef: float = 1.0
+ initialized: bool = False
+
+ @classmethod
+ def init(cls, world_sz: int, device: Optional[Union[int, torch.device]] = None):
+ if isinstance(device, int):
+ cls.device = torch.device(f"cuda:{device}")
+ cls.local_rank = device
+ elif isinstance(device, torch.device):
+ cls.device = device
+ cls.local_rank = device.index
+ elif device is None:
+ cls.device = torch.device(f"cuda:{torch.cuda.current_device()}")
+ cls.local_rank = torch.cuda.current_device()
+ else:
+ raise RuntimeError(f"device must be int or torch.device or None, but got {type(device)}")
+ cls.world_size = world_sz
+ assert isinstance(cls.device, torch.device)
+ with torch.cuda.device(cls.device.index):
+ cls.ref_events = [
+ torch.cuda.Event(enable_timing=True, blocking=False, interprocess=False) for _ in range(2)
+ ]
+ # warmup
+ for e in cls.ref_events:
+ e.record(stream=torch.cuda.default_stream())
+ for e in cls.ref_events:
+ e.synchronize()
+ cls.calibrate()
+ if cls.local_rank == 0:
+ NDTimelineLogger().debug(f"cls.initial_min_clock: {cls.initial_min_clock}ns")
+ cls.initialized = True
+
+ @classmethod
+ def sync_events(cls):
+ for i in range(len(cls.ref_events)):
+ cls.ref_events[i].synchronize()
+
+ @classmethod
+ def calibrate(cls):
+ # round-robin
+ calibrate_st = time.perf_counter()
+ next_pointer = (cls.ref_pointer + 1) % len(cls.ref_events)
+ cls.ref_pointer = next_pointer
+ ref = cls.ref_events[next_pointer]
+ with torch.cuda.device(cls.device.index):
+ if not cls.initialized:
+ torch.distributed.barrier()
+ torch.cuda.synchronize()
+ # torch.cuda.default_stream().synchronize()
+ ref.record(stream=torch.cuda.default_stream())
+ ref.synchronize()
+ ts_ns = int(time.time_ns())
+ ts = ts_ns / 1e6
+
+ if not cls.initialized:
+ my_clock = torch.tensor([ts_ns], dtype=torch.long, device=cls.device)
+ world_clocks = [torch.zeros([1], dtype=torch.long, device=cls.device) for _ in range(cls.world_size)]
+ torch.distributed.all_gather(world_clocks, my_clock)
+ all_clocks = [r.cpu().tolist()[0] for r in world_clocks]
+ min_clock = min(all_clocks)
+ cls.initial_min_clock = min_clock
+
+ cls.clock_diff = (ts_ns - cls.initial_min_clock) / 1e6 # to unit ms
+
+ # cpu-gpu calibrate
+ cpu_time = ts - cls.last_calibrated_at # ms
+ gpu_time = 0.0
+ cls.last_calibrated_at = ts # ms
+ if cls.initialized and 2 * 1e3 < cpu_time < 200000 * 1e3:
+ gpu_time = abs(cls.ref_events[0].elapsed_time(cls.ref_events[1])) # ms
+ gpu_cpu_diff = Decimal((gpu_time) - (cpu_time)) / Decimal(gpu_time)
+ cls.gpu_clock_residual_coef = float(1 - gpu_cpu_diff)
+ if cls.local_rank == 0:
+ NDTimelineLogger().info(
+ f"local rank 0, calibrate sync cpu moment: {ts_ns} ns, clock diff: {cls.clock_diff} ms, "
+ f"initial min: {cls.initial_min_clock} ns, "
+ f"gpu clock redidual coef: {cls.gpu_clock_residual_coef}, "
+ f"calibrate cpu: {cpu_time}ms, calibrate gpu: {gpu_time}ms"
+ )
+ NDTimelineLogger().info(
+ f"rank {cls.local_rank} calibrate cost {1000 * (time.perf_counter() - calibrate_st):4.2f}ms"
+ )
+
+ @classmethod
+ def elapsed_time(cls, end_event):
+ # cuda event elapsed_time return in unit ms
+ gpu_time = cls.ref_events[cls.ref_pointer].elapsed_time(end_event)
+ return gpu_time * cls.gpu_clock_residual_coef + cls.last_calibrated_at
+
+ @classmethod
+ def since_global_start_ts(cls, unix_ts):
+ # unix_ts in unit s
+ return unix_ts - cls.initial_min_clock / 1e9
+
+
+@unique
+class NDMetricLevel(Enum):
+ """
+ NDMetricLevel is used to define the level of metric.
+ """
+
+ FRAMEWORK_INFO = 2
+ USER_INFO = 3
+ INFO = 4
+
+ FRAMEWORK_DEBUG = 12
+ USER_DEBUG = 13
+ DEBUG = 14
+
+ FRAMEWORK_TRACE = 102
+ USER_TRACE = 103
+ TRACE = 104
+
+ def __lt__(self, other) -> bool:
+ return self.value < other.value
+
+ def __le__(self, other) -> bool:
+ return self.value <= other.value
+
+ def __gt__(self, other) -> bool:
+ return self.value > other.value
+
+ def __ge__(self, other) -> bool:
+ return self.value >= other.value
+
+ def __eq__(self, other) -> bool:
+ return self.value == other.value
+
+ def __neq__(self, other) -> bool:
+ return self.value != other.value
+
+
+@dataclass(frozen=False)
+class DeviceTimerMeta:
+ name: str = ""
+ is_cpu_op: bool = False
+ legal_tags: List[str] = dataclasses.field(default_factory=list)
+ step_getter: Optional[Callable] = None
+ enabled: bool = True
+ level: NDMetricLevel = dataclasses.field(default_factory=lambda: NDMetricLevel.FRAMEWORK_DEBUG)
+ device_id: int = -1
+ dispatch_mode: Literal["selected", "all"] = "all"
+ dst_names: List[str] = dataclasses.field(default_factory=list)
+ specified_extra: Dict[str, Any] = dataclasses.field(default_factory=dict)
+ common_extra: Dict[str, Any] = dataclasses.field(default_factory=dict)
+
+ def __post_init__(self):
+ if self.dispatch_mode not in ["selected", "all"]:
+ raise ValueError(f"invalid dispatch_mode {self.dispatch_mode}")
+ if not isinstance(self.level, NDMetricLevel):
+ raise ValueError(f"invalid type of level {type(self.level)}")
+
+ def copy(self):
+ return DeviceTimerMeta(
+ self.name,
+ self.is_cpu_op,
+ self.legal_tags.copy(),
+ self.step_getter,
+ self.enabled,
+ self.level,
+ self.device_id,
+ self.dispatch_mode,
+ self.dst_names.copy(),
+ self.specified_extra.copy(),
+ self.common_extra.copy(),
+ )
+
+
+class DeviceTimer:
+ def __init__(
+ self,
+ name: str,
+ is_cpu_op: bool = False,
+ legal_tags: Optional[List[str]] = None,
+ step_getter: Optional[Callable] = None,
+ enabled: bool = True,
+ level: NDMetricLevel = NDMetricLevel.FRAMEWORK_DEBUG,
+ device_id: int = 0,
+ dispatch_mode: Literal["selected", "all"] = "all",
+ dst_names: Optional[List[str]] = None,
+ specified_extra: Optional[Dict[str, Any]] = None,
+ common_extra: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ super().__init__()
+
+ legal_tags = legal_tags if legal_tags is not None else []
+ dst_names = dst_names if dst_names is not None else []
+ specified_extra = specified_extra if specified_extra is not None else {}
+ common_extra = common_extra if common_extra is not None else {}
+
+ if dispatch_mode not in ["all", "selected"]:
+ raise ValueError(f"invaid dispatch_mode {dispatch_mode} {type(dispatch_mode)}")
+ self.meta = DeviceTimerMeta(
+ name,
+ is_cpu_op,
+ legal_tags,
+ step_getter,
+ enabled,
+ level,
+ device_id,
+ dispatch_mode,
+ dst_names,
+ specified_extra,
+ common_extra,
+ )
+ for field_name in self.meta.__dict__:
+ setattr(self, field_name, getattr(self.meta, field_name))
+ if NDTIMELINE_INNER_GLOBAL_STEP_KEY not in self.legal_tags and step_getter is not None:
+ legal_tags.append(NDTIMELINE_INNER_GLOBAL_STEP_KEY)
+ if NDTIMELINE_STREAM_KEY not in self.legal_tags:
+ legal_tags.append(NDTIMELINE_STREAM_KEY)
+ self._started: bool = False
+ self._stream: torch.cuda.Stream = None
+ # list of [start_event, stop_event]
+ self._event_pairs: List[List[Event, Event]] = []
+ self._pool = DefaultEventPool
+ # list of [start_ts, duration, tag]
+ self._extra_records: List[List[float, float, Dict[str, Any]]] = []
+
+ def __repr__(self) -> str:
+ return f"DeviceTimer with {self.meta.__repr__()}"
+
+ def is_enabled(self) -> bool:
+ return self.enabled
+
+ def enable(self):
+ self.meta.enabled = True
+ self.enabled = True
+
+ def disable(self):
+ self.meta.enabled = False
+ self.enabled = False
+
+ def insert_record(
+ self,
+ start_ts: float,
+ duration: float,
+ tag: Optional[Dict[str, Any]] = None,
+ level: NDMetricLevel = NDMetricLevel.FRAMEWORK_DEBUG,
+ ):
+ if not self.enabled or self.meta.level > level:
+ return
+ tag = tag if tag is not None else {}
+ if self.step_getter is not None:
+ tag[NDTIMELINE_INNER_GLOBAL_STEP_KEY] = self.step_getter()
+ self._extra_records.append([start_ts, duration, tag])
+
+ def start(
+ self,
+ stream: torch.cuda.Stream = None,
+ tag: Optional[Dict[str, Any]] = None,
+ level: NDMetricLevel = NDMetricLevel.FRAMEWORK_DEBUG,
+ ) -> None:
+ """Start the timer"""
+ if not self.enabled or self.meta.level > level:
+ return
+ assert not self._started, "timer has already been started"
+ tag = tag if tag is not None else {}
+ if self.step_getter is not None:
+ tag[NDTIMELINE_INNER_GLOBAL_STEP_KEY] = self.step_getter()
+ if self.is_cpu_op:
+ self._extra_records.append([time.time(), None, tag])
+ self._started = True
+ return
+ start_event = self._pool.get(tag=tag)
+ stream_args = {}
+ if stream is not None:
+ self._stream = stream
+ self._stream.wait_stream(torch.cuda.default_stream())
+ stream_args = {"stream": self._stream}
+ start_event.record(**stream_args)
+ self._event_pairs.append([start_event, None])
+ self._started = True
+
+ def stop(self, tag: Optional[Dict[str, Any]] = None, level: NDMetricLevel = NDMetricLevel.FRAMEWORK_DEBUG) -> None:
+ """Stop the timer. May be called in another thread."""
+ if not self.enabled or self.meta.level > level:
+ return
+ assert self._started, "timer is not started"
+ tag = tag if tag is not None else {}
+ if self.is_cpu_op:
+ now = time.time()
+ assert self._extra_records[-1][1] is None, "duration is already set"
+ self._extra_records[-1][1] = now - self._extra_records[-1][0]
+ self._extra_records[-1][2] = {**tag, **self._extra_records[-1][2]}
+ self._started = False
+ return
+ stop_event = self._pool.get(tag=tag)
+ stream_args = {}
+ if self._stream is not None:
+ stream_args = {"stream": self._stream}
+ stop_event.record(**stream_args)
+ assert self._event_pairs[-1][-1] is None, "stop_event is already set"
+ self._event_pairs[-1][-1] = stop_event
+ self._started = False
+
+ def reset(self) -> None:
+ self._started = False
+ self._stream = None
+ self._event_pairs = []
+ self._extra_records = []
+
+ def elapsed(self, reset=True) -> Tuple[float, List[float], List[float], List[Dict[str, Any]]]:
+ """Calculate the elapsed time."""
+ if not self.enabled:
+ return 0.0, [], [], []
+ recent_elapsed_raw_parts = [0.0] * len(self._event_pairs)
+ recent_since_start_raw_parts = [0.0] * len(self._event_pairs)
+ tags = [{}] * len(self._event_pairs)
+ elapsed = 0.0
+ with torch.cuda.device(self.device_id):
+ for i, (start_event, stop_event) in enumerate(self._event_pairs):
+ stop_event.synchronize()
+ start_event.synchronize()
+ single_elapsed = start_event.elapsed_time(stop_event) / 1e3
+ single_since = GlobalReferenceTime.elapsed_time(start_event) / 1e3
+ elapsed += single_elapsed
+ recent_elapsed_raw_parts[i] = single_elapsed
+ recent_since_start_raw_parts[i] = single_since
+ tags[i] = {**start_event.tag, **stop_event.tag}
+ tags[i] = {k: tags[i][k] for k in tags[i] if k in self.legal_tags}
+ self._pool.release(start_event)
+ self._pool.release(stop_event)
+
+ if len(self._extra_records) > 0:
+ try:
+ elapsed += sum([record[1] for record in self._extra_records])
+ except TypeError as e:
+ NDTimelineLogger().error(
+ f"exception {e} detected in `elapsed` of {self.name}, possible unmatched start stop"
+ )
+ return 0.0, [], [], []
+ self._extra_records.sort(key=lambda x: x[0])
+ if len(recent_since_start_raw_parts) == 0:
+ recent_since_start_raw_parts = [record[0] for record in self._extra_records]
+ recent_elapsed_raw_parts = [record[1] for record in self._extra_records]
+ tags = [record[2] for record in self._extra_records]
+ else:
+ i = 0
+ for record in self._extra_records:
+ while i < len(recent_since_start_raw_parts) and recent_since_start_raw_parts[i] < record[0]:
+ i += 1
+ # a.insert(len(a), x) is equivalent to a.append(x).
+ recent_since_start_raw_parts.insert(i, record[0])
+ recent_elapsed_raw_parts.insert(i, record[1])
+ tags.insert(i, record[2])
+ i += 1
+ if reset:
+ self.reset()
+ return elapsed, recent_elapsed_raw_parts, recent_since_start_raw_parts, tags
+
+
+class NDTimerManager:
+ def __init__(
+ self,
+ world_info: WorldInfo,
+ handlers: Optional[List[Callable]] = None,
+ max_workers: int = 3,
+ device_id: Optional[int] = None,
+ init_cuda_dist: bool = True,
+ metric_level: NDMetricLevel = NDMetricLevel.TRACE,
+ is_nature_step: bool = True,
+ ) -> None:
+ self._name2timer = {}
+ self._name2active_tmp: Dict[str, bool] = {}
+ self._executor = ThreadPoolExecutor(max_workers=max_workers)
+ self._futures = []
+ self._is_initailized = False
+ self.world_info = world_info
+ self.handlers = handlers if handlers is not None else []
+ self._device_id = device_id
+ self.metric_level = metric_level
+ self.is_nature_step = is_nature_step
+ self._unregistered_timer_start = []
+ self._unregistered_timer_stop = []
+ self._unregistered_timer_records_insert = []
+ self._cur_global_step = 0
+
+ if init_cuda_dist:
+ self.init_cuda_dist_associated(device_id=device_id)
+
+ @property
+ def global_step(self):
+ return self._cur_global_step
+
+ @global_step.setter
+ def global_step(self, step: int):
+ if not isinstance(step, int):
+ raise ValueError(f"step {step} is not int")
+ self._cur_global_step = step
+
+ def init_cuda_dist_associated(self, device_id: Optional[int] = None):
+ self._device_id = device_id
+ if self._device_id is not None:
+ DefaultEventPool.init(device=self._device_id)
+ GlobalReferenceTime.init(device=self._device_id, world_sz=self.world_info["world_size"])
+ else:
+ DefaultEventPool.init()
+ GlobalReferenceTime.init(world_sz=self.world_info["world_size"])
+
+ def register_timers(self, timer_metas: List[DeviceTimerMeta]) -> None:
+ for meta in timer_metas:
+ if meta.device_id == -1:
+ if not meta.is_cpu_op:
+ meta.device_id = torch.cuda.current_device()
+ else:
+ meta.device_id = 0
+ if meta.step_getter is None and self.is_nature_step:
+
+ def getter():
+ return self._cur_global_step
+
+ meta.step_getter = getter
+ assert not self._is_initailized, "DeviceTimerManager should only be initialized once"
+ NDTimerManager._register_timers(timer_metas, self._name2timer)
+ self._is_initailized = True
+
+ @staticmethod
+ def _register_timers(timer_metas: List[DeviceTimerMeta], d: Dict[str, DeviceTimer]):
+ for meta in timer_metas:
+ d[meta.name] = DeviceTimer(**meta.__dict__)
+
+ @staticmethod
+ def _flush_timers(
+ handlers: List[Callable],
+ name2timer: Dict[str, DeviceTimer],
+ step_range: range,
+ world_info: WorldInfo,
+ require_calibrate: bool = False,
+ ) -> None:
+ if require_calibrate:
+ GlobalReferenceTime.calibrate()
+ for name in name2timer:
+ timer = name2timer[name]
+ elapsed_result = timer.elapsed()
+ for handler in handlers:
+ if timer.dispatch_mode == "selected" and handler.dispatch_key not in timer.dst_names:
+ continue
+ extra = timer.common_extra
+ if handler.dispatch_key in timer.specified_extra:
+ specified_extra = timer.specified_extra[handler.dispatch_key]
+ extra = {**extra, **specified_extra}
+ try:
+ handler(name, *elapsed_result, step_range, world_info, extra)
+ except Exception as e:
+ NDTimelineLogger().error(f"handler {handler} failed: {e}")
+ NDTimelineLogger().error(traceback.format_exc())
+ timer.meta = None # in case of CudaTimer obj gc failure due to meta obj
+
+ for handler in handlers:
+ handler(NDTIMELINE_FLUSH_SEPCIAL, 0.0, [], [], [], range(0, 1), world_info, extra)
+
+ def start_timer(self, name: str, tag: Optional[Dict[str, Any]] = None) -> None:
+ assert isinstance(self, NDTimerManager) or issubclass(type(self), NDTimerManager)
+ tag = tag if tag is not None else {}
+ try:
+ if name not in self._unregistered_timer_start:
+ stream = None
+ if NDTIMELINE_STREAM_KEY in tag:
+ stream = tag[NDTIMELINE_STREAM_KEY]
+ del tag[NDTIMELINE_STREAM_KEY]
+ self._name2timer[name].start(stream=stream, tag=tag, level=self.metric_level)
+ except KeyError:
+ self._unregistered_timer_start.append(name)
+ NDTimelineLogger().warning(f"metric {name} is not registered when `start_timer`, skipped")
+ except Exception:
+ NDTimelineLogger().error(f"trigger exception when `start_timer` metric {name}")
+ NDTimelineLogger().error(traceback.format_exc())
+
+ def stop_timer(self, name, tag: Optional[Dict[str, Any]] = None) -> None:
+ assert isinstance(self, NDTimerManager) or issubclass(type(self), NDTimerManager)
+ tag = tag if tag is not None else {}
+ try:
+ if name not in self._unregistered_timer_stop:
+ if NDTIMELINE_STREAM_KEY in tag:
+ del tag[NDTIMELINE_STREAM_KEY]
+ self._name2timer[name].stop(tag=tag, level=self.metric_level)
+ except KeyError:
+ self._unregistered_timer_stop.append(name)
+ NDTimelineLogger().warning(f"metric {name} is not registered when `stop_timer`, skipped")
+ except Exception:
+ NDTimelineLogger().error(f"trigger exception when `start_timer` metric {name}")
+ NDTimelineLogger().error(traceback.format_exc())
+
+ def insert_record(self, name, start_ts: float, duration: float, tag: Optional[Dict[str, Any]] = None):
+ assert isinstance(self, NDTimerManager) or issubclass(type(self), NDTimerManager)
+ tag = tag if tag is not None else {}
+ try:
+ if name not in self._unregistered_timer_records_insert:
+ self._name2timer[name].insert_record(start_ts, duration, tag, self.metric_level)
+ except KeyError:
+ self._unregistered_timer_records_insert.append(name)
+ NDTimelineLogger().warning(f"metric {name} is not registered when `insert_record`, skipped")
+ except Exception:
+ NDTimelineLogger().error(f"trigger exception when `insert_record` metric {name}")
+ NDTimelineLogger().error(traceback.format_exc())
+
+ def clear(self):
+ self.async_flush(
+ step_range=range(0, 10),
+ next_iter_enabled=False,
+ collect_future=False,
+ submit2handler=False,
+ keep_timer_state=True,
+ )
+
+ def disable_and_save(self):
+ is_autogc = gc.isenabled()
+ if is_autogc:
+ gc.disable()
+ for k in self._name2timer:
+ self._name2active_tmp[k] = self._name2timer[k].is_enabled()
+ self._name2timer[k].disable()
+ if is_autogc:
+ gc.enable()
+
+ def recover_from_history(self):
+ is_autogc = gc.isenabled()
+ if is_autogc:
+ gc.disable()
+ for k in self._name2timer:
+ if k in self._name2active_tmp:
+ if self._name2active_tmp[k]:
+ self._name2timer[k].enable()
+ else:
+ self._name2timer[k].disable()
+ del self._name2active_tmp[k]
+ if is_autogc:
+ gc.enable()
+
+ def async_flush(
+ self,
+ step_range: range,
+ next_iter_enabled: bool = True,
+ world_info: Optional[WorldInfo] = None,
+ handlers: Optional[List[Callable[..., None]]] = None,
+ collect_future: bool = True,
+ submit2handler: bool = True,
+ force_calibrate: bool = False,
+ dynamic_calibrate: bool = False,
+ keep_timer_state: bool = False,
+ sequential_calibrate: bool = True,
+ ):
+ st = time.perf_counter()
+ handlers = handlers if handlers is not None else []
+ enabled_timer_names = [name for name in self._name2timer if self._name2timer[name].meta.enabled]
+ NDTimelineLogger().debug(f"async flush triggered, {enabled_timer_names}")
+
+ unregistered = self._unregistered_timer_start.copy()
+ unregistered.extend(self._unregistered_timer_stop)
+ unregistered.extend(self._unregistered_timer_records_insert)
+ unregistered = list(set(unregistered))
+ if len(unregistered) > 0:
+ NDTimelineLogger().warning(f"unregistered timers: {unregistered}")
+
+ past_name2timer = self._name2timer
+ fresh_name2timer = {}
+ timer_metas = [past_name2timer[name].meta.copy() for name in past_name2timer]
+
+ if not keep_timer_state:
+ for meta in timer_metas:
+ meta.enabled = next_iter_enabled
+
+ # filter enabled timer
+ past_name2timer = {
+ name: past_name2timer[name]
+ for name in past_name2timer
+ if past_name2timer[name].meta.enabled and past_name2timer[name].meta.level <= self.metric_level
+ }
+
+ NDTimerManager._register_timers(timer_metas, fresh_name2timer)
+
+ is_autogc = gc.isenabled()
+ if is_autogc:
+ gc.disable()
+ self._name2timer = fresh_name2timer
+ if is_autogc:
+ gc.enable()
+
+ if collect_future:
+ i = 0
+ while i < len(self._futures):
+ if self._futures[i].done():
+ e = self._futures[i].exception()
+ if e is not None:
+ NDTimelineLogger().error("".join(traceback.format_exception(type(e), e, e.__traceback__)))
+ self._futures.pop(i)
+ else:
+ i += 1
+
+ if len(handlers) == 0:
+ handlers = self.handlers
+
+ require_calibrate = force_calibrate or (
+ dynamic_calibrate and GlobalReferenceTime.last_calibrated_at < (time.time() - 30 * 60) * 1e3
+ )
+ if require_calibrate and sequential_calibrate:
+ GlobalReferenceTime.calibrate()
+ require_calibrate = False
+
+ if submit2handler and len(past_name2timer) > 0:
+ world_info = self.world_info if world_info is None else self.world_info
+ future = self._executor.submit(
+ NDTimerManager._flush_timers, handlers, past_name2timer, step_range, world_info, require_calibrate
+ )
+ self._futures.append(future)
+
+ NDTimelineLogger().debug(f"async flush cost {1000 * (time.perf_counter() - st):4.2f}ms")
+
+ def wait(self) -> None:
+ if len(self._futures) == 0:
+ return
+ torch.distributed.barrier()
+ # wait at most 10 seconds
+ wait(self._futures, timeout=10, return_when=ALL_COMPLETED)
+ for f in self._futures:
+ e = f.exception(timeout=0.001)
+ if e is not None:
+ NDTimelineLogger().error("".join(traceback.format_exception(type(e), e, e.__traceback__)))
+ self._futures = []
+ # streamer can not respond to training process now
+ # assume msg will be handled in 3 seconds
+ time.sleep(3)
+
+
+class Singleton(type):
+ _instances = {}
+
+ def __call__(self, *args, **kwargs):
+ if self not in self._instances:
+ self._instances[self] = super().__call__(*args, **kwargs)
+ self._singleton_inited = True
+ return self._instances[self]
+
+
+class NDTimerManagerSingleton(NDTimerManager, metaclass=Singleton):
+ @classmethod
+ def is_initialized(cls) -> bool:
+ return hasattr(cls, "_singleton_inited") and cls._singleton_inited
+
+
+@contextlib.contextmanager
+def ndtimeit(name: str, tag: Optional[Dict[str, Any]] = None):
+ """reentrant timeit context manager"""
+ if not NDTimerManagerSingleton.is_initialized():
+ yield
+ return
+ tag = tag if tag is not None else {}
+ NDTimerManagerSingleton().start_timer(name, tag)
+ try:
+ yield
+ finally:
+ NDTimerManagerSingleton().stop_timer(name)
+
+
+@contextlib.contextmanager
+def ndtimeit_p2p(name: str, nccl_pg, peer: int, is_batched: bool = True, tag: Optional[Dict[str, Any]] = None):
+ if not NDTimerManagerSingleton.is_initialized():
+ yield
+ return
+ p2p_stream = get_nccl_p2p_stream(name=name, nccl_pg=nccl_pg, peer=peer, is_batched=is_batched)
+ if tag is not None:
+ tag[NDTIMELINE_STREAM_KEY] = p2p_stream
+ else:
+ tag = {NDTIMELINE_STREAM_KEY: p2p_stream}
+ NDTimerManagerSingleton().start_timer(name, tag)
+ try:
+ yield
+ finally:
+ NDTimerManagerSingleton().stop_timer(name)
+
+
+@contextlib.contextmanager
+def ndtimeit_coll(name: str, pg, tensor: torch.Tensor, tag: Optional[Dict[str, Any]] = None):
+ if not NDTimerManagerSingleton.is_initialized():
+ yield
+ return
+ coll_stream = get_nccl_coll_stream(name, pg, tensor)
+ if tag is not None:
+ tag[NDTIMELINE_STREAM_KEY] = coll_stream
+ else:
+ tag = {NDTIMELINE_STREAM_KEY: coll_stream}
+ NDTimerManagerSingleton().start_timer(name, tag)
+ try:
+ yield
+ finally:
+ NDTimerManagerSingleton().stop_timer(name)
+
+
+def ndtimer(metric: str, tags: Optional[Dict[str, Any]] = None):
+ def _ndtimeit_decorator(func):
+ @wraps(func)
+ def with_ndtimeit(*args, **kwargs):
+ with ndtimeit(metric, tags):
+ return func(*args, **kwargs)
+
+ return with_ndtimeit
+
+ return _ndtimeit_decorator
diff --git a/vescale/ndtimeline/variables.py b/vescale/ndtimeline/variables.py
new file mode 100644
index 0000000..7ed9d65
--- /dev/null
+++ b/vescale/ndtimeline/variables.py
@@ -0,0 +1,27 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import os
+
+SOCK_TIMEOUT_CLIENT: float = 2.0 # seconds
+SOCK_PARENT_DIR: str = "/opt/tiger/tmp/ndtimeline"
+SOCK_PATH: str = os.path.join(SOCK_PARENT_DIR, "ndtimeline.sock") # /opt/tiger/tmp/ndtimeline/ndtimeline.sock
+LOCAL_LOGGING_PATH: str = SOCK_PARENT_DIR
+DEFAULT_CUDA_EVENT_POOL_SIZE: int = 20
+NDTIMELINE_INNER_GLOBAL_STEP_KEY: str = "_inner_global_step"
+NDTIMELINE_STREAM_KEY: str = "stream_key"
+NDTIMELINE_FLUSH_SEPCIAL: str = "special"
diff --git a/vescale/ndtimeline/world_info.py b/vescale/ndtimeline/world_info.py
new file mode 100644
index 0000000..197e1a1
--- /dev/null
+++ b/vescale/ndtimeline/world_info.py
@@ -0,0 +1,123 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from dataclasses import dataclass
+from typing import Any, Dict
+
+
+@dataclass(frozen=False)
+class TopoInfo:
+ rank: int = 0
+ dp_rank: int = 0
+ ddp_rank: int = 0
+ tp_rank: int = 0
+ pp_rank: int = 0
+ local_rank: int = 0
+ ip: str = "0.0.0.0"
+ dp_size: int = 1
+ ddp_size: int = 1
+ tp_size: int = 1
+ pp_size: int = 1
+ world_size: int = 1
+
+ def __post_init__(self):
+ # validation
+ for field_name in self.__dict__:
+ field_content = self.__dict__[field_name]
+ if field_name.endswith("rank") and field_content < 0:
+ raise ValueError(f"TopoInfo instance's {field_name}={field_content}, expected nonnegative number")
+ if field_name.endswith("size") and field_content <= 0:
+ raise ValueError(f"WorldInfo instance's {field_name}={field_content}, expected positive number")
+
+
+@dataclass(frozen=False)
+class TrainingInfo:
+ role_id: int = 0
+ trial_id: int = 0
+ run_id: int = 0
+
+ def __post_init__(self):
+ # validation
+ for field_name in self.__dict__:
+ field_content = self.__dict__[field_name]
+ if field_content < 0:
+ raise ValueError(f"TrainingInfo instance's {field_name}={field_content}, expected nonnegative number")
+
+
+class WorldInfo:
+ def __init__(
+ self,
+ rank: int,
+ local_rank: int,
+ dp_rank: int = 0,
+ ddp_rank: int = 0,
+ tp_rank: int = 0,
+ pp_rank: int = 0,
+ dp_size: int = 1,
+ ddp_size: int = 1,
+ tp_size: int = 1,
+ pp_size: int = 1,
+ world_size: int = 1,
+ ip: str = "0.0.0.0",
+ role_id: int = 0,
+ run_id: int = 0,
+ trial_id: int = 0,
+ **extra_meta: Dict[str, Any],
+ ):
+ self.topo_info = TopoInfo(
+ rank=rank,
+ local_rank=local_rank,
+ dp_rank=dp_rank,
+ ddp_rank=ddp_rank,
+ tp_rank=tp_rank,
+ pp_rank=pp_rank,
+ dp_size=dp_size,
+ ddp_size=ddp_size,
+ tp_size=tp_size,
+ pp_size=pp_size,
+ world_size=world_size,
+ ip=ip,
+ )
+ self.training_info = TrainingInfo(
+ role_id=role_id,
+ trial_id=trial_id,
+ run_id=run_id,
+ )
+ self.extra_info = {}
+ for k in extra_meta:
+ self.extra_info[k] = extra_meta[k]
+
+ def __repr__(self) -> str:
+ return f"WorldInfo: {self.topo_info.__repr__()} {self.training_info.__repr__()} {self.extra_info.__repr__()}"
+
+ def __getitem__(self, key: str):
+ if key in self.topo_info.__dict__:
+ return self.topo_info.__dict__[key]
+ if key in self.training_info.__dict__:
+ return self.training_info.__dict__[key]
+ if key in self.extra_info:
+ return self.extra_info[key]
+ raise KeyError(f"{key} is not found")
+
+ def __setitem__(self, key: str, value: Any):
+ if key in self.topo_info.__dict__:
+ self.topo_info.__dict__[key] = value
+ if key in self.training_info.__dict__:
+ self.training_info.__dict__[key] = value
+ if key in self.extra_info:
+ self.extra_info[key] = value
+ raise KeyError(f"{key} is not found")
diff --git a/vescale/pipe/README.md b/vescale/pipe/README.md
new file mode 100644
index 0000000..028bb44
--- /dev/null
+++ b/vescale/pipe/README.md
@@ -0,0 +1,125 @@
+# veScale Pipeline Parallel (PP)
+
+## TLDR
+
+
+
+## What is PP?
+
+`Pipeline Parallel` (`PP`) partitions layers of a model across multiple devices to form a pipelined execution of the training.
+`PP` takes as input a list of microbatches of data per iteration and performs pipelined training execution (forward, backward, and optimizer update) on each microbatch, while overlaps communication with computation on each device.
+
+## Why veScale PP?
+
+Existing `PP` systems suffer multiple drawbacks as below, which prevent productization within a company:
+
+- _Complex API_: assuming that model developers are also systems experts in `PP`
+
+- _Hacking model code_: requiring manually rewrite the model code to run `PP`
+
+- _Lacking single device abstraction_: requiring manually rewrite the training script to be `PP` device-specific
+
+- _Lacking options of pipeline construction_: relying on a single option of graph tracing, or perfect graph tracing, or solely manual construction of the pipeline.
+
+- _Lacking customizability of pipeline schedule_: deeply coupling the entire runtime (e.g., compute, communication) with a specific `PP` schedule (e.g., `1F1B`)
+
+- _Lacking diverse model support_: supporting only sequential model architecture without branching, or supporting only pipeline stages having single input or single output without multiple input/output.
+
+## What is veScale PP?
+
+`veScale PP` offers a new `PP` framework that is both _**Easy-to-Use**_ and _**Easy-to-Customize**_, thus it is used internally in our production.
+Especially, `veScale PP` provides:
+
+- _Easy API_: hiding the complexity of `PP` systems and runtimes from model developers
+
+- _Zero model code change_: keeping the original torch model code as it is for transparent pipelined models
+
+- _Single device abstraction_: keeping the single device training script as it is for transparent pipelined training on multiple devices
+
+- _Multiple options of pipeline construction_: user can flexibly choose modes:
+
+ - `GRAPH_EAGER` mode automatically traces and parses the model into a graph, splits the graph into pipeline stages, and constructs each stage for pipeline execution
+
+ - graph tracer can also be choices or users
+
+ - `MANUAL_EAGER` mode manually constructs each pipeline stage for pipeline execution, without graph tracing, parsing, and splitting.
+
+- _Customizable pipeline schedule_: empowering users to define their custom pipeline schedules, beyond our built-in schedule as below:
+
+ - `1F1B`
+
+ - `Interleaved 1F1B`
+
+ - `Zero Bubble`
+
+- _Support diverse models_: support comprehensive model archictures for non-sequential models, multiple-input-multiple-output stages, and etc.
+
+## Why is veScale PP a better option than its counterparts?
+
+- Compared with Megatron-LM's PP, `veScale PP` offers not only a better __Ease-of-Use__ experience in all aspects (easy API, zero model code, single device abstraction, options of pipeline construction) but also a plus of __Customizability__ allowing users to conveniently customize new pipeline schedules.
+
+- Compared with DeepSpeed, `veScale PP` requires no modification of model code. It further supports multi-stage scheduling for non-sequential multimodal architecture and multi-input settings instead of being constrained by `nn.Sequential`'s syntax.
+
+- Compared with the pre-release torchtitan, `veScale PP` provides: i) single device abstraction of training script, ii) wider options of graph tracer support, iii) wider model architecture support, and iv) guarantees bitwise accuracy alignment between `PP` and single device code.
+
+## How does veScale PP work?
+
+Spinning up a `PP` job typically requires three steps: i) trace and parse model graph, ii) construct pipeline stage, and iii) execute pipeline schedule. Each step is handled by `PipeParser`, `PipeModule`, and `PipeEngine`. Upon receiving the model definition, `PipeParser` (`GRAPH_EAGER` mode) breaks down the model code to the intermediate representation of low-level modules and operators up to the granularity of your choice. Under `MANUAL_EAGER` mode, users only need to assign stage modules and their communication relationships. `PipeModule` collects parameters and operators, and optimizer states belonging to the same stage, and resolves communication topology among devices. `PipeEngine` will schedule steps to execute training according to pipeline schedules.
+
+## How to use veScale PP?
+
+- Example of using `GRAPH_EAGER` mode:
+
+ ```python
+ # zero model code change
+ class EightMLP(nn.Module):
+ def __init__(self, ...):
+ self.mlp1 = MLP(...)
+ ...
+ self.mlp8 = MLP(...)
+ def forward(...):
+ ...
+
+ # An EightMLP is composed of 8 submodules called MLP
+ model = EightMLP()
+ # or model = deferred_init(EightMLP)
+
+ from vescale.plan import PipelineParallelPlan, PipelineScheduleType, PipelineSplitMethodType, ModeType
+ from vescale.pipe import construct_pipeline_stage
+ from vescale.engine import PipeEngine
+ from vescale.dtensor.device_mesh import DeviceMesh
+
+ # create 3-dim DeviceMesh
+ device_mesh = DeviceMesh("cuda", [[[0]], [[1]], [[2]], [[3]]], mesh_dim_names=("PP", "DP", "TP"))
+
+ # prepare plan for pipeline parallelism
+ pipe_plan = PipelineParallelPlan(
+ mode=ModeType.GRAPH_EAGER,
+ split_method=PipelineSplitMethodType.MANUAL,
+ num_stages=4,
+ virtual_chunks=2,
+ smallest_unsplittable_units=[f"mlp{i + 1}" for i in range(8)], # maintain hierarchy of each MLP module
+ split_points=["mlp2", "mlp4", "mlp6", "mlp8"], # managed pipeline split points by fully qualified names
+ overlap_p2p_comm=True, # speedup option
+ schedule_type=PipelineScheduleType.INTERLEAVED_1F1B,
+ )
+
+ # parse model graph, split graph, and construct pipeline stage
+ pipe_stage = construct_pipeline_stage(model, pipe_plan, device_mesh)
+
+ # prepare pipeline schedule and execution engine
+ engine = PipeEngine(pipe_stage, device_mesh, pipe_plan)
+
+ # train PP model as if on single device
+ for minibatch_data in dataloader:
+ minibatch_loss, microbatch_outputs = engine(minibatch_data)
+ minibatch_loss.backward()
+ ...
+
+ ```
+
+- Example of using `MANUAL_EAGER` mode: Coming Soon.
+
+- APIs can be found in `/vescale/pipe/pipe_stage.py` and `/vescale/pipe/pipe.py`
+
+- More examples can be found in `/test/parallel/pipeline/api/test_simple_api.py`
\ No newline at end of file
diff --git a/vescale/pipe/__init__.py b/vescale/pipe/__init__.py
new file mode 100644
index 0000000..2403b36
--- /dev/null
+++ b/vescale/pipe/__init__.py
@@ -0,0 +1,26 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from .pipe_stage import (
+ PipeModule,
+ build_shared_module_group,
+ build_stage_module_and_dependency,
+ construct_stage_modules,
+ construct_pipeline_stage,
+)
+from .pipe_parser import PipeParser, parse_model_graph, split_pipeline_point, construct_pipeline_split_graph
+from .pipe_emmiter import ScheduleEngine, validate_pipeline_schedule
diff --git a/vescale/pipe/_schedules/__init__.py b/vescale/pipe/_schedules/__init__.py
new file mode 100644
index 0000000..66cb2c6
--- /dev/null
+++ b/vescale/pipe/_schedules/__init__.py
@@ -0,0 +1,21 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from .instruction_base import StageDeps, Shape, register_instruction
+from .pipedream_flush import OneFOneBInstrcutionGenerator
+from .looping_bfs import InterleavedOneFOneBInstructionGenerator
+from .zero_bubble_v import ZeroBubbleVInstrcutionGenerator
diff --git a/vescale/pipe/_schedules/instruction_base.py b/vescale/pipe/_schedules/instruction_base.py
new file mode 100644
index 0000000..d43474e
--- /dev/null
+++ b/vescale/pipe/_schedules/instruction_base.py
@@ -0,0 +1,552 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import enum
+from dataclasses import dataclass
+from collections import defaultdict
+from abc import ABCMeta, abstractmethod
+from typing import Sequence, Callable
+import torch
+from torch.distributed.distributed_c10d import get_rank
+from vescale.dtensor.device_mesh import DeviceMesh
+from vescale.dtensor.placement_types import Placement
+from vescale.pipe.pipe_stage import PipeModule
+from typing import List, Tuple, Union, Optional, Dict, Any
+import logging
+import numpy as np
+from vescale.plan.spec import PipelineP2PSpec
+
+Shape = Union[List[int], torch.Size]
+
+logger = logging.getLogger(__name__)
+registed_functions = {}
+
+
+def register_instruction(name):
+ assert name is not None, "The Instruction must have name"
+ if name in registed_functions:
+ msg = f"{name} allready in registed instruction"
+ logger.warning(msg)
+
+ def _register_instruction(func):
+ def wrap(*args, **kwargs):
+ return func(*args, **kwargs)
+
+ registed_functions.update({name: func})
+ return wrap
+
+ return _register_instruction
+
+
+@dataclass
+class CommPacket:
+ cur_mesh: DeviceMesh
+ peer_mesh: DeviceMesh
+ input_id: int
+ peer_stage: int
+ peer_sharding: List[Placement] = None
+ cur_sharding: List[Placement] = None
+ is_kwargs: bool = False
+
+
+class StageDeps:
+ def __init__(
+ self,
+ dep: np.ndarray,
+ meshes: List[DeviceMesh],
+ vpp_module_list: Union[List, PipeModule],
+ p2p_index_mapping: Optional[Dict[int, List[PipelineP2PSpec]]] = None,
+ ):
+ self.D = dep
+ self.M = vpp_module_list
+ self.meshes = meshes
+ self.is_vpp = self.get_num_chunks() > 1
+ self.mapping: Dict = {}
+ if p2p_index_mapping is None:
+ self.mapping = defaultdict(list)
+ self.generate_one_forward_mapping()
+ else:
+ self.mapping = p2p_index_mapping
+ self.parsing_forward_mapping()
+
+ self.recv_tables: Dict[int, List[CommPacket]] = defaultdict(list)
+ self.send_tables: Dict[int, List[CommPacket]] = defaultdict(list)
+ self.local_dataloader_list: Dict[Any, List[CommPacket]] = defaultdict(list)
+ self.construct_communication_graph()
+
+ def construct_communication_graph(self):
+ for i in range(self.num_stage):
+ cur_mesh = self.get_current_mesh(i)
+ cur_mapping = self.mapping[i] # get the index mapping i
+ prior_list = []
+ local_data_list = []
+ # stage_id: [input_idx, ...]
+ for p2p_spec in cur_mapping:
+ prev_stage_id = p2p_spec.peer_stage_idx
+ input_id = p2p_spec.peer_output_idx
+ if prev_stage_id != i: # not from self
+ prior_list.append((self.get_current_mesh(prev_stage_id), prev_stage_id, input_id))
+ else: # from self stage
+ local_data_list.append(input_id)
+
+ prior_list = sorted(prior_list, key=lambda item: (item[1], item[2]))
+ for device, pre, input_id in prior_list:
+ sr = CommPacket(
+ cur_mesh=cur_mesh, peer_mesh=device, input_id=input_id, peer_stage=pre
+ ) # input is single
+ self.recv_tables[i].append(sr)
+ for input_id in local_data_list:
+ sr = CommPacket(
+ cur_mesh=cur_mesh,
+ peer_mesh=None,
+ input_id=input_id,
+ peer_stage=None,
+ )
+ self.local_dataloader_list[i].append(sr)
+
+ # construct out degree
+ for i in range(self.num_stage):
+ prior_list = []
+ for j in range(self.num_stage):
+ if i == j: # don't check self , no cycle
+ continue
+ j_recvs = self.recv_tables[j]
+ for recv in j_recvs:
+ if recv.peer_stage == i: # is i send to j
+ send = CommPacket(
+ cur_mesh=recv.peer_mesh,
+ peer_mesh=recv.cur_mesh,
+ input_id=recv.input_id,
+ peer_stage=j,
+ )
+ prior_list.append(send)
+ # sort by input_id stage id is unneeded
+ sorted(prior_list, key=lambda item: item.input_id)
+ self.send_tables[i] = prior_list
+
+ def generate_one_forward_mapping(self):
+ for i in range(self.num_stage):
+ cur_mapping = self.mapping[i]
+ pre_stages = self.get_pre_stage(i, ignore_virtual=False)
+ assert len(pre_stages) <= 1, "multi branch stage need parse p2p_index_mapping"
+ for pre in pre_stages:
+ cur_mapping.append(PipelineP2PSpec(pre, 0))
+
+ if self.is_pipeline_first_stage(i):
+ cur_mapping.append(PipelineP2PSpec(i, 0))
+
+ def parsing_forward_mapping(self):
+ # 1: [(0,0), (1,0), (0,2)]
+ for i in range(self.num_stage):
+ if i not in self.mapping:
+ cur_indexing = []
+ pre_stages = self.get_pre_stage(i, ignore_virtual=False)
+ assert len(pre_stages) <= 1, "multi branch stage need parse p2p_index_mapping"
+ for pre in pre_stages:
+ cur_indexing.append(PipelineP2PSpec(pre, 0))
+ if self.is_pipeline_first_stage(i):
+ cur_indexing.append(PipelineP2PSpec(i, 0))
+ self.mapping.update({i: cur_indexing})
+
+ def get_send_comms(self, i):
+ return self.send_tables[i]
+
+ def get_recv_comms(self, i):
+ return self.recv_tables[i]
+
+ def get_local_comms(self, i):
+ return self.local_dataloader_list[i]
+
+ @property
+ def num_stage(self):
+ return len(self.D)
+
+ def is_first(self, s_id):
+ pre = self.D[:, s_id]
+ non_zero = np.count_nonzero(pre)
+ if non_zero == 0:
+ return True
+ return False
+
+ def is_last(self, s_id):
+ post = self.D[s_id]
+ non_zero = np.count_nonzero(post)
+ if non_zero == 0:
+ return True
+ return False
+
+ def get_pre_stage(self, i, ignore_virtual=True):
+ pre = self.D[:, i]
+ stage_ids = np.where(pre == 1)[0].tolist()
+ if self.is_first(i) and self.is_vpp and not ignore_virtual:
+ last_stages = list(filter(self.is_last, range(self.num_stage)))
+ return last_stages
+ else:
+ return stage_ids
+
+ def get_post_stage(self, i, ignore_virtual=True):
+ post = self.D[i]
+ stage_ids = np.where(post == 1)[0].tolist()
+
+ if self.is_last(i) and self.is_vpp and not ignore_virtual:
+ first_stages = list(filter(self.is_first, range(self.num_stage)))
+ return first_stages
+ else:
+ return stage_ids
+
+ def get_first_stage(self):
+ stages = []
+ for i in range(self.num_stage):
+ pre_stages = self.get_pre_stage(i)
+ if len(pre_stages) == 0: # in-degree is 0
+ stages.append(i)
+ return stages
+
+ def get_last_stage(self):
+ stages = []
+ for i in range(self.num_stage):
+ post_stages = self.get_post_stage(i)
+ if len(post_stages) == 0: # out-degree is 0
+ stages.append(i)
+ return stages
+
+ def get_current_model(self, i):
+ return self.M
+
+ def is_pipeline_first_stage(self, i):
+ pre = self.get_pre_stage(i)
+ return len(pre) == 0 # first stage has no input
+
+ def is_pipeline_last_stage(self, i):
+ post = self.get_post_stage(i)
+ return len(post) == 0 # last stage has no output
+
+ def is_vpp_first_stage(self, i, chunk_id):
+ return self.is_pipeline_first_stage(i) and chunk_id == 0
+
+ def is_vpp_last_stage(self, i, chunk_id):
+ return self.is_pipeline_last_stage(i) and (chunk_id == (self.get_num_chunks() - 1))
+
+ def get_num_chunks(self):
+ if isinstance(self.M, list):
+ return len(self.M)
+ else:
+ return self.M.virtual_chunks
+
+ def get_current_mesh(self, i):
+ return self.meshes[i]
+
+ def __str__(self):
+ tmp = "\n\n"
+ tmp += f"stages: {self.num_stage}, deps:{self.D}\n"
+ for i in range(self.num_stage):
+ tmp += f"\n===================stage:{i} start=======================\n"
+ tmp += "recv : \n"
+ for comm in self.recv_tables[i]:
+ tmp += f"\t\t recv from {comm.peer_stage} with input:{comm.input_id} comm:{comm}\n"
+ tmp += "send : \n"
+ for comm in self.send_tables[i]:
+ tmp += f"\t\t send to {comm.peer_stage} with input:{comm.input_id} comm:{comm}\n"
+ tmp += "local_dataloader_list : \n"
+ for comm in self.local_dataloader_list[i]:
+ tmp += f"\t\t local_dataloader with input:{comm.input_id} comm:{comm}\n"
+
+ tmp += f"===================stage:{i} end=======================\n\n"
+ return tmp
+
+
+def get_linear_pp_module_dep2(module_list: List, device_mesh_list: List[DeviceMesh]):
+ stage_len = len(device_mesh_list) # for forward
+ dep = np.zeros((stage_len, stage_len), dtype=np.int64)
+ for i in range(stage_len - 1):
+ dep[i][i + 1] = 1 # direct graph
+ return StageDeps(dep, device_mesh_list, module_list)
+
+
+@dataclass
+class Status:
+ batch_idx: int = 0
+ stage_id: int = 0
+ chunk_id: int = 0
+ f_b: "str" = "" # forward or backward
+ stg: "str" = "" # stage for 1f1b
+ k: int = 0
+
+ def __str__(self):
+ return f"b:{self.batch_idx}, c:{self.chunk_id}, {self.stg + '-' + self.f_b}"
+
+
+class PipelineSchema(metaclass=ABCMeta):
+ """
+ we define this class to abstract the pipeline execute
+ Args:
+ dep: the dependency for adjacency martrix
+ meshes: the list for stage of
+
+ """
+
+ def __init__(self, num_stage: int, meshes: Union[List[DeviceMesh], int], batches: int = 1):
+ self.num_stage = num_stage
+ self.meshes = meshes
+ self.batches = batches
+ self._schedules: List[List[Tuple]] = self._gen_schedule()
+
+ @property
+ @abstractmethod
+ def name(self):
+ """print schedule name"""
+ raise NotImplementedError()
+
+ @abstractmethod
+ def _gen_schedule(self):
+ """generator the pipelinne schedule for engine"""
+ raise NotImplementedError("not impl")
+
+ def __str__(self):
+ """print the pipeline clock work"""
+ stream = "\n"
+ d = " ".join([f"d{d:<24}" for d in range(self.num_mesh)])
+ stream += f"T k :{d:<24} \n"
+ for time, scheds in enumerate(self.schedules):
+ sched_str = " ".join([f"{str(sched):<24}" for sched in scheds])
+ stream += f"T {time:<2}: {sched_str} \n"
+ return stream
+
+ @property
+ def schedules(self):
+ """return schedules"""
+ return self._schedules
+
+ @property
+ def num_mesh(self):
+ """return the num mesh of tp group"""
+ if isinstance(self.meshes, Sequence):
+ return len(self.meshes)
+ elif isinstance(self.meshes, int):
+ return self.meshes
+ else:
+ raise NotImplementedError("unsupport device mesh list")
+
+ @property
+ def num_clock(self):
+ """return num schedule for the num clock"""
+
+ return len(self._schedules)
+
+
+@dataclass
+class BaseInstruction(metaclass=ABCMeta):
+ @abstractmethod
+ def run(self, *args, **kwargs):
+ raise NotImplementedError("unsupport run command")
+
+ @property
+ def name(self):
+ return "base_instruction"
+
+ def dump(self):
+ return f"{get_rank()}: {self}"
+
+
+class InstructionGenerator(metaclass=ABCMeta):
+ def __init__(
+ self,
+ deps: StageDeps,
+ meshes: int,
+ batches: int,
+ default_shape: Optional[Shape] = None,
+ default_dtype: Optional[torch.dtype] = None,
+ batch_shape_lists: Optional[List[Any]] = None,
+ batch_dtype_lists: Optional[List[Any]] = None,
+ forward_only=False,
+ num_chunk=1,
+ ):
+ self.deps = deps
+ self.meshes = meshes
+ self.num_chunk = num_chunk
+ self.batches = batches
+ self.default_shape = default_shape
+ self.default_dtype = default_dtype
+ self.batch_shape_lists = batch_shape_lists
+ self.batch_dtype_lists = batch_dtype_lists
+ self.forward_only = forward_only
+ self.instruction_list: List = []
+
+ """
+ generate instruction
+ """
+
+ @abstractmethod
+ def gen_instruction(self):
+ raise NotImplementedError("not implement")
+
+ """
+ get current stage instruction
+ """
+
+ def get_instruction_list(self, stage: int):
+ return self.instruction_list[stage]
+
+ """
+ update with batch idx, stage idx
+ """
+
+ def _set_inst(self, inst: BaseInstruction, s: int):
+ self.instruction_list[s].append(inst)
+
+ """
+ set instruction type
+ """
+
+ def execute(self, *args, **kwargs):
+ raise NotImplementedError("not implement")
+
+
+class InstructionBuilder:
+ global_instructions_funcs = defaultdict(list)
+ global_instructions_str = defaultdict(list)
+
+ constant_data = defaultdict()
+ user_data = defaultdict()
+ loss_fn: Callable = torch.sum
+ dataloader: Any
+ topo: StageDeps
+ model: Callable
+ stage_id: int
+ _pos = 0
+ _stack = None
+
+ def build_from_dict(self, instructions: Dict):
+ assert isinstance(instructions, dict), "instructions should be dict"
+ for stage_id, instruction_list in instructions.items():
+ cur_stage_ins_list = instruction_list
+ if isinstance(cur_stage_ins_list, str):
+ instructions_funcs = cur_stage_ins_list.split(",")
+ else:
+ instructions_funcs = cur_stage_ins_list
+
+ mapped_functions = [registed_functions[x] for x in instructions_funcs]
+
+ self.global_instructions_funcs[stage_id] = mapped_functions
+ self.global_instructions_str[stage_id] = instructions_funcs
+
+ def draw_instructions(self):
+ from matplotlib import pyplot as plt
+
+ fig, ax = plt.subplots()
+ # draw rectangle
+ stage_nums = len(self.global_instructions_str.keys())
+ for stage_id, instuctions_strs in self.global_instructions_str.items():
+ for id, stage_str in enumerate(instuctions_strs):
+ ax.add_patch(plt.Rectangle((id, -1 * stage_id), 1, 1, fill=False, edgecolor="black", lw=2))
+ ax.text(id + 0.5, -1 * stage_id + 0.5, stage_str, ha="center", va="center")
+
+ for stage_id in range(stage_nums):
+ ax.text(-0.5, -1 * stage_id + 0.5, stage_id, ha="center", va="center")
+ # set max xlim and ylim
+ max_stages = max(len(x) for x in self.global_instructions_str.values())
+ ax.set_xlim(0, max_stages)
+ ax.set_ylim(-1 * stage_nums + 1, 1)
+ ax.axis("off")
+ plt.savefig("instructions.png")
+
+ @property
+ def pos(self):
+ return self._pos
+
+ @property
+ def last(self):
+ return self._stack
+
+ def run(self, stage_id: int):
+ output = []
+ for pos, fn in enumerate(self.global_instructions_funcs[stage_id]):
+ self._pos = pos
+ out = fn()
+ self._stack = out
+ output.append(out)
+ return output
+
+ def export(self, stage_id, *args, **kwargs):
+ func_lists = self.global_instructions_funcs[stage_id]
+
+ class Model(torch.nn.Module):
+ def __init__(self, func_lists, model):
+ super().__init__()
+ self.func_lists = func_lists
+ self.model = model
+
+ def forward(self, *args, **kwargs):
+ for f in self.func_lists:
+ # TODO: handle this to make forward inst work.
+ if f.__name__ == "forward":
+ activation = self.model(*args, **kwargs)
+ args = (activation,)
+ else:
+ args, kwargs = f(*args, **kwargs)
+ return args, kwargs
+
+ model = Model(func_lists, self.model)
+ graph = torch.export.export(model, args)
+ return graph
+
+
+class CompilePPCollectiveKind(enum.Enum):
+ SEND = 1
+ RECV = 2
+ BORADCAST = 3 # for cross mesh collective
+ UNKNOWN = 4
+
+
+class CompilePPCollectiveOperator:
+ def __init__(
+ self,
+ kind: CompilePPCollectiveKind,
+ src: int = None,
+ dst: List[int] = None,
+ is_backward: bool = False,
+ ) -> None:
+ assert kind in (
+ CompilePPCollectiveKind.BORADCAST,
+ CompilePPCollectiveKind.SEND,
+ CompilePPCollectiveKind.RECV,
+ )
+ self.kind = kind
+ self.is_backward = is_backward
+
+ if self.kind is CompilePPCollectiveKind.SEND:
+ assert dst is not None and isinstance(dst, int)
+ elif self.kind is CompilePPCollectiveKind.RECV:
+ assert src is not None and isinstance(src, int)
+ else:
+ assert src is not None and isinstance(src, int)
+ assert dst is not None and isinstance(dst, List[int])
+ assert src in dst
+
+ self.src = src
+ self.dst = dst
+ pass
+
+ def __hash__(self) -> int:
+ if isinstance(self.dst, List[int]):
+ dst = tuple(self.dst)
+ else:
+ dst = self.dst
+ return hash((self.kind, self.src, dst, self.is_backward))
+
+
+VESCALE_INTRUCTION_BUILDER = InstructionBuilder()
diff --git a/vescale/pipe/_schedules/looping_bfs.py b/vescale/pipe/_schedules/looping_bfs.py
new file mode 100644
index 0000000..4d0b6e6
--- /dev/null
+++ b/vescale/pipe/_schedules/looping_bfs.py
@@ -0,0 +1,1789 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from vescale.pipe._schedules.instruction_base import (
+ PipelineSchema,
+ Status,
+ Shape,
+ InstructionGenerator,
+ StageDeps,
+ BaseInstruction,
+ CommPacket,
+ VESCALE_INTRUCTION_BUILDER as builder,
+ register_instruction,
+ registed_functions,
+)
+import contextlib
+from dataclasses import dataclass, field
+from vescale.dtensor.dtensor import DTensor
+import torch
+from collections import defaultdict
+from inspect import signature
+import numpy as np
+from vescale.dtensor.device_mesh import DeviceMesh
+from typing import List, Sequence, Optional, Dict, Union, Callable
+from functools import partial
+from vescale.dtensor._diff import dummy_p2p, manage_dump_file
+from vescale.pipe.p2p_communication import (
+ recv_forward,
+ drain_send_reqs,
+ drain_recv_reqs,
+ send_forward_backward_recv_forward_backward,
+ send_forward_recv_forward,
+ send_backward_recv_backward,
+)
+from vescale.model.base_gpt.utils import switch_dtensor
+
+
+@dataclass
+class RECV_FORWARD(BaseInstruction):
+ comm_packages: List[CommPacket] = field(default_factory=list)
+ tensor_shapes: Union[List[Shape], Shape] = field(default_factory=list)
+ tensor_dtypes: Union[List[torch.dtype], torch.dtype] = field(default_factory=list)
+ batch_p2p_comm: bool = True
+ batch_id: Optional[int] = None
+ is_pp_first_stage: bool = False
+ debug: str = ""
+
+ @property
+ def name(self):
+ return "recv_forward"
+
+ @dummy_p2p
+ def run(self) -> List:
+ if self.is_pp_first_stage:
+ return None
+
+ def f(info):
+ comm, shape, dtype = info
+ return recv_forward(
+ tensor_shape=shape,
+ recv_dtype=dtype,
+ current_device_mesh=comm.cur_mesh,
+ peer_device_mesh=comm.peer_mesh,
+ batch_p2p_comm=self.batch_p2p_comm,
+ )
+
+ infos = zip(self.comm_packages, self.tensor_shapes, self.tensor_dtypes)
+ out = list(map(f, infos))
+ return out if len(out) > 0 else None
+
+
+@dataclass
+class WAIT_FWD(BaseInstruction):
+ @property
+ def name(self):
+ return "wait_forward"
+
+ @dummy_p2p
+ def run(self, fwd_wait_handles: Optional[Sequence]):
+ if fwd_wait_handles is not None:
+ for req in fwd_wait_handles:
+ req.wait()
+
+
+@dataclass
+class DRAIN_SEND_REQS(BaseInstruction):
+ @property
+ def name(self):
+ return "drain_send_reqs"
+
+ @dummy_p2p
+ def run(self):
+ drain_send_reqs()
+
+
+@dataclass
+class DRAIN_RECV_REQS(BaseInstruction):
+ drain_type: str = "all"
+ check_bwd_wait: bool = False
+
+ @property
+ def name(self):
+ return "drain_recv_reqs"
+
+ @dummy_p2p
+ def run(self, bwd_wait_handles: Optional[Sequence]):
+ if self.check_bwd_wait:
+ if bwd_wait_handles is not None:
+ drain_recv_reqs(self.drain_type)
+ else:
+ drain_recv_reqs(self.drain_type)
+
+
+@dataclass
+class DEALLOCATE_OUTPUT_TENSOR(BaseInstruction):
+ @property
+ def name(self):
+ return "deallocate tensor"
+
+ @dummy_p2p
+ def run(self, output_tensor, deallocate_pipeline_outputs):
+ def deallocate(output_tensor):
+ if (output_tensor is None) or (not deallocate_pipeline_outputs):
+ return
+ assert isinstance(
+ output_tensor, [torch.Tensor, DTensor]
+ ), f"expected Tensor, found {type(output_tensor).__name__}."
+ assert output_tensor._base is None, "counter-productive to free a view of another tensor."
+ if isinstance(output_tensor, [torch.Tensor, DTensor]):
+ output_tensor._local_tensor.data = torch.empty(
+ (1,),
+ device=output_tensor.device,
+ dtype=output_tensor.dtype,
+ )
+ else:
+ output_tensor.data = torch.empty(
+ (1,),
+ device=output_tensor.device,
+ dtype=output_tensor.dtype,
+ )
+ return
+
+ if not isinstance(output_tensor, Sequence):
+ output_tensor = [output_tensor]
+ map(deallocate, output_tensor)
+
+
+@dataclass
+class APPEND_INPUTS(BaseInstruction):
+ chunk: int = 0
+
+ @property
+ def name(self):
+ return "append inputs"
+
+ @dummy_p2p
+ def run(self, input_tensor, input_tensors):
+ input_tensors[self.chunk].append(input_tensor)
+
+
+@dataclass
+class APPEND_GRADS(BaseInstruction):
+ chunk: int = 0
+
+ @property
+ def name(self):
+ return "append grads"
+
+ @dummy_p2p
+ def run(self, output_tensor_grad, output_tensor_grads):
+ output_tensor_grads[self.chunk].append(output_tensor_grad)
+
+
+@dataclass
+class SEND_FORWARD_BACKWARD_RECV_FORWARD_BACKWARD(BaseInstruction):
+ recv_prev: bool = False
+ recv_next: bool = False
+ send_comms: List[CommPacket] = field(default_factory=list)
+ recv_comms: List[CommPacket] = field(default_factory=list)
+ recv_shapes: List[Shape] = field(default_factory=list)
+ recv_dtypes: List[torch.dtype] = field(default_factory=list)
+ batch_p2p_comm: bool = True
+ debug: str = ""
+
+ @property
+ def name(self):
+ return "send forward backward recv forward backward"
+
+ @dummy_p2p
+ def run(self, output_tensor, input_tensor_grad):
+ if not isinstance(output_tensor, Sequence):
+ output_tensor = [output_tensor]
+ if not isinstance(input_tensor_grad, Sequence):
+ input_tensor_grad = [input_tensor_grad]
+
+ def f(info):
+ output_tensor, input_tensor_grad, recv_comm, send_comm, tensor_shape, dtype = info
+ if isinstance(output_tensor, DTensor):
+ output_tensor = output_tensor._local_tensor
+ if isinstance(input_tensor_grad, DTensor):
+ input_tensor_grad = input_tensor_grad._local_tensor
+
+ input_tensor, output_tensor_grad = send_forward_backward_recv_forward_backward(
+ output_tensor=output_tensor,
+ input_tensor_grad=input_tensor_grad,
+ recv_prev=self.recv_prev,
+ recv_next=self.recv_next,
+ current_device_mesh=send_comm.cur_mesh,
+ prev_device_mesh=recv_comm.peer_mesh,
+ next_device_mesh=send_comm.peer_mesh,
+ tensor_shape=tensor_shape,
+ recv_dtype=dtype,
+ batch_p2p_comm=self.batch_p2p_comm,
+ )
+ return input_tensor, output_tensor_grad
+
+ zipped_data = list(
+ zip(
+ output_tensor,
+ input_tensor_grad,
+ self.recv_comms,
+ self.send_comms,
+ self.recv_shapes,
+ self.recv_dtypes,
+ )
+ )
+
+ outputs = list(map(f, zipped_data))
+
+ if len(outputs) > 1:
+ if self.overlap_p2p_comm:
+ out = [x[0] for x in outputs]
+ handle = [x[1] for x in outputs]
+ return out, handle
+ else:
+ return outputs
+ else:
+ return outputs[0]
+
+
+@dataclass
+class SEND_FORWARD_RECV_FORWARD(BaseInstruction):
+ recv_prev: bool = False
+ send_shapes: List[Shape] = field(default_factory=list)
+ send_tensor_shapes_unpad: List[Shape] = field(default_factory=list)
+ send_dtypes: List[torch.dtype] = field(default_factory=list)
+ batch_p2p_comm: bool = True
+ overlap_p2p_comm: bool = False
+ send_comms: List[CommPacket] = field(default_factory=list)
+ recv_comms: List[CommPacket] = field(default_factory=list)
+ microbatch_id: int = 0
+ debug: str = ""
+
+ @property
+ def name(self):
+ return "send forward recv forward"
+
+ @dummy_p2p
+ def run(self, output_tensor):
+ if not isinstance(output_tensor, Sequence):
+ output_tensor = [output_tensor]
+
+ def f(info):
+ output_tensor, recv_comm, send_comm, tensor_shape, tensor_shape_unpad, dtype = info
+ if isinstance(output_tensor, DTensor):
+ output_tensor = output_tensor._local_tensor
+ output = send_forward_recv_forward(
+ output_tensor,
+ recv_prev=self.recv_prev,
+ tensor_shape=tensor_shape,
+ send_tensor_shape_unpad=tensor_shape_unpad,
+ overlap_p2p_comm=self.overlap_p2p_comm,
+ batch_p2p_comm=self.batch_p2p_comm,
+ recv_dtype=dtype,
+ current_device_mesh=send_comm.cur_mesh,
+ prev_device_mesh=recv_comm.peer_mesh,
+ next_device_mesh=send_comm.peer_mesh,
+ )
+ return output
+
+ zipped_data = list(
+ zip(
+ output_tensor,
+ self.recv_comms,
+ self.send_comms,
+ self.send_shapes,
+ self.send_tensor_shapes_unpad,
+ self.send_dtypes,
+ )
+ )
+
+ outputs = list(map(f, zipped_data))
+
+ if len(outputs) > 1:
+ if self.overlap_p2p_comm:
+ out = [x[0] for x in outputs]
+ handle = [x[1] for x in outputs]
+ return out, handle
+ else:
+ return outputs
+ else:
+ return outputs[0]
+
+
+@dataclass
+class SEND_BACKWARD_RECV_BACKWARD(BaseInstruction):
+ recv_next: bool = False
+ send_shapes: List[Shape] = field(default_factory=list)
+ send_tensor_shapes_unpad: List[Shape] = field(default_factory=list)
+ send_dtypes: List[torch.dtype] = field(default_factory=list)
+ batch_p2p_comm: bool = True
+ overlap_p2p_comm: bool = False
+ send_comms: List[CommPacket] = field(default_factory=list)
+ recv_comms: List[CommPacket] = field(default_factory=list)
+ debug: str = ""
+
+ @property
+ def name(self):
+ return "send backward recv backward"
+
+ @dummy_p2p
+ def run(self, input_tensor_grad):
+ if not isinstance(input_tensor_grad, Sequence):
+ input_tensor_grad = [input_tensor_grad]
+
+ def f(info):
+ input_tensor_grad, recv_comm, send_comm, tensor_shape, tensor_shape_unpad, dtype = info
+ if isinstance(input_tensor_grad, DTensor):
+ input_tensor_grad = input_tensor_grad._local_tensor
+ output = send_backward_recv_backward(
+ input_tensor_grad,
+ recv_next=self.recv_next,
+ tensor_shape=tensor_shape,
+ send_tensor_shape_unpad=tensor_shape_unpad,
+ overlap_p2p_comm=self.overlap_p2p_comm,
+ batch_p2p_comm=self.batch_p2p_comm,
+ recv_dtype=dtype,
+ current_device_mesh=send_comm.cur_mesh,
+ prev_device_mesh=recv_comm.peer_mesh,
+ next_device_mesh=send_comm.peer_mesh,
+ )
+ return output
+
+ zipped_data = list(
+ zip(
+ input_tensor_grad,
+ self.recv_comms,
+ self.send_comms,
+ self.send_shapes,
+ self.send_tensor_shapes_unpad,
+ self.send_dtypes,
+ )
+ )
+
+ output = list(map(f, zipped_data))
+
+ if len(output) > 1:
+ if self.overlap_p2p_comm:
+ return [x[0] for x in output], [x[1] for x in output]
+ else:
+ return output
+ else:
+ return output[0]
+
+
+@dataclass
+class SET_INPUTGRAD_TO_NONE(BaseInstruction):
+ @property
+ def name(self):
+ return "set inputgrad to none"
+
+ @dummy_p2p
+ def run(self):
+ return None
+
+
+@dataclass
+class SET_OUTPUT_TO_NONE(BaseInstruction):
+ @property
+ def name(self):
+ return "set output to none"
+
+ @dummy_p2p
+ def run(self):
+ return None
+
+
+@dataclass
+class BWD(BaseInstruction):
+ is_vpp_last_stage: bool = False
+ last_microbatch_for_model_chunk: bool = False
+ grad_sync_chunk_id: int = 0
+ grad_sync_microbatch_id: int = 0
+ model_chunk_id: int = 0
+ microbatch_id: int = 0
+ debug: str = ""
+
+ @property
+ def name(self):
+ return "backward"
+
+ def backward_step(
+ self,
+ input_tensor,
+ output_tensor,
+ output_tensor_grad,
+ grad_scaler=None,
+ deallocate_pipeline_outputs=False,
+ ):
+ """Backward step through passed-in output tensor.
+
+ If last stage, output_tensor_grad is None, otherwise gradient of loss
+ with respect to stage's output tensor.
+
+ Returns gradient of loss with respect to input tensor (None if first
+ stage)."""
+
+ # NOTE: This code currently can handle at most one skip connection. It
+ # needs to be modified slightly to support arbitrary numbers of skip
+ # connections.
+
+ # Retain the grad on the input_tensor.
+ unwrap_input_tensor_grad = False
+ if not isinstance(input_tensor, list):
+ input_tensor = [input_tensor]
+ unwrap_input_tensor_grad = True
+ for x in input_tensor:
+ if x is not None:
+ x.retain_grad()
+
+ if not isinstance(output_tensor, list):
+ output_tensor = [output_tensor]
+ if not isinstance(output_tensor_grad, list):
+ output_tensor_grad = [output_tensor_grad]
+
+ # extract loss value from output tensors
+ if isinstance(output_tensor[0], Sequence):
+ for j in range(len(output_tensor[0])):
+ if output_tensor[0][j].ndim == 0 and output_tensor[0][j].numel() == 1:
+ loss_value = output_tensor[0][j]
+ break
+ else:
+ loss_value = output_tensor[0][-1]
+ else:
+ loss_value = output_tensor[0]
+
+ # Backward pass.
+ if output_tensor_grad[0] is None and grad_scaler is not None:
+ loss_value = grad_scaler(loss_value)
+ # FIXME: For virtual pipeline, there may exist frozen layer without grad;
+ # Need to verify if this solution is correct
+ if not loss_value.requires_grad:
+ return None
+
+ model_chunk_id = builder.user_data["model_chunk_id"]
+ model = builder.model[model_chunk_id]
+ if deallocate_pipeline_outputs:
+ assert 0
+ else:
+ switch_dtensor(torch.autograd.backward)(loss_value, grad_tensors=output_tensor_grad[0])
+
+ model_chunk_id = builder.user_data["model_chunk_id"]
+ model = builder.model[model_chunk_id]
+
+ # Collect the grad of the input_tensor.
+ input_tensor_grad = [None]
+ if input_tensor is not None:
+ input_tensor_grad = []
+ for x in input_tensor:
+ if x is None:
+ input_tensor_grad.append(None)
+ else:
+ input_tensor_grad.append(x.grad)
+
+ if unwrap_input_tensor_grad:
+ input_tensor_grad = input_tensor_grad[0]
+
+ return input_tensor_grad
+
+ @dummy_p2p
+ def run(
+ self,
+ input_tensors,
+ output_tensors,
+ output_tensor_grads,
+ grad_sync_func,
+ synchronized_model_chunks,
+ kwargs: dict,
+ ):
+ grad_scaler, model, deallocate_pipeline_outputs = (
+ kwargs["grad_scaler"],
+ kwargs["model"],
+ kwargs["deallocate_pipeline_outputs"],
+ )
+ if self.is_vpp_last_stage:
+ if len(output_tensor_grads[self.model_chunk_id]) == 0:
+ output_tensor_grads[self.model_chunk_id].append(None)
+ input_tensor = input_tensors[self.model_chunk_id].pop(0)
+ output_tensor = output_tensors[self.model_chunk_id].pop(0)
+ output_tensor_grad = output_tensor_grads[self.model_chunk_id].pop(0)
+ input_tensor_grad = self.backward_step(
+ input_tensor, output_tensor, output_tensor_grad, grad_scaler, deallocate_pipeline_outputs
+ )
+
+ def f(input_tensor):
+ if input_tensor is not None:
+ assert isinstance(input_tensor, (torch.Tensor, DTensor)), input_tensor
+ input_tensor.grad = None
+ DEALLOCATE_OUTPUT_TENSOR().run(input_tensor, deallocate_pipeline_outputs)
+
+ if not isinstance(input_tensor, Sequence):
+ map(f, [input_tensor])
+ else:
+ map(f, input_tensor)
+
+ # launch grad synchronization (custom grad sync)
+ # Note: Asynchronous communication tends to slow down compute.
+ # To reduce idling from mismatched microbatch times, we launch
+ # asynchronous communication at the same time across the
+ # pipeline-parallel group.
+ if grad_sync_func is not None:
+ if self.grad_sync_microbatch_id >= 0 and self.last_microbatch_for_model_chunk:
+ grad_sync_func(model[self.grad_sync_chunk_id])
+ synchronized_model_chunks.add(self.grad_sync_chunk_id)
+ return input_tensor_grad
+
+
+@dataclass
+class FWD(BaseInstruction):
+ microbatch_id: int = 0
+ model_chunk_id: int = 0
+ param_sync_chunk_id: int = 0
+ is_vpp_first_stage: bool = False
+ is_vpp_last_stage: bool = False
+ forward_only: bool = False
+ num_model_chunks: int = 1
+ num_microbatches: int = 1
+ param_sync_microbatch_id: int = 0
+ first_microbatch_for_model_chunk: bool = True
+ optimizer_step_successful: bool = True
+ overlap_p2p_comm: bool = False
+ param_sync_overlap: bool = False
+ debug: str = ""
+
+ @property
+ def name(self):
+ return "forward"
+
+ def forward_step(
+ self,
+ data_iterator,
+ input_tensor,
+ model,
+ forward_data_store,
+ is_pp_first_stage: bool,
+ is_pp_last_stage: bool,
+ autocast_dtype=torch.float,
+ enable_autocast=False,
+ model_chunk_id=0,
+ ):
+ """Forward step for passed-in model.
+
+ If first stage, input tensor is obtained from data_iterator, otherwise
+ passed-in input_tensor is used.
+
+ Returns output tensor."""
+ if enable_autocast:
+ context_manager = torch.autocast("cuda", dtype=autocast_dtype)
+ else:
+ context_manager = contextlib.nullcontext()
+ with context_manager:
+
+ def prepare_data():
+ model_chunk_id = builder.user_data["model_chunk_id"]
+ ground_truth = []
+ if builder.user_data["is_pp_first_stage"]:
+ local_tensors = next(builder.dataloader[model_chunk_id])
+ true_input_tensor = None
+ else:
+ local_tensors = next(builder.dataloader[model_chunk_id])
+ if isinstance(local_tensors, Sequence) and len(local_tensors) > 1:
+ ground_truth.append(local_tensors[-1])
+ elif isinstance(local_tensors, Dict) and "labels" in local_tensors:
+ ground_truth.append(local_tensors["labels"])
+ true_input_tensor = builder.user_data["p2p_tensors"]
+ if isinstance(true_input_tensor, Sequence) and len(true_input_tensor) == 1:
+ true_input_tensor = true_input_tensor[0]
+
+ return true_input_tensor, local_tensors, ground_truth
+
+ builder.user_data["model_chunk_id"] = model_chunk_id
+ builder.user_data["p2p_tensors"] = input_tensor
+ builder.user_data["is_pp_first_stage"] = is_pp_first_stage
+ builder.user_data["is_pp_last_stage"] = is_pp_last_stage
+ builder.user_data["prepare_data_fn"] = prepare_data
+ p2p_input, local_input, ground_truth = registed_functions["vescale_interleaved_1f1b_pre_forward_data"]()
+ builder.user_data["ground_truth"] = ground_truth
+ output_tensor = registed_functions["vescale_interleaved_1f1b_forward"](p2p_input, local_input)
+ builder.user_data["output_tensor"] = output_tensor
+
+ if is_pp_last_stage:
+ output_tensor, loss_tensor = registed_functions["vescale_interleaved_1f1b_loss_fn"]()
+ forward_data_store.append((output_tensor, loss_tensor))
+ if builder.loss_fn is None:
+ return output_tensor
+ else:
+ return loss_tensor
+
+ return output_tensor
+
+ @dummy_p2p
+ def run(self, input_tensors, output_tensors, param_sync_func, kwargs):
+ # dump arguments for underlying fwd/bwd helpers
+ data_iterator, model, forward_data_store, dtype, enable_autocast = (
+ kwargs["data_iterator"],
+ kwargs["model"],
+ kwargs["forward_data_store"],
+ kwargs["dtype"],
+ kwargs["enable_autocast"],
+ )
+
+ assert param_sync_func is None
+ # TODO: implment logic for param_sync_func with PipeModule's utils
+ if param_sync_func is not None:
+ if self.param_sync_microbatch_id < self.num_microbatches and self.first_microbatch_for_model_chunk:
+ if 1 < self.param_sync_chunk_id < self.num_model_chunks:
+ param_sync_func(model[self.param_sync_chunk_id].parameters())
+
+ if self.overlap_p2p_comm and self.param_sync_overlap:
+ drain_recv_reqs("forward")
+
+ # forward step
+ if self.is_vpp_first_stage:
+ if len(input_tensors[self.model_chunk_id]) == len(output_tensors[self.model_chunk_id]):
+ input_tensors[self.model_chunk_id].append(None)
+
+ input_tensor = input_tensors[self.model_chunk_id][-1]
+ output_tensor = self.forward_step(
+ data_iterator=data_iterator,
+ input_tensor=input_tensor,
+ model=model,
+ forward_data_store=forward_data_store,
+ is_pp_first_stage=self.is_vpp_first_stage,
+ is_pp_last_stage=self.is_vpp_last_stage,
+ autocast_dtype=dtype,
+ enable_autocast=enable_autocast,
+ model_chunk_id=self.model_chunk_id,
+ )
+ output_tensors[self.model_chunk_id].append(output_tensor)
+
+ # if forward-only, no need to save tensors for a backward pass
+ if self.forward_only:
+ input_tensors[self.model_chunk_id].pop()
+ output_tensors[self.model_chunk_id].pop()
+
+ return output_tensor
+
+
+@dataclass
+class BUBBLE(BaseInstruction):
+ @property
+ def name(self):
+ return "bubble"
+
+ def run(self):
+ return
+
+
+@dataclass
+class LAUNCH_SHARED_UNITS_SYNC(BaseInstruction):
+ num_chunks: int = 1
+
+ @property
+ def name(self):
+ return "launch remain grad sync"
+
+ @dummy_p2p
+ def run(self, model):
+ for model_chunk_id in range(self.num_chunks):
+ # if isinstance(model, PipeModule):
+ # model.sync_shared_params(share_params=False, model_chunk_id=model_chunk_id)
+ ...
+
+
+class InterleavedPipeDreramFlush(PipelineSchema):
+ def __init__(
+ self,
+ num_chunks: int,
+ meshes: Sequence[DeviceMesh],
+ default_shape: Shape,
+ default_dtype: torch.dtype = torch.float32,
+ batches: int = 1,
+ input_shapes: Optional[List] = None,
+ input_shapes_unpad: Optional[List] = None,
+ **kwargs,
+ ):
+ assert batches % len(meshes) == 0, "Interleaved 1f1b only support mircobatch size mode device size"
+ assert batches // len(meshes) > 1, "Interleaved 1f1b only support mircobatch size = Interger * device size"
+ self.num_chunks = num_chunks
+ self.total_num_microbatches = num_chunks * batches
+ self.input_shapes = input_shapes
+ self.input_shapes_unpad = input_shapes_unpad
+ self.default_tensor_shape = default_shape
+ self.default_dtype = default_dtype
+ super().__init__(len(meshes), meshes, batches)
+
+ @property
+ def name(self):
+ return "Interleaved 1f1b"
+
+ def get_variable_tensor_shape(self, microbatch_id: int):
+ if self.input_shapes is None or len(self.input_shapes) == 0 or microbatch_id >= self.total_num_microbatches:
+ return self.default_tensor_shape
+
+ microbatch_group_size = self.num_mesh * self.num_chunks
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ microbatch_id_curr_model_chunk = microbatch_group_id * self.num_mesh + microbatch_id_in_group % self.num_mesh
+ tensor_shape = self.input_shapes[microbatch_id_curr_model_chunk]
+
+ return tensor_shape
+
+ def get_variable_tensor_shape_unpad(self, microbatch_id: int):
+ if (
+ self.input_shapes_unpad is None
+ or len(self.input_shapes_unpad) == 0
+ or microbatch_id >= self.total_num_microbatches
+ ):
+ return None
+
+ microbatch_group_size = self.num_mesh * self.num_chunks
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ microbatch_id_curr_model_chunk = microbatch_group_id * self.num_mesh + microbatch_id_in_group % self.num_mesh
+ return self.input_shapes_unpad[microbatch_id_curr_model_chunk]
+
+ def get_model_chunk_id(self, microbatch_id: int, forward: bool):
+ """Helper method to get the model chunk ID given the iteration number."""
+ microbatch_id_in_group = microbatch_id % (self.num_mesh * self.num_chunks)
+ model_chunk_id = microbatch_id_in_group // self.num_mesh
+ if not forward:
+ model_chunk_id = self.num_chunks - model_chunk_id - 1
+ return model_chunk_id
+
+ def is_first_microbatch_for_model_chunk_eager(self, microbatch_id: int) -> bool:
+ """Check if an iteration is the first for a model chunk eagerly"""
+ if microbatch_id % self.num_mesh != 0:
+ # Not the first time to run this model chunk
+ # For pipeline stage 0, chunk 0 is used by mb(0)
+ # mb(p), mb(2p), ...
+ return False
+ # grouping microbatches by pp_size, the groups will run different model chunk iteratively
+ microbatch_group_id = microbatch_id // self.num_mesh
+ if microbatch_group_id < self.num_chunks:
+ return True
+ return False
+
+ def is_first_microbatch_for_model_chunk(self, microbatch_id: int) -> bool:
+ """Check if an iteration is the first for a model chunk."""
+ # pp(0): mb(3+1)
+ # pp(1): mb(2+1)
+ # pp(2): mb(1+1)
+ # pp(3): mb(0+1)
+ microbatch_group_size = self.num_mesh * self.num_chunks
+ num_microbatch_groups = self.total_num_microbatches // microbatch_group_size
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ if microbatch_group_id == 0:
+ return microbatch_id_in_group % self.num_mesh == 0
+ else:
+ return False
+
+ def is_last_microbatch_for_model_chunk(self, microbatch_id: int) -> bool:
+ """Check if an iteration is the last for a model chunk."""
+ microbatch_group_size = self.num_mesh * self.num_chunks
+ num_microbatch_groups = self.total_num_microbatches // microbatch_group_size
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ if microbatch_group_id == num_microbatch_groups - 1:
+ return microbatch_id_in_group % self.num_mesh == self.num_mesh - 1
+ else:
+ return False
+
+ def _gen_schedule(self):
+ b = self.batches
+ d = self.num_mesh
+ s = self.num_chunks
+
+ warmup_batches = [min((d - i - 1) * 2 + (s - 1) * d, b * s) for i in range(d)]
+ self.warmup_batches = warmup_batches
+ remaining = [(b * s - w) for w in warmup_batches]
+ self.remaining = remaining
+ num_clock = (b * s + d - 1) * 2 # time todo flush
+ schedules = [[None] * d for c in range(num_clock)]
+ new_timeline = list(range(d))
+ bwd_done_idx = np.zeros(shape=[num_clock, d, s], dtype=np.int32)
+ next_fwd_batch_idx = np.zeros(shape=[d, s], dtype=np.int32)
+ next_bwd_batch_idx = np.zeros(shape=[d, s], dtype=np.int32)
+ # warm-up steps
+ for i in range(d):
+ for k in range(warmup_batches[i]):
+ t_i = new_timeline[i]
+ chunk_id = self.get_model_chunk_id(k, forward=True)
+ schedules[t_i][i] = Status(next_fwd_batch_idx[i][chunk_id], i, chunk_id, "F", "WUp", k)
+ new_timeline[i] += 1 # self add for new timeline
+ next_fwd_batch_idx[i][chunk_id] += 1 # do next micro batch
+
+ for i in reversed(range(d)):
+ for k in range(remaining[i]):
+ t_i = new_timeline[i]
+ f_k = k + warmup_batches[i]
+ chunk_id = self.get_model_chunk_id(f_k, forward=True)
+ schedules[t_i][i] = Status(next_fwd_batch_idx[i][chunk_id], i, chunk_id, "F", "1f1b", k)
+ next_fwd_batch_idx[i][chunk_id] += 1 # do next micro batch
+ bwd_k = k
+ chunk_id = self.get_model_chunk_id(bwd_k, forward=False)
+ bwd_done_idx[t_i][i] = bwd_done_idx[t_i - 1][i]
+ bwd_done_idx[t_i][i][chunk_id] = next_bwd_batch_idx[i][chunk_id]
+ t_i += 1
+
+ # do backward
+ if i + 1 < d:
+ while bwd_done_idx[t_i][i + 1][chunk_id] < next_bwd_batch_idx[i][chunk_id]:
+ assert bwd_done_idx[t_i - 1][i][chunk_id] == next_bwd_batch_idx[i][chunk_id]
+ bwd_done_idx[t_i][i][chunk_id] = bwd_done_idx[t_i - 1][i][chunk_id]
+ t_i = t_i + 1
+
+ if k == remaining[i] - 1: # last iterator
+ schedules[t_i][i] = Status(next_bwd_batch_idx[i][chunk_id], i, chunk_id, "B", "1f1b-l", k)
+ else:
+ schedules[t_i][i] = Status(next_bwd_batch_idx[i][chunk_id], i, chunk_id, "B", "1f1b", k)
+
+ bwd_done_idx[t_i][i] = bwd_done_idx[t_i - 1][i]
+ bwd_done_idx[t_i][i][chunk_id] = next_bwd_batch_idx[i][chunk_id]
+ next_bwd_batch_idx[i][chunk_id] += 1
+ new_timeline[i] = t_i + 1
+
+ # run cooldown passes
+ for i in reversed(range(d)):
+ for k in range(remaining[i], self.total_num_microbatches):
+ t_i = new_timeline[i]
+ bwd_k = k
+ chunk_id = self.get_model_chunk_id(bwd_k, forward=False)
+ if i + 1 < d:
+ while bwd_done_idx[t_i][i + 1][chunk_id] <= next_bwd_batch_idx[i][chunk_id]:
+ bwd_done_idx[t_i][i] = bwd_done_idx[t_i - 1][i]
+ bwd_done_idx[t_i][i][chunk_id] = next_bwd_batch_idx[i][chunk_id]
+ t_i = t_i + 1
+ schedules[t_i][i] = Status(next_bwd_batch_idx[i][chunk_id], i, chunk_id, "B", "CD", k)
+ bwd_done_idx[t_i][i] = bwd_done_idx[t_i - 1][i]
+ bwd_done_idx[t_i][i] = next_bwd_batch_idx[i]
+ next_bwd_batch_idx[i][chunk_id] += 1
+ new_timeline[i] = t_i + 1
+ bwd_done_idx[new_timeline[i] : num_clock, i, :] = b
+
+ return schedules
+
+
+class InterleavedOneFOneBInstructionGenerator(InstructionGenerator):
+ def __init__(
+ self,
+ deps: StageDeps,
+ meshes: List[DeviceMesh],
+ batches: int,
+ default_shape: Optional[Shape] = None,
+ default_dtype: Optional[torch.dtype] = None,
+ batch_shape_lists: Optional[List[Dict[int, Shape]]] = None,
+ batch_dtype_lists: Optional[List[Dict[int, torch.dtype]]] = None,
+ input_shapes: List[Dict[int, Shape]] = None,
+ input_shapes_unpad: List[Dict[int, Shape]] = None,
+ num_chunks: int = 1,
+ batch_p2p_comm: bool = True,
+ param_sync_overlap: bool = False,
+ overlap_p2p_comm: bool = False,
+ grad_sync_overlap: bool = False,
+ forward_only: bool = False,
+ ):
+ forward_only = True if not torch.is_grad_enabled() else forward_only
+ super().__init__(
+ deps=deps,
+ meshes=meshes,
+ batches=batches,
+ default_shape=default_shape,
+ default_dtype=default_dtype,
+ batch_shape_lists=batch_shape_lists,
+ batch_dtype_lists=batch_dtype_lists,
+ num_chunk=num_chunks,
+ forward_only=forward_only,
+ )
+ self.batch_p2p_comm = batch_p2p_comm
+ self.overlap_p2p_comm = overlap_p2p_comm
+ self.param_sync_overlap = param_sync_overlap
+ self.grad_sync_overlap = grad_sync_overlap
+ self.num_stage = len(meshes)
+ self.num_chunks = num_chunks
+ self.num_meshes = self.num_stage
+ self.schema = InterleavedPipeDreramFlush(
+ num_chunks=self.num_chunks,
+ meshes=self.meshes,
+ batches=self.batches,
+ default_shape=default_shape,
+ default_dtype=default_dtype,
+ input_shapes=input_shapes,
+ input_shapes_unpad=input_shapes_unpad,
+ )
+ self.forward_only = forward_only
+
+ def get_tensor_shape(self, microbatch_id: int, input_id: int = 0):
+ if (
+ self.schema.input_shapes is None
+ or len(self.schema.input_shapes) == 0
+ or microbatch_id >= self.schema.total_num_microbatches
+ ):
+ return self.schema.default_tensor_shape
+ microbatch_group_size = self.num_mesh * self.num_chunks
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ microbatch_id_curr_model_chunk = microbatch_group_id * self.num_mesh + microbatch_id_in_group % self.num_mesh
+ tensor_shape = self.schema.input_shapes[microbatch_id_curr_model_chunk]
+ if isinstance(tensor_shape, Dict):
+ tensor_shape = tensor_shape[input_id]
+ return tensor_shape
+
+ def get_variable_tensor_shape_unpad(self, microbatch_id: int, input_id: int = 0):
+ if (
+ self.schema.input_shapes is None
+ or len(self.schema.input_shapes) == 0
+ or microbatch_id >= self.schema.total_num_microbatches
+ ):
+ return self.schema.default_tensor_shape
+ microbatch_group_size = self.num_mesh * self.num_chunks
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ microbatch_id_curr_model_chunk = microbatch_group_id * self.num_mesh + microbatch_id_in_group % self.num_mesh
+ tensor_shape = self.schema.input_shapes_unpad[microbatch_id_curr_model_chunk]
+ if isinstance(tensor_shape, Dict):
+ tensor_shape = tensor_shape[input_id]
+ return tensor_shape
+
+ def get_tensor_dtype(self, microbatch_id: int, input_id: int = 0):
+ if (
+ self.batch_dtype_lists is None
+ or len(self.batch_dtype_lists) == 0
+ or microbatch_id >= self.schema.total_num_microbatches
+ ):
+ return self.default_dtype
+ microbatch_group_size = self.num_mesh * self.num_chunks
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ microbatch_id_curr_model_chunk = microbatch_group_id * self.num_mesh + microbatch_id_in_group % self.num_mesh
+ tensor_dtype = self.batch_dtype_lists[microbatch_id_curr_model_chunk]
+ if isinstance(tensor_dtype, Dict):
+ tensor_dtype = tensor_dtype[input_id]
+ return tensor_dtype
+
+ def get_shape_or_dtype(self, ff: Callable, comm_packages: List[CommPacket], microbatch_id):
+ def _get_shape_or_dtype(f: Callable, package: CommPacket):
+ return f(microbatch_id, package.input_id)
+
+ return list(map(partial(_get_shape_or_dtype, ff), comm_packages))
+
+ # call by pipe emitter
+ def gen_instruction(self):
+ schedules: List = self.schema.schedules
+ self.instruction_list = [[] for _ in range(self.num_stage)]
+ first_time_1f1b = [True] * self.num_stage
+ first_time_cool_down = [True] * self.num_stage
+ _forward_only = self.forward_only
+ if not torch.is_grad_enabled():
+ self.forward_only = True
+
+ # before warmup
+ for s in range(self.num_meshes):
+ recv_comms = self.deps.get_recv_comms(s)
+ tensor_shapes = self.get_shape_or_dtype(self.get_tensor_shape, recv_comms, 0)
+ tensor_dtypes = self.get_shape_or_dtype(self.get_tensor_dtype, recv_comms, 0)
+ self._set_inst(
+ RECV_FORWARD(
+ comm_packages=recv_comms,
+ tensor_shapes=tensor_shapes,
+ tensor_dtypes=tensor_dtypes,
+ batch_p2p_comm=self.batch_p2p_comm,
+ batch_id=0,
+ is_pp_first_stage=self.deps.is_pipeline_first_stage(s),
+ debug="before warm-up",
+ ),
+ s,
+ )
+
+ one_f_one_b_set = [set() for _ in range(self.num_meshes)]
+
+ for clk, stages_schemas in enumerate(schedules):
+ for s, schema in enumerate(stages_schemas):
+ is_pp_first_stage = self.deps.is_pipeline_first_stage(s)
+ is_pp_last_stage = self.deps.is_pipeline_last_stage(s)
+ send_comms = self.deps.get_send_comms(s)
+ recv_comms = self.deps.get_recv_comms(s)
+ if schema:
+ stg = schema.stg
+ k = schema.k
+ send_shapes = self.get_shape_or_dtype(self.get_tensor_shape, send_comms, k)
+ send_dtypes = self.get_shape_or_dtype(self.get_tensor_dtype, send_comms, k)
+ send_shapes_unpad = self.get_shape_or_dtype(self.get_variable_tensor_shape_unpad, send_comms, k)
+ recv_shapes = self.get_shape_or_dtype(self.get_tensor_shape, recv_comms, k)
+ recv_dtypes = self.get_shape_or_dtype(self.get_tensor_dtype, recv_comms, k)
+ if "WUp" in stg:
+ if not self.overlap_p2p_comm:
+ self._set_inst(WAIT_FWD(), s)
+ elif not self.param_sync_overlap:
+ self._set_inst(DRAIN_RECV_REQS(drain_type="forward"), s)
+ # TODO: all warmup batch check
+
+ model_chunk_id = self.schema.get_model_chunk_id(k, forward=True)
+ is_vpp_first_stage = self.deps.is_vpp_first_stage(s, model_chunk_id)
+ is_vpp_last_stage = self.deps.is_vpp_last_stage(s, model_chunk_id)
+ param_sync_microbatch_id = k + self.schema.num_mesh
+ param_sync_chunk_id = self.schema.get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
+ first_microbatch_for_model_chunk = self.schema.is_first_microbatch_for_model_chunk(k)
+ self._set_inst(
+ FWD(
+ microbatch_id=k,
+ model_chunk_id=model_chunk_id,
+ param_sync_chunk_id=param_sync_chunk_id,
+ is_vpp_first_stage=is_vpp_first_stage,
+ is_vpp_last_stage=is_vpp_last_stage,
+ forward_only=self.forward_only,
+ num_model_chunks=self.num_chunk,
+ num_microbatches=self.batches * self.num_chunk,
+ param_sync_microbatch_id=param_sync_microbatch_id,
+ first_microbatch_for_model_chunk=first_microbatch_for_model_chunk,
+ overlap_p2p_comm=self.overlap_p2p_comm,
+ param_sync_overlap=self.param_sync_overlap,
+ ),
+ s,
+ )
+ # Determine if tensor should be received from previous stage.
+ next_forward_model_chunk_id = self.schema.get_model_chunk_id(k + 1, forward=True)
+ recv_prev = True
+ if is_pp_first_stage:
+ if next_forward_model_chunk_id == 0:
+ recv_prev = False
+ if k == (self.schema.total_num_microbatches - 1):
+ recv_prev = False
+
+ if is_vpp_last_stage:
+ self._set_inst(SET_OUTPUT_TO_NONE(), s)
+
+ if not self.overlap_p2p_comm:
+ if k == (self.schema.warmup_batches[s] - 1) and not self.forward_only:
+ self._set_inst(SET_INPUTGRAD_TO_NONE(), s)
+ recv_next = True
+ if is_pp_last_stage:
+ recv_next = False
+ self._set_inst(
+ SEND_FORWARD_BACKWARD_RECV_FORWARD_BACKWARD(
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ recv_comms=recv_comms,
+ send_comms=send_comms,
+ recv_shapes=recv_shapes,
+ recv_dtypes=recv_dtypes,
+ batch_p2p_comm=self.batch_p2p_comm,
+ debug="none p2p overlap, last batch warm-up",
+ ),
+ s,
+ )
+
+ self._set_inst(APPEND_GRADS(chunk=self.num_chunk - 1), s)
+ else:
+ self._set_inst(
+ SEND_FORWARD_RECV_FORWARD(
+ recv_prev=recv_prev,
+ send_shapes=send_shapes,
+ send_tensor_shapes_unpad=send_shapes_unpad,
+ send_dtypes=send_dtypes,
+ batch_p2p_comm=self.batch_p2p_comm,
+ overlap_p2p_comm=self.overlap_p2p_comm,
+ microbatch_id=k,
+ send_comms=send_comms,
+ recv_comms=recv_comms,
+ debug="none p2p overlap, warm-up",
+ ),
+ s,
+ )
+
+ self._set_inst(APPEND_INPUTS(chunk=next_forward_model_chunk_id), s)
+ else:
+ tensor_shapes = self.get_shape_or_dtype(self.get_tensor_shape, send_comms, k + 1)
+ tensor_dtypes = self.get_shape_or_dtype(self.get_tensor_dtype, send_comms, k + 1)
+
+ self._set_inst(
+ SEND_FORWARD_RECV_FORWARD(
+ recv_prev=recv_prev,
+ send_shapes=tensor_shapes,
+ send_tensor_shapes_unpad=send_shapes_unpad,
+ send_dtypes=tensor_dtypes,
+ batch_p2p_comm=self.batch_p2p_comm,
+ overlap_p2p_comm=self.overlap_p2p_comm,
+ send_comms=send_comms,
+ recv_comms=recv_comms,
+ debug="p2p overlap, warm up",
+ ),
+ s,
+ )
+ if k == (self.schema.warmup_batches[s] - 1) and not self.forward_only:
+ self._set_inst(SET_INPUTGRAD_TO_NONE(), s)
+ recv_next = True
+ if is_pp_last_stage:
+ recv_next = False
+ self._set_inst(
+ SEND_BACKWARD_RECV_BACKWARD(
+ recv_next=recv_next,
+ send_shapes=send_shapes,
+ send_tensor_shapes_unpad=send_shapes_unpad,
+ send_dtypes=send_dtypes,
+ batch_p2p_comm=self.batch_p2p_comm,
+ overlap_p2p_comm=self.overlap_p2p_comm,
+ send_comms=send_comms,
+ recv_comms=recv_comms,
+ debug="warm-up",
+ ),
+ s,
+ )
+ self._set_inst(APPEND_GRADS(self.num_chunk - 1), s)
+ self._set_inst(APPEND_INPUTS(chunk=next_forward_model_chunk_id), s)
+ self._set_inst(DEALLOCATE_OUTPUT_TENSOR(), s)
+ elif "1f1b" in stg: # 1f1b stage
+ forward_k = k + self.schema.warmup_batches[s]
+ if first_time_1f1b[s]:
+ if self.overlap_p2p_comm:
+ self._set_inst(DRAIN_SEND_REQS(), s)
+ first_time_1f1b[s] = False
+ if k in one_f_one_b_set[s]:
+ continue
+ else:
+ one_f_one_b_set[s].add(k)
+ if self.overlap_p2p_comm:
+ if not self.param_sync_overlap:
+ self._set_inst(DRAIN_RECV_REQS(drain_type="forward"), s)
+ self._set_inst(DEALLOCATE_OUTPUT_TENSOR(), s)
+
+ model_chunk_id = self.schema.get_model_chunk_id(forward_k, forward=True)
+ is_vpp_first_stage = self.deps.is_vpp_first_stage(s, model_chunk_id)
+ is_vpp_last_stage = self.deps.is_vpp_last_stage(s, model_chunk_id)
+ param_sync_microbatch_id = forward_k + self.schema.num_mesh
+ param_sync_chunk_id = (
+ self.schema.get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
+ )
+ first_microbatch_for_model_chunk = self.schema.is_first_microbatch_for_model_chunk(
+ forward_k
+ )
+ self._set_inst(
+ FWD(
+ microbatch_id=forward_k,
+ model_chunk_id=model_chunk_id,
+ param_sync_chunk_id=param_sync_chunk_id,
+ param_sync_microbatch_id=param_sync_microbatch_id,
+ param_sync_overlap=self.param_sync_overlap,
+ first_microbatch_for_model_chunk=first_microbatch_for_model_chunk,
+ is_vpp_first_stage=is_vpp_first_stage,
+ is_vpp_last_stage=is_vpp_last_stage,
+ forward_only=self.forward_only,
+ num_model_chunks=self.num_chunk,
+ num_microbatches=self.batches * self.num_chunk,
+ debug="1f1b",
+ ),
+ s,
+ )
+ # Determine if current stage has anything to send in either direction,
+ # otherwise set tensor to None.
+ # Last virtual stage no activation tensor to send
+ if is_vpp_last_stage:
+ self._set_inst(SET_OUTPUT_TO_NONE(), s)
+ # Determine if peers are sending, and where in data structure to put
+ # received tensors.
+ recv_prev = True
+ if is_pp_first_stage:
+ # First stage is ahead of last stage by (pipeline_parallel_size - 1).
+ next_forward_model_chunk_id = self.schema.get_model_chunk_id(
+ forward_k - (self.schema.num_mesh - 1), forward=True
+ )
+ if next_forward_model_chunk_id == (self.schema.num_chunks - 1):
+ recv_prev = False
+ next_forward_model_chunk_id += 1
+ else:
+ next_forward_model_chunk_id = self.schema.get_model_chunk_id(
+ forward_k + 1, forward=True
+ )
+
+ # If last iteration, don't receive; we already received one extra
+ # before the start of the for loop.
+ if k == (self.schema.remaining[s] - 1):
+ recv_prev = False
+
+ # Send activation tensor to the next stage and receive activation tensor from the
+ # previous stage
+ tensor_shape = self.schema.get_variable_tensor_shape(forward_k + 1)
+ send_tensor_shape_unpad = self.schema.get_variable_tensor_shape_unpad(forward_k)
+ self._set_inst(
+ SEND_FORWARD_RECV_FORWARD(
+ recv_prev=recv_prev,
+ send_shapes=send_shapes,
+ send_tensor_shapes_unpad=send_shapes_unpad,
+ send_dtypes=send_dtypes,
+ batch_p2p_comm=self.batch_p2p_comm,
+ overlap_p2p_comm=self.overlap_p2p_comm,
+ send_comms=send_comms,
+ recv_comms=recv_comms,
+ microbatch_id=forward_k,
+ debug="1f1b",
+ ),
+ s,
+ )
+ self._set_inst(DRAIN_RECV_REQS(drain_type="backward"), s)
+
+ # Backward pass.
+ backward_k = k
+ grad_sync_microbatch_id = backward_k - s
+ grad_sync_chunk_id = self.schema.get_model_chunk_id(grad_sync_microbatch_id, forward=False)
+ last_microbatch_for_model_chunk = self.schema.is_last_microbatch_for_model_chunk(
+ grad_sync_microbatch_id
+ )
+ backward_model_chunk_id = self.schema.get_model_chunk_id(backward_k, forward=False)
+ is_vpp_first_stage = self.deps.is_vpp_first_stage(s, backward_model_chunk_id)
+ is_vpp_last_stage = self.deps.is_vpp_last_stage(s, backward_model_chunk_id)
+ self._set_inst(
+ BWD(
+ is_vpp_last_stage=is_vpp_last_stage,
+ last_microbatch_for_model_chunk=last_microbatch_for_model_chunk,
+ grad_sync_chunk_id=grad_sync_chunk_id,
+ grad_sync_microbatch_id=grad_sync_microbatch_id,
+ model_chunk_id=backward_model_chunk_id,
+ microbatch_id=backward_k,
+ debug="1f1b",
+ ),
+ s,
+ )
+
+ # First virtual stage no activation gradient tensor to send
+ if is_vpp_first_stage:
+ self._set_inst(SET_INPUTGRAD_TO_NONE(), s)
+
+ # Determine if the current virtual stage has an activation gradient tensor to receive
+ recv_next = True
+ if is_pp_last_stage:
+ # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
+ next_backward_model_chunk_id = self.schema.get_model_chunk_id(
+ backward_k - (self.schema.num_mesh - 1), forward=False
+ )
+ if next_backward_model_chunk_id == 0:
+ recv_next = False
+ next_backward_model_chunk_id -= 1
+ else:
+ next_backward_model_chunk_id = self.schema.get_model_chunk_id(
+ backward_k + 1, forward=False
+ )
+
+ tensor_shape = self.schema.get_variable_tensor_shape(backward_k + 1)
+ send_tensor_shape_unpad = self.schema.get_variable_tensor_shape_unpad(backward_k)
+ self._set_inst(
+ SEND_BACKWARD_RECV_BACKWARD(
+ recv_next=recv_next,
+ send_shapes=send_shapes,
+ send_tensor_shapes_unpad=send_shapes_unpad,
+ send_dtypes=send_dtypes,
+ batch_p2p_comm=self.batch_p2p_comm,
+ overlap_p2p_comm=self.overlap_p2p_comm,
+ send_comms=send_comms,
+ recv_comms=recv_comms,
+ debug="1f1b",
+ ),
+ s,
+ )
+ else:
+ model_chunk_id = self.schema.get_model_chunk_id(forward_k, forward=True)
+ is_vpp_first_stage = self.deps.is_vpp_first_stage(s, model_chunk_id)
+ is_vpp_last_stage = self.deps.is_vpp_last_stage(s, model_chunk_id)
+ param_sync_microbatch_id = forward_k + self.schema.num_mesh
+ param_sync_chunk_id = (
+ self.schema.get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
+ )
+ first_microbatch_for_model_chunk = self.schema.is_first_microbatch_for_model_chunk(
+ forward_k
+ )
+ self._set_inst(
+ FWD(
+ microbatch_id=forward_k,
+ model_chunk_id=model_chunk_id,
+ is_vpp_first_stage=is_vpp_first_stage,
+ is_vpp_last_stage=is_vpp_last_stage,
+ param_sync_chunk_id=param_sync_chunk_id,
+ param_sync_microbatch_id=param_sync_microbatch_id,
+ first_microbatch_for_model_chunk=first_microbatch_for_model_chunk,
+ forward_only=self.forward_only,
+ ),
+ s,
+ )
+
+ # Backward pass.
+ backward_k = k
+ grad_sync_microbatch_id = backward_k - s
+ grad_sync_chunk_id = self.schema.get_model_chunk_id(grad_sync_microbatch_id, forward=False)
+ last_microbatch_for_model_chunk = self.schema.is_last_microbatch_for_model_chunk(
+ grad_sync_microbatch_id
+ )
+ backward_model_chunk_id = self.schema.get_model_chunk_id(backward_k, forward=False)
+ is_vpp_first_stage = self.deps.is_vpp_first_stage(s, backward_model_chunk_id)
+ is_vpp_last_stage = self.deps.is_vpp_last_stage(s, backward_model_chunk_id)
+ self._set_inst(
+ BWD(
+ microbatch_id=backward_k,
+ model_chunk_id=backward_model_chunk_id,
+ is_vpp_last_stage=is_vpp_last_stage,
+ last_microbatch_for_model_chunk=last_microbatch_for_model_chunk,
+ grad_sync_microbatch_id=grad_sync_microbatch_id,
+ grad_sync_chunk_id=grad_sync_chunk_id,
+ debug="1f1b",
+ ),
+ s,
+ )
+
+ # Send output_tensor and input_tensor_grad, receive input_tensor
+ # and output_tensor_grad.
+
+ # Determine if current stage has anything to send in either direction,
+ # otherwise set tensor to None.
+ forward_model_chunk_id = self.schema.get_model_chunk_id(forward_k, forward=True)
+ is_vpp_last_stage = self.deps.is_vpp_last_stage(s, forward_model_chunk_id)
+ if is_vpp_last_stage:
+ self._set_inst(SET_OUTPUT_TO_NONE(), s)
+ backward_model_chunk_id = self.schema.get_model_chunk_id(backward_k, forward=False)
+ is_vpp_first_stage = self.deps.is_vpp_first_stage(s, backward_model_chunk_id)
+ if is_vpp_first_stage:
+ self._set_inst(SET_INPUTGRAD_TO_NONE(), s)
+
+ # Determine if peers are sending, and where in data structure to put
+ # received tensors.
+ recv_prev = True
+ if is_pp_first_stage:
+ # First stage is ahead of last stage by (pipeline_parallel_size - 1).
+ next_forward_model_chunk_id = self.schema.get_model_chunk_id(
+ forward_k - (self.num_meshes - 1), forward=True
+ )
+ if next_forward_model_chunk_id == (self.num_chunks - 1):
+ recv_prev = False
+ next_forward_model_chunk_id += 1
+ else:
+ next_forward_model_chunk_id = self.schema.get_model_chunk_id(
+ forward_k + 1, forward=True
+ )
+
+ recv_next = True
+ if is_pp_last_stage:
+ # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
+ next_backward_model_chunk_id = self.schema.get_model_chunk_id(
+ backward_k - (self.num_meshes - 1), forward=False
+ )
+ if next_backward_model_chunk_id == 0:
+ recv_next = False
+ next_backward_model_chunk_id -= 1
+ else:
+ next_backward_model_chunk_id = self.schema.get_model_chunk_id(
+ backward_k + 1, forward=False
+ )
+
+ # If last iteration, don't receive; we already received one extra
+ # before the start of the for loop.
+ if k == (self.schema.remaining[s] - 1):
+ recv_prev = False
+
+ self._set_inst(
+ SEND_FORWARD_BACKWARD_RECV_FORWARD_BACKWARD(
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ send_comms=send_comms,
+ recv_comms=recv_comms,
+ recv_shapes=recv_shapes,
+ recv_dtypes=recv_dtypes,
+ batch_p2p_comm=self.batch_p2p_comm,
+ debug="1f1b",
+ ),
+ s,
+ )
+
+ self._set_inst(DEALLOCATE_OUTPUT_TENSOR(), s)
+
+ # Put input_tensor and output_tensor_grad in data structures in the
+ # right location.
+ if recv_prev:
+ self._set_inst(APPEND_INPUTS(chunk=next_forward_model_chunk_id), s)
+ if recv_next:
+ self._set_inst(APPEND_GRADS(chunk=next_backward_model_chunk_id), s)
+
+ # launch grad_sync_func here to overlap with p2p communication
+ if self.grad_sync_overlap:
+ raise NotImplementedError("grad sync is not implement yet")
+ elif stg == "CD": # cool down stage
+ if first_time_cool_down[s]:
+ self._set_inst(DEALLOCATE_OUTPUT_TENSOR(), s)
+ if self.overlap_p2p_comm:
+ self._set_inst(DRAIN_SEND_REQS(), s)
+ if not self.forward_only:
+ if self.overlap_p2p_comm:
+ self._set_inst(DRAIN_RECV_REQS(drain_type="all", check_bwd_wait=True), s)
+ first_time_cool_down[s] = False
+ if self.forward_only:
+ continue # forward have no backward phase
+
+ grad_sync_microbatch_id = k - s
+ grad_sync_chunk_id = self.schema.get_model_chunk_id(grad_sync_microbatch_id, forward=False)
+ last_microbatch_for_model_chunk = self.schema.is_last_microbatch_for_model_chunk(
+ grad_sync_microbatch_id
+ )
+ model_chunk_id = self.schema.get_model_chunk_id(k, forward=False)
+ is_vpp_last_stage = self.deps.is_vpp_first_stage(s, model_chunk_id)
+ self._set_inst(
+ BWD(
+ microbatch_id=k,
+ is_vpp_last_stage=is_vpp_last_stage,
+ model_chunk_id=model_chunk_id,
+ last_microbatch_for_model_chunk=last_microbatch_for_model_chunk,
+ grad_sync_chunk_id=grad_sync_chunk_id,
+ grad_sync_microbatch_id=grad_sync_microbatch_id,
+ debug="cooldown",
+ ),
+ s,
+ )
+ next_backward_model_chunk_id = self.schema.get_model_chunk_id(k + 1, forward=False)
+ recv_next = True
+ if is_pp_last_stage:
+ if next_backward_model_chunk_id == (self.schema.num_chunks - 1):
+ recv_next = False
+ if k == (self.schema.total_num_microbatches - 1):
+ recv_next = False
+
+ tensor_shape = self.schema.get_variable_tensor_shape(k + 1)
+ send_tensor_shape_unpad = self.schema.get_variable_tensor_shape_unpad(k)
+ self._set_inst(
+ SEND_BACKWARD_RECV_BACKWARD(
+ recv_next=recv_next,
+ send_shapes=send_shapes,
+ send_tensor_shapes_unpad=send_shapes_unpad,
+ send_dtypes=send_dtypes,
+ batch_p2p_comm=self.batch_p2p_comm,
+ overlap_p2p_comm=self.overlap_p2p_comm,
+ send_comms=send_comms,
+ recv_comms=recv_comms,
+ debug="cooldown",
+ ),
+ s,
+ )
+ self._set_inst(APPEND_GRADS(chunk=next_backward_model_chunk_id), s)
+
+ if self.grad_sync_overlap:
+ raise NotImplementedError("grad sync is not support yet")
+
+ if self.overlap_p2p_comm:
+ self._set_inst(DRAIN_RECV_REQS(drain_type="all"), s)
+ else: # bubble
+ # do any other
+ self._set_inst(BUBBLE(), s)
+ # Launch any remaining grad reductions
+ # if grad_sync_func is not None:
+ # for model_chunk_id in range(num_model_chunks):
+ # if model_chunk_id not in synchronized_model_chunks:
+ # grad_sync_func(model[model_chunk_id], model_chunk_id)
+ # synchronized_model_chunks.add(model_chunk_id)
+
+ # add cool down things
+ for s in range(self.num_meshes):
+ if not self.forward_only:
+ # Launch any remaining grad reductions
+ self._set_inst(LAUNCH_SHARED_UNITS_SYNC(num_chunks=self.deps.get_num_chunks()), s)
+
+ if self.overlap_p2p_comm:
+ self._set_inst(DRAIN_SEND_REQS(), s)
+
+ # restore original self.forward_only if the current context manager is torch.no_grad()
+ if not torch.is_grad_enabled():
+ self.forward_only = _forward_only
+
+ self.gen_instruction_str_list()
+ return self.instruction_list
+
+ def gen_instruction_str_list(self):
+ instruction_lists = self.instruction_list
+ stage_strs = defaultdict(str)
+ for stage_id, instruction_list in enumerate(instruction_lists):
+ cur_stage_str = stage_strs[stage_id]
+ for inst in instruction_list:
+ cur_stage_str += f"{VESACLE_INSTRUCTION_MAPPING_V[type(inst)]},"
+ cur_stage_str = cur_stage_str[:-1]
+ stage_strs[stage_id] = cur_stage_str
+ builder.build_from_dict(stage_strs)
+
+ @manage_dump_file
+ def execute(
+ self,
+ stage_id,
+ autocast_dtype=torch.float32,
+ enable_autocast=False,
+ grad_scaler=None,
+ deallocate_pipeline_outputs=False,
+ param_sync_func=None,
+ grad_sync_func=None,
+ ):
+ # if the current context manager is torch.no_grad(), do not compute backward
+ temp_forward_only = self.forward_only
+ if not torch.is_grad_enabled():
+ self.forward_only = False
+
+ # init constant data
+ builder.constant_data["autocast_dtype"] = autocast_dtype
+ builder.constant_data["enable_autocast"] = enable_autocast
+ builder.constant_data["grad_scaler"] = grad_scaler
+ builder.constant_data["deallocate_pipeline_outputs"] = deallocate_pipeline_outputs
+ builder.constant_data["param_sync_func"] = param_sync_func
+ builder.constant_data["grad_sync_func"] = grad_sync_func
+
+ # Model chunk IDs with synchronized grads
+ builder.user_data["synchronized_model_chunks"] = set()
+ builder.user_data["input_tensors"] = [[] for _ in range(self.num_chunk)]
+ builder.user_data["output_tensors"] = [[] for _ in range(self.num_chunk)]
+ builder.user_data["output_tensor_grads"] = [[] for _ in range(self.num_chunk)]
+ builder.user_data["fwd_wait_handles"] = None
+ builder.user_data["bwd_wait_handles"] = None
+ builder.user_data["output_tensor"] = None
+ builder.user_data["input_tensor"] = None
+ builder.user_data["output_tensor_grad"] = None
+ builder.user_data["forward_data_store"] = []
+ model = self.deps.get_current_model(stage_id)
+
+ builder.model = model
+ instruction_list = self.get_instruction_list(stage_id)
+ builder.stage_id = stage_id
+ builder_instruction_list = builder.global_instructions_funcs[stage_id]
+
+ for inst, fn in zip(instruction_list, builder_instruction_list):
+ builder.user_data["inst"] = inst
+ fn()
+
+ # restore original self.forward_only if the current context manager is torch.no_grad()
+ if not torch.is_grad_enabled():
+ self.forward_only = temp_forward_only
+
+ return builder.user_data["forward_data_store"]
+
+
+@register_instruction(name="vescale_interleavd_1f1b_recv_forward")
+def vpp_recv_forward():
+ inst = builder.user_data["inst"]
+ tmp = inst.run()
+ input_tensors = builder.user_data["input_tensors"]
+ input_tensors[0].append(tmp)
+
+
+@register_instruction(name="vescale_interleavd_1f1b_forward")
+def vpp_forward():
+ inst = builder.user_data["inst"]
+ user_data = builder.user_data
+ forward_data_store = user_data["forward_data_store"]
+ input_tensors = user_data["input_tensors"]
+ output_tensors = user_data["output_tensors"]
+
+ constant_data = builder.constant_data
+ autocast_dtype = constant_data["autocast_dtype"]
+ enable_autocast = constant_data["enable_autocast"]
+ param_sync_func = constant_data["param_sync_func"]
+
+ forward_args = {
+ "data_iterator": builder.dataloader,
+ "model": builder.model,
+ "forward_data_store": forward_data_store,
+ "dtype": autocast_dtype,
+ "enable_autocast": enable_autocast,
+ }
+ output_tensor = inst.run(
+ input_tensors=input_tensors,
+ output_tensors=output_tensors,
+ param_sync_func=param_sync_func,
+ kwargs=forward_args,
+ )
+ user_data["output_tensor"] = output_tensor
+
+
+@register_instruction(name="vescale_interleaved_1f1b_backward")
+def vpp_backward():
+ inst = builder.user_data["inst"]
+ model = builder.model
+ grad_scaler = builder.constant_data["grad_scaler"]
+ deallocate_pipeline_outputs = builder.constant_data["deallocate_pipeline_outputs"]
+ backward_args = {
+ "grad_scaler": grad_scaler,
+ "model": model,
+ "deallocate_pipeline_outputs": deallocate_pipeline_outputs,
+ }
+ grad_sync_func = builder.constant_data["grad_sync_func"]
+ input_tensors = builder.user_data["input_tensors"]
+ output_tensors = builder.user_data["output_tensors"]
+ output_tensor_grads = builder.user_data["output_tensor_grads"]
+ synchronized_model_chunks = builder.user_data["synchronized_model_chunks"]
+
+ input_tensor_grad = inst.run(
+ input_tensors=input_tensors,
+ output_tensors=output_tensors,
+ output_tensor_grads=output_tensor_grads,
+ grad_sync_func=grad_sync_func,
+ synchronized_model_chunks=synchronized_model_chunks,
+ kwargs=backward_args,
+ )
+ builder.user_data["input_tensor_grad"] = input_tensor_grad
+
+
+@register_instruction(name="vescale_interleavd_1f1b_set_output_to_none")
+def vpp_set_output_to_none():
+ inst = builder.user_data["inst"]
+ output_tensor = inst.run()
+ builder.user_data["output_tensor"] = None
+
+
+@register_instruction(name="vescale_interleavd_1f1b_set_input_grad_to_none")
+def vpp_set_input_grad_to_none():
+ inst = builder.user_data["inst"]
+ input_tensor_grad = inst.run()
+ builder.user_data["input_tensor_grad"] = input_tensor_grad
+
+
+@register_instruction(name="vescale_interleaved_1f1b_send_forward_recv_forward")
+def vpp_send_forward_recv_forward():
+ inst = builder.user_data["inst"]
+ output_tensor = builder.user_data["output_tensor"]
+ input_tensor = inst.run(output_tensor=output_tensor)
+ if inst.overlap_p2p_comm:
+ input_tensor, fwd_wait_handles = input_tensor
+ builder.user_data["fwd_wait_handles"] = fwd_wait_handles
+ builder.user_data["input_tensor"] = input_tensor
+
+
+@register_instruction(name="vescale_interleavd_1f1b_send_backward_recv_backward")
+def vpp_send_backward_recv_backward():
+ inst = builder.user_data["inst"]
+ input_tensor_grad = builder.user_data["input_tensor_grad"]
+ output_tensor_grad = inst.run(input_tensor_grad=input_tensor_grad)
+ if inst.overlap_p2p_comm:
+ output_tensor_grad, bwd_wait_handles = output_tensor_grad
+ builder.user_data["bwd_wait_handles"] = bwd_wait_handles
+ builder.user_data["output_tensor_grad"] = output_tensor_grad
+
+
+@register_instruction(name="vescale_interleaved_1f1b_send_forward_backward_recv_forward_backward")
+def vpp_send_forward_backward_recv_forward_backward():
+ inst = builder.user_data["inst"]
+ output_tensor = builder.user_data["output_tensor"]
+ input_tensor_grad = builder.user_data["input_tensor_grad"]
+ input_tensor, output_tensor_grad = inst.run(output_tensor=output_tensor, input_tensor_grad=input_tensor_grad)
+ builder.user_data["input_tensor"] = input_tensor
+ builder.user_data["output_tensor_grad"] = output_tensor_grad
+
+
+@register_instruction(name="vescale_interleavd_1f1b_append_grads")
+def vpp_append_grads():
+ inst = builder.user_data["inst"]
+ output_tensor_grads = builder.user_data["output_tensor_grads"]
+ output_tensor_grad = builder.user_data["output_tensor_grad"]
+ inst.run(output_tensor_grad, output_tensor_grads)
+
+
+@register_instruction(name="vescale_interleavd_1f1b_append_inputs")
+def vpp_append_inputs():
+ inst = builder.user_data["inst"]
+ input_tensor = builder.user_data["input_tensor"]
+ input_tensors = builder.user_data["input_tensors"]
+ inst.run(input_tensor, input_tensors)
+
+
+@register_instruction(name="vescale_interleavd_1f1b_deallocate_output_tensor")
+def vpp_deallocate_tensors():
+ inst = builder.user_data["inst"]
+ deallocate_pipeline_outputs = builder.constant_data["deallocate_pipeline_outputs"]
+ output_tensor = builder.user_data["output_tensor"]
+ inst.run(output_tensor=output_tensor, deallocate_pipeline_outputs=deallocate_pipeline_outputs)
+
+
+@register_instruction(name="vescale_interleaved_1f1b_drain_send_reqs")
+def vpp_drain_send_reqs():
+ inst = builder.user_data["inst"]
+ inst.run()
+
+
+@register_instruction(name="vescale_interleaved_1f1b_drain_recv_reqs")
+def vpp_drain_recv_reqs():
+ inst = builder.user_data["inst"]
+ bwd_wait_handles = builder.user_data["bwd_wait_handles"]
+ inst.run(bwd_wait_handles=bwd_wait_handles)
+
+
+@register_instruction(name="vescale_interleaved_1f1b_wait_fwd")
+def vpp_wait_fwd():
+ inst = builder.user_data["inst"]
+ fwd_wait_handles = builder.user_data["fwd_wait_handles"]
+ inst.run(fwd_wait_handles=fwd_wait_handles)
+
+
+@register_instruction(name="vescale_interleavd_1f1b_launch_shared_units_sync")
+def vpp_launch_shared_units_sync():
+ model = builder.model
+ inst = builder.user_data["inst"]
+ inst.run(model=model)
+
+
+@register_instruction(name="vescale_interleaved_1f1b_pre_forward_data")
+def vpp_prepare_forward_args():
+ fn = builder.user_data["prepare_data_fn"]
+ return fn()
+
+
+@register_instruction(name="vescale_interleaved_1f1b_forward")
+def forward_fn(p2p_input, local_input):
+ model_chunk_id = builder.user_data["model_chunk_id"]
+ if isinstance(builder.model, Sequence):
+
+ def _feed_input(data):
+ if isinstance(data, Sequence):
+ return model(*data)
+ elif isinstance(data, Dict):
+ return model(**data)
+ else:
+ return model(data)
+
+ model = builder.model[model_chunk_id]
+ if p2p_input is not None:
+ return _feed_input(p2p_input)
+ else:
+ return _feed_input(local_input)
+ else:
+ return builder.model(p2p_input, local_input, model_chunk_id)
+
+
+@register_instruction(name="vescale_interleaved_1f1b_loss_fn")
+def loss_fn():
+ loss_func = builder.loss_fn
+ output_tensor = builder.user_data["output_tensor"]
+ if loss_func is None:
+ return output_tensor, None
+ temp_tensor = output_tensor
+ args_spec = signature(loss_func)
+ args_len = len(args_spec.parameters.keys())
+ if args_len == 1:
+ output_tensor = loss_func(output_tensor)
+ else:
+ ground_truth = builder.user_data["ground_truth"]
+ loss_fn_inputs = [output_tensor] + ground_truth
+ output_tensor = loss_func(*loss_fn_inputs)
+ assert args_len == len(loss_fn_inputs), "Mismatch of loss function #args and #actual inputs!"
+ builder.user_data["output_tensor"] = output_tensor
+ return temp_tensor, output_tensor
+
+
+VESACLE_INSTRUCTION_MAPPING_V = {
+ RECV_FORWARD: "vescale_interleavd_1f1b_recv_forward",
+ FWD: "vescale_interleavd_1f1b_forward",
+ BWD: "vescale_interleaved_1f1b_backward",
+ SET_OUTPUT_TO_NONE: "vescale_interleavd_1f1b_set_output_to_none",
+ SET_INPUTGRAD_TO_NONE: "vescale_interleavd_1f1b_set_input_grad_to_none",
+ SEND_FORWARD_RECV_FORWARD: "vescale_interleaved_1f1b_send_forward_recv_forward",
+ SEND_BACKWARD_RECV_BACKWARD: "vescale_interleavd_1f1b_send_backward_recv_backward",
+ SEND_FORWARD_BACKWARD_RECV_FORWARD_BACKWARD: "vescale_interleaved_1f1b_send_forward_backward_recv_forward_backward",
+ APPEND_GRADS: "vescale_interleavd_1f1b_append_grads",
+ APPEND_INPUTS: "vescale_interleavd_1f1b_append_inputs",
+ DEALLOCATE_OUTPUT_TENSOR: "vescale_interleavd_1f1b_deallocate_output_tensor",
+ DRAIN_SEND_REQS: "vescale_interleaved_1f1b_drain_send_reqs",
+ DRAIN_RECV_REQS: "vescale_interleaved_1f1b_drain_recv_reqs",
+ WAIT_FWD: "vescale_interleaved_1f1b_wait_fwd",
+ LAUNCH_SHARED_UNITS_SYNC: "vescale_interleavd_1f1b_launch_shared_units_sync",
+}
diff --git a/vescale/pipe/_schedules/pipedream_flush.py b/vescale/pipe/_schedules/pipedream_flush.py
new file mode 100644
index 0000000..5ec3204
--- /dev/null
+++ b/vescale/pipe/_schedules/pipedream_flush.py
@@ -0,0 +1,1287 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from vescale.pipe._schedules.instruction_base import (
+ PipelineSchema,
+ Status,
+ InstructionGenerator,
+ Shape,
+ BaseInstruction,
+ StageDeps,
+ CommPacket,
+ CompilePPCollectiveKind,
+ CompilePPCollectiveOperator,
+ VESCALE_INTRUCTION_BUILDER as builder,
+ register_instruction,
+ registed_functions,
+)
+from functools import partial
+from dataclasses import dataclass
+from dataclasses import field
+from collections import defaultdict
+from vescale.dtensor.dtensor import DTensor, make_dtensor
+import contextlib
+import torch
+import torch.distributed as dist
+from inspect import signature
+from vescale.dtensor.device_mesh import DeviceMesh
+from typing import Sequence, Optional, List, Union, Dict, Callable, Tuple
+import numpy as np
+from vescale.pipe.p2p_communication import (
+ recv_backward,
+ recv_forward,
+ send_backward,
+ send_forward,
+ send_forward_recv_backward,
+ send_backward_recv_forward,
+)
+from vescale.ndtimeline import ndtimer, ndtimeit_p2p
+from vescale.ndtimeline.predefined import FORWARD_COMPUTE, BACKWARD_COMPUTE, CROSS_MESH_RECV, CROSS_MESH_SEND
+from vescale.pipe.pipe_stage import PipeModule
+from vescale.dtensor._diff import dummy_p2p, manage_dump_file
+from torch.distributed._functional_collectives import send, recv
+from vescale.dtensor.placement_types import Placement
+from vescale.dtensor._utils import compute_global_tensor_info
+from torch.distributed.distributed_c10d import _get_default_group
+
+
+def maybe_tensor(tensor):
+ if isinstance(tensor, DTensor):
+ return tensor._local_tensor
+ elif isinstance(tensor, torch.Tensor):
+ return tensor
+ else:
+ raise RuntimeError(f"Error parsing tensor {tensor}")
+
+
+def cross_mesh_recv(comm, p2p_tensor):
+ mapping_group = comm.cur_mesh.get_mapping_rank(comm.peer_mesh)
+ if isinstance(mapping_group, int): # equal size
+ default_pg = _get_default_group()
+ with ndtimeit_p2p(CROSS_MESH_RECV, default_pg, mapping_group, is_batched=False):
+ tensor = torch.empty((3, 3), device=p2p_tensor.device, dtype=torch.int64)
+ recv(tensor, mapping_group, default_pg)
+ p_size = sum(tensor[:, 0] >= 0)
+ tensor = tensor[:p_size]
+ sharding_type = [Placement.serialize_from_tensor(p) for p in tensor]
+ sharding = sharding_type
+ if len(sharding_type) > 0:
+ global_shape, global_stride = compute_global_tensor_info(p2p_tensor, comm.cur_mesh, sharding)
+ p2p_tensor = make_dtensor(
+ p2p_tensor,
+ comm.cur_mesh,
+ sharding,
+ shape=torch.Size(global_shape),
+ dtype=p2p_tensor.dtype,
+ requires_grad=p2p_tensor.requires_grad,
+ stride=tuple(global_stride),
+ )
+ return p2p_tensor
+ else:
+ return p2p_tensor
+ else:
+ raise NotImplementedError("currently not support change mesh size")
+
+
+def cross_mesh_send(comm, dt):
+ mapping_group = comm.cur_mesh.get_mapping_rank(comm.peer_mesh)
+ if isinstance(mapping_group, int): # equal size
+ default_pg = _get_default_group()
+ with ndtimeit_p2p(CROSS_MESH_SEND, default_pg, mapping_group, is_batched=False):
+ if isinstance(dt, DTensor):
+ send_sharding = torch.stack(
+ [p.serialize_to_tensor(dt.device) for p in dt._spec.placements]
+ + [
+ torch.full((3,), -1, device=dt.device, dtype=torch.int64)
+ for _ in range(3 - len(dt._spec.placements))
+ ]
+ )
+ send(send_sharding, mapping_group, default_pg)
+ else: # tensor
+ send(torch.full((3, 3), -1, device=dt.device, dtype=torch.int64), mapping_group, default_pg)
+ else:
+ raise NotImplementedError("currently not support change mesh size")
+
+
+def cross_mesh_double(comm, fwd_tensor, p2p_tensor):
+ if isinstance(fwd_tensor, DTensor):
+ placements = fwd_tensor._spec.placements
+ global_shape, global_stride = compute_global_tensor_info(p2p_tensor, comm.cur_mesh, placements)
+ p2p_tensor = make_dtensor(
+ p2p_tensor,
+ comm.cur_mesh,
+ placements,
+ shape=torch.Size(global_shape),
+ dtype=p2p_tensor.dtype,
+ requires_grad=p2p_tensor.requires_grad,
+ stride=tuple(global_stride),
+ )
+ return p2p_tensor
+
+
+@dataclass
+class RECV_FORWARD(BaseInstruction):
+ comm_packages: List[CommPacket] = field(default_factory=list)
+ tensor_shapes: Union[List[Shape], Shape] = field(default_factory=list)
+ tensor_dtypes: Union[List[torch.dtype], torch.dtype] = field(default_factory=list)
+ batch_id: Optional[int] = None
+ debug: str = ""
+
+ @property
+ def name(self):
+ return "recv_forward"
+
+ def run(self) -> List:
+ def f(info):
+ comm, shape, dtype = info
+ p2p_tensor = recv_forward(
+ tensor_shape=shape,
+ recv_dtype=dtype,
+ current_device_mesh=comm.cur_mesh,
+ peer_device_mesh=comm.peer_mesh,
+ )
+ p2p_tensor = cross_mesh_recv(comm, p2p_tensor)
+ return p2p_tensor
+
+ infos = zip(self.comm_packages, self.tensor_shapes, self.tensor_dtypes)
+ out = list(map(f, infos))
+ return out if len(out) > 0 else None
+
+ def compile(self) -> List[CompilePPCollectiveOperator]:
+ out: List[CompilePPCollectiveOperator] = []
+ for comm in self.comm_packages:
+ cur_mesh, peer_mesh = comm.cur_mesh, comm.peer_mesh
+ coordinate = (cur_mesh.mesh == dist.get_rank()).nonzero(as_tuple=True)
+ src = peer_mesh.mesh[coordinate].item()
+
+ out.append(CompilePPCollectiveOperator(kind=CompilePPCollectiveKind.RECV, src=src))
+ return out
+
+
+@dataclass
+class SEND_FORWARD(BaseInstruction):
+ comm_packages: List[CommPacket] = field(default_factory=list)
+ tensor_shapes: List[Shape] = field(default_factory=list)
+ batch_id: int = 0
+
+ @property
+ def name(self):
+ return "send_forward"
+
+ @dummy_p2p
+ def run(self, output_tensors: List[torch.Tensor]):
+ if not isinstance(output_tensors, list):
+ output_tensors = [output_tensors]
+
+ def f(info):
+ output_tensor, comm, shape = info
+ send_forward(
+ output_tensor=maybe_tensor(output_tensor),
+ current_device_mesh=comm.cur_mesh,
+ peer_device_mesh=comm.peer_mesh,
+ tensor_shape=shape,
+ )
+ cross_mesh_send(comm, output_tensor)
+
+ infos = zip(output_tensors, self.comm_packages, self.tensor_shapes)
+ return list(map(f, infos))
+
+ def compile(self) -> List[CompilePPCollectiveOperator]:
+ out: List[CompilePPCollectiveOperator] = []
+ for comm in self.comm_packages:
+ cur_mesh, peer_mesh = comm.cur_mesh, comm.peer_mesh
+ coordinate = (cur_mesh.mesh == dist.get_rank()).nonzero(as_tuple=True)
+ dst = peer_mesh.mesh[coordinate].item()
+
+ out.append(CompilePPCollectiveOperator(kind=CompilePPCollectiveKind.SEND, dst=dst))
+ return out
+
+
+@dataclass
+class RECV_BACKWARD(BaseInstruction):
+ comm_packages: List[CommPacket] = field(default_factory=list)
+ tensor_shapes: Union[List[Shape], Shape] = field(default_factory=list)
+ tensor_dtypes: List[torch.dtype] = field(default_factory=list)
+
+ @property
+ def name(self):
+ return "recv_backward"
+
+ @dummy_p2p
+ def run(self):
+ def f(info):
+ comm, shape, dtype = info
+ p2p_tensor = recv_backward(
+ tensor_shape=shape,
+ recv_dtype=dtype,
+ current_device_mesh=comm.cur_mesh,
+ peer_device_mesh=comm.peer_mesh,
+ )
+ p2p_tensor = cross_mesh_recv(comm, p2p_tensor)
+ return p2p_tensor
+
+ infos = zip(self.comm_packages, self.tensor_shapes, self.tensor_dtypes)
+ out = list(map(f, infos))
+ return out if len(out) > 0 else None
+
+ def compile(self) -> List[CompilePPCollectiveOperator]:
+ out: List[CompilePPCollectiveOperator] = []
+ for comm in self.comm_packages:
+ cur_mesh, peer_mesh = comm.cur_mesh, comm.peer_mesh
+ coordinate = (cur_mesh.mesh == dist.get_rank()).nonzero(as_tuple=True)
+ src = peer_mesh.mesh[coordinate].item()
+
+ out.append(CompilePPCollectiveOperator(kind=CompilePPCollectiveKind.RECV, src=src, is_backward=True))
+ return out
+
+
+@dataclass
+class SEND_BACKWARD(BaseInstruction):
+ recv_comms: List[CommPacket] = field(default_factory=list)
+ tensor_shapes: Union[List[Shape], Shape] = field(default_factory=list)
+
+ @property
+ def name(self):
+ return "send_backward"
+
+ @dummy_p2p
+ def run(self, input_tensor_grad):
+ if not isinstance(input_tensor_grad, list):
+ input_tensor_grad = [input_tensor_grad]
+
+ def f(info):
+ grad, comm, shape = info
+ send_backward(
+ input_tensor_grad=maybe_tensor(grad),
+ current_device_mesh=comm.cur_mesh,
+ peer_device_mesh=comm.peer_mesh,
+ tensor_shape=shape,
+ )
+ cross_mesh_send(comm, grad)
+
+ infos = zip(input_tensor_grad, self.recv_comms, self.tensor_shapes)
+ return list(map(f, infos))
+
+ def compile(self) -> List[CompilePPCollectiveOperator]:
+ out: List[CompilePPCollectiveOperator] = []
+ for comm in self.recv_comms:
+ cur_mesh, peer_mesh = comm.cur_mesh, comm.peer_mesh
+ coordinate = (cur_mesh.mesh == dist.get_rank()).nonzero(as_tuple=True)
+ dst = peer_mesh.mesh[coordinate].item()
+
+ out.append(CompilePPCollectiveOperator(kind=CompilePPCollectiveKind.SEND, dst=dst, is_backward=True))
+ return out
+
+
+@dataclass
+class SEND_FORWARD_RECV_BACKWARD(BaseInstruction):
+ comm_packages: List[CommPacket] = field(default_factory=list)
+ tensor_shapes: Union[List[Shape], Shape] = field(default_factory=list)
+ tensor_dtypes: Union[List[torch.dtype], torch.dtype] = field(default_factory=list)
+ send_batch_id: int = 0
+ recv_batch_id: int = 0
+
+ @property
+ def name(self):
+ return "send_forward_recv_backward"
+
+ @dummy_p2p
+ def run(self, output_tensors):
+ if not isinstance(output_tensors, list):
+ output_tensors = [output_tensors]
+
+ def f(info):
+ output_tensor, comm, shape, dtype = info
+ p2p_tensor = send_forward_recv_backward(
+ output_tensor=maybe_tensor(output_tensor),
+ current_device_mesh=comm.cur_mesh,
+ peer_device_mesh=comm.peer_mesh,
+ tensor_shape=shape,
+ recv_dtype=dtype,
+ )
+ p2p_tensor = cross_mesh_double(comm, output_tensor, p2p_tensor)
+ return p2p_tensor
+
+ infos = zip(output_tensors, self.comm_packages, self.tensor_shapes, self.tensor_dtypes)
+ out = list(map(f, infos))
+ return out if len(out) > 0 else None
+
+ def compile(self) -> List[CompilePPCollectiveOperator]:
+ out: List[CompilePPCollectiveOperator] = []
+ for comm in self.comm_packages:
+ cur_mesh, peer_mesh = comm.cur_mesh, comm.peer_mesh
+ coordinate = (cur_mesh.mesh == dist.get_rank()).nonzero(as_tuple=True)
+ peer_rank = peer_mesh.mesh[coordinate].item()
+
+ out.append(CompilePPCollectiveOperator(kind=CompilePPCollectiveKind.SEND, dst=peer_rank))
+ out.append(CompilePPCollectiveOperator(kind=CompilePPCollectiveKind.RECV, src=peer_rank, is_backward=True))
+ return out
+
+
+@dataclass
+class SEND_BACKWARD_RECV_FORWARD(BaseInstruction):
+ recv_comms: List[CommPacket]
+ tensor_shapes: Union[List[Shape], Shape] = field(default_factory=list)
+ tensor_dtypes: Union[List[torch.dtype], torch.dtype] = field(default_factory=list)
+
+ @property
+ def name(self):
+ return "send_backward_recv_forward"
+
+ @dummy_p2p
+ def run(self, input_tensor_grad):
+ if not isinstance(input_tensor_grad, list):
+ input_tensor_grad = [input_tensor_grad]
+
+ def f(info):
+ grad, comm, shape, dtype = info
+ p2p_tenosr = send_backward_recv_forward(
+ input_tensor_grad=maybe_tensor(grad),
+ current_device_mesh=comm.cur_mesh,
+ peer_device_mesh=comm.peer_mesh,
+ tensor_shape=shape,
+ recv_dtype=dtype,
+ )
+ p2p_tenosr = cross_mesh_double(comm, grad, p2p_tenosr)
+ return p2p_tenosr
+
+ infos = zip(input_tensor_grad, self.recv_comms, self.tensor_shapes, self.tensor_dtypes)
+
+ out = list(map(f, infos))
+ return out if len(out) > 0 else None
+
+ def compile(self) -> List[CompilePPCollectiveOperator]:
+ out: List[CompilePPCollectiveOperator] = []
+ for comm in self.recv_comms:
+ cur_mesh, peer_mesh = comm.cur_mesh, comm.peer_mesh
+ coordinate = (cur_mesh.mesh == dist.get_rank()).nonzero(as_tuple=True)
+ peer_rank = peer_mesh.mesh[coordinate].item()
+
+ out.append(CompilePPCollectiveOperator(kind=CompilePPCollectiveKind.SEND, dst=peer_rank, is_backward=True))
+ out.append(CompilePPCollectiveOperator(kind=CompilePPCollectiveKind.RECV, src=peer_rank))
+ return out
+
+
+@dataclass
+class FORWARD_STEP(BaseInstruction):
+ model: Optional[Union[torch.nn.Module, PipeModule]] = None
+ is_pp_first_stage: bool = False
+ is_pp_last_stage: bool = False
+ local_comm: List[CommPacket] = field(default_factory=list)
+ p2p_comm: List[CommPacket] = field(default_factory=list)
+ p2p_index_mapping: List[Tuple[int, int]] = field(default_factory=list)
+ stage_id: int = 0
+ batch_id: int = 0
+ forward_only: bool = False
+
+ @property
+ def name(self):
+ return "forward_step"
+
+ def construct_input_args(self, p2p_tensors, local_inputs):
+ """
+ stage 0: a , c
+ stage 1: b
+ stage 2: dataloader
+
+ stage 2: forward(c,b,dataloader,a)
+
+ p2p_order: [(0, 2), (1, 0), (2, 0), (0, 0)]
+ send_order: [(0, 0), (0, 2), (1, 0)]
+ we assume that the p2p send is follow interge order
+
+ we assume that the p2p will allways be args
+
+ """
+ if not isinstance(local_inputs, (Sequence, Dict)):
+ local_inputs = [local_inputs]
+ if not isinstance(p2p_tensors, list):
+ p2p_tensors = [p2p_tensors]
+ p2p_index_without_local = list(
+ filter(lambda item: item.peer_stage_idx != self.stage_id, self.p2p_index_mapping)
+ )
+ p2p_send_order = sorted(p2p_index_without_local)
+ local_input_mapping = list(filter(lambda item: item.peer_stage_idx == self.stage_id, self.p2p_index_mapping))
+
+ args = []
+ kwargs = {}
+ ground_truth = []
+ for item in self.p2p_index_mapping:
+ if item.peer_stage_idx == self.stage_id:
+ index = local_input_mapping.index(item)
+ args.append(local_inputs[index])
+ else:
+ index = p2p_send_order.index(item)
+ args.append(p2p_tensors[index])
+ if isinstance(local_inputs, Sequence) and len(local_inputs) > 1:
+ ground_truth.append(local_inputs[-1])
+ elif isinstance(local_inputs, Dict) and "labels" in local_inputs:
+ ground_truth.append(local_inputs["labels"])
+ return args, kwargs, ground_truth
+
+ @dummy_p2p
+ def run(self, input_tensor, kwargs):
+ """Forward step for passed-in model.
+
+ If first stage, input tensor is obtained from data_iterator, otherwise
+ passed-in input_tensor is used.
+
+ Returns output tensor."""
+
+ data_iterator, forward_data_store, autocast_dtype, enable_autocast = (
+ kwargs["data_iterator"],
+ kwargs["forward_data_store"],
+ kwargs["autocast_dtype"],
+ kwargs["enable_autocast"],
+ )
+ if enable_autocast:
+ context_manager = torch.autocast("cuda", dtype=autocast_dtype)
+ else:
+ context_manager = contextlib.nullcontext()
+ with context_manager:
+
+ def prepare_data():
+ local_tensors = []
+ ground_truth = []
+ if data_iterator is not None:
+ if isinstance(data_iterator, list):
+ if len(data_iterator) > self.batch_id:
+ local_tensors = data_iterator[self.batch_id]
+ else:
+ local_tensors = next(data_iterator)
+ if isinstance(local_tensors, Sequence) and len(local_tensors) > 1:
+ ground_truth.append(local_tensors[-1])
+ elif isinstance(local_tensors, Dict) and "labels" in local_tensors:
+ ground_truth.append(local_tensors["labels"])
+ return input_tensor, local_tensors, ground_truth
+
+ builder.user_data["prepare_data_fn"] = prepare_data
+ builder.user_data["batch_id"] = self.batch_id
+ builder.user_data["p2p_tensors"] = input_tensor
+ p2p_tensor, local_tensors, ground_truth = registed_functions["vescale_1f1b_pre_forward_data"]()
+ builder.user_data["ground_truth"] = ground_truth
+ output_tensor = registed_functions["vescale_1f1b_forward"](p2p_tensor, local_tensors)
+ builder.user_data["output_tensor"] = output_tensor
+
+ if self.is_pp_last_stage:
+ # update status machine
+ output_tensor, loss_tensor = registed_functions["vescale_1f1b_loss_fn"]()
+ forward_data_store.append((output_tensor, loss_tensor))
+ if builder.loss_fn is None:
+ return output_tensor
+ else:
+ return loss_tensor
+
+ return output_tensor
+
+
+@dataclass
+class BACKWARD_STEP(BaseInstruction):
+ @property
+ def name(self):
+ return "backward step"
+
+ @dummy_p2p
+ def run(self, input_tensor, output_tensor, output_tensor_grad, kwargs):
+ """Backward step through passed-in output tensor.
+
+ If last stage, output_tensor_grad is None, otherwise gradient of loss
+ with respect to stage's output tensor.
+
+ Returns gradient of loss with respect to input tensor (None if first
+ stage)."""
+
+ grad_scaler = kwargs["grad_scaler"]
+ deallocate_pipeline_outputs = kwargs["deallocate_pipeline_outputs"]
+ # NOTE: This code currently can handle at most one skip connection. It
+ # needs to be modified slightly to support arbitrary numbers of skip
+ # connections.
+
+ # Retain the grad on the input_tensor.
+ unwrap_input_tensor_grad = False
+ if not isinstance(input_tensor, list):
+ input_tensor = [input_tensor]
+ unwrap_input_tensor_grad = True
+ for x in input_tensor:
+ if x is not None:
+ x.retain_grad()
+
+ if not isinstance(output_tensor, list):
+ output_tensor = [output_tensor]
+ if not isinstance(output_tensor_grad, list):
+ output_tensor_grad = [output_tensor_grad]
+
+ # extract loss value from output tensors
+ if isinstance(output_tensor[0], Sequence):
+ for j in range(len(output_tensor[0])):
+ if output_tensor[0][j].ndim == 0 and output_tensor[0][j].numel() == 1:
+ loss_value = output_tensor[0][j]
+ break
+ else:
+ loss_value = output_tensor[0][-1]
+ else:
+ loss_value = output_tensor[0]
+
+ # Backward pass.
+ if len(output_tensor_grad) == 0 and grad_scaler is not None:
+ output_tensor = grad_scaler(loss_value)
+
+ if deallocate_pipeline_outputs:
+ assert 0
+ else:
+ torch.autograd.backward(loss_value, grad_tensors=output_tensor_grad[0])
+
+ # Collect the grad of the input_tensor.
+ input_tensor_grad = [None]
+ if input_tensor is not None:
+ input_tensor_grad = []
+ for x in input_tensor:
+ if x is None:
+ input_tensor_grad.append(None)
+ else:
+ input_tensor_grad.append(x.grad)
+
+ if unwrap_input_tensor_grad:
+ input_tensor_grad = input_tensor_grad[0]
+
+ return input_tensor_grad
+
+
+@dataclass
+class DEALLOCATE_OUTPUT_TENSOR(BaseInstruction):
+ deallocate_out: bool = True
+
+ @property
+ def name(self):
+ return "deallocate output tensor "
+
+ @dummy_p2p
+ def run(self, out, deallocate_pipeline_outputs=False):
+ """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
+
+ This method should be called right after the output tensor has been
+ sent to the next pipeline stage. At this point, the output tensor is
+ only useful for its '.grad_fn' field, and not its '.data'.
+ """
+ # TODO: support DTensor
+ if (out is None) or (not deallocate_pipeline_outputs):
+ return
+
+ def f(out):
+ assert isinstance(out, [torch.Tensor, DTensor]), f"expected Tensor, found {type(out).__name__}."
+ assert out._base is None, "counter-productive to free a view of another tensor."
+ if isinstance(out, [torch.Tensor, DTensor]):
+ out._local_tensor.data = torch.empty(
+ (1,),
+ device=out.device,
+ dtype=out.dtype,
+ )
+ else:
+ out.data = torch.empty(
+ (1,),
+ device=out.device,
+ dtype=out.dtype,
+ )
+
+ if not isinstance(out, list):
+ for o in out:
+ f(o)
+ else:
+ f(out)
+
+
+@dataclass
+class APPEND_INPUTS(BaseInstruction):
+ @property
+ def name(self):
+ return "append_inputs"
+
+ @dummy_p2p
+ def run(self, input_tensors, input_tensor):
+ input_tensors.append(input_tensor)
+
+
+@dataclass
+class APPEND_OUTPUTS(BaseInstruction):
+ @property
+ def name(self):
+ return "append_outputs"
+
+ @dummy_p2p
+ def run(self, output_tensors, output_tensor):
+ output_tensors.append(output_tensor)
+
+
+@dataclass
+class POP_INPUT(BaseInstruction):
+ @property
+ def name(self):
+ return "pop input"
+
+ @dummy_p2p
+ def run(self, input_tensors):
+ input_tensor = input_tensors.pop(0)
+ return input_tensor
+
+
+@dataclass
+class POP_OUTPUT(BaseInstruction):
+ @property
+ def name(self):
+ return "pop output"
+
+ @dummy_p2p
+ def run(self, output_tensors):
+ output_tensor = output_tensors.pop(0)
+ return output_tensor
+
+
+class PipeDream(PipelineSchema):
+ """
+ generate pipedream schedule (a.k.a 1f1b)
+ memory-efficient than gpipe
+ """
+
+ @property
+ def name(self):
+ return "1f1b"
+
+ def _gen_schedule(self):
+ """
+ run forward then run backward
+ the sequence timeline as show before
+ d: device
+ m: batches
+ T: timeline
+
+ T (m,d) (m,d) (m,d)
+ - ------ ------ -------
+
+ 0 (0,0,F)
+ 1 (1,0,F) (0,1,F)
+ 2 (2,0,F) (1,1,F) (0,2,F)
+ 3 (0,2,B)
+ 4 (0,1,B) (1,2,F)
+ 5 (0,0,B) (2,1,F) (1,2,B)
+ 6 (3,0,F) (1,1,B) (2,2,F)
+ ...
+ """
+ m = self.batches
+ d = self.num_mesh
+
+ num_clock = (m + d - 1) * 2 # time todo flush
+ schedules = [[None] * d for c in range(num_clock)]
+ warmup_batches = [min(d - i - 1, m) for i in range(d)]
+ remain_batches = [m - i for i in warmup_batches]
+ next_fwd_batch_idx = [0 for _ in range(d)]
+ next_bwd_batch_idx = [0 for _ in range(d)]
+
+ self.warmup_batches = warmup_batches
+ self.remain_batches = remain_batches
+
+ new_timeline = list(range(d))
+ """
+ t_i|m
+ 0 1 2
+ 0 0 0 0
+ 1 0 0 0
+ 2 0 0 0
+ 3 0 0 1
+ 4 0 1 1
+ 5 1 1 1
+ 1f1b
+ """
+ bwd_done_idx = np.zeros(shape=[num_clock, d], dtype=np.int32)
+ # warm-up steps
+ for i in range(d):
+ for k in range(warmup_batches[i]):
+ t_i = new_timeline[i]
+ schedules[t_i][i] = Status(batch_idx=next_fwd_batch_idx[i], stage_id=i, f_b="F", stg="WUp", k=k)
+ new_timeline[i] += 1 # self add for new timeline
+ next_fwd_batch_idx[i] += 1 # do next micro batch
+
+ # run 1f1b steps
+ for i in reversed(range(d)):
+ for idx in range(remain_batches[i]):
+ # do forward
+ t_i = new_timeline[i]
+ schedules[t_i][i] = Status(batch_idx=next_fwd_batch_idx[i], stage_id=i, f_b="F", stg="1f1b", k=idx)
+ next_fwd_batch_idx[i] += 1
+ bwd_done_idx[t_i][i] = next_bwd_batch_idx[i]
+ t_i += 1
+
+ # do backward
+ if i + 1 < d:
+ while bwd_done_idx[t_i][i + 1] < next_bwd_batch_idx[i]:
+ # if the stage 2 is done, the stage i must be equal 0
+ assert bwd_done_idx[t_i - 1][i] == next_bwd_batch_idx[i]
+ bwd_done_idx[t_i][i] = bwd_done_idx[t_i - 1][i]
+ t_i = t_i + 1
+
+ if idx == remain_batches[i] - 1: # last iterator
+ schedules[t_i][i] = Status(
+ batch_idx=next_bwd_batch_idx[i], stage_id=i, f_b="B", stg="1f1b-l", k=idx
+ )
+ else:
+ schedules[t_i][i] = Status(batch_idx=next_bwd_batch_idx[i], stage_id=i, f_b="B", stg="1f1b", k=idx)
+ bwd_done_idx[t_i][i] = next_bwd_batch_idx[i]
+ next_bwd_batch_idx[i] += 1
+ new_timeline[i] = t_i + 1
+
+ # run cool duwn
+ for i in reversed(range(d)):
+ for k in range(warmup_batches[i]):
+ assert i + 1 < d
+ t_i = new_timeline[i]
+ while bwd_done_idx[t_i][i + 1] <= next_bwd_batch_idx[i]:
+ bwd_done_idx[t_i][i] = next_bwd_batch_idx[i]
+ t_i = t_i + 1
+ schedules[t_i][i] = Status(batch_idx=next_bwd_batch_idx[i], stage_id=i, f_b="B", stg="CD", k=k)
+ bwd_done_idx[t_i][i] = next_bwd_batch_idx[i]
+ next_bwd_batch_idx[i] += 1
+ new_timeline[i] = t_i + 1
+ if i > 0:
+ bwd_done_idx[new_timeline[i] : num_clock, i] = m
+ return schedules
+
+
+class OneFOneBInstrcutionGenerator(InstructionGenerator):
+ def __init__(
+ self,
+ deps: StageDeps,
+ meshes: List[DeviceMesh],
+ batches: int,
+ default_shape: Optional[Shape] = None,
+ default_dtype: Optional[torch.dtype] = None,
+ batch_shape_lists: Optional[List[Dict[int, Shape]]] = None,
+ batch_dtype_lists: Optional[List[Dict[int, torch.dtype]]] = None,
+ forward_only: bool = False,
+ ):
+ forward_only = True if not torch.is_grad_enabled() else forward_only
+ super().__init__(
+ deps=deps,
+ meshes=meshes,
+ batches=batches,
+ default_shape=default_shape,
+ default_dtype=default_dtype,
+ batch_shape_lists=batch_shape_lists,
+ batch_dtype_lists=batch_dtype_lists,
+ forward_only=forward_only,
+ )
+ self.num_stage = len(meshes)
+ self.schema = PipeDream(num_stage=self.num_stage, meshes=meshes, batches=self.batches)
+ self.forward_only = forward_only
+
+ def get_tensor_shape(self, microbatch_id, input_id):
+ if self.batch_shape_lists:
+ if input_id in self.batch_shape_lists[microbatch_id].keys():
+ return self.batch_shape_lists[microbatch_id][input_id]
+ return self.default_shape
+
+ def get_tensor_dtype(self, microbatch_id, input_id):
+ if self.batch_dtype_lists:
+ if input_id in self.batch_dtype_lists[microbatch_id].keys():
+ return self.batch_dtype_lists[microbatch_id][input_id]
+ return self.default_dtype
+
+ def get_tensor_shapes_and_dtypes(self, comm_packages: List[CommPacket], microbatch_id: int):
+ def get_shape_or_dtype(f: Callable, package: CommPacket):
+ return f(microbatch_id, package.input_id)
+
+ shapes = map(partial(get_shape_or_dtype, self.get_tensor_shape), comm_packages)
+ dtypes = map(partial(get_shape_or_dtype, self.get_tensor_dtype), comm_packages)
+ return list(shapes), list(dtypes)
+
+ # call by pipe emitter
+ def gen_instruction(self):
+ # If the context is torch.no_grad(), only execute forward
+ _forward_only = self.forward_only
+ if not torch.is_grad_enabled():
+ self.forward_only = True
+
+ schedules = self.schema.schedules
+ self.instruction_list = [[] for _ in range(self.num_stage)]
+ stack = defaultdict(list) # for 1f1b
+ first_time_1f1b = [True] * self.num_stage
+ for clk, stages_schemas in enumerate(schedules):
+ for s, schema in enumerate(stages_schemas):
+ send_comms = self.deps.get_send_comms(s)
+ recv_comms = self.deps.get_recv_comms(s)
+ p2p_index_mapping = self.deps.mapping[s]
+ cur_model = self.deps.get_current_model(s)
+ local_comm = self.deps.get_local_comms(s)
+ is_pp_first_stage = self.deps.is_pipeline_first_stage(s)
+ is_pp_last_stage = self.deps.is_pipeline_last_stage(s)
+ if isinstance(cur_model, Sequence):
+ assert self.num_chunk == 1, "1f1b support model chunk is 1."
+ cur_model = cur_model[0]
+ # batch size, stage idx, forward backward,
+ if schema:
+ b_idx = schema.batch_idx
+ stg = schema.stg
+ if "WUp" in stg: # warmup stage
+ # recv forward
+ recv_shapes, recv_dtypes = self.get_tensor_shapes_and_dtypes(recv_comms, b_idx)
+ send_shapes, _ = self.get_tensor_shapes_and_dtypes(send_comms, b_idx)
+ self._set_inst(
+ RECV_FORWARD(
+ comm_packages=recv_comms,
+ tensor_shapes=recv_shapes,
+ tensor_dtypes=recv_dtypes,
+ batch_id=b_idx,
+ debug="warm-up",
+ ),
+ s,
+ )
+ self._set_inst(
+ FORWARD_STEP(
+ model=cur_model,
+ is_pp_first_stage=is_pp_first_stage,
+ is_pp_last_stage=is_pp_last_stage,
+ local_comm=local_comm,
+ p2p_comm=recv_comms,
+ p2p_index_mapping=p2p_index_mapping,
+ stage_id=s,
+ batch_id=b_idx,
+ forward_only=self.forward_only,
+ ),
+ s,
+ )
+ self._set_inst(
+ SEND_FORWARD(
+ comm_packages=send_comms,
+ tensor_shapes=send_shapes,
+ batch_id=b_idx,
+ ),
+ s,
+ )
+
+ if not self.forward_only:
+ self._set_inst(APPEND_INPUTS(), s)
+ self._set_inst(APPEND_OUTPUTS(), s)
+ self._set_inst(DEALLOCATE_OUTPUT_TENSOR(), s)
+ elif "1f1b" in stg: # 1f1b stage
+ cur_st = stack[s]
+ if len(cur_st) < 2:
+ cur_st.append(schema) # lazy do
+ else:
+ raise RuntimeError("unknown schedule")
+
+ if len(cur_st) == 2:
+ if first_time_1f1b[s]:
+ recv_shapes, recv_dtypes = self.get_tensor_shapes_and_dtypes(recv_comms, b_idx)
+ # before run 1f1b
+ self._set_inst(
+ RECV_FORWARD(
+ comm_packages=recv_comms,
+ tensor_shapes=recv_shapes,
+ tensor_dtypes=recv_dtypes,
+ batch_id=b_idx,
+ debug="first 1f1b",
+ ),
+ s,
+ )
+ first_time_1f1b[s] = False
+ fwd = cur_st[0]
+ bwd = cur_st[1]
+ fw_b_idx = fwd.batch_idx
+ bw_b_idx = bwd.batch_idx
+ self._set_inst(
+ FORWARD_STEP(
+ model=cur_model,
+ is_pp_first_stage=is_pp_first_stage,
+ is_pp_last_stage=is_pp_last_stage,
+ local_comm=local_comm,
+ p2p_comm=recv_comms,
+ p2p_index_mapping=p2p_index_mapping,
+ stage_id=s,
+ batch_id=fw_b_idx,
+ forward_only=self.forward_only,
+ ),
+ s,
+ )
+
+ if self.forward_only:
+ send_shapes, _ = self.get_tensor_shapes_and_dtypes(send_comms, fw_b_idx)
+ self._set_inst(
+ SEND_FORWARD(
+ comm_packages=send_comms, tensor_shapes=send_shapes, batch_id=fw_b_idx
+ ),
+ s,
+ )
+ last_iteration = fwd.k == (self.schema.remain_batches[s] - 1)
+ if not last_iteration:
+ recv_shapes, recv_dtypes = self.get_tensor_shapes_and_dtypes(recv_comms, fw_b_idx)
+ self._set_inst(
+ RECV_FORWARD(
+ comm_packages=recv_comms,
+ tensor_shapes=recv_shapes,
+ tensor_dtypes=recv_dtypes,
+ batch_id=fw_b_idx,
+ debug="last_1f1b",
+ ),
+ s,
+ )
+ stack[s].clear()
+ else:
+ send_shapes, send_dtypes = self.get_tensor_shapes_and_dtypes(send_comms, bw_b_idx)
+ self._set_inst(
+ SEND_FORWARD_RECV_BACKWARD(
+ comm_packages=send_comms,
+ tensor_shapes=send_shapes,
+ tensor_dtypes=send_dtypes,
+ send_batch_id=fw_b_idx,
+ recv_batch_id=bw_b_idx,
+ ),
+ s,
+ )
+ self._set_inst(APPEND_INPUTS(), s)
+ self._set_inst(APPEND_OUTPUTS(), s)
+ self._set_inst(DEALLOCATE_OUTPUT_TENSOR(), s)
+ self._set_inst(POP_INPUT(), s)
+ self._set_inst(POP_OUTPUT(), s)
+ self._set_inst(BACKWARD_STEP(), s)
+ self._set_inst(DEALLOCATE_OUTPUT_TENSOR(deallocate_out=False), s)
+
+ if stg == "1f1b-l":
+ recv_shapes, recv_dtypes = self.get_tensor_shapes_and_dtypes(recv_comms, bw_b_idx)
+ self._set_inst(SEND_BACKWARD(recv_comms=recv_comms, tensor_shapes=recv_shapes), s)
+ else:
+ recv_shapes, recv_dtypes = self.get_tensor_shapes_and_dtypes(recv_comms, fw_b_idx)
+ self._set_inst(
+ SEND_BACKWARD_RECV_FORWARD(
+ recv_comms=recv_comms, tensor_shapes=recv_shapes, tensor_dtypes=recv_dtypes
+ ),
+ s,
+ )
+ stack[s].clear() # save for next
+ else: # 1f1b do f
+ continue
+ elif stg == "CD": # cool down stage
+ if not self.forward_only:
+ self._set_inst(POP_INPUT(), s)
+ self._set_inst(POP_OUTPUT(), s)
+ # recv backward
+
+ send_shapes, send_dtypes = self.get_tensor_shapes_and_dtypes(send_comms, b_idx)
+ self._set_inst(
+ RECV_BACKWARD(
+ comm_packages=send_comms, tensor_shapes=send_shapes, tensor_dtypes=send_dtypes
+ ),
+ s,
+ )
+ # backward step
+ self._set_inst(BACKWARD_STEP(), s)
+ # deallocate input, output
+ self._set_inst(DEALLOCATE_OUTPUT_TENSOR(), s)
+ self._set_inst(DEALLOCATE_OUTPUT_TENSOR(deallocate_out=False), s)
+ # send backward
+ recv_shapes, recv_dtypes = self.get_tensor_shapes_and_dtypes(recv_comms, b_idx)
+ self._set_inst(SEND_BACKWARD(recv_comms=recv_comms, tensor_shapes=recv_shapes), s)
+ else: # bubble
+ # TODO
+ # do any other
+ continue
+ self.gen_instruction_str_list()
+
+ # restore original self.forward_only if the current context manager is torch.no_grad()
+ if not torch.is_grad_enabled():
+ self.forward_only = _forward_only
+
+ return self.instruction_list
+
+ def gen_instruction_str_list(self):
+ instruction_lists = self.instruction_list
+ stage_strs = defaultdict(str)
+ for stage_id, instruction_list in enumerate(instruction_lists):
+ cur_stage_str = stage_strs[stage_id]
+ for inst in instruction_list:
+ cur_stage_str += f"{VESACLE_INSTRUCTION_MAPPING[type(inst)]},"
+ cur_stage_str = cur_stage_str[:-1]
+ stage_strs[stage_id] = cur_stage_str
+ builder.build_from_dict(stage_strs)
+
+ @manage_dump_file
+ def execute(
+ self,
+ stage_id,
+ autocast_dtype=torch.float,
+ enable_autocast=False,
+ grad_scaler=None,
+ deallocate_pipeline_outputs=False,
+ ):
+ builder.constant_data["autocast_dtype"] = autocast_dtype
+ builder.constant_data["enable_autocast"] = enable_autocast
+ builder.constant_data["grad_scaler"] = grad_scaler
+ builder.constant_data["deallocate_pipeline_outputs"] = deallocate_pipeline_outputs
+
+ user_data = builder.user_data
+ user_data["input_tensors"] = []
+ user_data["output_tensors"] = []
+ user_data["input_tensor"] = None # engine need to maintain the dataflow
+ user_data["output_tensor"] = None # engine need to maintian the output flow
+ user_data["output_tensor_grad"] = None
+ user_data["input_tensor_grad"] = None
+ user_data["forward_data_store"] = []
+
+ instruction_list = self.get_instruction_list(stage_id)
+ builder.stage_id = stage_id
+ builder_instruction_list = builder.global_instructions_funcs[stage_id]
+
+ _forward_only = self.forward_only
+ if not torch.is_grad_enabled():
+ self.forward_only = True
+
+ for inst, fn in zip(instruction_list, builder_instruction_list):
+ user_data["inst"] = inst
+ fn()
+
+ # restore original self.forward_only if the current context manager is torch.no_grad()
+ if not torch.is_grad_enabled():
+ self.forward_only = _forward_only
+
+ return user_data["forward_data_store"]
+
+
+@register_instruction(name="vescale_1f1b_recv_forward")
+def vescale_recv_forward():
+ user_data = builder.user_data
+ inst = user_data["inst"]
+ input_tensor = inst.run()
+ builder.user_data["input_tensor"] = input_tensor
+ return input_tensor
+
+
+@register_instruction(name="vescale_1f1b_recv_backward")
+def vescale_recv_backward():
+ user_data = builder.user_data
+ inst = user_data["inst"]
+ output_tensor_grad = inst.run()
+ builder.user_data["output_tensor_grad"] = output_tensor_grad
+ return output_tensor_grad
+
+
+@register_instruction(name="vescale_1f1b_send_forward")
+def vescale_send_forward():
+ user_data = builder.user_data
+ inst = user_data["inst"]
+ output_tensor = user_data["output_tensor"]
+ inst.run(output_tensors=output_tensor)
+
+
+@register_instruction(name="vescale_1f1b_send_backward")
+def vescale_send_backward():
+ user_data = builder.user_data
+ inst = user_data["inst"]
+ input_tensor_grad = user_data["input_tensor_grad"]
+ inst.run(input_tensor_grad=input_tensor_grad)
+
+
+@register_instruction(name="vescale_1f1b_send_forward_recv_backward")
+def vescale_send_forward_recv_backward():
+ user_data = builder.user_data
+ inst = user_data["inst"]
+ output_tensor = user_data["output_tensor"]
+ output_tensor_grad = inst.run(output_tensors=output_tensor)
+ builder.user_data["output_tensor_grad"] = output_tensor_grad
+
+
+@register_instruction(name="vescale_1f1b_send_backward_recv_forward")
+def vescale_send_backward_recv_forward():
+ user_data = builder.user_data
+ inst = user_data["inst"]
+ input_tensor_grad = user_data["input_tensor_grad"]
+ with torch.no_grad():
+ input_tensor = inst.run(input_tensor_grad=input_tensor_grad)
+ builder.user_data["input_tensor"] = input_tensor
+
+
+@register_instruction(name="vescale_1f1b_forward_step")
+@ndtimer(FORWARD_COMPUTE)
+def vescale_forward_step():
+ user_data = builder.user_data
+ constant_data = builder.constant_data
+ inst = user_data["inst"]
+ input_tensor = user_data["input_tensor"]
+ forward_data_store = user_data["forward_data_store"]
+ autocast_dtype = constant_data["autocast_dtype"]
+ builder.model = inst.model
+ if not autocast_dtype:
+ autocast_dtype = torch.float32
+ enable_autocast = constant_data["enable_autocast"]
+ if not enable_autocast:
+ enable_autocast = False
+ if forward_data_store is None:
+ forward_data_store = []
+ forward_args = {
+ "data_iterator": builder.dataloader,
+ "forward_data_store": forward_data_store,
+ "autocast_dtype": autocast_dtype,
+ "enable_autocast": enable_autocast,
+ }
+ output_tensor = inst.run(input_tensor=input_tensor, kwargs=forward_args)
+ builder.user_data["output_tensor"] = output_tensor
+ builder.user_data["forward_data_store"] = forward_data_store
+
+
+@register_instruction(name="vescale_1f1b_loss_fn")
+def loss_fn():
+ user_data = builder.user_data
+ output_tensor = user_data["output_tensor"]
+ loss_func = builder.loss_fn
+ if loss_func is None or output_tensor is None:
+ return output_tensor, None
+ temp_tensor = output_tensor
+ ground_truth = user_data["ground_truth"]
+ # signature provides a more uniform way to parse callable arguments, including lambda functions
+ args_spec = signature(loss_func)
+ args_len = len(args_spec.parameters.keys())
+ if args_len == 1:
+ output_tensor = loss_func(output_tensor)
+ else:
+ ground_truth = builder.user_data["ground_truth"]
+ loss_fn_inputs = [output_tensor] + ground_truth
+ output_tensor = loss_func(*loss_fn_inputs)
+ assert args_len == len(loss_fn_inputs), "Mismatch of loss function #args and #actual inputs!"
+ return temp_tensor, output_tensor
+
+
+@register_instruction(name="vescale_1f1b_pre_forward_data")
+def prepare_data():
+ user_data = builder.user_data
+ return user_data["prepare_data_fn"]()
+
+
+@register_instruction(name="vescale_1f1b_forward")
+def forward_fn(p2p_input, local_input):
+ if isinstance(builder.model, PipeModule):
+ return builder.model(p2p_input, local_input, chunk_id=0)
+ else:
+
+ def _feed_input(model, data):
+ if isinstance(data, Sequence):
+ return model(*data)
+ elif isinstance(data, Dict):
+ return model(**data)
+ else:
+ return model(data)
+
+ if p2p_input is not None:
+ return _feed_input(builder.model, p2p_input)
+ else:
+ return _feed_input(builder.model, local_input)
+
+
+@register_instruction(name="vescale_1f1b_backward_step")
+@ndtimer(BACKWARD_COMPUTE)
+def vescale_backward_step():
+ constant_data = builder.constant_data
+ grad_scaler = constant_data["grad_scaler"]
+ deallocate_pipeline_outputs = constant_data["deallocate_pipeline_outputs"]
+ backward_args = {
+ "grad_scaler": grad_scaler,
+ "deallocate_pipeline_outputs": deallocate_pipeline_outputs,
+ }
+
+ user_data = builder.user_data
+ input_tensor = user_data["input_tensor"]
+ output_tensor = user_data["output_tensor"]
+ output_tensor_grad = user_data["output_tensor_grad"]
+ inst = user_data["inst"]
+
+ input_tensor_grad = inst.run(
+ input_tensor=input_tensor,
+ output_tensor=output_tensor,
+ output_tensor_grad=output_tensor_grad,
+ kwargs=backward_args,
+ )
+ builder.user_data["input_tensor_grad"] = input_tensor_grad
+
+
+@register_instruction(name="vescale_1f1b_pop_input")
+def vescale_1f1b_pop_input():
+ user_data = builder.user_data
+ inst = user_data["inst"]
+ input_tensors = user_data["input_tensors"]
+ input_tensor = inst.run(input_tensors=input_tensors)
+ builder.user_data["input_tensor"] = input_tensor
+
+
+@register_instruction(name="vescale_1f1b_pop_output")
+def vescale_1f1b_pop_output():
+ user_data = builder.user_data
+ inst = user_data["inst"]
+ output_tensors = user_data["output_tensors"]
+ output_tensor = inst.run(output_tensors=output_tensors)
+ builder.user_data["output_tensor"] = output_tensor
+
+
+@register_instruction(name="vescale_1f1b_append_inputs")
+def vescale_1f1b_append_inputs():
+ user_data = builder.user_data
+ inst = user_data["inst"]
+ input_tensors = user_data["input_tensors"]
+ input_tensor = user_data["input_tensor"]
+ if input_tensors is None:
+ input_tensors = []
+ inst.run(input_tensors=input_tensors, input_tensor=input_tensor)
+ user_data["input_tensors"] = input_tensors
+
+
+@register_instruction(name="vescale_1f1b_append_outputs")
+def vescale_1f1b_append_outputs():
+ user_data = builder.user_data
+ inst = user_data["inst"]
+ output_tensors = user_data["output_tensors"]
+ output_tensor = user_data["output_tensor"]
+ if output_tensors is None:
+ output_tensors = []
+ inst.run(output_tensors=output_tensors, output_tensor=output_tensor)
+ user_data["output_tensors"] = output_tensors
+
+
+@register_instruction(name="vescale_1f1b_deallocate_output_tensor")
+def vescale_1f1b_deallocate_output_tensor():
+ user_data = builder.user_data
+ inst = user_data["inst"]
+ const_data = builder.constant_data
+ deallocate_pipeline_outputs = const_data["deallocate_pipeline_outputs"]
+ if inst.deallocate_out:
+ output_tensor = user_data["output_tensor"]
+ inst.run(output_tensor, deallocate_pipeline_outputs=deallocate_pipeline_outputs)
+ else:
+ input_tensor = user_data["input_tensor"]
+ if input_tensor and input_tensor[0] is not None:
+ input_tensor[0].grad = None
+ inst.run(input_tensor, deallocate_pipeline_outputs=deallocate_pipeline_outputs)
+
+
+VESACLE_INSTRUCTION_MAPPING = {
+ RECV_FORWARD: "vescale_1f1b_recv_forward",
+ RECV_BACKWARD: "vescale_1f1b_recv_backward",
+ SEND_FORWARD: "vescale_1f1b_send_forward",
+ SEND_BACKWARD: "vescale_1f1b_send_backward",
+ SEND_FORWARD_RECV_BACKWARD: "vescale_1f1b_send_forward_recv_backward",
+ SEND_BACKWARD_RECV_FORWARD: "vescale_1f1b_send_backward_recv_forward",
+ FORWARD_STEP: "vescale_1f1b_forward_step",
+ BACKWARD_STEP: "vescale_1f1b_backward_step",
+ POP_INPUT: "vescale_1f1b_pop_input",
+ POP_OUTPUT: "vescale_1f1b_pop_output",
+ APPEND_INPUTS: "vescale_1f1b_append_inputs",
+ APPEND_OUTPUTS: "vescale_1f1b_append_outputs",
+ DEALLOCATE_OUTPUT_TENSOR: "vescale_1f1b_deallocate_output_tensor",
+}
diff --git a/vescale/pipe/_schedules/pp_collective_emitter.py b/vescale/pipe/_schedules/pp_collective_emitter.py
new file mode 100644
index 0000000..edae5a5
--- /dev/null
+++ b/vescale/pipe/_schedules/pp_collective_emitter.py
@@ -0,0 +1,289 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from typing import List, Union, Dict
+import logging
+
+import torch
+from torch.export.graph_signature import TensorArgument
+
+from vescale.pipe.pipe_emmiter import ScheduleEngine, OneFOneBInstrcutionGenerator
+from vescale.plan.spec import ScheduleType
+from vescale.pipe._schedules.instruction_base import (
+ BaseInstruction,
+ CompilePPCollectiveKind,
+ CompilePPCollectiveOperator,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def read_fg(fg):
+ num_inputs = 0
+ num_outputs = None
+ for node in fg.graph.nodes:
+ if node.op == "placeholder":
+ num_inputs += 1
+ if node.op == "output":
+ num_outputs = len(node.args[0])
+ return num_inputs, num_outputs
+
+
+class PPCollectiveOpEmitter:
+ def __init__(self, curr_rank: int = None) -> None:
+ self.num_params_and_buffers = self.num_real_inputs = self.num_real_outputs = None
+
+ self.curr_rank = curr_rank
+
+ self.fwd_send_dsts = []
+ self.bwd_send_dsts = []
+ self.fwd_recv_srcs = []
+ self.bwd_recv_srcs = []
+
+ def gen_pp_collective_topo_from_schedule_engine(self, pipe_engine: ScheduleEngine):
+ fwd_recv_srcs, fwd_send_dsts, bwd_send_dsts, bwd_recv_srcs = set(), set(), set(), set()
+ assert (
+ pipe_engine.schedule == ScheduleType.SIMPLE_1F1B
+ ), "For inserting send/recv operators, we only need the topology information, please consider use this simplier PipeSchedule"
+ assert isinstance(
+ pipe_engine.p_emmiter.instruction_generator, OneFOneBInstrcutionGenerator
+ ), "For inserting send/recv operators, we only need the topology information, please consider use this simplier PipeSchedule"
+ insts: List[BaseInstruction] = pipe_engine.get_instruction_list(pipe_engine.stage_id)
+ compiled_insts: List[List[CompilePPCollectiveOperator]] = [
+ inst.compile() for inst in insts if hasattr(inst, "compile")
+ ]
+ flat_compile_insts = []
+ for list_insts in compiled_insts:
+ flat_compile_insts.extend(list_insts)
+ for inst in flat_compile_insts:
+ if inst.kind is CompilePPCollectiveKind.BORADCAST:
+ raise NotImplementedError("broadcast is not supported now")
+ elif inst.kind is CompilePPCollectiveKind.SEND:
+ if inst.is_backward:
+ bwd_send_dsts.add(inst.dst)
+ else:
+ fwd_send_dsts.add(inst.dst)
+ elif inst.kind is CompilePPCollectiveKind.RECV:
+ if inst.is_backward:
+ bwd_recv_srcs.add(inst.src)
+ else:
+ fwd_recv_srcs.add(inst.src)
+ else:
+ raise NotImplementedError("Unknown collective operators")
+ self.gen_pp_collective_topo_from_given(
+ list(fwd_send_dsts), list(fwd_recv_srcs), list(bwd_send_dsts), list(bwd_recv_srcs)
+ )
+
+ def gen_pp_collective_topo_from_given(
+ self,
+ fwd_send_dsts: List[int] = None,
+ fwd_recv_srcs: List[int] = None,
+ bwd_send_dsts: List[int] = None,
+ bwd_recv_srcs: List[int] = None,
+ ):
+ self.fwd_send_dsts = fwd_send_dsts
+ self.fwd_recv_srcs = fwd_recv_srcs
+ self.bwd_send_dsts = bwd_send_dsts
+ self.bwd_recv_srcs = bwd_recv_srcs
+
+ # this function should return a dict to indicate a output_spec change in ExportedProgram
+ def insert_send_fwd(self, fg: torch.fx.GraphModule) -> Dict[str, str]:
+ if not self.fwd_send_dsts:
+ return {}
+ assert len(self.fwd_send_dsts) == self.num_real_outputs
+ replaced_outputs = {}
+ for node in fg.graph.nodes:
+ if node.op != "output":
+ continue
+ with fg.graph.inserting_before(node):
+ node_args = node.args[0]
+ for i in range(self.num_real_outputs):
+ arg = node_args[i]
+ new_node = fg.graph.create_node(
+ op="call_function",
+ target=torch.ops.c10d_functional.send.default,
+ args=(
+ arg,
+ self.fwd_send_dsts[i],
+ f"{self.curr_rank}{self.fwd_send_dsts[i]}",
+ [self.curr_rank, self.fwd_send_dsts[i]],
+ 2,
+ ),
+ kwargs={},
+ name="pp_send_fwd",
+ )
+ new_node.meta["stack_trace"] = "inserted by pp_collective_emitter"
+ new_node.meta["val"] = arg.meta.get("val", None)
+ new_node.meta["tensor_meta"] = arg.meta.get("tensor_meta", None)
+ replaced_outputs[arg.name] = new_node.name
+ node.replace_input_with(arg, new_node)
+ fg.recompile()
+ return replaced_outputs
+
+ def insert_recv_fwd(self, fg: torch.fx.GraphModule):
+ if not self.fwd_recv_srcs:
+ return
+ assert len(self.fwd_recv_srcs) == self.num_real_inputs
+ seen_placeholders = 0
+ for node in fg.graph.nodes:
+ if node.op != "placeholder":
+ continue
+ seen_placeholders += 1
+ if seen_placeholders <= self.num_params_and_buffers:
+ continue
+ real_input_idx = seen_placeholders - self.num_params_and_buffers - 1
+ with fg.graph.inserting_after(node):
+ src = self.fwd_recv_srcs[real_input_idx]
+ new_node = fg.graph.create_node(
+ op="call_function",
+ target=torch.ops.c10d_functional.recv.default,
+ args=(
+ node,
+ src,
+ f"{src}{self.curr_rank}",
+ [src, self.curr_rank],
+ 2,
+ ),
+ kwargs={},
+ name="pp_recv_fwd",
+ )
+ new_node.meta["stack_trace"] = "inserted by pp_collective_emitter"
+ new_node.meta["val"] = node.meta.get("val", None)
+ new_node.meta["tensor_meta"] = node.meta.get("tensor_meta", None)
+ for user in list(node.users):
+ if user == new_node:
+ continue
+ user.replace_input_with(node, new_node)
+
+ fg.recompile()
+
+ def insert_send_bwd(self, fg: torch.fx.GraphModule):
+ if not self.bwd_send_dsts:
+ return
+ assert len(self.bwd_send_dsts) == self.num_real_inputs
+ for node in fg.graph.nodes:
+ if node.op != "output":
+ continue
+ with fg.graph.inserting_before(node):
+ args = node.args[0]
+ for i in range(self.num_real_inputs):
+ dst = self.bwd_send_dsts[i]
+ arg = args[i + self.num_params_and_buffers]
+ new_node = fg.graph.create_node(
+ op="call_function",
+ target=torch.ops.c10d_functional.send.default,
+ args=(
+ arg,
+ dst,
+ f"{self.curr_rank}{dst}",
+ [self.curr_rank, dst],
+ 2,
+ ),
+ kwargs={},
+ name="pp_send_bwd",
+ )
+ new_node.meta["stack_trace"] = "inserted by pp_collective_emitter"
+ new_node.meta["val"] = arg.meta.get("val", None)
+ new_node.meta["tensor_meta"] = arg.meta.get("tensor_meta", None)
+ node.replace_input_with(arg, new_node)
+ fg.recompile()
+
+ def insert_recv_bwd(self, fg: torch.fx.GraphModule):
+ if not self.bwd_recv_srcs:
+ return
+ assert len(self.bwd_recv_srcs) == self.num_real_outputs
+ seen_placeholders = 0
+ for node in fg.graph.nodes:
+ if node.op != "placeholder":
+ continue
+ seen_placeholders += 1
+ if seen_placeholders <= self.num_params_and_buffers:
+ continue
+ with fg.graph.inserting_after(node):
+ src = self.bwd_recv_srcs[seen_placeholders - self.num_params_and_buffers - 1]
+ new_node = fg.graph.create_node(
+ op="call_function",
+ target=torch.ops.c10d_functional.recv.default,
+ args=(
+ node,
+ src,
+ f"{src}{self.curr_rank}",
+ [src, self.curr_rank],
+ 2,
+ ),
+ kwargs={},
+ name="pp_recv_bwd",
+ )
+ new_node.meta["stack_trace"] = "inserted by pp_collective_emitter"
+ new_node.meta["val"] = node.meta.get("val", None)
+ new_node.meta["tensor_meta"] = node.meta.get("tensor_meta", None)
+ for user in list(node.users):
+ if user == new_node:
+ continue
+ user.replace_input_with(node, new_node)
+
+ fg.recompile()
+
+ def load_original_graph_module(self, original_gm):
+ named_parameters = dict(original_gm.named_parameters(remove_duplicate=False))
+ named_buffers = dict(original_gm.named_buffers(remove_duplicate=False))
+ self.num_params_and_buffers = len(named_buffers) + len(named_parameters)
+ self.num_real_inputs, self.num_real_outputs = read_fg(original_gm)
+
+ def run(self, fg: Union[torch.fx.GraphModule, torch.export.ExportedProgram] = None, is_backward: bool = None):
+ if isinstance(fg, torch.fx.GraphModule):
+ logging.info(
+ "You are inserting PP collective operators to a torch.compiled graph, make sure call PPCollectiveOpEmitter.load_original_graph_module first"
+ )
+ assert (
+ self.num_real_outputs is not None
+ and self.num_params_and_buffers is not None
+ and self.num_real_inputs is not None
+ ), "Please call PPCollectiveOpEmitter.load_original_graph_module first"
+
+ assert is_backward is not None, "Please provide is_backward argument"
+ if not is_backward:
+ num_total_inputs, _ = read_fg(fg)
+ else:
+ _, num_total_inputs = read_fg(fg)
+ assert num_total_inputs == self.num_real_inputs + self.num_params_and_buffers
+ if not is_backward:
+ self.insert_send_fwd(fg)
+ self.insert_recv_fwd(fg)
+ else:
+ self.insert_send_bwd(fg)
+ self.insert_recv_bwd(fg)
+ return fg
+
+ elif isinstance(fg, torch.export.ExportedProgram):
+ logging.info("You are inserting PP collective operators to a torch.exported graph")
+ ep = fg
+ self.num_params_and_buffers = len(ep.state_dict)
+ fg = ep.graph_module
+ self.num_real_inputs, self.num_real_outputs = read_fg(fg)
+ self.num_real_inputs -= self.num_params_and_buffers
+ replaced_outputs = self.insert_send_fwd(fg)
+ self.insert_recv_fwd(fg)
+
+ # output_spec changes
+ for o_spec in ep._graph_signature.output_specs:
+ if isinstance(o_spec.arg, TensorArgument) and o_spec.arg.name in replaced_outputs:
+ o_spec.arg = TensorArgument(replaced_outputs[o_spec.arg.name])
+ return ep
+
+ else:
+ raise NotImplementedError("Unknown model type")
diff --git a/vescale/pipe/_schedules/zero_bubble_v.py b/vescale/pipe/_schedules/zero_bubble_v.py
new file mode 100644
index 0000000..294a806
--- /dev/null
+++ b/vescale/pipe/_schedules/zero_bubble_v.py
@@ -0,0 +1,1170 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from typing import List, Sequence, Optional, Dict
+from collections import deque, defaultdict
+from dataclasses import dataclass
+from inspect import signature
+import contextlib
+
+import torch
+
+from vescale.pipe._schedules.instruction_base import (
+ InstructionGenerator,
+ StageDeps,
+ CommPacket,
+ register_instruction,
+ Shape,
+ registed_functions,
+ VESCALE_INTRUCTION_BUILDER as builder,
+)
+from vescale.pipe.p2p_communication import (
+ recv_backward,
+ recv_forward,
+ send_backward,
+ send_forward,
+)
+from vescale.dtensor._diff import manage_dump_file
+from vescale.dtensor.device_mesh import DeviceMesh
+from vescale.dtensor.dtensor import DTensor, make_dtensor
+from vescale.ndtimeline import ndtimeit_p2p
+from vescale.ndtimeline.predefined import CROSS_MESH_RECV, CROSS_MESH_SEND
+from torch.distributed._functional_collectives import send, recv
+from vescale.dtensor.placement_types import Placement
+from vescale.dtensor._utils import compute_global_tensor_info
+from torch.distributed.distributed_c10d import _get_default_group
+from vescale.model.base_gpt.utils import switch_dtensor
+
+import logging
+
+logger = logging.getLogger(__file__)
+
+
+def maybe_tensor(tensor):
+ if isinstance(tensor, DTensor):
+ return tensor._local_tensor
+ elif isinstance(tensor, torch.Tensor):
+ return tensor
+ else:
+ raise RuntimeError(f"Error parsing tensor {tensor}")
+
+
+def cross_mesh_recv(comm, p2p_tensor):
+ mapping_group = comm.cur_mesh.get_mapping_rank(comm.peer_mesh)
+ if isinstance(mapping_group, int): # equal size
+ default_pg = _get_default_group()
+ with ndtimeit_p2p(CROSS_MESH_RECV, default_pg, mapping_group, is_batched=False):
+ tensor = torch.empty((3, 3), device=p2p_tensor.device, dtype=torch.int64)
+ recv(tensor, mapping_group, default_pg)
+ p_size = sum(tensor[:, 0] >= 0)
+ tensor = tensor[:p_size]
+ sharding_type = [Placement.serialize_from_tensor(p) for p in tensor]
+ sharding = sharding_type
+ if len(sharding_type) > 0:
+ global_shape, global_stride = compute_global_tensor_info(p2p_tensor, comm.cur_mesh, sharding)
+ p2p_tensor = make_dtensor(
+ p2p_tensor,
+ comm.cur_mesh,
+ sharding,
+ shape=torch.Size(global_shape),
+ dtype=p2p_tensor.dtype,
+ requires_grad=p2p_tensor.requires_grad,
+ stride=tuple(global_stride),
+ )
+ return p2p_tensor
+ else:
+ return p2p_tensor
+ else:
+ raise NotImplementedError("currently not support change mesh size")
+
+
+def cross_mesh_send(comm, dt):
+ mapping_group = comm.cur_mesh.get_mapping_rank(comm.peer_mesh)
+ if isinstance(mapping_group, int): # equal size
+ default_pg = _get_default_group()
+ with ndtimeit_p2p(CROSS_MESH_SEND, default_pg, mapping_group, is_batched=False):
+ if isinstance(dt, DTensor):
+ send_sharding = torch.stack(
+ [p.serialize_to_tensor(dt.device) for p in dt._spec.placements]
+ + [
+ torch.full((3,), -1, device=dt.device, dtype=torch.int64)
+ for _ in range(3 - len(dt._spec.placements))
+ ]
+ )
+ send(send_sharding, mapping_group, default_pg)
+ else: # tensor
+ send(torch.full((3, 3), -1, device=dt.device, dtype=torch.int64), mapping_group, default_pg)
+ else:
+ raise NotImplementedError("currently not support change mesh size")
+
+
+def cross_mesh_double(comm, fwd_tensor, p2p_tensor):
+ if isinstance(fwd_tensor, DTensor):
+ placements = fwd_tensor._spec.placements
+ global_shape, global_stride = compute_global_tensor_info(p2p_tensor, comm.cur_mesh, placements)
+ p2p_tensor = make_dtensor(
+ p2p_tensor,
+ comm.cur_mesh,
+ placements,
+ shape=torch.Size(global_shape),
+ dtype=p2p_tensor.dtype,
+ requires_grad=p2p_tensor.requires_grad,
+ stride=tuple(global_stride),
+ )
+ return p2p_tensor
+
+
+@dataclass(eq=True, frozen=True)
+class ScheduledNode:
+ type: str
+ chunk: int
+ stage: int
+ minibatch: int
+ start_time: int
+ completion_time: int
+ rollback: bool = False
+
+ def get_send_comms(self, total_stages, deps):
+ if self.chunk == 0:
+ return (
+ [
+ CommPacket(
+ cur_mesh=deps.get_current_mesh(self.stage),
+ peer_mesh=deps.get_current_mesh(self.stage + 1),
+ input_id=0,
+ peer_stage=self.stage + 1,
+ )
+ ]
+ if self.stage != total_stages
+ else []
+ )
+ else:
+ return (
+ [
+ CommPacket(
+ cur_mesh=deps.get_current_mesh(self.stage),
+ peer_mesh=deps.get_current_mesh(self.stage - 1),
+ input_id=0,
+ peer_stage=self.stage - 1,
+ )
+ ]
+ if self.stage != 0
+ else []
+ )
+
+ def get_recv_comms(self, total_stages, deps):
+ if self.chunk == 0:
+ return (
+ [
+ CommPacket(
+ cur_mesh=deps.get_current_mesh(self.stage),
+ peer_mesh=deps.get_current_mesh(self.stage - 1),
+ input_id=0,
+ peer_stage=self.stage - 1,
+ )
+ ]
+ if self.stage != 0
+ else []
+ )
+ else:
+ return (
+ [
+ CommPacket(
+ cur_mesh=deps.get_current_mesh(self.stage),
+ peer_mesh=deps.get_current_mesh(self.stage + 1),
+ input_id=0,
+ peer_stage=self.stage + 1,
+ )
+ ]
+ if self.stage != total_stages
+ else []
+ )
+
+
+class CostGraph:
+ def __init__(self, n_stage, n_micro, f_cost, b_cost, w_cost, c_cost, f_mem, b_mem, w_mem, max_mem=None):
+ self.n_node = 6 * n_stage * n_micro
+ self.n_stage = n_stage
+ self.n_micro = n_micro
+ self.f_cost = f_cost
+ self.b_cost = b_cost
+ self.w_cost = w_cost
+ self.c_cost = c_cost
+ self.f_mem = f_mem
+ self.b_mem = b_mem
+ self.w_mem = w_mem
+ self.fbw_cost = [f_cost, b_cost, w_cost]
+ self.fbw_mem = [f_mem, b_mem, w_mem]
+ self.max_mem = max_mem or f_mem * self.n_stage * 2
+
+ def get_id(self, cat, chunk, stage, micro):
+ return (
+ cat * 2 * self.n_stage * self.n_micro + chunk * self.n_stage * self.n_micro + stage * self.n_micro + micro
+ )
+
+ def try_v_schedule(self, fill_f=True, fill_b=True, approved_bubble=None):
+ count = []
+ for i in range(self.n_stage):
+ count.append([0] * 6)
+
+ end_time = [-1] * self.n_node
+ cur_time = [0] * self.n_stage
+ mem = [0] * self.n_stage
+ stage_bubble = [0] * self.n_stage
+ pending_w = [deque() for _ in range(self.n_stage)]
+ schedule = [[] for _ in range(self.n_stage)]
+ stage_str = [" " * i for i in range(self.n_stage)]
+
+ if approved_bubble is None:
+ approved_bubble = [-1] * self.n_stage
+ max_approved_bubble = max(approved_bubble)
+
+ def get_max_stage_bubble(stage=-1):
+ max_stage_bubble = 0
+ for bb in stage_bubble:
+ max_stage_bubble = max(max_stage_bubble, bb)
+ if stage >= 0:
+ max_stage_bubble = max(max_stage_bubble, max_approved_bubble - approved_bubble[stage])
+ return max_stage_bubble
+
+ def put_w(stage):
+ assert len(pending_w[stage]) > 0
+ _, chunk_, _ = pending_w[stage].popleft()
+ put(2, chunk_, stage)
+
+ def put(cat, chunk, stage, assert_cnt=True):
+ _tmp = _no_bubble = cur_time[stage] + self.fbw_cost[cat]
+ _cnt = count[stage][cat * 2 + chunk]
+ if _cnt >= self.n_micro:
+ if not assert_cnt:
+ stage_str[stage] += " "
+ cur_time[stage] = _tmp # TODO
+ return
+ raise AssertionError()
+ assert mem[stage] + self.fbw_mem[cat] <= self.max_mem
+ stage_str[stage] += "FfBbWw"[cat * 2 + chunk] + str(_cnt + 1) + " " * (3 - len(str(_cnt + 1)))
+ if cat > 0 or chunk > 0:
+ last_id = cat * 2 + chunk - 1
+ if cat < 2:
+ assert end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] >= 0
+ else:
+ assert end_time[self.get_id(1, chunk, stage, _cnt)] >= 0
+ if chunk == 1 and cat < 2:
+ if stage < self.n_stage - 1:
+ _fa_id = self.get_id(cat, chunk, stage + 1, _cnt)
+ assert end_time[_fa_id] >= 0
+ _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
+ if chunk == 0 and cat < 2:
+ if stage > 0:
+ _fa_id = self.get_id(cat, chunk, stage - 1, _cnt)
+ assert end_time[_fa_id] >= 0, f"{cat}, {chunk}, {stage}, {_cnt}"
+ _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
+ _id = self.get_id(cat, chunk, stage, _cnt)
+ if count[stage][0] > 0:
+ stage_bubble[stage] += _tmp - _no_bubble
+ end_time[_id] = _tmp
+ cur_time[stage] = _tmp
+ mem[stage] += self.fbw_mem[cat]
+ # noinspection PyTypeChecker
+ schedule[stage].append((cat, chunk, _cnt))
+ if cat == 1:
+ pending_w[stage].append((2, chunk, _cnt))
+ count[stage][cat * 2 + chunk] += 1
+
+ for i in range(self.n_stage):
+ put(0, 0, i)
+ for i in range(self.n_stage - 1, -1, -1):
+ if i == self.n_stage - 1:
+ put(0, 1, i)
+ continue
+ tmp = end_time[self.get_id(0, 1, i + 1, 0)] + self.c_cost
+ while (
+ mem[i] + self.fbw_mem[0] * (2 + i * 2) <= self.max_mem
+ and cur_time[i] + self.fbw_cost[0] <= tmp
+ and count[i][0] < self.n_micro
+ ):
+ for j in range(i + 1):
+ put(0, 0, j)
+ put(0, 1, i)
+ iter_chunk_ = 0
+ end_tmp = 0
+ for i in range(self.n_stage):
+ if i == 0:
+ end_tmp = cur_time[0] + self.fbw_cost[1]
+ continue
+ tmp = end_tmp + self.c_cost
+ while (
+ count[i][0] + count[i][1] < count[i - 1][0] + count[i - 1][1]
+ or count[i][1] <= count[i - 1][1] < self.n_micro
+ ):
+ for j in range(self.n_stage - 1, i - 1, -1):
+ if count[j][iter_chunk_] < self.n_micro:
+ put(0, iter_chunk_, j)
+ iter_chunk_ = 1 - iter_chunk_
+
+ for _ in range(2 * self.n_micro):
+ # check mem before putting b
+ for i in range(self.n_stage):
+ while mem[i] + self.fbw_mem[1] > self.max_mem:
+ assert len(pending_w[i]) > 0
+ put_w(i)
+ b0_ranks, b1_ranks = [], []
+ for i in range(self.n_stage):
+ if count[i][3] >= count[i][2]:
+ b0_ranks.append(i)
+ elif i == self.n_stage - 1:
+ b1_ranks.append(i)
+ else:
+ fa_id = self.get_id(1, 1, i + 1, count[i][3])
+ if end_time[fa_id] >= 0 or count[i][2] >= self.n_micro:
+ b1_ranks.append(i)
+ else:
+ b0_ranks.append(i)
+ b_ranks = []
+ # put b1
+ for i in reversed(b1_ranks):
+ b_ranks.append((i, 1))
+ # put b0
+ for i in b0_ranks:
+ b_ranks.append((i, 0))
+ for i, _chunk_ in b_ranks:
+ fa_id = -1
+ if _chunk_ == 1 and i < self.n_stage - 1:
+ fa_id = self.get_id(1, 1, i + 1, count[i][3])
+ if _chunk_ == 0 and i > 0:
+ fa_id = self.get_id(1, 0, i - 1, count[i][2])
+ while (
+ len(pending_w[i]) > 0
+ and fa_id >= 0
+ and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]
+ ):
+ # fill the bubble
+ put_w(i)
+ if (
+ len(pending_w[i]) > 0
+ and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]
+ ):
+ if _chunk_ == 1:
+ put_w(i)
+ elif fill_b:
+ put_w(i)
+ put(1, _chunk_, i)
+
+ # put f
+ for i in range(self.n_stage):
+ if count[i][1] >= self.n_micro:
+ continue
+ put_item = None
+ if count[i][1] >= count[i][0]:
+ put_item = 0
+ elif i == self.n_stage - 1:
+ put_item = 1
+ else:
+ if end_time[self.get_id(0, 1, i + 1, count[i][1])] >= 0:
+ put_item = 1
+ elif count[i][0] < self.n_micro:
+ if i == 0:
+ put_item = 0
+ elif end_time[self.get_id(0, 0, i - 1, count[i][0])] >= 0:
+ put_item = 0
+ if put_item is None:
+ continue
+ # check mem before putting f
+ while mem[i] + self.fbw_mem[0] > self.max_mem:
+ assert len(pending_w[i]) > 0
+ put_w(i)
+ fa_id = -1
+ if put_item == 0 and i > 0:
+ fa_id = self.get_id(0, 0, i - 1, count[i][0])
+ if put_item == 1 and i < self.n_stage - 1:
+ fa_id = self.get_id(0, 1, i + 1, count[i][1])
+ while (
+ len(pending_w[i]) > 0
+ and fa_id >= 0
+ and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]
+ ):
+ # fill the bubble
+ put_w(i)
+ if (
+ len(pending_w[i]) > 0
+ and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]
+ ):
+ if fill_f:
+ put_w(i)
+ put(0, put_item, i)
+
+ for i in range(self.n_stage):
+ while len(pending_w[i]) > 0:
+ put_w(i)
+
+ max_bubble = get_max_stage_bubble()
+ expected_time = sum(self.fbw_cost) * self.n_micro * 2
+ bubble_rate = max_bubble / expected_time
+ if max_approved_bubble < 0 or max_bubble < max_approved_bubble:
+ _schedule, _end_time, _max_bubble = self.try_v_schedule(
+ fill_f=fill_f,
+ fill_b=fill_b,
+ approved_bubble=stage_bubble,
+ )
+ if _max_bubble < max_bubble:
+ return _schedule, _end_time, _max_bubble
+ return schedule, end_time, max_bubble
+
+ def print_details(self, end_time, print_scaling=1):
+ for stage in range(self.n_stage):
+ stage_str = ["."] * int(max(end_time) / print_scaling)
+ for _cat in range(3):
+ for _chunk in range(2):
+ for _micro in range(self.n_micro):
+ _id = self.get_id(_cat, _chunk, stage, _micro)
+ if end_time[_id] < 0:
+ continue
+ end = int(end_time[_id] / print_scaling)
+ start = int((end_time[_id] - self.fbw_cost[_cat]) / print_scaling)
+ for j in range(start, end):
+ if j == start or j == end - 1:
+ stage_str[j] = "FfBbWw"[_cat * 2 + _chunk]
+ elif j == start + 1:
+ if _micro >= 10:
+ stage_str[j] = str(_micro // 10)
+ else:
+ stage_str[j] = str(_micro)
+ elif j == start + 2 and _micro >= 10:
+ stage_str[j] = str(_micro % 10)
+ else:
+ stage_str[j] = "-"
+ _str = ""
+ for _c in stage_str:
+ _str += _c
+ print(_str)
+
+ def get_v_schedule(self, only_run_time=False):
+ schedule, end_time, max_bubble = None, None, None
+ expected_time = sum(self.fbw_cost) * self.n_micro * 2
+ for fill_b in [True, False]:
+ for fill_f in [True, False]:
+ _schedule, _end_time, _max_bubble = self.try_v_schedule(fill_b=fill_b, fill_f=fill_f)
+ if max_bubble is None or _max_bubble < max_bubble:
+ max_bubble = _max_bubble
+ schedule = _schedule
+ end_time = _end_time
+ if only_run_time:
+ return max_bubble + expected_time
+ bubble_rate = max_bubble / (expected_time + max_bubble)
+ msg = "%2d %3d, [%5d %5d %5d %5d], %6d -> %6.4f" % (
+ self.n_stage,
+ self.n_micro,
+ *self.fbw_cost,
+ self.c_cost,
+ self.max_mem // self.f_mem,
+ bubble_rate,
+ )
+
+ logger.info(msg)
+ local_order = [[] for _ in range(self.n_stage)]
+ comm_id = {}
+ comm_id_counter = 0
+ post_validation_time = 0
+ for i in range(self.n_stage - 1, -1, -1):
+ pv_id = min(2 * (self.n_stage - 1 - i), self.n_micro - 1)
+ post_validation_time = max(
+ post_validation_time, end_time[self.get_id(0, 0, i, pv_id)] - self.fbw_cost[0] - self.c_cost
+ )
+ for it in ["RECV_", "SEND_", ""]:
+ if i == 0 and it == "SEND_":
+ continue
+ if i == self.n_stage - 1 and it == "RECV_":
+ continue
+ stage_ = i
+ local_order[stage_].append(
+ ScheduledNode(
+ type=it + "POST_VALIDATION",
+ chunk=0,
+ stage=stage_,
+ minibatch=0,
+ start_time=post_validation_time,
+ completion_time=post_validation_time,
+ )
+ )
+ comm_id[local_order[stage_][-1]] = comm_id_counter
+ comm_id_counter += 1
+ for i in range(self.n_stage):
+ for _cat_, _chunk_, _micro_ in schedule[i]:
+ complete_time = end_time[self.get_id(_cat_, _chunk_, i, _micro_)]
+ local_order[i].append(
+ ScheduledNode(
+ type="FBW"[_cat_],
+ chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
+ stage=i,
+ minibatch=_micro_,
+ start_time=complete_time - self.fbw_cost[_cat_],
+ completion_time=complete_time,
+ )
+ )
+ if _cat_ == 2: # no communication for W
+ continue
+ cat_str = "FORWARD" if _cat_ == 0 else "BACKWARD"
+
+ def communicate(send_recv, stage_):
+ # noinspection PyTypeChecker
+ local_order[stage_].append(
+ ScheduledNode(
+ type=send_recv + cat_str,
+ chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
+ stage=stage_,
+ minibatch=_micro_,
+ start_time=complete_time,
+ completion_time=complete_time,
+ )
+ )
+ comm_id[local_order[stage_][-1]] = comm_id_counter
+
+ if _chunk_ == 1 and i > 0:
+ communicate("SEND_", i)
+ communicate("RECV_", i - 1)
+ if _chunk_ == 0 and i < self.n_stage - 1:
+ communicate("SEND_", i)
+ communicate("RECV_", i + 1)
+ comm_id_counter += 1
+ for rank in range(self.n_stage):
+ # For nodes with the same timestamp on the same stage, communication will be prioritized.
+ def even_breaker(x: ScheduledNode):
+ # Compute nodes are always delayed.
+ if x.type in ["F", "B", "W"]:
+ return comm_id_counter
+ # For comm nodes, order by their unique comm id
+ return comm_id[x]
+
+ local_order[rank] = sorted(local_order[rank], key=lambda x: (x.start_time, even_breaker(x)))
+ # If a recv with intersects with previous computation, reorder them so that recv
+ # is executed before computation and hence can be overlapped.
+ for i in range(len(local_order[rank])):
+ if (
+ i > 0
+ and local_order[rank][i - 1].type in {"F", "B", "W"}
+ and local_order[rank][i].type.startswith("RECV")
+ and "POST_VALIDATION" not in local_order[rank][i].type
+ and local_order[rank][i].start_time <= local_order[rank][i - 1].completion_time
+ ):
+ local_order[rank][i], local_order[rank][i - 1] = local_order[rank][i - 1], local_order[rank][i]
+
+ local_order_with_rollback = [[] for _ in range(self.n_stage)]
+ for rank in range(self.n_stage):
+ rollback_comm = set()
+ if rank > 0:
+ for node in local_order[rank - 1]:
+ if node.type == "POST_VALIDATION":
+ break
+ if node.type == "SEND_FORWARD":
+ assert node.chunk == 0
+ rollback_comm.add(node.minibatch)
+ for node in local_order[rank]:
+ if node.type == "RECV_FORWARD" and node.chunk == 0 and node.minibatch in rollback_comm:
+ rollback = True
+ rollback_comm.remove(node.minibatch)
+ else:
+ rollback = False
+ local_order_with_rollback[rank].append(
+ ScheduledNode(
+ type=node.type,
+ chunk=node.chunk,
+ stage=node.stage,
+ minibatch=node.minibatch,
+ start_time=node.start_time,
+ completion_time=node.completion_time,
+ rollback=rollback,
+ )
+ )
+ assert len(rollback_comm) == 0
+ msg = ""
+ for node in local_order_with_rollback[rank]:
+ msg += f"{node.type}-{node.minibatch}-{int(node.rollback)},"
+ msg = msg[:-1] + "\n"
+ logger.info(msg)
+
+ return local_order_with_rollback
+
+
+class ZeroBubbleVInstrcutionGenerator(InstructionGenerator):
+ def __init__(
+ self,
+ deps: StageDeps,
+ meshes: List[DeviceMesh],
+ batches: int,
+ f_cost: int,
+ b_cost: int,
+ w_cost: int,
+ c_cost: int,
+ f_mem: int,
+ b_mem: int,
+ w_mem: int,
+ max_mem=None,
+ default_shape: Optional[Shape] = None,
+ default_dtype: Optional[torch.dtype] = None,
+ ):
+ self.num_chunk = 2 # for ZBV, manually set num chunks be 2 for each worker
+ self.deps = deps
+ n_stage = deps.num_stage
+ n_micro = batches
+ self.cost_graph = CostGraph(n_stage, n_micro, f_cost, b_cost, w_cost, c_cost, f_mem, b_mem, w_mem, max_mem=None)
+ self.num_stage = len(meshes)
+ self.schema = self.cost_graph.get_v_schedule()
+ self.default_shape = default_shape
+ self.default_dtype = default_dtype
+
+ def gen_instruction(self):
+ self.instruction_list = [[] for _ in range(self.num_stage)]
+ self.instruction_list_str = ["" for _ in range(self.num_stage)]
+
+ for stage in range(self.num_stage):
+ stage_str = ""
+ for node in self.schema[stage]:
+ self._set_inst(node, stage)
+ stage_str += node.type + ","
+ stage_str = stage_str[:-1]
+
+ self.gen_instruction_str_list()
+
+ def gen_instruction_str_list(self):
+ instruction_lists = self.instruction_list
+ stage_strs = defaultdict(str)
+ for stage_id, instruction_list in enumerate(instruction_lists):
+ cur_stage_str = stage_strs[stage_id]
+ for inst in instruction_list:
+ cur_stage_str += f"{VESCALE_INSTRUCTION_MAPPING_ZBV[inst.type]},"
+ cur_stage_str = cur_stage_str[:-1]
+ stage_strs[stage_id] = cur_stage_str
+ builder.build_from_dict(stage_strs)
+
+ @manage_dump_file
+ def execute(
+ self,
+ stage_id,
+ autocast_dtype=torch.float32,
+ enable_autocast=False,
+ grad_scaler=None,
+ deallocate_pipeline_outputs=False,
+ ):
+ # init constant data
+ builder.constant_data["autocast_dtype"] = autocast_dtype
+ builder.constant_data["enable_autocast"] = enable_autocast
+ builder.constant_data["grad_scaler"] = grad_scaler
+ builder.constant_data["deallocate_pipeline_outputs"] = deallocate_pipeline_outputs
+ builder.constant_data["total_stages"] = self.num_stage
+ builder.constant_data["stagedeps"] = self.deps
+ builder.constant_data["default_shape"] = self.default_shape
+ builder.constant_data["default_dtype"] = self.default_dtype
+
+ # Model chunk IDs with synchronized grads
+ builder.user_data["synchronized_model_chunks"] = set()
+ builder.user_data["input_tensors"] = [[] for _ in range(self.num_chunk)]
+ builder.user_data["output_tensors"] = [[] for _ in range(self.num_chunk)]
+ builder.user_data["output_tensor_grads"] = [[] for _ in range(self.num_chunk)]
+ builder.user_data["fwd_wait_handles"] = None
+ builder.user_data["bwd_wait_handles"] = None
+ builder.user_data["output_tensor"] = None
+ builder.user_data["input_tensor"] = (None, None)
+ builder.user_data["output_tensor_grad"] = None
+ builder.user_data["forward_data_store"] = []
+ model = self.deps.get_current_model(stage_id)
+
+ builder.model = model
+ instruction_list = self.get_instruction_list(stage_id)
+ builder.stage_id = stage_id
+ builder_instruction_list = builder.global_instructions_funcs[stage_id]
+
+ assert len(instruction_list) == len(builder_instruction_list)
+ # print(f"cur stage {stage_id} debug inst list: {instruction_list} len inst {len(instruction_list)}")
+
+ for inst, fn in zip(instruction_list, builder_instruction_list):
+ builder.user_data["inst"] = inst
+ fn()
+
+ return builder.user_data["forward_data_store"]
+
+
+# communication
+
+
+@register_instruction(name="vescale_zbv_send_forward")
+def vescale_zbv_send_forward():
+ inst = builder.user_data["inst"]
+ output_tensors = builder.user_data["output_tensor"]
+
+ if not isinstance(output_tensors, list):
+ output_tensors = [output_tensors]
+
+ def f(info):
+ output_tensor, comm, shape = info
+ send_forward(
+ output_tensor=maybe_tensor(output_tensor),
+ current_device_mesh=comm.cur_mesh,
+ peer_device_mesh=comm.peer_mesh,
+ tensor_shape=shape,
+ )
+ cross_mesh_send(comm, output_tensor)
+
+ comm_packages = inst.get_send_comms(builder.constant_data["total_stages"], builder.constant_data["stagedeps"])
+
+ shapes = [builder.constant_data["default_shape"] for _ in comm_packages]
+ infos = zip(output_tensors, comm_packages, shapes)
+ return list(map(f, infos))
+
+
+@register_instruction(name="vescale_zbv_recv_forward")
+def vescale_zbv_recv_forward():
+ inst = builder.user_data["inst"]
+ chunk_id = inst.chunk
+ mbx = inst.minibatch
+
+ def f(info):
+ comm, shape, dtype = info
+ p2p_tensor = recv_forward(
+ tensor_shape=shape,
+ recv_dtype=dtype,
+ current_device_mesh=comm.cur_mesh,
+ peer_device_mesh=comm.peer_mesh,
+ )
+ p2p_tensor = cross_mesh_recv(comm, p2p_tensor)
+ return p2p_tensor
+
+ comm_packages = inst.get_recv_comms(builder.constant_data["total_stages"], builder.constant_data["stagedeps"])
+ shapes = [builder.constant_data["default_shape"] for _ in comm_packages]
+ dtypes = [builder.constant_data["default_dtype"] for _ in comm_packages]
+ infos = zip(comm_packages, shapes, dtypes)
+ out = list(map(f, infos))
+ input_tensor = out if len(out) > 0 else None
+ builder.user_data["input_tensor"] = (input_tensor, mbx)
+ builder.user_data["input_tensors"][chunk_id].append((input_tensor, mbx))
+ return input_tensor
+
+
+@register_instruction(name="vescale_zbv_send_backward")
+def vescale_zbv_send_backward():
+ inst = builder.user_data["inst"]
+ input_tensor_grad = builder.user_data["input_tensor_grad"]
+ if not isinstance(input_tensor_grad, list):
+ input_tensor_grad = [input_tensor_grad]
+
+ def f(info):
+ grad, comm, shape = info
+ send_backward(
+ input_tensor_grad=maybe_tensor(grad),
+ current_device_mesh=comm.cur_mesh,
+ peer_device_mesh=comm.peer_mesh,
+ tensor_shape=shape,
+ )
+ cross_mesh_send(comm, grad)
+
+ recv_comms = inst.get_recv_comms(builder.constant_data["total_stages"], builder.constant_data["stagedeps"])
+ shapes = [builder.constant_data["default_shape"] for _ in recv_comms]
+ infos = zip(input_tensor_grad, recv_comms, shapes)
+ return list(map(f, infos))
+
+
+@register_instruction(name="vescale_zbv_recv_backward")
+def vescale_zbv_recv_backward():
+ inst = builder.user_data["inst"]
+ chunk_id = inst.chunk
+
+ def f(info):
+ comm, shape, dtype = info
+ p2p_tensor = recv_backward(
+ tensor_shape=shape,
+ recv_dtype=dtype,
+ current_device_mesh=comm.cur_mesh,
+ peer_device_mesh=comm.peer_mesh,
+ )
+ p2p_tensor = cross_mesh_recv(comm, p2p_tensor)
+ return p2p_tensor
+
+ comm_packages = inst.get_send_comms(builder.constant_data["total_stages"], builder.constant_data["stagedeps"])
+ shapes = [builder.constant_data["default_shape"] for _ in comm_packages]
+ dtypes = [builder.constant_data["default_dtype"] for _ in comm_packages]
+ infos = zip(comm_packages, shapes, dtypes)
+ out = list(map(f, infos))
+ output_tensor_grad = out if len(out) > 0 else None
+
+ builder.user_data["output_tensor_grad"] = output_tensor_grad
+ builder.user_data["output_tensor_grads"][chunk_id].append(output_tensor_grad)
+ return output_tensor_grad
+
+
+# forward
+
+
+@register_instruction(name="vescale_zbv_forward")
+def vescale_zbv_forward():
+ inst = builder.user_data["inst"]
+ chunk_id = inst.chunk
+ stage_id = inst.stage
+ mbx = inst.minibatch
+ cur_model = builder.model[chunk_id]
+
+ user_data = builder.user_data
+ forward_data_store = user_data["forward_data_store"]
+ input_tensors = user_data["input_tensors"]
+ output_tensors = user_data["output_tensors"]
+
+ constant_data = builder.constant_data
+ autocast_dtype = constant_data["autocast_dtype"]
+ enable_autocast = constant_data["enable_autocast"]
+
+ is_pp_first_stage = stage_id == 0 and chunk_id == 0
+ is_pp_last_stage = stage_id == 0 and chunk_id == 1
+
+ # forward step
+ if is_pp_first_stage:
+ if len(input_tensors[chunk_id]) == len(output_tensors[chunk_id]):
+ input_tensors[chunk_id].append(None)
+
+ # find corresponding input tensor
+ input_tensor = None
+ for cur_item in input_tensors[chunk_id]:
+ if cur_item is not None and cur_item[1] == mbx:
+ input_tensor = cur_item[0]
+
+ if not is_pp_first_stage:
+ assert input_tensor is not None
+
+ if enable_autocast:
+ context_manager = torch.autocast("cuda", dtype=autocast_dtype)
+ else:
+ context_manager = contextlib.nullcontext()
+
+ with context_manager:
+
+ def prepare_data():
+ model_chunk_id = builder.user_data["model_chunk_id"]
+ ground_truth = []
+ if builder.user_data["is_pp_first_stage"]:
+ true_input_tensor = next(builder.dataloader[model_chunk_id])
+ # keep the input tensor in builder
+ if len(input_tensors[chunk_id]) == len(output_tensors[chunk_id]) + 1:
+ true_input_tensor.requires_grad_()
+ builder.user_data["input_tensors"][chunk_id].pop()
+ builder.user_data["input_tensors"][chunk_id].append((true_input_tensor, mbx))
+ else:
+ local_tensors = next(builder.dataloader[model_chunk_id])
+ if isinstance(local_tensors, Sequence) and len(local_tensors) > 1:
+ ground_truth.append(local_tensors[-1])
+ elif isinstance(local_tensors, Dict) and "labels" in local_tensors:
+ ground_truth.append(local_tensors["labels"])
+ true_input_tensor = builder.user_data["p2p_tensors"]
+ if isinstance(true_input_tensor, Sequence):
+ true_input_tensor = true_input_tensor[0]
+
+ return (true_input_tensor,), {}, ground_truth
+
+ builder.user_data["model_chunk_id"] = chunk_id
+ builder.user_data["p2p_tensors"] = input_tensor
+ builder.user_data["is_pp_first_stage"] = is_pp_first_stage
+ builder.user_data["is_pp_last_stage"] = is_pp_last_stage
+ builder.user_data["prepare_data_fn"] = prepare_data
+ args, kwargs, ground_truth = builder.user_data["prepare_data_fn"]()
+ builder.user_data["ground_truth"] = ground_truth
+ output_tensor = cur_model(*args, **kwargs)
+
+ if is_pp_last_stage:
+ output_tensor, loss_tensor = registed_functions["vescale_zbv_loss_fn"](output_tensor)
+ forward_data_store.append((output_tensor, loss_tensor))
+ output_tensor = output_tensor if builder.loss_fn is None else loss_tensor
+
+ if stage_id + 1 == builder.constant_data["total_stages"] and chunk_id == 0:
+ # turn around the forward direction
+ builder.user_data["input_tensor"] = (output_tensor, mbx)
+ builder.user_data["input_tensors"][chunk_id + 1].append((output_tensor, mbx))
+
+ builder.user_data["output_tensors"][chunk_id].append(output_tensor)
+ user_data["output_tensor"] = output_tensor
+
+
+# backward
+
+
+@register_instruction(name="vescale_zbv_backward_b")
+def vescale_zbv_backward_b():
+ inst = builder.user_data["inst"]
+ chunk_id = inst.chunk
+ stage_id = inst.stage
+ grad_scaler = builder.constant_data["grad_scaler"]
+ deallocate_pipeline_outputs = builder.constant_data["deallocate_pipeline_outputs"]
+
+ input_tensors = builder.user_data["input_tensors"]
+ output_tensors = builder.user_data["output_tensors"]
+ output_tensor_grads = builder.user_data["output_tensor_grads"]
+
+ is_pp_last_stage = stage_id == 0 and chunk_id == 1
+
+ if is_pp_last_stage:
+ if len(output_tensor_grads[chunk_id]) == 0:
+ output_tensor_grads[chunk_id].append(None)
+ input_tensor = input_tensors[chunk_id].pop(0)[0]
+ output_tensor = output_tensors[chunk_id][0]
+ output_tensor_grad = output_tensor_grads[chunk_id][0]
+
+ # Retain the grad on the input_tensor.
+ unwrap_input_tensor_grad = False
+ if not isinstance(input_tensor, list):
+ input_tensor = [input_tensor]
+ unwrap_input_tensor_grad = True
+ for x in input_tensor:
+ if x is not None:
+ x.retain_grad()
+
+ if not isinstance(output_tensor, list):
+ output_tensor = [output_tensor]
+ if not isinstance(output_tensor_grad, list):
+ output_tensor_grad = [output_tensor_grad]
+
+ # extract loss value from output tensors
+ if isinstance(output_tensor[0], Sequence):
+ for j in range(len(output_tensor[0])):
+ if output_tensor[0][j].ndim == 0 and output_tensor[0][j].numel() == 1:
+ loss_value = output_tensor[0][j]
+ break
+ else:
+ loss_value = output_tensor[0][-1]
+ else:
+ loss_value = output_tensor[0]
+
+ # Backward pass.
+ if output_tensor_grad[0] is None and grad_scaler is not None:
+ loss_value = grad_scaler(loss_value)
+ # FIXME: For virtual pipeline, there may exist frozen layer without grad;
+ # Need to verify if this solution is correct
+ if not loss_value.requires_grad:
+ return None
+
+ if deallocate_pipeline_outputs:
+ assert 0
+ # custom_backward(output_tensor[0], output_tensor_grad[0])
+ else:
+ input_tensor_grad = switch_dtensor(torch.autograd.grad)(
+ loss_value,
+ input_tensor,
+ grad_outputs=output_tensor_grad[0],
+ retain_graph=True,
+ allow_unused=True,
+ materialize_grads=True,
+ )[0]
+
+ if unwrap_input_tensor_grad:
+ input_tensor_grad = input_tensor_grad[0]
+
+ def f(input_tensor):
+ if input_tensor is not None:
+ assert isinstance(input_tensor, (torch.Tensor, DTensor)), input_tensor
+ input_tensor.grad = None
+
+ nonlocal output_tensor
+
+ if not isinstance(output_tensor, Sequence):
+ output_tensor = [output_tensor]
+
+ if (output_tensor is None) or (not deallocate_pipeline_outputs):
+ return
+ assert isinstance(
+ output_tensor, [torch.Tensor, DTensor]
+ ), f"expected Tensor, found {type(output_tensor).__name__}."
+ assert output_tensor._base is None, "counter-productive to free a view of another tensor."
+ if isinstance(output_tensor, [torch.Tensor, DTensor]):
+ output_tensor._local_tensor.data = torch.empty(
+ (1,),
+ device=output_tensor.device,
+ dtype=output_tensor.dtype,
+ )
+ else:
+ output_tensor.data = torch.empty(
+ (1,),
+ device=output_tensor.device,
+ dtype=output_tensor.dtype,
+ )
+ return
+
+ if not isinstance(input_tensor, Sequence):
+ map(f, [input_tensor])
+ else:
+ map(f, input_tensor)
+
+ if stage_id + 1 == builder.constant_data["total_stages"] and chunk_id == 1:
+ # turn around the forward direction
+ builder.user_data["output_tensor_grad"] = input_tensor_grad
+ builder.user_data["output_tensor_grads"][chunk_id - 1].append(output_tensor_grad)
+
+ builder.user_data["input_tensor_grad"] = input_tensor_grad
+
+
+@register_instruction(name="vescale_zbv_backward_w")
+def vescale_zbv_backward_w():
+ inst = builder.user_data["inst"]
+ chunk_id = inst.chunk
+ stage_id = inst.stage
+ cur_model = builder.model[chunk_id]
+ grad_scaler = builder.constant_data["grad_scaler"]
+ deallocate_pipeline_outputs = builder.constant_data["deallocate_pipeline_outputs"]
+
+ output_tensors = builder.user_data["output_tensors"]
+ output_tensor_grads = builder.user_data["output_tensor_grads"]
+
+ is_pp_last_stage = stage_id == 0 and chunk_id == 1
+
+ if is_pp_last_stage:
+ if len(output_tensor_grads[chunk_id]) == 0:
+ output_tensor_grads[chunk_id].append(None)
+ output_tensor = output_tensors[chunk_id].pop(0)
+ output_tensor_grad = output_tensor_grads[chunk_id].pop(0)
+
+ if not isinstance(output_tensor, list):
+ output_tensor = [output_tensor]
+ if not isinstance(output_tensor_grad, list):
+ output_tensor_grad = [output_tensor_grad]
+
+ # Backward pass.
+ if output_tensor_grad[0] is None and grad_scaler is not None:
+ output_tensor = grad_scaler(output_tensor[0])
+ # FIXME: For virtual pipeline, there may exist frozen layer without grad;
+ # Need to verify if this solution is correct
+ if not output_tensor[0].requires_grad:
+ return None
+
+ # Gather params
+ nps = {}
+ for key, value in cur_model.named_parameters():
+ nps[key] = value
+
+ if deallocate_pipeline_outputs:
+ assert 0
+ else:
+ params_grad = switch_dtensor(torch.autograd.grad)(
+ output_tensor[0],
+ nps.values(),
+ grad_outputs=output_tensor_grad[0],
+ retain_graph=True,
+ allow_unused=True,
+ materialize_grads=True,
+ )
+
+ # Manually set each params grad
+ for param, grad in zip(nps.values(), params_grad):
+ param.grad = grad
+
+
+# validation
+
+
+@register_instruction(name="vescale_zbv_post_validation")
+def vescale_zbv_post_validation():
+ pass
+
+
+@register_instruction(name="vescale_zbv_recv_post_validation")
+def vescale_zbv_recv_post_validation():
+ pass
+
+
+@register_instruction(name="vescale_zbv_send_post_validation")
+def vescale_zbv_send_post_validation():
+ pass
+
+
+# loss
+
+
+@register_instruction(name="vescale_zbv_loss_fn")
+def vescale_zbv_loss_fn(output_tensor):
+ loss_func = builder.loss_fn
+ if loss_func is None:
+ return output_tensor, None
+ temp_tensor = output_tensor
+ args_spec = signature(loss_func)
+ args_len = len(args_spec.parameters.keys())
+ if args_len == 1:
+ output_tensor = loss_func(output_tensor)
+ else:
+ ground_truth = builder.user_data["ground_truth"]
+ loss_fn_inputs = [output_tensor] + ground_truth
+ output_tensor = loss_func(*loss_fn_inputs)
+ assert args_len == len(loss_fn_inputs), "Mismatch of loss function #args and #actual inputs!"
+ builder.user_data["output_tensor"] = output_tensor
+ return temp_tensor, output_tensor
+
+
+VESCALE_INSTRUCTION_MAPPING_ZBV = {
+ "RECV_FORWARD": "vescale_zbv_recv_forward",
+ "SEND_FORWARD": "vescale_zbv_send_forward",
+ "F": "vescale_zbv_forward",
+ "B": "vescale_zbv_backward_b",
+ "W": "vescale_zbv_backward_w",
+ "RECV_BACKWARD": "vescale_zbv_recv_backward",
+ "SEND_BACKWARD": "vescale_zbv_send_backward",
+ "RECV_POST_VALIDATION": "vescale_zbv_recv_post_validation",
+ "SEND_POST_VALIDATION": "vescale_zbv_send_post_validation",
+ "POST_VALIDATION": "vescale_zbv_post_validation",
+}
+
+if __name__ == "__main__":
+ settings = [
+ # p, n, f, b, w, c, h, a, l
+ # (8, 24, 18522, 18086, 9337, 601, 2304, 24, 24),
+ # (8, 32, 18513, 18086, 9331, 626, 2304, 24, 24),
+ # (8, 64, 18546, 18097, 9321, 762, 2304, 24, 24),
+ # (8, 24, 29718, 29444, 19927, 527, 4096, 32, 32),
+ # (8, 32, 29802, 29428, 19530, 577, 4096, 32, 32),
+ # (8, 64, 29935, 29621, 19388, 535, 4096, 32, 32),
+ # (16, 48, 11347, 11248, 8132, 377, 5120, 40, 48),
+ # (16, 64, 11307, 11254, 8101, 379, 5120, 40, 48),
+ # (16, 128, 11325, 11308, 8109, 378, 5120, 40, 48),
+ # (32, 96, 10419, 10207, 7715, 408, 6144, 48, 64),
+ # (32, 128, 10408, 10204, 7703, 408, 6144, 48, 64),
+ # (32, 256, 10402, 10248, 7698, 460, 6144, 48, 64),
+ (4, 8, 6, 4, 4, 1, 4096, 32, 32),
+ # (8, 24, 29444, 29718, 19927, 527, 4096, 32, 32),
+ # ( 8, 32, 16099, 16504, 7589, 540, 2304, 24, 16),
+ # (16, 48, 14407, 14380, 9676, 1610, 4096, 32, 32),
+ # (16, 64, 14412, 14393, 9688, 1621, 4096, 32, 32),
+ # (16, 128, 14316, 14306, 9639, 1619, 4096, 32, 32),
+ # (24, 72, 6763, 6969, 5251, 755, 5120, 40, 48),
+ # (24, 96, 6783, 6984, 5259, 758, 5120, 40, 48),
+ # (24, 192, 6785, 6990, 5260, 770, 5120, 40, 48),
+ # (32, 96, 9458, 9748, 7288, 879, 6144, 48, 64),
+ # (32, 128, 9469, 9744, 7306, 892, 6144, 48, 64),
+ # (32, 256, 9447, 9644, 7193, 887, 6144, 48, 64),
+ ]
+ s = 1024
+
+ # h, a, s = 4096, 32, 1024
+ # cost_f, cost_b, cost_w, cost_c = 29718, 29444, 19927, 527
+ for p, n, f, b, w, c, h, a, _ in settings:
+ mem_f = 34 * h + 5 * a * s
+ mem_w = -32 * h
+ mem_b = -mem_w - mem_f
+ for m_offset in range(p + 1):
+ graph = CostGraph(
+ n_stage=p,
+ n_micro=n,
+ f_cost=f,
+ b_cost=b,
+ w_cost=w,
+ c_cost=c,
+ f_mem=mem_f,
+ b_mem=mem_b,
+ w_mem=mem_w,
+ max_mem=mem_f * (p * 2 + m_offset),
+ )
+ graph.get_v_schedule()
+ break
diff --git a/vescale/pipe/p2p_communication.py b/vescale/pipe/p2p_communication.py
new file mode 100644
index 0000000..61605b6
--- /dev/null
+++ b/vescale/pipe/p2p_communication.py
@@ -0,0 +1,1005 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+# Some code are adapted p2p_communication.py in Megatron-LM.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+################################################################################
+
+from enum import Enum
+import os
+import torch
+import torch.distributed as dist
+from vescale.dtensor.dtensor import DTensor
+from vescale.dtensor.device_mesh import DeviceMesh
+from vescale.ndtimeline import ndtimeit_p2p
+from vescale.ndtimeline.predefined import (
+ RECV_FORWARD,
+ RECV_BACKWARD,
+ SEND_FORWARD,
+ SEND_BACKWARD,
+ SEND_FORWARD_RECV_BACKWARD,
+ SEND_BACKWARD_RECV_FORWARD,
+)
+from typing import Optional, List, Union, Tuple
+from torch.distributed.distributed_c10d import ProcessGroup
+
+try:
+ from torch.distributed.distributed_c10d import _coalescing_manager
+except ImportError:
+ print("Warning: cannot import coalescing_manager. It may impact PP performance")
+
+# Types
+Shape = Union[List[int], torch.Size]
+# For P2P overlap, currently we do not differ fwd/bwd reqs;
+# Hence, drain func will sync both fwd and bwd p2p ops.
+GLOBAL_COUNTER = 0
+INTERMEDIATE_SHAPES = []
+MINIBATCH_STEPS = 0
+
+
+def reset_global_counter():
+ global GLOBAL_COUNTER
+ global MINIBATCH_STEPS
+ GLOBAL_COUNTER = 0
+ MINIBATCH_STEPS += 1
+
+
+class OpType(Enum):
+ SEND, RECV_FWD, RECV_BWD = 0, 1, 2
+
+
+p2p_overlap = False
+send_reqs = []
+recv_fwd_reqs = []
+recv_bwd_reqs = []
+
+
+# Sync P2P-send OP
+def drain_send_reqs():
+ global send_reqs
+ if len(send_reqs) == 0:
+ return
+ for req in send_reqs:
+ req.wait()
+ send_reqs.clear()
+
+
+# Sync P2P-recv OP: we differ forward recv reqs from backward recv reqs
+# to enable 1F1B P2P communication overlap
+def drain_recv_reqs(drain_type="all"):
+ global recv_fwd_reqs, recv_bwd_reqs
+ if drain_type == "all" or drain_type == "forward":
+ if len(recv_fwd_reqs) > 0:
+ for req in recv_fwd_reqs:
+ req.wait()
+ recv_fwd_reqs.clear()
+ if drain_type == "all" or drain_type == "backward":
+ if len(recv_bwd_reqs) > 0:
+ for req in recv_bwd_reqs:
+ req.wait()
+ recv_bwd_reqs.clear()
+
+
+def _mapping_local_rank_to_target_rank_by_device_mesh(
+ *, current_device_mesh: DeviceMesh, target_device_mesh: DeviceMesh, local_rank: int
+):
+ """Mapping local rank in current device mesh to find target rank in target device mesh
+
+ Takes the following arguments:
+ current_device_mesh: current device mesh for locate rank position
+ target_device_mesh: target device mesh for mapping to target rank
+ Returns:
+ target_rank
+ """
+ if target_device_mesh is None:
+ return None
+ current_device_mesh_list = current_device_mesh.mesh.view(-1).tolist()
+ assert local_rank in current_device_mesh_list
+ current_rank_pos = current_device_mesh_list.index(local_rank)
+ target_rank = target_device_mesh.mesh.view(-1).tolist()[current_rank_pos]
+ return target_rank
+
+
+def _get_p2p_send_recv_process_group(
+ *, current_device_mesh: DeviceMesh, target_device_mesh: DeviceMesh, local_rank: int
+):
+ target_rank = _mapping_local_rank_to_target_rank_by_device_mesh(
+ current_device_mesh=current_device_mesh, target_device_mesh=target_device_mesh
+ )
+ return list(local_rank, target_rank)
+
+
+def _communicate_shapes(
+ *,
+ tensor_send_next: torch.tensor,
+ tensor_send_prev: torch.tensor,
+ prev_rank: int,
+ next_rank: int,
+ recv_prev: bool,
+ recv_next: bool,
+ local_rank: int,
+ shape_dim: int = 3,
+):
+ """Communicate tensor shapes between stages. Used to communicate
+ tensor shapes before the actual tensor communication happens.
+ This is required when the sequence lengths across micro batches
+ are not uniform.
+
+ Takes the following arguments:
+ tensor_send_next: DTensor or torch.tensor to send to next rank (no tensor sent if
+ set to None).
+ tensor_send_prev: DTensor or torch.tensor to send to prev rank (no tensor sent if
+ set to None).
+ prev_rank: prev rank for send/recv rank
+ next_rank: next rank for send/recv rank
+ recv_prev: boolean for whether tensor should be received from
+ previous rank.
+ recv_next: boolean for whether tensor should be received from
+ next rank.
+ shape_dim: default to 3, which is set in megatron, in this refactor func, you can set shape dim
+ Returns:
+ (recv_prev_shape, recv_next_shape)
+ """
+
+ recv_prev_shape_tensor = None
+ recv_next_shape_tensor = None
+ send_prev_shape_tensor = None
+ send_next_shape_tensor = None
+
+ if recv_prev:
+ recv_prev_shape_tensor = torch.empty((shape_dim), device=torch.cuda.current_device(), dtype=torch.int64)
+ if recv_next:
+ recv_next_shape_tensor = torch.empty((shape_dim), device=torch.cuda.current_device(), dtype=torch.int64)
+ if tensor_send_prev is not None:
+ if isinstance(tensor_send_prev, DTensor):
+ send_prev_shape_tensor = torch.tensor(
+ tensor_send_prev._local_tensor.size(), device=torch.cuda.current_device(), dtype=torch.int64
+ )
+ else:
+ send_prev_shape_tensor = torch.tensor(
+ tensor_send_prev.size(), device=torch.cuda.current_device(), dtype=torch.int64
+ )
+ if tensor_send_next is not None:
+ if isinstance(tensor_send_next, DTensor):
+ send_next_shape_tensor = torch.tensor(
+ tensor_send_next._local_tensor.size(), device=torch.cuda.current_device(), dtype=torch.int64
+ )
+ else:
+ send_next_shape_tensor = torch.tensor(
+ tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64
+ )
+ ops = []
+ if send_prev_shape_tensor is not None:
+ send_prev_op = torch.distributed.P2POp(torch.distributed.isend, send_prev_shape_tensor, prev_rank)
+ ops.append(send_prev_op)
+ if recv_next_shape_tensor is not None:
+ recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, recv_next_shape_tensor, next_rank)
+ ops.append(recv_next_op)
+ if send_next_shape_tensor is not None:
+ send_next_op = torch.distributed.P2POp(torch.distributed.isend, send_next_shape_tensor, next_rank)
+ ops.append(send_next_op)
+ if recv_prev_shape_tensor is not None:
+ recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, recv_prev_shape_tensor, prev_rank)
+ ops.append(recv_prev_op)
+
+ if len(ops) > 0:
+ reqs = torch.distributed.batch_isend_irecv(ops)
+ for req in reqs:
+ req.wait()
+
+ # To protect against race condition when using batch_isend_irecv().
+ # should take this out once the bug with batch_isend_irecv is resolved.
+ if not _coalescing_manager:
+ torch.cuda.synchronize()
+
+ recv_prev_shape = [0, 0, 0]
+ if recv_prev_shape_tensor is not None:
+ recv_prev_shape = recv_prev_shape_tensor.tolist()
+
+ recv_next_shape = [0, 0, 0]
+ if recv_next_shape_tensor is not None:
+ recv_next_shape = recv_next_shape_tensor.tolist()
+
+ return recv_prev_shape, recv_next_shape
+
+
+def _batched_p2p_ops(
+ *,
+ tensor_send_prev: Optional[torch.Tensor],
+ tensor_recv_prev: Optional[torch.Tensor],
+ tensor_send_next: Optional[torch.Tensor],
+ tensor_recv_next: Optional[torch.Tensor],
+ prev_rank: int,
+ next_rank: int,
+ group: torch.distributed.ProcessGroup,
+ local_rank: int,
+ send_tensor_shape_unpad: Shape = None,
+ p2p_overlap=False,
+):
+ ops = []
+ if tensor_send_prev is not None:
+ send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev, prev_rank)
+ ops.append(send_prev_op)
+ if tensor_recv_prev is not None:
+ recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev, prev_rank)
+ ops.append(recv_prev_op)
+ if tensor_send_next is not None:
+ send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next, next_rank)
+ ops.append(send_next_op)
+ if tensor_recv_next is not None:
+ recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next, next_rank)
+ ops.append(recv_next_op)
+ if len(ops) > 0:
+ reqs = torch.distributed.batch_isend_irecv(ops)
+ else:
+ reqs = []
+ return reqs
+
+
+def check_nan(tensor_list, check=False):
+ if check:
+ for t in tensor_list:
+ assert not torch.isnan(t).any(), (
+ "tensor shape: "
+ + str(t.shape)
+ + ", dtype: "
+ + str(t.dtype)
+ + ", device: "
+ + str(t.device)
+ + ", # of NaN elements: "
+ + str(torch.sum(torch.isnan(t)).item())
+ + ", NaN element indexes: "
+ + str(torch.isnan(t).nonzero())
+ )
+
+
+def _p2p_ops(
+ *,
+ tensor_send_prev: Optional[torch.Tensor],
+ tensor_recv_prev: Optional[torch.Tensor],
+ tensor_send_next: Optional[torch.Tensor],
+ tensor_recv_next: Optional[torch.Tensor],
+ prev_rank: int,
+ next_rank: int,
+ group: torch.distributed.ProcessGroup,
+ local_rank: int,
+ p2p_overlap=False,
+ send_tensor_shape_unpad: Shape = None,
+ # file=None,
+):
+ reqs = []
+
+ """
+ by now the megatron pingpong
+ send recv is not supported because the global
+ devicemeshmanager is not impled. we will use
+ the ucx and mpi two-end no-blocking api to do
+ the send recv
+ """
+ stage_id = int(os.environ.get("STAGE_ID", "0"))
+ op_type = []
+ if stage_id % 2:
+ if tensor_send_next is not None:
+ if send_tensor_shape_unpad is not None:
+ assert (
+ send_tensor_shape_unpad[0] <= tensor_send_next.shape[0]
+ ), f"{send_tensor_shape_unpad} vs {tensor_send_next.shape}"
+ check_nan([tensor_send_next[: send_tensor_shape_unpad[0]]])
+ else:
+ check_nan([tensor_send_next])
+ send_next_req = torch.distributed.isend(
+ tensor=tensor_send_next,
+ dst=next_rank,
+ group=group,
+ )
+ reqs.append(send_next_req)
+ op_type.append(OpType.SEND)
+
+ if tensor_recv_prev is not None:
+ recv_prev_req = torch.distributed.irecv(
+ tensor=tensor_recv_prev,
+ src=prev_rank,
+ group=group,
+ )
+ reqs.append(recv_prev_req)
+ op_type.append(OpType.RECV_FWD)
+
+ if tensor_send_prev is not None:
+ if send_tensor_shape_unpad is not None:
+ assert (
+ send_tensor_shape_unpad[0] <= tensor_send_prev.shape[0]
+ ), f"{send_tensor_shape_unpad} vs {tensor_send_prev.shape}"
+ check_nan([tensor_send_prev[: send_tensor_shape_unpad[0]]])
+ else:
+ check_nan([tensor_send_prev])
+
+ send_prev_req = torch.distributed.isend(
+ tensor=tensor_send_prev,
+ dst=prev_rank,
+ group=group,
+ )
+ reqs.append(send_prev_req)
+ op_type.append(OpType.SEND)
+
+ if tensor_recv_next is not None:
+ recv_next_req = torch.distributed.irecv(
+ tensor=tensor_recv_next,
+ src=next_rank,
+ group=group,
+ )
+ reqs.append(recv_next_req)
+ op_type.append(OpType.RECV_BWD)
+
+ else:
+ if tensor_recv_prev is not None:
+ recv_prev_req = torch.distributed.irecv(
+ tensor=tensor_recv_prev,
+ src=prev_rank,
+ group=group,
+ )
+ reqs.append(recv_prev_req)
+ op_type.append(OpType.RECV_FWD)
+ if tensor_send_next is not None:
+ if send_tensor_shape_unpad is not None:
+ assert (
+ send_tensor_shape_unpad[0] <= tensor_send_next.shape[0]
+ ), f"{send_tensor_shape_unpad} vs {tensor_send_next.shape}"
+ check_nan([tensor_send_next[: send_tensor_shape_unpad[0]]])
+ else:
+ check_nan([tensor_send_next])
+ send_next_req = torch.distributed.isend(
+ tensor=tensor_send_next,
+ dst=next_rank,
+ group=group,
+ )
+ reqs.append(send_next_req)
+ op_type.append(OpType.SEND)
+
+ if tensor_recv_next is not None:
+ recv_next_req = torch.distributed.irecv(
+ tensor=tensor_recv_next,
+ src=next_rank,
+ group=group,
+ )
+ reqs.append(recv_next_req)
+ op_type.append(OpType.RECV_BWD)
+
+ if tensor_send_prev is not None:
+ if send_tensor_shape_unpad is not None:
+ assert (
+ send_tensor_shape_unpad[0] <= tensor_send_prev.shape[0]
+ ), f"{send_tensor_shape_unpad} vs {tensor_send_prev.shape}"
+ check_nan([tensor_send_prev[: send_tensor_shape_unpad[0]]])
+ else:
+ check_nan([tensor_send_prev])
+
+ send_prev_req = torch.distributed.isend(
+ tensor=tensor_send_prev,
+ dst=prev_rank,
+ group=group,
+ )
+ reqs.append(send_prev_req)
+ op_type.append(OpType.SEND)
+
+ if p2p_overlap:
+ # For P2P-comm overlap
+ global send_reqs, recv_fwd_reqs, recv_bwd_reqs
+ for i in range(len(op_type)):
+ if op_type[i] == OpType.SEND:
+ send_reqs.append(reqs[i])
+ elif op_type[i] == OpType.RECV_FWD:
+ recv_fwd_reqs.append(reqs[i])
+ elif op_type[i] == OpType.RECV_BWD:
+ recv_bwd_reqs.append(reqs[i])
+
+ return reqs
+
+
+def _communicate(
+ *,
+ tensor_send_next: Optional[torch.Tensor],
+ tensor_send_prev: Optional[torch.Tensor],
+ recv_prev: bool,
+ recv_next: bool,
+ current_device_mesh: DeviceMesh,
+ prev_device_mesh: DeviceMesh = None,
+ next_device_mesh: DeviceMesh = None,
+ tensor_shape: Shape = None,
+ send_tensor_shape_unpad: Shape = None,
+ batch_p2p_comm: bool = True,
+ wait_on_reqs: bool = True,
+ dtype: Optional[torch.dtype],
+ group: ProcessGroup = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Communicate tensors between stages. Used as helper method in other
+ communication methods that are used in vescale/schedules.py.
+
+ Arguments:
+ tensor_send_next (torch.Tensor, optional):
+ Tensor to send to next rank (no tensor sent if None)
+
+ tensor_send_prev (torch.Tensor, optional):
+ Tensor to send to prev rank (no tensor sent if None)
+
+ current_device_mesh (DeviceMesh, required):
+ Current device mesh for locate rank position
+
+ prev_device_mesh (DeviceMesh, required):
+ Target device mesh for mapping to pre rank
+
+ next_device_mesh (DeviceMesh, required):
+ Target device mesh for mapping to next rank
+
+ recv_prev (boolean, required):
+ whether tensor should be received from previous rank.
+
+ recv_next (boolean, required):
+ whether tensor should be received from next rank.
+
+ tensor_shape (List[int] or torch.Size, required):
+ shape of tensor to receive (this method assumes that all
+ tensors sent and received in a single function call are
+ the same shape). If none, using dynamic shape
+
+ batch_p2p_comm (boolean, required):
+ If true use batch_isend_irecv, otherwise use individual
+ isend and irecv calls.
+
+ wait_on_reqs (boolean, optional, default=False):
+ For non-batched p2p communication, wait on each request
+ before returning.
+
+ dtype (torch.dtype, required if either recv_{prev,next} is True):
+ this must be the type of the tensors that will be
+ received, will typically be params_dtype, but in the case
+ of fp32 residual connections might be torch.float.
+
+ variable_seq_lengths (bool, optional, default=False):
+ Support for variable sequence lengths across
+ microbatches. Setting this communicates the size of
+ tensors during pipeline parallelism communication, because
+ of this extra overhead it should only be set if the
+ sequence length is not constant during training.
+
+ Returns:
+ tuple containing
+
+ - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
+ - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.
+
+ """
+ # Init p2p_overlap: Use a global var to enable p2p comm overlap,
+ # so as not to change the original APIs
+
+ global p2p_overlap
+ if not wait_on_reqs and not p2p_overlap:
+ p2p_overlap = True
+
+ # Create placeholder tensors for receive in forward and backward directions
+ # if needed.
+ tensor_recv_prev = None
+ tensor_recv_next = None
+
+ # This will come from config in the next version, for now hard
+ # code it here to match existing functionality.
+ batch_p2p_sync = True
+ local_rank = current_device_mesh.get_rank()
+ # parse current device mesh and target device mesh
+ prev_rank = _mapping_local_rank_to_target_rank_by_device_mesh(
+ local_rank=local_rank, current_device_mesh=current_device_mesh, target_device_mesh=prev_device_mesh
+ )
+ next_rank = _mapping_local_rank_to_target_rank_by_device_mesh(
+ local_rank=local_rank, current_device_mesh=current_device_mesh, target_device_mesh=next_device_mesh
+ )
+ # flag to reuse intermediate tensor shapes of recorded tensors in first minibatch
+ reuse_intermediate_shapes = os.environ.get("REUSE_COMM_SHAPE", "0") == "1"
+
+ if tensor_shape is not None:
+ recv_prev_shape = tensor_shape
+ recv_next_shape = tensor_shape
+ else:
+ global GLOBAL_COUNTER
+ global INTERMEDIATE_SHAPES
+ global MINIBATCH_STEPS
+ if reuse_intermediate_shapes and MINIBATCH_STEPS > 1:
+ recv_prev_shape, recv_next_shape = INTERMEDIATE_SHAPES[GLOBAL_COUNTER]
+ else:
+ recv_prev_shape, recv_next_shape = _communicate_shapes(
+ tensor_send_next=tensor_send_next,
+ tensor_send_prev=tensor_send_prev,
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ prev_rank=prev_rank,
+ next_rank=next_rank,
+ local_rank=local_rank,
+ )
+ if reuse_intermediate_shapes:
+ INTERMEDIATE_SHAPES.append((recv_prev_shape, recv_next_shape))
+ GLOBAL_COUNTER += 1
+
+ if recv_prev:
+ if dtype is None:
+ raise RuntimeError("dtype must be provided if recv_prev is True")
+ if recv_prev_shape is None:
+ raise RuntimeError(
+ "tensor_shape must be specified if recv_prev is True. "
+ "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
+ )
+ tensor_recv_prev = torch.empty(
+ recv_prev_shape, requires_grad=True, device=torch.cuda.current_device(), dtype=dtype
+ )
+ if recv_next:
+ if dtype is None:
+ raise RuntimeError("dtype must be provided if recv_next is True")
+ if recv_next_shape is None:
+ raise RuntimeError(
+ "tensor_shape must be specified if recv_next is True. "
+ "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
+ )
+ tensor_recv_next = torch.empty(
+ recv_next_shape, requires_grad=True, device=torch.cuda.current_device(), dtype=dtype
+ )
+
+ # Send tensors in both the forward and backward directions as appropriate.
+ if batch_p2p_comm:
+ assert wait_on_reqs
+ p2p_func = _batched_p2p_ops
+ else:
+ p2p_func = _p2p_ops
+
+ # if file:
+ # file.write(
+ # f"\np2p tensor_send_prev:{tensor_send_prev}, tensor_recv_prev:{tensor_recv_prev} {id(tensor_recv_prev)}, tensor_send_next:{tensor_send_next} {id(tensor_send_next)}, tensor_recv_next:{tensor_recv_next}, prev_rank: {prev_rank}, next_rank: {next_rank}, local_rank: {local_rank}\n"
+ # )
+ # file.flush()
+ reqs = p2p_func(
+ tensor_send_prev=tensor_send_prev,
+ tensor_recv_prev=tensor_recv_prev,
+ tensor_send_next=tensor_send_next,
+ tensor_recv_next=tensor_recv_next,
+ prev_rank=prev_rank,
+ next_rank=next_rank,
+ group=group,
+ local_rank=local_rank,
+ send_tensor_shape_unpad=send_tensor_shape_unpad,
+ p2p_overlap=p2p_overlap,
+ )
+
+ if wait_on_reqs and len(reqs) > 0:
+ for req in reqs:
+ req.wait()
+ reqs = None
+
+ if batch_p2p_comm and batch_p2p_sync:
+ # To protect against race condition when using batch_isend_irecv().
+ if not _coalescing_manager:
+ torch.cuda.synchronize()
+
+ return tensor_recv_prev, tensor_recv_next, reqs
+
+
+def recv_forward(
+ tensor_shape: Shape,
+ recv_dtype: torch.dtype,
+ current_device_mesh: DeviceMesh,
+ peer_device_mesh: Optional[DeviceMesh] = None,
+ batch_p2p_comm: bool = True,
+) -> torch.Tensor:
+ """Receive tensor from previous rank in pipeline (forward receive).
+
+ See _communicate for argument details.
+
+ Args:
+ tensor_shape (Shape): shape of imminenently arrived tensors
+ recv_dtype (torch.dtype): data types of received tensors
+ current_device_mesh (DeviceMesh): sub-DeviceMesh of current stage
+ peer_device_mesh (Optional[DeviceMesh]): sub-DeviceMesh of sender/recipient stage
+ batch_p2p_comm (bool): switch to execute batched p2p transfer when turned on
+
+ Returns:
+ Received forward tensor
+
+ """
+ if peer_device_mesh is None:
+ intput_tensor = None
+ return intput_tensor
+ prev_rank = _mapping_local_rank_to_target_rank_by_device_mesh(
+ local_rank=current_device_mesh.get_rank(),
+ current_device_mesh=current_device_mesh,
+ target_device_mesh=peer_device_mesh,
+ )
+ with ndtimeit_p2p(RECV_FORWARD, dist.group.WORLD, prev_rank, batch_p2p_comm):
+ input_tensor, _, _ = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=None,
+ current_device_mesh=current_device_mesh,
+ prev_device_mesh=peer_device_mesh,
+ recv_prev=True,
+ recv_next=False,
+ tensor_shape=tensor_shape,
+ batch_p2p_comm=batch_p2p_comm,
+ dtype=recv_dtype,
+ )
+ return input_tensor
+
+
+def recv_backward(
+ tensor_shape: Shape,
+ recv_dtype: torch.dtype,
+ current_device_mesh: DeviceMesh,
+ peer_device_mesh: Optional[DeviceMesh] = None,
+ batch_p2p_comm: bool = True,
+) -> torch.Tensor:
+ """Receive tensor from next rank in pipeline (backward receive).
+
+ See _communicate for argument details.
+
+ Args:
+ tensor_shape (Shape): shape of imminenently arrived tensors
+ recv_dtype (torch.dtype): data types of received tensors
+ current_device_mesh (DeviceMesh): sub-DeviceMesh of current stage
+ peer_device_mesh (Optional[DeviceMesh]): sub-DeviceMesh of sender/recipient stage
+ batch_p2p_comm (bool): switch to execute batched p2p transfer when turned on
+
+ Returns:
+ Received output tensor gradient.
+
+ """
+ if peer_device_mesh is None:
+ output_tensor_grad = None
+ return output_tensor_grad
+ next_rank = _mapping_local_rank_to_target_rank_by_device_mesh(
+ local_rank=current_device_mesh.get_rank(),
+ current_device_mesh=current_device_mesh,
+ target_device_mesh=peer_device_mesh,
+ )
+ with ndtimeit_p2p(RECV_BACKWARD, dist.group.WORLD, next_rank, batch_p2p_comm):
+ _, output_tensor_grad, _ = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=None,
+ current_device_mesh=current_device_mesh,
+ next_device_mesh=peer_device_mesh,
+ recv_prev=False,
+ recv_next=True,
+ tensor_shape=tensor_shape,
+ dtype=recv_dtype,
+ batch_p2p_comm=batch_p2p_comm,
+ )
+ return output_tensor_grad
+
+
+def send_forward(
+ output_tensor: torch.Tensor,
+ current_device_mesh: DeviceMesh,
+ peer_device_mesh: Optional[DeviceMesh] = None,
+ tensor_shape: Optional[Shape] = None,
+ batch_p2p_comm: bool = True,
+) -> None:
+ """Send tensor to next rank in pipeline (forward send).
+
+ See _communicate for argument details.
+
+ Args:
+ output_tensor (torch.Tensor): backward input received from previous stage
+ current_device_mesh (DeviceMesh): sub-DeviceMesh of current stage
+ peer_device_mesh (Optional[DeviceMesh]): sub-DeviceMesh of sender/recipient stage
+ tensor_shape (Shape): shape of imminenently arrived tensors
+ batch_p2p_comm (bool): switch to execute batched p2p transfer when turned on
+
+ """
+
+ if peer_device_mesh is None:
+ return
+ next_rank = _mapping_local_rank_to_target_rank_by_device_mesh(
+ local_rank=current_device_mesh.get_rank(),
+ current_device_mesh=current_device_mesh,
+ target_device_mesh=peer_device_mesh,
+ )
+ with ndtimeit_p2p(SEND_FORWARD, dist.group.WORLD, next_rank, batch_p2p_comm):
+ _communicate(
+ tensor_send_next=output_tensor,
+ tensor_send_prev=None,
+ current_device_mesh=current_device_mesh,
+ next_device_mesh=peer_device_mesh,
+ recv_prev=False,
+ recv_next=False,
+ tensor_shape=tensor_shape,
+ batch_p2p_comm=batch_p2p_comm,
+ dtype=None,
+ )
+
+
+def send_backward(
+ input_tensor_grad: torch.Tensor,
+ current_device_mesh: DeviceMesh,
+ peer_device_mesh: Optional[DeviceMesh] = None,
+ tensor_shape: Optional[Shape] = None,
+ batch_p2p_comm: bool = True,
+) -> None:
+ """Send tensor to previous rank in pipeline (backward send).
+
+ See _communicate for argument details.
+
+ Args:
+ input_tensor_grad (torch.Tensor): input tensor gradients
+ current_device_mesh (DeviceMesh): sub-DeviceMesh of current stage
+ peer_device_mesh (Optional[DeviceMesh]): sub-DeviceMesh of sender/recipient stage
+ tensor_shape (Shape): shape of imminenently arrived tensors
+ batch_p2p_comm (bool): switch to execute batched p2p transfer when turned on
+
+ """
+
+ if peer_device_mesh is None:
+ return
+ prev_rank = _mapping_local_rank_to_target_rank_by_device_mesh(
+ local_rank=current_device_mesh.get_rank(),
+ current_device_mesh=current_device_mesh,
+ target_device_mesh=peer_device_mesh,
+ )
+ with ndtimeit_p2p(SEND_BACKWARD, dist.group.WORLD, prev_rank, batch_p2p_comm):
+ _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=input_tensor_grad,
+ current_device_mesh=current_device_mesh,
+ prev_device_mesh=peer_device_mesh,
+ recv_prev=False,
+ recv_next=False,
+ tensor_shape=tensor_shape,
+ batch_p2p_comm=batch_p2p_comm,
+ dtype=None,
+ )
+
+
+def send_forward_recv_backward(
+ output_tensor: torch.Tensor,
+ tensor_shape: Shape,
+ recv_dtype: torch.dtype,
+ current_device_mesh: DeviceMesh,
+ peer_device_mesh: Optional[DeviceMesh] = None,
+ batch_p2p_comm: bool = True,
+) -> torch.Tensor:
+ """Batched send and recv with next rank in pipeline.
+
+ See _communicate for argument details.
+
+ Args:
+ output_tensor (torch.Tensor): backward input received from previous stage
+ tensor_shape (Shape): shape of imminenently arrived tensors
+ recv_dtype (torch.dtype): data types of received tensors
+ current_device_mesh (DeviceMesh): sub-DeviceMesh of current stage
+ peer_device_mesh (Optional[DeviceMesh]): sub-DeviceMesh of sender/recipient stage
+ batch_p2p_comm (bool): switch to execute batched p2p transfer when turned on
+
+ Returns:
+ Received output tensor gradients.
+
+ """
+
+ if peer_device_mesh is None:
+ output_tensor_grad = None
+ return output_tensor_grad
+ next_rank = _mapping_local_rank_to_target_rank_by_device_mesh(
+ local_rank=current_device_mesh.get_rank(),
+ current_device_mesh=current_device_mesh,
+ target_device_mesh=peer_device_mesh,
+ )
+ with ndtimeit_p2p(SEND_FORWARD_RECV_BACKWARD, dist.group.WORLD, next_rank, batch_p2p_comm):
+ _, output_tensor_grad, _ = _communicate(
+ tensor_send_next=output_tensor,
+ tensor_send_prev=None,
+ current_device_mesh=current_device_mesh,
+ next_device_mesh=peer_device_mesh,
+ recv_prev=False,
+ recv_next=True,
+ tensor_shape=tensor_shape,
+ dtype=recv_dtype,
+ batch_p2p_comm=batch_p2p_comm,
+ )
+ return output_tensor_grad
+
+
+def send_backward_recv_forward(
+ input_tensor_grad: torch.Tensor,
+ tensor_shape: Shape,
+ recv_dtype: torch.dtype,
+ current_device_mesh: DeviceMesh,
+ peer_device_mesh: Optional[DeviceMesh] = None,
+ batch_p2p_comm: bool = True,
+) -> torch.Tensor:
+ """
+ Batched send and recv with previous rank in pipeline.
+
+ See _communicate for argument details.
+
+ Args:
+ input_tensor_grad (torch.Tensor): input tensor gradients
+ tensor_shape (Shape): shape of imminenently arrived tensors
+ recv_dtype (torch.dtype): data types of received tensors
+ current_device_mesh (DeviceMesh): sub-DeviceMesh of current stage
+ peer_device_mesh (Optional[DeviceMesh]): sub-DeviceMesh of sender/recipient stage
+ batch_p2p_comm (bool): switch to execute batched p2p transfer when turned on
+
+ Returns:
+ Received tensor.
+
+ """
+ if peer_device_mesh is None:
+ input_tensor = None
+ return input_tensor
+ prev_rank = _mapping_local_rank_to_target_rank_by_device_mesh(
+ local_rank=current_device_mesh.get_rank(),
+ current_device_mesh=current_device_mesh,
+ target_device_mesh=peer_device_mesh,
+ )
+ with ndtimeit_p2p(SEND_BACKWARD_RECV_FORWARD, dist.group.WORLD, prev_rank, batch_p2p_comm):
+ input_tensor, _, _ = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=input_tensor_grad,
+ current_device_mesh=current_device_mesh,
+ prev_device_mesh=peer_device_mesh,
+ recv_prev=True,
+ recv_next=False,
+ tensor_shape=tensor_shape,
+ dtype=recv_dtype,
+ batch_p2p_comm=batch_p2p_comm,
+ )
+ return input_tensor
+
+
+def send_forward_recv_forward(
+ output_tensor: torch.Tensor,
+ recv_prev: bool,
+ tensor_shape: Shape,
+ current_device_mesh: DeviceMesh,
+ prev_device_mesh: DeviceMesh,
+ next_device_mesh: DeviceMesh,
+ send_tensor_shape_unpad: Shape = None,
+ overlap_p2p_comm: bool = False,
+ recv_dtype: Optional[torch.dtype] = None,
+ batch_p2p_comm: bool = True,
+ group: ProcessGroup = None,
+) -> torch.Tensor:
+ """Batched recv from previous rank and send to next rank in pipeline.
+
+ See _communicate for argument details.
+ """
+ # auto state change
+ if prev_device_mesh is None:
+ recv_prev = False
+ if next_device_mesh is None:
+ input_tensor, _, wait_handles = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=None,
+ current_device_mesh=current_device_mesh,
+ prev_device_mesh=prev_device_mesh,
+ next_device_mesh=next_device_mesh,
+ recv_prev=recv_prev,
+ recv_next=False,
+ tensor_shape=tensor_shape,
+ send_tensor_shape_unpad=send_tensor_shape_unpad,
+ batch_p2p_comm=batch_p2p_comm,
+ wait_on_reqs=(not overlap_p2p_comm),
+ dtype=recv_dtype,
+ group=group,
+ )
+ else:
+ input_tensor, _, wait_handles = _communicate(
+ tensor_send_next=output_tensor,
+ tensor_send_prev=None,
+ current_device_mesh=current_device_mesh,
+ prev_device_mesh=prev_device_mesh,
+ next_device_mesh=next_device_mesh,
+ recv_prev=recv_prev,
+ recv_next=False,
+ tensor_shape=tensor_shape,
+ send_tensor_shape_unpad=send_tensor_shape_unpad,
+ batch_p2p_comm=batch_p2p_comm,
+ wait_on_reqs=(not overlap_p2p_comm),
+ dtype=recv_dtype,
+ group=group,
+ )
+ if overlap_p2p_comm:
+ return input_tensor, wait_handles
+ return input_tensor
+
+
+def send_backward_recv_backward(
+ input_tensor_grad: torch.Tensor,
+ recv_next: bool,
+ tensor_shape: Shape,
+ current_device_mesh: DeviceMesh,
+ prev_device_mesh: DeviceMesh,
+ next_device_mesh: DeviceMesh,
+ send_tensor_shape_unpad: Shape = None,
+ overlap_p2p_comm: bool = False,
+ recv_dtype: Optional[torch.dtype] = None,
+ batch_p2p_comm: bool = True,
+ group: ProcessGroup = None,
+) -> torch.Tensor:
+ """Batched recv from next rank and send to previous rank in pipeline.
+
+ See _communicate for argument details.
+ """
+ # auto state change
+ if next_device_mesh is None:
+ recv_next = False
+ if prev_device_mesh is None:
+ _, output_tensor_grad, wait_handles = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=None,
+ current_device_mesh=current_device_mesh,
+ prev_device_mesh=prev_device_mesh,
+ next_device_mesh=next_device_mesh,
+ recv_prev=False,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ send_tensor_shape_unpad=send_tensor_shape_unpad,
+ batch_p2p_comm=batch_p2p_comm,
+ wait_on_reqs=(not overlap_p2p_comm),
+ dtype=recv_dtype,
+ group=group,
+ # file=file,
+ )
+ else:
+ _, output_tensor_grad, wait_handles = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=input_tensor_grad,
+ current_device_mesh=current_device_mesh,
+ prev_device_mesh=prev_device_mesh,
+ next_device_mesh=next_device_mesh,
+ recv_prev=False,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ send_tensor_shape_unpad=send_tensor_shape_unpad,
+ batch_p2p_comm=batch_p2p_comm,
+ wait_on_reqs=(not overlap_p2p_comm),
+ dtype=recv_dtype,
+ group=group,
+ )
+ if overlap_p2p_comm:
+ return output_tensor_grad, wait_handles
+ return output_tensor_grad
+
+
+def send_forward_backward_recv_forward_backward(
+ output_tensor: torch.Tensor,
+ input_tensor_grad: torch.Tensor,
+ recv_prev: bool,
+ recv_next: bool,
+ tensor_shape: Shape,
+ current_device_mesh: DeviceMesh,
+ prev_device_mesh: DeviceMesh,
+ next_device_mesh: DeviceMesh,
+ recv_dtype: Optional[torch.dtype] = None,
+ batch_p2p_comm: bool = True,
+) -> torch.Tensor:
+ """Batched send and recv with previous and next ranks in pipeline.
+
+ See _communicate for argument details.
+ """
+ input_tensor, output_tensor_grad, _ = _communicate(
+ tensor_send_next=output_tensor,
+ tensor_send_prev=input_tensor_grad,
+ current_device_mesh=current_device_mesh,
+ prev_device_mesh=prev_device_mesh,
+ next_device_mesh=next_device_mesh,
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ dtype=recv_dtype,
+ batch_p2p_comm=batch_p2p_comm,
+ )
+ return input_tensor, output_tensor_grad
diff --git a/vescale/pipe/pipe_emmiter.py b/vescale/pipe/pipe_emmiter.py
new file mode 100644
index 0000000..d60fc3d
--- /dev/null
+++ b/vescale/pipe/pipe_emmiter.py
@@ -0,0 +1,356 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+# mypy: ignore-errors
+from vescale.dtensor.device_mesh import DeviceMesh
+from vescale.plan.pipeline_parallel import PipelineParallelPlan
+from vescale.plan.spec import PipelineScheduleType
+from vescale.pipe._schedules import (
+ OneFOneBInstrcutionGenerator,
+ InterleavedOneFOneBInstructionGenerator,
+ ZeroBubbleVInstrcutionGenerator,
+ StageDeps,
+ Shape,
+)
+from vescale.pipe._schedules.instruction_base import VESCALE_INTRUCTION_BUILDER as builder
+from vescale.pipe.p2p_communication import reset_global_counter
+from vescale.devicemesh_api.api import VeDeviceMesh
+from collections import OrderedDict
+from typing import Callable, Iterator, List, Sequence, Union
+import torch
+import torch.distributed as dist
+import logging
+import os
+
+
+logger = logging.Logger(__file__)
+
+
+class PipelineEmitter:
+ """Pipeline Emitter."""
+
+ def __init__(
+ self,
+ deps: StageDeps,
+ meshes: Sequence[DeviceMesh],
+ schedule: str,
+ batches: int,
+ tensor_shape: Shape,
+ dtype: torch.dtype,
+ num_chunks: int = 1,
+ input_shapes: List[Shape] = None,
+ input_shapes_unpad: List[Shape] = None,
+ forward_only=False,
+ overlap_p2p_comm=False,
+ batch_p2p_comm: bool = True,
+ param_sync_overlap=False,
+ grad_sync_overlap=False,
+ **kwargs,
+ ):
+ self.deps = deps
+ self.num_stage = deps.num_stage
+ self.meshes = meshes
+ self.batches = batches
+ self.num_chunks = num_chunks
+ self.overlap_p2p_comm = overlap_p2p_comm
+ self.batch_p2p_comm = batch_p2p_comm
+ self.param_sync_overlap = param_sync_overlap
+ self.forward_only = forward_only
+ self.grad_sync_overlap = grad_sync_overlap
+ if schedule == PipelineScheduleType.SIMPLE_1F1B:
+ self.num_meshes = meshes
+ self.instruction_generator = OneFOneBInstrcutionGenerator(
+ deps=deps,
+ meshes=self.meshes,
+ batches=batches,
+ default_shape=tensor_shape,
+ default_dtype=dtype,
+ forward_only=self.forward_only,
+ )
+
+ elif schedule == PipelineScheduleType.INTERLEAVED_1F1B:
+ self.instruction_generator = InterleavedOneFOneBInstructionGenerator(
+ deps=deps,
+ meshes=self.meshes,
+ batches=batches,
+ default_shape=tensor_shape,
+ default_dtype=dtype,
+ input_shapes=input_shapes,
+ input_shapes_unpad=input_shapes_unpad,
+ num_chunks=self.num_chunks,
+ batch_p2p_comm=batch_p2p_comm,
+ overlap_p2p_comm=overlap_p2p_comm,
+ param_sync_overlap=param_sync_overlap,
+ grad_sync_overlap=grad_sync_overlap,
+ forward_only=forward_only,
+ )
+
+ elif schedule == PipelineScheduleType.ZERO_BUBBLE:
+ self.instruction_generator = ZeroBubbleVInstrcutionGenerator(
+ deps=deps,
+ meshes=self.meshes,
+ batches=batches,
+ default_shape=tensor_shape,
+ default_dtype=dtype,
+ **kwargs,
+ )
+ else:
+ raise NotImplementedError("unsupport schedule type")
+ self.instruction_list: List[List] = self.gen_instruction()
+
+ def gen_instruction(self):
+ """
+ Generates instruction steps of a pipeline schedule.
+ """
+ return self.instruction_generator.gen_instruction()
+
+ def get_instruction_list(self, stage: int):
+ """
+ Generates instruction steps of a pipeline schedule for a particular pipeline stage.
+
+ Args:
+ stage (int): pipeline stage id
+
+ """
+ return self.instruction_generator.get_instruction_list(stage)
+
+
+class ScheduleEngine:
+ def __init__(
+ self,
+ deps: StageDeps,
+ meshes: int,
+ schedule: PipelineScheduleType,
+ batches: int,
+ data_iterator: Union[Iterator, List[Iterator]],
+ stage_id: int,
+ shape: Union[Shape, Sequence[Shape]],
+ dtype: Union[torch.dtype, Sequence[torch.dtype]] = torch.float32,
+ num_chunks=1,
+ input_shapes: List[Shape] = None,
+ input_shapes_unpad: List[Shape] = None,
+ forward_only=False,
+ overlap_p2p_comm=False,
+ batch_p2p_comm: bool = True,
+ param_sync_overlap=False,
+ grad_sync_overlap=False,
+ send_dtypes_map: OrderedDict = None,
+ loss_fn: Callable = lambda x: torch.sum(x),
+ global_mesh: VeDeviceMesh = None,
+ **kwargs,
+ ):
+ os.environ["STAGE_ID"] = str(stage_id)
+ self.p_emmiter = PipelineEmitter(
+ deps,
+ meshes,
+ schedule,
+ batches,
+ shape,
+ dtype,
+ num_chunks=num_chunks,
+ input_shapes=input_shapes,
+ input_shapes_unpad=input_shapes_unpad,
+ forward_only=forward_only,
+ overlap_p2p_comm=overlap_p2p_comm,
+ batch_p2p_comm=batch_p2p_comm,
+ param_sync_overlap=param_sync_overlap,
+ grad_sync_overlap=grad_sync_overlap,
+ **kwargs,
+ )
+ self.schedule = schedule
+ self.deps = deps
+ self.instruction_list = self.get_instruction_list(stage_id)
+ self.stage_id = stage_id
+ self.shape = shape
+ self.dtype = dtype
+ self.chunk = num_chunks
+ self.send_dtypes_map = send_dtypes_map
+ builder.topo = deps
+ builder.dataloader = data_iterator
+ builder.loss_fn = loss_fn
+ self.src_loss_rank = -1
+ self.global_mesh = global_mesh
+ if self.global_mesh:
+ all_ranks = list(range(dist.get_world_size()))
+ dp_rank = self.global_mesh.get_data_parallel_rank()
+ tp_rank = self.global_mesh.get_tensor_parallel_rank()
+ same_pipeline_group = [
+ rank for rank in all_ranks if self.global_mesh.get_strategy_coordinate(rank)[1:] == [dp_rank, tp_rank]
+ ]
+ for rank in same_pipeline_group:
+ if self.global_mesh.get_strategy_coordinate(rank)[0] == self.global_mesh.size(0) - 1:
+ self.src_loss_rank = rank
+ break
+ # the group for all ranks in the same pipeline to share final loss outputs
+ self.sync_loss_group = dist.new_group(ranks=same_pipeline_group, backend="nccl")
+
+ def set_data_iterator(self, data_iterator: List, data_shape=None):
+ """
+ Assigns minibatch data to instruction builder.
+
+ Args:
+ data_iterator (List): a minibatch list of microbatch data
+
+ """
+ assert builder.dataloader
+ builder.dataloader = data_iterator
+ if data_shape:
+ self.shape = data_shape
+ builder.constant_data["shape"] = data_shape
+
+ def get_instruction_list(self, stage_id):
+ return self.p_emmiter.get_instruction_list(stage_id)
+
+ def sync_output_loss_per_pipeline(self, loss: torch.Tensor):
+ """
+ A debug mode function that synchronizes minibatch loss
+ with all stages of a pipeline.
+
+ Args:
+ data_iterator (List): a minibatch list of microbatch data
+
+ """
+ assert self.global_mesh, "Must initialize per-pipeline dist group before synchronizing loss!"
+ if loss is None:
+ loss = torch.tensor(0.0, dtype=torch.float).cuda(dist.get_rank())
+ dist.broadcast(loss, src=self.src_loss_rank, group=self.sync_loss_group)
+
+ # monkey patch torch.tensor loss backward as empty tensor to make it a dummy function
+ def _empty_backward():
+ return None
+
+ loss.backward = _empty_backward
+ return loss
+
+ def _collect_microbatch_losses(self, outputs):
+ # monkey patch torch.tensor loss backward as empty tensor to make it a dummy function
+ def _empty_backward():
+ return None
+
+ output_losses = []
+ for microbatch_output, microbatch_loss in outputs:
+ if microbatch_loss is None:
+ if isinstance(microbatch_output, Sequence):
+ for j in range(len(microbatch_output)):
+ if microbatch_output[j].ndim == 0 and microbatch_output[j].numel() == 1:
+ loss_value = microbatch_output[j]
+ break
+ else:
+ raise ValueError("Loss values not found.")
+ else:
+ loss_value = microbatch_output
+ else:
+ # monkey patch microbatch loss backward as empty tensor to make it a dummy function
+ loss_value = microbatch_loss
+ output_losses.append(loss_value)
+ if not output_losses:
+ return None
+ tensor_device = output_losses[0].device
+ minibatch_loss = torch.tensor(sum(output_losses), device=tensor_device)
+ minibatch_loss.backward = _empty_backward
+ return minibatch_loss
+
+ @staticmethod
+ def execute(
+ instance,
+ *,
+ deallocate_pipeline_outputs: bool = False,
+ autocast_dtype: torch.dtype = torch.float,
+ enable_autocast: bool = False,
+ grad_scaler=None,
+ param_sync_func=None,
+ grad_sync_func=None,
+ debug_mode=False,
+ ):
+ """
+ Main entry point of executing forward and backward
+ computation of a minibatch.
+
+ Args:
+ instance (ScheduleEngine): a minibatch list of microbatch data
+ deallocate_pipeline_outputs (bool): deallocate tensors
+ autocast_dtype (torch.dtype): autocast data types
+ enable_autocast (bool): turn on to enable tensor autocast
+ grad_scaler (Callable): gradient scaler
+ param_sync_func (Callable): gradient synchronization function
+ debug_mode (bool): turn on to generate debugging outputs
+
+ Returns:
+ A tuple of two elements:
+ 1). loss of this minibatch of data,
+ 2). a list of tuple of outputs per microbatch, where for each tuple:
+ - 2.1). the first element is output of the original model
+ - 2.2). the second element is the loss of this microbatch.
+ If loss_fn is not provided at initialization, it means loss
+ is computed in 2.1) and here will return None
+
+ """
+ reset_global_counter()
+ if instance.schedule == PipelineScheduleType.SIMPLE_1F1B:
+ minibatch_outputs = instance.p_emmiter.instruction_generator.execute(
+ stage_id=instance.stage_id,
+ enable_autocast=enable_autocast,
+ autocast_dtype=autocast_dtype,
+ grad_scaler=grad_scaler,
+ deallocate_pipeline_outputs=deallocate_pipeline_outputs,
+ )
+ minibatch_loss = instance._collect_microbatch_losses(minibatch_outputs)
+ if debug_mode:
+ minibatch_loss = instance.sync_output_loss_per_pipeline(minibatch_loss)
+ return minibatch_loss, minibatch_outputs
+ elif instance.schedule == PipelineScheduleType.INTERLEAVED_1F1B:
+ minibatch_outputs = instance.p_emmiter.instruction_generator.execute(
+ stage_id=instance.stage_id,
+ enable_autocast=enable_autocast,
+ autocast_dtype=autocast_dtype,
+ grad_scaler=grad_scaler,
+ deallocate_pipeline_outputs=deallocate_pipeline_outputs,
+ param_sync_func=param_sync_func,
+ grad_sync_func=grad_sync_func,
+ )
+ minibatch_loss = instance._collect_microbatch_losses(minibatch_outputs)
+ if debug_mode:
+ minibatch_loss = instance.sync_output_loss_per_pipeline(minibatch_loss)
+ return minibatch_loss, minibatch_outputs
+ elif instance.schedule == PipelineScheduleType.ZERO_BUBBLE:
+ minibatch_outputs = instance.p_emmiter.instruction_generator.execute(
+ stage_id=instance.stage_id,
+ enable_autocast=enable_autocast,
+ autocast_dtype=autocast_dtype,
+ grad_scaler=grad_scaler,
+ deallocate_pipeline_outputs=deallocate_pipeline_outputs,
+ )
+ minibatch_loss = instance._collect_microbatch_losses(minibatch_outputs)
+ if debug_mode:
+ minibatch_loss = instance.sync_output_loss_per_pipeline(minibatch_loss)
+ return minibatch_loss, minibatch_outputs
+ else:
+ raise NotImplementedError("Unsupported Schedule!")
+
+
+def validate_pipeline_schedule(plan: PipelineParallelPlan):
+ """
+ Validates pipeline schedule settings in Pipeline ParallelPlan.
+
+ Args:
+ plan (PipelineParallelPlan): configuration of pipeline parallel API attributes
+
+ """
+ if plan.schedule_type == PipelineScheduleType.INTERLEAVED_1F1B:
+ assert plan.virtual_chunks > 1
+ elif plan.schedule_type == PipelineScheduleType.SIMPLE_1F1B:
+ assert plan.virtual_chunks == 1
diff --git a/vescale/pipe/pipe_parser.py b/vescale/pipe/pipe_parser.py
new file mode 100644
index 0000000..18cd180
--- /dev/null
+++ b/vescale/pipe/pipe_parser.py
@@ -0,0 +1,652 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+
+from typing import Sequence, Dict, List, Union, Any, Optional
+import torch
+import re
+import torch.nn as nn
+import logging
+from inspect import signature
+from vescale.pipe.tracer import ModelTracer, HFModelTracer, hf_symbolic_trace
+from torch.fx.passes.split_utils import split_by_tags
+from vescale.plan.pipeline_parallel import PipelineParallelPlan
+from vescale.plan.spec import PipelineSplitMethodType, TracerType
+
+NUM_DEFAULT_ARGS = 3
+
+try:
+ # New import path
+ from torch.export._trace import _export_to_torch_ir # noqa: PGH004
+except ImportError:
+ try:
+ # Old import path
+ from torch._export import _export_to_torch_ir # noqa: F401
+ except ImportError:
+ print("Could not import _export_to_torch_ir. Please make sure your PyTorch " "version is newer than 2.2.0.")
+
+
+logger = logging.Logger(__file__)
+
+
+class PipeParser:
+ def __init__(self):
+ self.orig_to_split_fqn_mapping = {}
+
+ def parse(
+ self, module: nn.Module, plan: Optional[PipelineParallelPlan] = None, **kwargs: Any
+ ) -> torch.fx.GraphModule:
+ """
+ Applies cascade trace capture using upstream torch.fx symbolic tracer, huggingface
+ tracer and dynamo export tracer respectively. To trigger cascade parser, select
+ TracerType.AUTO in PipelineParallelPlan's tracer_type field
+
+ Args:
+ module (nn.Module): the model from which we trace its forward execution graph.
+
+ Returns:
+ Model trace graph.
+
+ """
+ parser_args = {}
+ if plan and plan.smallest_unsplittable_units:
+ parser_args["partition_units"] = plan.smallest_unsplittable_units
+ if kwargs:
+ parser_args.update(kwargs)
+ try:
+ msg = "Applying Default torch.fx symbolic tracing..."
+ logger.info(msg)
+ traced = self.parse_torch_fx(module, **parser_args)
+ except Exception as e:
+ try:
+ msg = f"Default torch.fx symbolic tracing failed: {e}\nApplying HuggingFace Tracer..."
+ logger.warning(msg)
+ traced = self.parse_huggingface_fx(module, **parser_args)
+ except Exception as e2:
+ try:
+ msg = f"HuggingFace tracing failed: {e2}\nApplying Dynamo Export Tracer..."
+ logger.warning(msg)
+ traced = self.parse_dynamo_export(module, **parser_args)
+ except Exception as e3:
+ msg = f"Dynamo export tracing failed: {e3}"
+ logger.warning(msg)
+ raise e3
+ print(f"Below is visualization of the traced model graph:\n{traced}")
+ return traced
+
+ def partition_stage(
+ self, module: nn.Module, model_graph: torch.fx.GraphModule, plan: PipelineParallelPlan
+ ) -> List[str]:
+ """
+ Partitions models by split criterion. The function first annotates graph nodes and ops by stage
+ boundary, and then split stages into model partition modules (torch.fx.GraphModule).
+
+ Args:
+ module (nn.Module): the model.
+ model_graph (torch.fx.GraphModule): the trace graph of the model.
+ plan (PipelineParallelPlan): configuration of pipeline paralellism API.
+
+ Returns:
+ The executable trace graph partitioned by stage boundary,
+ and mappings of submodules before and after partition.
+
+ """
+ split_points = self.split(model_graph, plan)
+ plan.split_points = split_points
+ splited_graph = self.split_stage(model_graph, module, plan)
+ return splited_graph
+
+ def split(self, graph: torch.fx.GraphModule, plan: PipelineParallelPlan):
+ """
+ Generates or verifies pipeline split points, and writes updates to PipelineParallelPlan.
+
+ Args:
+ graph (torch.fx.GraphModule): symbolic trace graph of the entire model
+ plan (PipelineParallelPlan): configuration of attributes for pipeline parallleism API
+
+ Returns:
+ A list of fully qualified names of stage split points.
+
+ """
+ criterion = plan.split_method
+ boundaries = plan.split_points
+ nodes = list(graph.graph.nodes)
+ trimmed_nodes = nodes[1:-1] # remove input and output nodes in graph
+ node_names = [nd.name for nd in nodes]
+ trimmed_node_names = []
+ for nd in nodes[1:-1]:
+ if nd.op == "call_module":
+ trimmed_node_names.append(nd.target)
+ else:
+ trimmed_node_names.append(nd.name)
+ num_stages = plan.num_stages
+ num_chunk_per_stage = plan.virtual_chunks
+ num_model_partitions = num_stages * num_chunk_per_stage
+ nodes_size = len(trimmed_nodes)
+ trimmed_module_indices = [idx for idx in range(nodes_size) if trimmed_nodes[idx].op == "call_module"]
+ modules_only_size = len(trimmed_module_indices)
+ assert criterion in [
+ PipelineSplitMethodType.UNIFORM,
+ PipelineSplitMethodType.MANUAL,
+ PipelineSplitMethodType.AUTO,
+ PipelineSplitMethodType.PARAMETERS,
+ PipelineSplitMethodType.SIMULATOR,
+ PipelineSplitMethodType.FLOPS,
+ ]
+ if criterion == PipelineSplitMethodType.UNIFORM:
+ if plan.uniform_split_ops:
+ module_indices = self._partition_uniform(modules_only_size, num_model_partitions)
+ indices = [trimmed_module_indices[module_indices[idx]] for idx in range(len(module_indices))]
+ else:
+ indices = self._partition_uniform(nodes_size, num_model_partitions)
+ final_boundaries = []
+ for idx in indices:
+ if nodes[idx].op == "call_module" and trimmed_nodes[idx].name != trimmed_nodes[idx].target:
+ final_boundaries.append(trimmed_nodes[idx].name.replace("_", "."))
+ else:
+ final_boundaries.append(trimmed_nodes[idx].name)
+ plan.split_points = final_boundaries
+ elif criterion == PipelineSplitMethodType.MANUAL:
+ assert boundaries, "Must provide stage boundaries for MANUAL mode during stage partition!"
+ if boundaries and all(isinstance(x, str) for x in boundaries):
+ for fqn in boundaries:
+ assert (
+ fqn in node_names
+ or fqn.replace(".", "_") in node_names
+ or any(name.startswith(fqn) for name in node_names)
+ )
+ elif boundaries and all(isinstance(x, int) for x in boundaries):
+ # Under indexing-based partition, model graph's execution order is visualized as followed
+ boundaries.sort()
+ assert 0 <= boundaries[0] <= boundaries[-1] < len(nodes)
+ # convert submodule indices into fully qualified names
+ new_boundaries = []
+ for idx in boundaries:
+ if nodes[idx].op == "call_module":
+ new_boundaries.append(nodes[idx].name.replace("_", "."))
+ else:
+ new_boundaries.append(nodes[idx].name)
+ boundaries = new_boundaries
+ else:
+ raise ValueError("Input must be either a list of path strings or partition indices!")
+ if boundaries[-1] != node_names[-2]:
+ boundaries.append(node_names[-2])
+
+ final_boundaries = self._handle_virtual_stage_boundaries(
+ boundaries,
+ trimmed_node_names,
+ num_chunk_per_stage,
+ plan.enable_vpp_split_points,
+ )
+ # assert no stage boundary is a prefix of other boundaries
+ _boundaries = set(final_boundaries)
+ for this_bd in _boundaries:
+ for bd in _boundaries:
+ if this_bd != bd:
+ assert not this_bd.startswith(bd)
+ assert len(final_boundaries) == num_model_partitions
+ else:
+ raise NotImplementedError
+ return final_boundaries
+
+ def _partition_uniform(self, num_items, num_parts):
+ assert num_items % num_parts == 0, "#graph nodes must be partitioned by #stages!"
+ assert num_items >= num_parts, "#model partitions must not be less than #graph nodes!"
+ parts = [0] * (num_parts + 1)
+ # First check for the trivial edge case
+ if num_items <= num_parts:
+ for p in range(num_parts + 1):
+ parts[p] = min(p, num_items)
+ else:
+ chunksize = num_items // num_parts
+ residual = num_items - (chunksize * num_parts)
+ parts = torch.arange(0, (num_parts + 1) * chunksize, chunksize)
+ for i in range(residual):
+ parts[i + 1 :] += 1
+ parts = parts.tolist()
+ if parts[0] == 0:
+ parts = parts[1:]
+ parts = [x - 1 for x in parts]
+ return parts
+
+ def _handle_virtual_stage_boundaries(
+ self,
+ boundaries: List[Union[str, int]],
+ node_names: List[str],
+ num_chunk_per_stage: int,
+ use_manual_vpp_boundary: bool,
+ ):
+ if isinstance(boundaries[0], int):
+ boundaries = [node_names[idx] for idx in boundaries]
+ if num_chunk_per_stage > 1 and not use_manual_vpp_boundary:
+ new_indices = []
+ indices = list(range(len(node_names)))
+ raw_stage_indices = []
+ for fqn in boundaries:
+ if fqn not in node_names:
+ fqn = fqn.replace(".", "_")
+ raw_stage_indices.append(node_names.index(fqn))
+ if raw_stage_indices[-1] < len(node_names) - 1:
+ raw_stage_indices[-1].append(len(node_names) - 1)
+ for i in range(len(raw_stage_indices)):
+ if i == 0:
+ sublist = torch.tensor(indices[: raw_stage_indices[i] + 1])
+ else:
+ sublist = torch.tensor(indices[raw_stage_indices[i - 1] + 1 : raw_stage_indices[i] + 1])
+ assert (
+ len(sublist) >= num_chunk_per_stage
+ ), "#operators and modules in a stage must be no smaller than #virtual pipeline chunks!"
+ sublist_list = sublist.tensor_split(num_chunk_per_stage)
+ new_indices += [int(sub[-1]) for sub in sublist_list]
+ boundaries = [node_names[idx] for idx in new_indices]
+ return boundaries
+
+ def annotate_pipeline_stage(
+ self, graph: torch.fx.GraphModule, root_module: nn.Module, boundaries: List, partition_units: List
+ ):
+ """
+ Annotates stage split boundaries of each stage on the model graph.
+
+ Args:
+ graph (torch.fx.GraphModule): model trace graph
+ root_module (nn.Module): raw model
+ boundaries (List): a list of pipeline stage split points in the form of fully qualified names
+ partition_units (List): smallest unsplittable unit in a model trace graph
+
+ Returns:
+ Model graph with stage split points annotated.
+
+ """
+
+ def identify_base_units(submodule, partition_units, submodule_name):
+ return (
+ len(list(submodule.children())) == 0
+ or submodule_name in partition_units
+ or type(submodule) in partition_units
+ )
+
+ splited_module_names = boundaries
+ assert len(splited_module_names) > 0, "need to have bigger than 1 nodes"
+ max_dfn_for_modules = [0 for _ in range(len(splited_module_names))]
+ node_lists = list(graph.graph.nodes)
+ node_lists_names = [node.name for node in node_lists]
+ node_lists_target_names = [node.target for node in node_lists]
+ submodule_paths = {name for name, _ in root_module.named_modules()}
+ for stage_id, submodule_name in enumerate(splited_module_names):
+ stage_tag = stage_id
+ sub_module_unions = []
+ if submodule_name in node_lists_names:
+ boundary_node = node_lists[node_lists_names.index(submodule_name)]
+ else:
+ boundary_node = node_lists[node_lists_target_names.index(submodule_name)]
+ if submodule_name in submodule_paths:
+ submodule = root_module.get_submodule(submodule_name)
+ if identify_base_units(submodule, partition_units, submodule_name): # for leaf module
+ sub_module_unions.append(submodule_name)
+ else:
+ for name, _ in submodule.named_children():
+ sub_module_unions.append(submodule_name + "." + name)
+ sub_module_unions = [re.sub(r"\.", "_", name) for name in sub_module_unions]
+ else:
+ if boundary_node.op == "call_method" or boundary_node.op == "call_function":
+ sub_module_unions.append(boundary_node.name)
+ else:
+ raise ValueError(
+ "Stage boundary can only be of ``call_module``, ``call_method`` and ``call_function``!"
+ )
+ stage_max_dfn = 0
+ # set tag with the node Sequence, to O(N)
+ for dfn in range(len(node_lists)):
+ node = node_lists[dfn]
+ if node.name in sub_module_unions:
+ # TODO: tag should be partition_chunk{id} instead of stage, as it may lead to confusion in interleaved 1F1B schedules
+ node.tag = f"stage{str(stage_tag)}"
+ stage_max_dfn = max(dfn, stage_max_dfn)
+ max_dfn_for_modules[stage_id] = stage_max_dfn
+
+ # annotate the first stage
+ for dfn in range(len(node_lists)):
+ if dfn <= max_dfn_for_modules[0]:
+ node_lists[dfn].tag = "stage0"
+ else:
+ break
+
+ slow = 0
+ cur_dfn_num = 0
+ fast = max_dfn_for_modules[cur_dfn_num]
+ # using fast slow ptr to annotate graph
+
+ while fast < len(node_lists) and slow < len(node_lists):
+ while slow <= fast:
+ node_lists[slow].tag = node_lists[fast].tag
+ slow += 1
+ cur_dfn_num += 1
+ if cur_dfn_num < len(max_dfn_for_modules):
+ fast = max_dfn_for_modules[cur_dfn_num]
+ else:
+ while slow < len(node_lists):
+ node_lists[slow].tag = node_lists[fast].tag
+ slow += 1
+ return graph
+
+ def split_stage(
+ self, graph: torch.fx.GraphModule, root_module: nn.Module, plan: PipelineParallelPlan
+ ) -> torch.fx.GraphModule:
+ """
+ Split a model graph into multiple pipeline stage subgraphs.
+
+ Args:
+ graph (torch.fx.GraphModule): model graph
+ root_module (nn.Module): raw model
+ plan (PipelineParallelPlan): configuration of attributes for pipeline parallleism API
+
+ Returns:
+ Edited model graph that contains subgraph of each virtual module chunk of a pipeline stage.
+ For example,
+ ```
+ Before:
+ original_graph:
+ module1: xxx
+ module2: xxx
+ module3: xxx
+ module4: xxx
+
+ After:
+ split_graph:
+ stage0:
+ module1: xxx
+ module2: xxx
+ stage1:
+ module3: xxx
+ module4: xxx
+ ```
+
+ """
+ if graph is None:
+ return None
+
+ boundaries = plan.split_points
+ partition_units = plan.smallest_unsplittable_units
+ graph = self.annotate_pipeline_stage(graph, root_module, boundaries, partition_units)
+ tags = [f"stage{str(num)}" for num in range(len(boundaries))]
+ # split by PyTorch upstream's split_by_tags
+ split_graph, orig_to_split_fqn_mapping = split_by_tags(graph, tags, return_fqn_mapping=True)
+ for i in range(1, len(tags)):
+ # input placeholder node of each stage-specific graph
+ placeholder_node = list(getattr(split_graph, tags[i]).graph.nodes)[0]
+ if placeholder_node.op == "placeholder" and placeholder_node.name != "x":
+ placeholder_node.name = "x"
+
+ return split_graph
+
+ def parse_torch_fx(
+ self, model: nn.Module, partition_units: List[str] = None, shard_plan: Dict = None
+ ) -> torch.fx.GraphModule:
+ """
+ Applies torch.fx symbolic trace to capture model graph.
+
+ Args:
+ model (nn.Module): raw model
+ partition_units (List[str]): a list of smallest unsplittable modules such that the parser will
+ not flatten their underlying components during parsing
+ shard_plan (Dict): dictionary of sharding plan, if users would like to wrap up tensor parallelized
+ modules as unsplittable units
+
+ Returns:
+ Captured torch.fx.GraphModule
+
+ """
+ if partition_units is None:
+ partition_units = []
+ input_names = list(signature(model.forward).parameters.keys())
+ if "input_ids" in input_names and "inputs_embeds" in input_names:
+ input_names.remove("inputs_embeds")
+ if shard_plan:
+ hierarchy_substructure_qualified_names = self._hierarchy_structure_names(model, shard_plan)
+ partition_units += hierarchy_substructure_qualified_names
+ traced: torch.fx.GraphModule = hf_symbolic_trace(
+ model,
+ input_names=input_names,
+ disable_check=True,
+ tracer_cls=ModelTracer,
+ partition_modules=partition_units,
+ )
+ return traced
+
+ def parse_dynamo_export(self, model: nn.Module, *args: Sequence, **kwargs: Dict):
+ """
+ Applies capture model graph with torch dynamo.export.
+
+ Args:
+ model (nn.Module): raw model
+
+ Returns:
+ Captured torch.fx.GraphModule
+
+ """
+ traced: torch.fx.GraphModule = _export_to_torch_ir(model, args=args, kwargs=kwargs)
+ return traced
+
+ def parse_huggingface_fx(
+ self, model, partition_units: List[str] = None, shard_plan: Dict = None, default_settings: bool = True
+ ):
+ """
+ Applies symbolic trace with huggingface-like fx.
+
+ Args:
+ model (nn.Module): raw model
+ partition_units (List[str]): a list of smallest unsplittable modules such that the parser will
+ not flatten their underlying components during parsing
+ shard_plan (Dict): dictionary of sharding plan, if users would like to wrap up tensor parallelized
+ modules as unsplittable units
+
+ Returns:
+ Captured torch.fx.GraphModule
+
+ """
+ if partition_units is None:
+ partition_units = []
+ input_arguments = signature(model.forward).parameters.keys()
+ # parser flattens module hierachy during parse. Maintain hierachy so that it can still be accessed by a sharding plan
+ if shard_plan:
+ hierarchy_substructure_qualified_names = self._hierarchy_structure_names(model, shard_plan)
+ partition_units += hierarchy_substructure_qualified_names
+ input_names = list(input_arguments)
+ if default_settings:
+ default_input_names, default_unit_modules = self._default_parse_info(model, input_names)
+ if default_input_names:
+ input_names = default_input_names
+ if default_unit_modules:
+ partition_units = default_unit_modules
+ if "input_ids" in input_names and "inputs_embeds" in input_names:
+ # two arguments cannot occur simultanenously
+ input_names.remove("inputs_embeds")
+ input_names = input_names[:NUM_DEFAULT_ARGS]
+ traced: torch.fx.GraphModule = hf_symbolic_trace(
+ model,
+ input_names=input_names,
+ disable_check=True,
+ tracer_cls=HFModelTracer,
+ partition_modules=partition_units,
+ )
+ return traced
+
+ def _hierarchy_structure_names(self, model, shard_plan):
+ modules_to_maintain_hierarchy = set()
+ self._collect_hierachical_modules_paths(model, shard_plan["forward"], modules_to_maintain_hierarchy)
+ self._collect_hierachical_modules_paths(model, shard_plan["parameter"], modules_to_maintain_hierarchy)
+ return modules_to_maintain_hierarchy
+
+ def _collect_hierachical_modules_paths(self, model, plan_dict, module_paths):
+ for path_to_submodule, _ in model.named_modules():
+ for plan_fqn in plan_dict:
+ pattern = plan_fqn.rsplit(".", 1)[0]
+ if (
+ re.match(pattern, path_to_submodule)
+ and len(list(model.get_submodule(path_to_submodule).children())) != 0
+ ):
+ module_paths.add(path_to_submodule)
+
+ def _locate_module_classes(self, model, paths_to_submodules):
+ if paths_to_submodules is None:
+ return paths_to_submodules
+ visited = set(paths_to_submodules)
+ submodule_classes = set()
+ for name, submodule in model.named_modules():
+ if name in visited:
+ submodule_classes.add(type(submodule))
+ return list(submodule_classes)
+
+ def _default_parse_info(self, model, input_names, num_default_args=3):
+ from transformers.models.whisper.modeling_whisper import WhisperModel
+ from transformers.models.mixtral.modeling_mixtral import (
+ MixtralModel,
+ MixtralRMSNorm,
+ MixtralSparseMoeBlock,
+ MixtralAttention,
+ )
+ from transformers.models.biogpt.modeling_biogpt import BioGptModel, BioGptAttention
+ from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Model, DisentangledSelfAttention
+ from transformers.models.marian.modeling_marian import MarianModel, MarianAttention, MarianEncoderLayer
+ from transformers.models.blenderbot.modeling_blenderbot import (
+ BlenderbotModel,
+ BlenderbotAttention,
+ BlenderbotEncoderLayer,
+ )
+ from transformers.models.layoutlmv3.modeling_layoutlmv3 import LayoutLMv3Model, LayoutLMv3SelfAttention
+ from transformers.models.phi.modeling_phi import PhiModel, PhiAttention
+ from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXModel, GPTNeoXAttention
+ from transformers.models.falcon.modeling_falcon import FalconModel, FalconAttention
+ from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeModel, GPTBigCodeAttention
+ from transformers.models.vit.modeling_vit import ViTModel, ViTEmbeddings, ViTSelfAttention
+ from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model, Wav2Vec2Attention
+ from transformers.models.speecht5.modeling_speecht5 import SpeechT5Model, SpeechT5Attention
+ from transformers.models.bloom.modeling_bloom import BloomModel, BloomAttention
+
+ model_type = type(model)
+ input_names = partition_unit_classes = None
+ if model_type == MixtralModel:
+ partition_unit_classes = [MixtralRMSNorm, MixtralSparseMoeBlock, MixtralAttention]
+ elif model_type == BioGptModel:
+ partition_unit_classes = [BioGptAttention]
+ elif model_type == DebertaV2Model:
+ partition_unit_classes = [DisentangledSelfAttention]
+ elif model_type == MarianModel:
+ partition_unit_classes = [MarianAttention, MarianEncoderLayer]
+ elif model_type == BlenderbotModel:
+ partition_unit_classes = [BlenderbotAttention, BlenderbotEncoderLayer]
+ elif model_type == LayoutLMv3Model:
+ partition_unit_classes = [LayoutLMv3SelfAttention]
+ elif model_type == PhiModel:
+ partition_unit_classes = [PhiAttention]
+ elif model_type == GPTNeoXModel:
+ partition_unit_classes = [GPTNeoXAttention]
+ elif model_type == FalconModel:
+ partition_unit_classes = [FalconAttention]
+ elif model_type == GPTBigCodeModel:
+ partition_unit_classes = [GPTBigCodeAttention]
+ elif model_type == ViTModel:
+ partition_unit_classes = [ViTEmbeddings, ViTSelfAttention]
+ elif model_type == Wav2Vec2Model:
+ partition_unit_classes = [Wav2Vec2Attention]
+ elif model_type == SpeechT5Model:
+ partition_unit_classes = [SpeechT5Attention]
+ elif model_type == BloomModel:
+ input_names = ["attention_mask", "head_mask", "inputs_embeds"]
+ partition_unit_classes = [BloomAttention]
+ elif model_type == WhisperModel:
+ input_names = ["input_features", "decoder_input_ids"]
+
+ if input_names:
+ input_names = input_names[:num_default_args]
+ return input_names, partition_unit_classes
+
+
+def parse_model_graph(parser: PipeParser, model: nn.Module, plan: PipelineParallelPlan) -> torch.fx.GraphModule:
+ """
+ Pipeline Parallelism API that performs parsing given tracer types.
+
+ Args:
+ parser (PipeParser): model parser
+ model (nn.Module): raw model
+ plan (PipelineParallelPlan): configuration of pipeline paralellism API.
+
+ Returns:
+ Captured torch.fx.GraphModule
+
+ """
+ tracer_type = plan.tracer_type
+ tracer_kwargs = plan.tracer_kwargs
+ if tracer_kwargs is None:
+ tracer_kwargs = {}
+ if tracer_type == TracerType.AUTO:
+ model_graph = parser.parse(model, plan)
+ else:
+ if "partition_units" not in tracer_kwargs and tracer_type in [TracerType.TORCH_FX, TracerType.HF_FX]:
+ tracer_kwargs["partition_units"] = plan.smallest_unsplittable_units
+ if tracer_type == TracerType.TORCH_FX:
+ model_graph = parser.parse_torch_fx(model, **tracer_kwargs)
+ elif tracer_type == TracerType.HF_FX:
+ model_graph = parser.parse_huggingface_fx(model, **tracer_kwargs)
+ elif tracer_type == TracerType.TORCH_DYNAMO:
+ model_graph = parser.parse_dynamo_export(model, **tracer_kwargs)
+ else:
+ raise NotImplementedError(f"Logic of tracer {tracer_type} has not been implemented yet.")
+ return model_graph
+
+
+def split_pipeline_point(model: nn.Module, plan: PipelineParallelPlan):
+ """
+ Pipeline Parallelism API that updates pipeline stage split points.
+
+ Args:
+ model (nn.Module): raw model
+ plan (PipelineParallelPlan): configuration of pipeline paralellism API.
+
+ Returns:
+ Captured torch.fx.GraphModule.
+
+ """
+ # obtain the traced graph of entire model if pipeline parallelism is on
+ parser = PipeParser()
+ model_graph = parse_model_graph(parser, model, plan)
+ split_points = parser.split(model_graph, plan)
+ plan.split_points = split_points
+ return split_points, model_graph, parser
+
+
+def construct_pipeline_split_graph(model: nn.Module, plan: PipelineParallelPlan, update_split_points: bool = False):
+ """
+ Pipeline Parallelism API that performs pipeline stage split.
+
+ Args:
+ model (nn.Module): raw model
+ plan (PipelineParallelPlan): configuration of pipeline paralellism API.
+ update_split_points (bool): set this switch on to update pipeline split points in-place.
+
+ Returns:
+ Captured torch.fx.GraphModule.
+
+ """
+ parser = PipeParser()
+ model_graph = parse_model_graph(parser, model, plan)
+ if update_split_points:
+ split_points = parser.split(model_graph, plan)
+ plan.split_points = split_points
+ # partition model graph into virtual pipeline chunks per stage
+ split_graph = parser.split_stage(model_graph, model, plan)
+ return split_graph
diff --git a/vescale/pipe/pipe_stage.py b/vescale/pipe/pipe_stage.py
new file mode 100644
index 0000000..9e91f56
--- /dev/null
+++ b/vescale/pipe/pipe_stage.py
@@ -0,0 +1,563 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+"""
+This `PipeModule` Class is the abstraction of a pipeline stage.
+
+PipeModule takes as microbatch input 1). List of data per microbatch, 2). Dictionary of data per microbatch, 3). torch.Tensor
+
+PipeModule takes both 1). p2p transmitted data from incoming stages, and 2). local data inputs
+
+Each Pipeline stage can run single batch data forward, just like nn.Modules, we can
+use forward functions and new p2p ops to replement pipeline forward and backward
+
+For Example 1.
+ ```python
+ stage: PipeModule = ...
+ single_data = ... # a single microbatch of data
+ fwd = stage(single_data)
+ p2p_send_recv( ... )
+ ```
+
+For Example 2.
+ ```python
+ stage: PipeModule = ...
+ p2p_data = ... # a torch.Tensor from last stage
+ local_data = Dict(...) # a single microbatch of data
+ fwd = stage(p2p_data, local_inputs=local_data)
+ p2p_send_recv( ... )
+ ```
+
+"""
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+import numpy as np
+import inspect
+import re
+from typing import Dict, List, Tuple, Union, Optional, Sequence, Callable, Any
+from vescale.optim.base_optimizer import BasicOptimizer
+from vescale.optim.distributed_optimizer import DistributedOptimizer
+from vescale.devicemesh_api.api import VeDeviceMesh
+from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
+from vescale.dtensor.dtensor import DTensor
+from vescale.plan import PipelineParallelPlan, PipelineP2PSpec
+from vescale.pipe.pipe_parser import construct_pipeline_split_graph
+from collections import defaultdict
+
+
+class PipeModule(nn.Module):
+ def __init__(
+ self,
+ module: Union[nn.Module, List],
+ doptimizer: Union[BasicOptimizer, DistributedOptimizer],
+ lr_scheduler: Callable,
+ stage_deps: np.ndarray,
+ p2p_index_mapping: Dict,
+ config: PipelineParallelPlan,
+ ):
+ super().__init__()
+ self.stage_modules = {}
+ if isinstance(module, List):
+ for i in range(len(module)):
+ self.stage_modules[i] = module[i]
+ else:
+ self.stage_modules[0] = module
+ self.doptimizer = doptimizer
+ self.lr_scheduler = lr_scheduler
+ self.shared_module_process_groups = defaultdict()
+ self.sync_chunk_ids = set()
+ self.shared_path_this_stage = {}
+ self.shared_module_mapping = {}
+ self.config = config
+ self.num_stages = self.config.num_stages
+ self.virtual_chunks = self.config.virtual_chunks
+ self.stage_deps = stage_deps
+ self.p2p_index_mapping = p2p_index_mapping
+
+ def forward(
+ self,
+ inputs: Union[torch.Tensor, List, Dict],
+ local_inputs: Union[torch.Tensor, List, Dict] = None,
+ chunk_id: int = 0,
+ ):
+ """
+ Forward propagation function of a pipeline stage. This function processes inputs to model chunks from p2p data transfers
+ and local dataloaders.
+
+ Note:
+ - inputs (Union[torch.Tensor, List, Dict]): transmitted data received from another pipeline stage
+ - local_inputs (Union[torch.Tensor, List, Dict]): optional input of local data
+ - chunk_id (int): identifier of dictating what virtual model chunk to execute in interleaved 1f1b schedule.
+ If it is the simple 1f1b schedule, chunk_id=0.
+
+ Args:
+ inputs (torch.Tensor, list, dict): inputs fed into model partition module.
+ local_inputs (torch.Tensor, list, dict): local inputs from dataloaders, used when executing pipeline schedule.
+
+ Returns:
+ Output activations.
+
+ """
+ chunk_module = self.stage_modules[chunk_id]
+ if local_inputs is None:
+ if isinstance(inputs, list):
+ return chunk_module(*inputs)
+ elif isinstance(inputs, dict):
+ return chunk_module(**inputs)
+ elif inputs is None:
+ return chunk_module()
+ else:
+ return chunk_module(inputs)
+ else:
+ combined_data = self._prepare_inputs(chunk_module, inputs, local_inputs)
+ return chunk_module(**combined_data)
+
+ def _prepare_inputs(self, module, inputs, local_inputs=None):
+ fwd = module.module.forward if isinstance(module, DDP) else module.forward
+ sig = inspect.signature(fwd)
+ arguments = list(sig.parameters.keys())
+ dict_inputs = self._prepare_data_formats(arguments, inputs)
+ dict_local_inputs = self._prepare_data_formats(arguments, local_inputs)
+ final_inputs = {}
+ for key in arguments:
+ input_val, local_val = dict_inputs.get(key), dict_local_inputs.get(key)
+ if input_val is not None:
+ final_inputs[key] = input_val
+ elif local_val is not None:
+ final_inputs[key] = local_val
+ elif sig.parameters[key].default is not inspect.Parameter.empty:
+ final_inputs[key] = sig.parameters[key].default.default
+ return final_inputs
+
+ def _prepare_data_formats(self, keys, data):
+ if data is None or isinstance(data, Sequence) and len(data) == 1 and data[0] is None:
+ if keys:
+ return {keys[0]: None}
+ return None
+ if isinstance(data, torch.Tensor):
+ data = [data]
+ if isinstance(data, Sequence):
+ args_length = min(len(data), len(keys))
+ data = {keys[i]: data[i] for i in range(args_length)}
+ return data
+
+ def __getitem__(self, module_chunk_id: int):
+ assert module_chunk_id in self.stage_modules, "Virtual chunk id not existed!"
+ return self.stage_modules[module_chunk_id]
+
+ @property
+ def get_optimizer(self):
+ return self.doptimizer
+
+ @property
+ def get_lr_scheduler(self):
+ return self.lr_scheduler
+
+ def parameters(self):
+ parameters = []
+ for chunk_id in range(self.virtual_chunks):
+ parameters += list(self.stage_modules[chunk_id].parameters())
+ return parameters
+
+ def has_shared_params(self, global_mesh: VeDeviceMesh, group_id: int, tp_rank: int) -> bool:
+ """
+ Checks whether this stage has submodules to synchronize parameters or gradients.
+ An additional use case of this function is to dictate if a submodule's shared parameter
+ (invoked by self.get_shared_module()) participates in grad norm clipping.
+
+ Args:
+ global_mesh (VeDeviceMesh): global DeviceMesh with which one looks up communication information.
+ group_id (int): specify groups of modules across stages to synchronize. Default by 0.
+ tp_rank (int): tensor model parallel rank of current stage.
+
+ Returns:
+ whether a stage contains sharable parameters
+
+ """
+ local_rank = global_mesh.get_local_rank()
+ return not (
+ not self.shared_module_process_groups
+ or tp_rank not in self.shared_module_process_groups[group_id]
+ or local_rank not in dist.get_process_group_ranks(self.shared_module_process_groups[group_id][tp_rank])
+ )
+
+ def sync_shared_params(
+ self, global_mesh: VeDeviceMesh, group_id: int = 0, share_params: bool = True, chunk_id: int = 0
+ ):
+ """
+ Synchronize parameters of reused modules e.g.
+ Embedding. This function is invoked in each run of PP schedule.
+
+ Args:
+ global_mesh (VeDeviceMesh): global DeviceMesh with which one looks up communication information.
+ group_id (int): specify groups of modules across stages to synchronize. Default by 0.
+ share_params (bool): if True, sync weight parameters; otherwise, share gradients.
+ chunk_id (int): identify if current virtual model chunk in this stage has any module to synchronize.
+
+ """
+ tp_rank = global_mesh.get_tensor_parallel_rank()
+ if (
+ not self.has_shared_params(global_mesh, group_id=group_id, tp_rank=tp_rank)
+ or chunk_id not in self.sync_chunk_ids
+ ):
+ return
+ # assume that each model chunk has at most 1 sharable sub-module per shared group
+ shared_submodule_path = self.shared_path_this_stage[(group_id, chunk_id)]
+ model_chunk = self.stage_modules[chunk_id]
+ if isinstance(model_chunk, DDP):
+ model_chunk = model_chunk.module
+ target_module = model_chunk.get_submodule(shared_submodule_path)
+ if getattr(target_module, "get_word_embeddings_weight", None):
+ target_module = target_module.get_word_embeddings_weight()
+
+ # assume tp coordinate is always the last dimension
+ sync_group = self.shared_module_process_groups[group_id][tp_rank]
+ group_size = dist.get_world_size(group=sync_group)
+
+ if share_params:
+ if isinstance(target_module.data, DTensor):
+ dist.all_reduce(target_module.data._local_tensor, group=sync_group)
+ else:
+ dist.all_reduce(target_module.data, group=sync_group)
+ target_module.data /= group_size
+ else:
+ # if type is DTensr, then do local_tensor.grad
+ if target_module.grad is not None:
+ target_module.grad.data /= group_size
+ dist.all_reduce(target_module.grad.data, group=sync_group)
+ else: # DDP Module
+ target_module.main_grad /= group_size
+ dist.all_reduce(target_module.main_grad, group=sync_group)
+
+
+def construct_stage_modules(
+ model: nn.Module,
+ plan: PipelineParallelPlan,
+ global_mesh: VeDeviceMesh,
+ update_split_points: bool = False,
+):
+ """
+ Pipeline Parallelism API that constructs ingredients for building PipelineModule.
+
+ Args:
+ model (nn.Module): raw model
+ plan (PipelineParallelPlan): configuration of pipeline paralellism API.
+ update_split_points (bool): set this switch on to update pipeline split points in-place.
+
+ Returns:
+ Triplet of 1). list of modules in a pipeline stage, 2). abstraction of send-receive dependency relationship
+ among stages, 3). P2P input index mapping.
+
+ """
+ num_stages = plan.num_stages
+ virtual_chunks = plan.virtual_chunks
+ split_graph = construct_pipeline_split_graph(model, plan, update_split_points=update_split_points)
+
+ # assign modules to stage, establish stage dependency and input mapping
+ stage_modules, stage_dependency, p2p_index_mapping = build_stage_module_and_dependency(
+ split_graph,
+ num_stages,
+ virtual_chunks,
+ stage_id=global_mesh.get_pipeline_parallel_rank(),
+ )
+ submodules_this_stage = []
+ for chunk_id in range(len(stage_modules)):
+ submodules_this_stage.append(stage_modules[chunk_id])
+ return submodules_this_stage, stage_dependency, p2p_index_mapping
+
+
+def construct_pipeline_stage(
+ model: nn.Module,
+ plan: PipelineParallelPlan,
+ global_mesh: VeDeviceMesh,
+ lr_scheduler: Optional[Union[Callable, Tuple[Callable, Any]]] = None,
+ update_split_points: bool = False,
+):
+ """
+ Pipeline Parallelism API that constructs PipeModule from the raw model.
+
+ Args:
+ model (nn.Module): raw model.
+ plan (PipelineParallelPlan): configuration of pipeline paralellism API.
+ lr_scheduler (Optional[Union[Callable, Tuple[Callable, Any]]]): learning rate scheduler.
+ update_split_points (bool): set this switch on to update pipeline split points in-place.
+
+ Returns:
+ Pipeline stage.
+
+ """
+ stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules(
+ model, plan, global_mesh, update_split_points
+ )
+ return PipeModule(stage_modules, None, lr_scheduler, stage_dependency, p2p_index_mapping, plan)
+
+
+def build_shared_module_group(
+ pipe_module: PipeModule,
+ split_graph: torch.fx.GraphModule,
+ num_stages: int,
+ virtual_chunks: int,
+ shared_module_path_groups: List[List],
+ global_mesh: VeDeviceMesh,
+):
+ """
+ Pipeline Parallelism API that establishes groups of modules which
+ synchronize parameters or gradients amongst one another.
+
+ Args:
+ pipe_module (PipeModule): pipeline stage to assign synchronzied mapping.
+ split_graph (torch.fx.GraphModule): the global model graph split into stages.
+ num_stages (int): number of pipeline stages.
+ virtual_chunks (int): number of virtual pipeline stage chunks in a stage.
+ shared_module_path_groups (List[List]): list of groups of module fully qualified names,
+ where modules in the same group synchronizes parameters or gradients.
+ global_mesh (VeDeviceMesh): global DeviceMesh with which one looks up communication information.
+
+ Returns:
+ Tuple of 1). a dictionary of shared group items, 2). a dictionary of shared group this stage is involved
+ 3). synchronized model chunk ids, and 4). path to the shared submodule, if applicable.
+
+ """
+ shared_module_process_groups = defaultdict()
+ shared_module_mapping = {}
+ sync_chunk_ids = set()
+ shared_path_this_stage = {}
+ module_partition_names_by_stage = [[] for _ in range(num_stages)]
+ num_model_partitions = num_stages * virtual_chunks
+ for j in range(num_model_partitions):
+ module_partition_names_by_stage[j % num_stages].append(f"stage{j}")
+ stage_id = global_mesh.get_pipeline_parallel_rank()
+ # establish process groups of synchronizing shared embeddings
+ if shared_module_path_groups:
+ shared_module_process_groups, shared_module_mapping, shared_info = _establish_shared_module_groups(
+ num_stages,
+ virtual_chunks,
+ module_partition_names_by_stage,
+ split_graph,
+ shared_module_path_groups,
+ global_mesh,
+ )
+ for group_id, group in enumerate(shared_info):
+ for _stage_id, chunk_id, path in group:
+ if _stage_id == stage_id:
+ sync_chunk_ids.add(chunk_id)
+ shared_path_this_stage[(group_id, chunk_id)] = path
+ pipe_module.shared_module_process_groups = shared_module_process_groups
+ pipe_module.shared_module_mapping = shared_module_mapping
+ pipe_module.sync_chunk_ids = sync_chunk_ids
+ pipe_module.shared_path_this_stage = shared_path_this_stage
+ return shared_module_process_groups, shared_module_mapping, sync_chunk_ids, shared_path_this_stage
+
+
+def build_stage_module_and_dependency(
+ split_graph: torch.fx.GraphModule,
+ num_stages: int,
+ virtual_chunks: int,
+ stage_id: int,
+):
+ """
+ Establishes sub-modules of the same stage as well as the send-receive relationship among stages.
+
+ Args:
+ split_graph (torch.fx.GraphModule): the global model graph split into stages.
+ num_stages (int): number of pipeline stages.
+ virtual_chunks (int): number of virtual pipeline stage chunks in a stage.
+ stage_id (int): pipeline stage id.
+
+ Returns:
+ Submodules of a pipeline stage, inter-stage dependency, and P2P input mapping.
+
+ """
+ # generate inter-stage communication dependency and communication mapping
+ stage_dependency, p2p_index_mapping = _generate_stage_dependencies(split_graph, num_stages, virtual_chunks)
+ # build sub-modules belonging to the current pipeline stage
+ stage_modules = _build_module(split_graph, num_stages, virtual_chunks, stage_id)
+ return stage_modules, stage_dependency, p2p_index_mapping
+
+
+def _generate_stage_dependencies(graph: torch.fx.GraphModule, num_stage: int, virtual_chunks: int):
+ """
+ Generates inter-stage dependency and P2P index mapping across stages.
+
+ Args:
+ graph (torch.fx.GraphModule): the whole trace graph of the model.
+
+ Returns:
+ Mapping of inter-stage dependency and p2p index mapping.
+
+ """
+ stage_to_chunk_mapping = _get_stage_to_chunk_mapping(virtual_chunks, num_stage)
+ _stage_to_chunk_mapping = {}
+ for stage_id, partition_ids in stage_to_chunk_mapping.items():
+ for part_id in partition_ids:
+ _stage_to_chunk_mapping[part_id] = stage_id
+ stage_to_chunk_mapping = _stage_to_chunk_mapping
+
+ stage_rule = r"stage\d+"
+ stage2node = {}
+ for node in graph.graph.nodes:
+ if re.match(stage_rule, node.name):
+ stage2node.update({node.name: node})
+
+ stage_deps = np.zeros((num_stage, num_stage))
+ for node_name, node in stage2node.items():
+ partition_id = int(node_name[5:])
+ stage_id = stage_to_chunk_mapping[partition_id]
+ node_user = node.users.keys()
+ for u_node in node_user:
+ if u_node.name in stage2node:
+ u_id = int(u_node.name[5:])
+ target_stage_id = stage_to_chunk_mapping[u_id]
+ if stage_deps[target_stage_id][stage_id] or stage_id == num_stage - 1:
+ # no recurring edge!
+ continue
+ stage_deps[stage_id][target_stage_id] = 1
+
+ # construct p2p index mapping
+ p2p_index_mapping = {}
+ for node_name, node in stage2node.items():
+ partition_id = int(node_name[5:])
+ stage_id = stage_to_chunk_mapping[partition_id]
+ args_mapping = []
+ for input_id, arg_node in enumerate(node.args):
+ if arg_node.name in stage2node:
+ arg_partition_id = int(arg_node.name[5:])
+ arg_stage_id = stage_to_chunk_mapping[arg_partition_id]
+ args_mapping.append(PipelineP2PSpec(arg_stage_id, input_id))
+ else: # should from local
+ args_mapping.append(PipelineP2PSpec(stage_id, input_id))
+ p2p_index_mapping.update({stage_id: args_mapping})
+
+ return stage_deps, p2p_index_mapping
+
+
+def _establish_shared_module_groups(
+ num_stage,
+ virtual_chunks,
+ module_partition_names_by_stage,
+ split_graph,
+ shared_module_path_groups,
+ global_mesh: VeDeviceMesh,
+):
+ """
+ Identify groups of modules to share gradients/weights, e.g. embedding layers
+ upon initialization and at the end of a pipeline schedule run.
+ """
+ all_named_modules = [[] for _ in range(num_stage)]
+ for stage_id in range(num_stage):
+ for chunk_id in range(virtual_chunks):
+ key_name = module_partition_names_by_stage[stage_id][chunk_id]
+ module_graph = split_graph.get_submodule(key_name)
+ all_named_modules[stage_id].append({name for name, _ in module_graph.named_modules()})
+
+ shared_module_paths = [[] for _ in range(len(shared_module_path_groups))]
+ for idx, shared_module_group in enumerate(shared_module_path_groups):
+ for module_path in shared_module_group:
+ stage_id, chunk_id = _locate_shared_module(module_path, all_named_modules, num_stage, virtual_chunks)
+ shared_module_paths[idx].append((stage_id, chunk_id, module_path))
+ shared_stages_groups = [
+ [stage for stage, _, _ in shared_module_paths[idx]] for idx in range(len(shared_module_path_groups))
+ ]
+
+ all_tp_submeshes = global_mesh.get_global_tensor_parallel_meshes()
+ # TODO: in future, keep track of multiple groups of shared modules
+ all_tp_groups = []
+ map_id = 0
+ for dm in all_tp_submeshes:
+ mesh_list = dm.mesh.tolist()
+ converted_pp_ranks = [global_mesh.get_strategy_coordinate(_idx)[0] for _idx in mesh_list]
+ assert all(i == converted_pp_ranks[0] for i in converted_pp_ranks)
+ for pp_rank in shared_stages_groups[map_id]:
+ if pp_rank == converted_pp_ranks[0]:
+ all_tp_groups.append(mesh_list)
+ break
+
+ shared_tp_comm_groups = list(zip(*all_tp_groups))
+ shared_module_process_groups = defaultdict(dict)
+ shared_module_mapping = {}
+ shared_module_mapping[map_id] = shared_stages_groups[map_id]
+ for tp_idx, shared_group in enumerate(shared_tp_comm_groups):
+ sync_embed_pg = dist.new_group(ranks=shared_group, backend="nccl")
+ shared_module_process_groups[map_id][tp_idx] = sync_embed_pg
+ return shared_module_process_groups, shared_module_mapping, shared_module_paths
+
+
+def _locate_shared_module(module_path, all_named_modules, num_stage, virtual_chunks):
+ for stage_id in range(num_stage):
+ for chunk_id in range(virtual_chunks):
+ if module_path in all_named_modules[stage_id][chunk_id]:
+ return stage_id, chunk_id
+ raise ValueError(f"Module to be synchronized not found: {module_path}")
+
+
+def _build_model_chunks(stage_id, model_graph, mapping):
+ assert stage_id in mapping
+ pipeline_chunks = {}
+ unique_id = 0
+ for chunk_id, partition_id in enumerate(mapping[stage_id]):
+ key = f"stage{partition_id}"
+ virtual_pipeline_module = getattr(model_graph, key)
+ # assign unique id for each low-level submodule
+ for _, submodule in virtual_pipeline_module.named_modules():
+ if len(list(submodule.children())) == 0:
+ registered_module_id = f"module_{stage_id}_{chunk_id}_{unique_id}"
+ virtual_pipeline_module.module_id = registered_module_id
+ unique_id += 1
+ pipeline_chunks[chunk_id] = virtual_pipeline_module
+ return pipeline_chunks
+
+
+def _build_module(model_graph: torch.fx.GraphModule, num_stages: int, num_model_chunks: int, stage_id: int):
+ """
+ Builds model chunks by stage, and assigns unique submodule id to every basic modules.
+
+ Args:
+ model_graph (torch.fx.GraphModule): the model trace graph with stage partitions.
+ num_stages (int): number of pipeline stages.
+ num_model_chunks (int): number of virtual pipeline chunks per stage.
+ dist_api (VeDeviceMesh): an object of DeviceMesh API.
+
+ Returns:
+ Mapping of chunk id to model partitions of the current stage.
+
+ """
+ stage_to_chunk = _get_stage_to_chunk_mapping(num_model_chunks, num_stages)
+ return _build_model_chunks(stage_id, model_graph, stage_to_chunk)
+
+
+def _get_stage_to_chunk_mapping(num_model_chunks, num_stages):
+ """
+ Gets a mapping from stage id to model partition ids.
+
+ Args:
+ num_model_chunks (int): number of virtual pipeline chunks per stage.
+ num_stages (int): number of pipeline stages.
+
+ Returns:
+ Mapping from stages to their model chunks.
+
+ """
+ if num_model_chunks == 1:
+ stage_to_chunk = {i: [i] for i in range(num_stages)}
+ else:
+ length = num_stages * num_model_chunks
+ stage_to_chunk = {i: [] for i in range(num_stages)}
+ for i in range(length):
+ stage_to_chunk[i % num_stages].append(i)
+ return stage_to_chunk
diff --git a/vescale/pipe/tracer.py b/vescale/pipe/tracer.py
new file mode 100644
index 0000000..85ded75
--- /dev/null
+++ b/vescale/pipe/tracer.py
@@ -0,0 +1,709 @@
+################################################################################
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
+################################################################################
+
+import torch
+import torch.nn as nn
+import torch.fx as fx
+import collections
+import warnings
+import math
+import inspect
+from torch.fx import Tracer, Graph, Proxy, GraphModule
+from torch.fx.proxy import ParameterProxy
+from transformers.utils.fx import (
+ _proxies_to_metas,
+ _generate_random_int,
+ check_if_model_is_supported,
+ _FX_SUPPORTED_MODELS_WITH_KV_CACHE,
+ _IS_IN_DEBUG_MODE,
+ _MANUAL_META_OVERRIDES,
+ HFProxy,
+ HFAttribute,
+ HFTracer,
+)
+
+try:
+ from transformers.utils.fx import _gen_constructor_wrapper
+except Exception as e:
+ warnings.warn("Util path changed. Now load from a new path")
+ from transformers.utils.fx import gen_constructor_wrapper as _gen_constructor_wrapper
+
+from transformers.utils.import_utils import (
+ TORCH_FX_REQUIRED_VERSION,
+ get_torch_version,
+ is_torch_fx_available,
+ is_peft_available,
+)
+from torch.fx._compatibility import compatibility
+from transformers.modeling_utils import PreTrainedModel
+from typing import Any, Callable, Dict, List, Optional, Union, Sequence, Type
+from transformers.models.auto import get_values
+from transformers.models.auto.modeling_auto import (
+ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_BACKBONE_MAPPING_NAMES,
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_FOR_CTC_MAPPING_NAMES,
+ MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
+ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
+ MODEL_FOR_PRETRAINING_MAPPING_NAMES,
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
+)
+
+if is_peft_available():
+ from peft import PeftModel
+
+
+_IS_PARTITION_MODULE = "PARTITION"
+
+
+class ModelTracer(fx.Tracer):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+
+ def is_leaf_module(self, m, module_qualified_name):
+ return (
+ m.__module__.startswith("torch.nn")
+ or m.__module__.startswith("torch.ao.nn")
+ or hasattr(m, _IS_PARTITION_MODULE)
+ ) and not isinstance(m, torch.nn.Sequential)
+
+
+class HFModelTracer(Tracer):
+ """
+ Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
+ regular PyTorch torch.fx.Proxy.
+ """
+
+ # Feature flag for proxying accesses to buffer values
+ proxy_buffer_attributes: bool = True
+ allow_insert_stateless_mods: bool = True
+ _TORCH_METHODS_TO_PATCH = [
+ "arange",
+ "zeros",
+ "ones",
+ "full",
+ "full_like",
+ "eye",
+ "empty",
+ "tensor",
+ "clamp",
+ "finfo",
+ ]
+ supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
+
+ def __init__(self, autowrap_modules=(math,), autowrap_functions=(), partition_modules=None):
+ super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
+
+ if not is_torch_fx_available():
+ raise ImportError(
+ f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version "
+ f"{TORCH_FX_REQUIRED_VERSION} is supported."
+ )
+
+ self.visited_partition_module_paths = set()
+ self.partition_module_classes_and_fqns = set() if partition_modules is None else set(partition_modules)
+
+ def _generate_dummy_input(
+ self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str]
+ ) -> Dict[str, torch.Tensor]:
+ """Generates dummy input for model inference recording."""
+ # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
+ # from pickle, or from the "__class__" attribute in the general case.
+ model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
+ device = model.device
+ inputs_dict = {}
+
+ # when tracing a model with KV cache, we simply need to unsure that the KV cache length is larger than one to
+ # rightfully pass certain controlflows (Example: https://github.com/huggingface/transformers/blob/5c8d941d66734811d2ef6f57f15b44f7fb7a98c4/src/transformers/modeling_attn_mask_utils.py#L162).
+ # After tracing, the model can then still be used with arbitrary lengths different than the one used during tracing.
+ kv_cache_length = 5
+
+ if input_name in ["labels", "start_positions", "end_positions"]:
+ batch_size = shape[0]
+ if model_class_name in [
+ *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
+ *get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
+ *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
+ *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
+ *get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
+ ]:
+ inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
+ elif model_class_name in [
+ *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
+ *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
+ "XLNetForQuestionAnswering",
+ ]:
+ inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
+ inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
+ elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
+ if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
+ raise ValueError(
+ "Could not retrieve the problem type for the sequence classification task, please set "
+ 'model.config.problem_type to one of the following values: "regression", '
+ '"single_label_classification", or "multi_label_classification".'
+ )
+
+ if model.config.problem_type == "regression":
+ labels_shape = (batch_size, model.config.num_labels)
+ labels_dtype = torch.float32
+ elif model.config.problem_type == "single_label_classification":
+ labels_shape = (batch_size,)
+ labels_dtype = torch.long
+ elif model.config.problem_type == "multi_label_classification":
+ labels_shape = (batch_size, model.config.num_labels)
+ labels_dtype = torch.float32
+ else:
+ raise ValueError(
+ 'Expected model.config.problem_type to be either: "regression", "single_label_classification"'
+ f', or "multi_label_classification", but "{model.config.problem_type}" was provided.'
+ )
+ inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
+
+ elif model_class_name in [
+ *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
+ *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
+ *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
+ *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
+ *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
+ *get_values(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES),
+ "GPT2DoubleHeadsModel",
+ "PeftModelForCausalLM",
+ "PeftModelForSeq2SeqLM",
+ ]:
+ inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
+ elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]:
+ inputs_dict["labels"] = torch.zeros(shape, dtype=torch.float32, device=device)
+ else:
+ raise NotImplementedError(
+ f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
+ )
+ elif "pixel_values" in input_name:
+ batch_size = shape[0]
+ image_size = getattr(model.config, "image_size", None)
+ if image_size is None:
+ if hasattr(model.config, "vision_config"):
+ image_size = model.config.vision_config.image_size
+ elif hasattr(model.config, "encoder"):
+ image_size = model.config.encoder.image_size
+ else:
+ image_size = (_generate_random_int(), _generate_random_int())
+
+ # If no num_channels is in the config, use some arbitrary value.
+ num_channels = getattr(model.config, "num_channels", 3)
+ if not isinstance(image_size, collections.abc.Iterable):
+ image_size = (image_size, image_size)
+ height, width = image_size
+ inputs_dict[input_name] = torch.zeros(
+ batch_size, num_channels, height, width, dtype=torch.float32, device=device
+ )
+ elif "bbox" in input_name:
+ inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
+ elif "input_features" in input_name:
+ inputs_dict[input_name] = torch.zeros(
+ *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
+ )
+ elif "visual_feats" in input_name:
+ inputs_dict[input_name] = torch.zeros(
+ shape
+ + [
+ model.config.visual_feat_dim,
+ ],
+ dtype=torch.float,
+ device=device,
+ )
+ elif "visual_pos" in input_name:
+ inputs_dict[input_name] = torch.zeros(
+ shape
+ + [
+ model.config.visual_pos_dim,
+ ],
+ dtype=torch.float,
+ device=device,
+ )
+ elif "inputs" in input_name:
+ inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
+ elif "input_values" in input_name:
+ batch_size, _ = shape
+ # Generating big sequence length for audio inputs.
+ seq_length = _generate_random_int(low=10000, high=20000)
+ inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
+ elif "mask" in input_name:
+ if "past_key_values" in input_names:
+ mask_shape = [shape[0], shape[1] + kv_cache_length]
+ else:
+ mask_shape = shape
+
+ inputs_dict[input_name] = torch.zeros(mask_shape, dtype=torch.long, device=device)
+ elif "ids" in input_name:
+ inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
+ elif "past_key_values" in input_name:
+ if model.config.model_type not in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
+ raise NotImplementedError(
+ f"Symbolic trace with past_key_values input is not supported yet for the model {model.config.model_type}. Please open an issue or a PR in Transformers repository if you would like to see the support added."
+ )
+ num_heads = model.config.num_attention_heads
+ head_dim = model.config.hidden_size // model.config.num_attention_heads
+
+ cache_shape = (shape[0], num_heads, kv_cache_length, head_dim)
+ pkv = tuple(
+ (
+ torch.rand(cache_shape, dtype=torch.float, device=device),
+ torch.rand(cache_shape, dtype=torch.float, device=device),
+ )
+ for i in range(model.config.num_hidden_layers)
+ )
+ inputs_dict[input_name] = pkv
+ else:
+ shape_with_hidden_size = shape + [model.config.hidden_size]
+ inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device)
+
+ return inputs_dict
+
+ def is_leaf_module(self, m, module_qualified_name):
+ return (not self._stateless_mod_instanciation_depends_on_proxies(m)) and (
+ hasattr(m, _IS_PARTITION_MODULE)
+ or (
+ m.__module__.startswith("torch.nn")
+ or m.__module__.startswith("torch.ao.nn")
+ and not isinstance(m, torch.nn.Sequential)
+ )
+ )
+
+ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
+ rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
+
+ if kind == "placeholder" and target in self.meta_args:
+ rv.install_metadata(self.meta_args[target])
+ return rv
+
+ if target in self.orig_fns:
+ # NOTE: tensor constructors in PyTorch define the `device` argument as
+ # *kwargs-only*. That is why this works. If you add methods to
+ # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
+ # this will break and you will likely see issues where we cannot infer
+ # the size of the output.
+ if "device" in kwargs:
+ kwargs["device"] = "meta"
+
+ try:
+ args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
+ kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)
+
+ if kind == "call_function":
+ meta_target = _MANUAL_META_OVERRIDES.get(target, target)
+ meta_out = meta_target(*args_metas, **kwargs_metas)
+ if isinstance(meta_out, torch.Tensor):
+ meta_out = meta_out.to(device="meta")
+ elif kind == "call_method":
+ method = getattr(args_metas[0].__class__, target)
+ meta_target = _MANUAL_META_OVERRIDES.get(method, method)
+ meta_out = meta_target(*args_metas, **kwargs_metas)
+ elif kind == "call_module":
+ if not hasattr(self, "orig_forward"):
+ raise AttributeError(f"{self} does not have an attribute called orig_forward")
+ self._disable_module_getattr = True
+ try:
+ mod = self.root.get_submodule(target)
+
+ mod_type = type(mod)
+ assert not any(path for path in self.visited_partition_module_paths if target.startswith(path))
+ self.visited_partition_module_paths.add(target)
+ # assert mod_type not in self.partition_module_classes
+ if mod_type in _MANUAL_META_OVERRIDES:
+ meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
+ else:
+ if self.partition_module_classes_and_fqns and (
+ target in self.partition_module_classes_and_fqns
+ or mod_type in self.partition_module_classes_and_fqns
+ ):
+ raise ValueError # not to recurse into partition module's forward()
+ meta_out = self.orig_forward(*args_metas, **kwargs_metas)
+ except: # noqa: E722
+ mod = self.root.get_submodule(target)
+ mod_type = type(mod)
+ meta_out = None
+ finally:
+ self._disable_module_getattr = False
+ elif kind == "get_attr":
+ self._disable_module_getattr = True
+ try:
+ attr_itr = self.root
+ atoms = target.split(".")
+ for atom in atoms:
+ attr_itr = getattr(attr_itr, atom)
+ if isinstance(attr_itr, torch.Tensor):
+ meta_out = attr_itr.to(device="meta")
+ else:
+ meta_out = attr_itr
+ finally:
+ self._disable_module_getattr = False
+ else:
+ return rv
+
+ if not isinstance(rv, Proxy):
+ raise ValueError("Don't support composite output yet")
+ rv.install_metadata(meta_out)
+ except Exception as e:
+ if _IS_IN_DEBUG_MODE:
+ warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
+
+ return rv
+
+ # Replaced by .getattr from PyTorch 1.13
+ def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
+ if getattr(self, "_disable_module_getattr", False):
+ return attr_val
+ else:
+
+ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
+ for n, p in collection_to_search:
+ if attr_val is p:
+ if n not in parameter_proxy_cache:
+ kwargs = {}
+ if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+ kwargs["proxy_factory_fn"] = (
+ None
+ if not self.param_shapes_constant
+ else lambda node: ParameterProxy(self, node, n, attr_val)
+ )
+ val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
+ parameter_proxy_cache[n] = val_proxy
+ return parameter_proxy_cache[n]
+ return None
+
+ if isinstance(attr_val, torch.nn.Parameter):
+ maybe_parameter_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_parameters(), parameter_proxy_cache
+ )
+ if maybe_parameter_proxy is not None:
+ return maybe_parameter_proxy
+
+ if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
+ maybe_buffer_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_buffers(), parameter_proxy_cache
+ )
+ if maybe_buffer_proxy is not None:
+ return maybe_buffer_proxy
+
+ return attr_val
+
+ # Needed for PyTorch 1.13+
+ def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
+ return self._module_getattr(attr, attr_val, parameter_proxy_cache)
+
+ def call_module(self, m, forward, args, kwargs):
+ self.orig_forward = forward
+ return super().call_module(m, forward, args, kwargs)
+
+ def proxy(self, node):
+ return HFProxy(node, self)
+
+ def trace(
+ self,
+ root: Union[torch.nn.Module, Callable[..., Any]],
+ concrete_args: Optional[Dict[str, Any]] = None,
+ dummy_inputs: Optional[Dict[str, Any]] = None,
+ complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
+ ) -> Graph:
+ """
+ Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
+ `torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
+ the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
+ `torch.nn.Module` instance to use as the root and add embedded constants to.
+
+ Args:
+ root (`torch.nn.Module` or `Callable`):
+ Either a `torch.nn.Module`` or a function to be traced through. If root is not a
+ [`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
+ concrete_args (`Dict[str, Any], *optional*):
+ Concrete arguments that should not be treated as Proxies
+ dummy_inputs (`Dict[str, Any]`, *optional*):
+ The dummy inputs needed to handle data-dependent control-flow if `root` is not a
+ [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
+ [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
+ complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
+ If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
+ `dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing.
+
+ Returns:
+ `torch.fx.Graph`:
+ A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.
+
+ """
+ sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root)
+
+ if concrete_args is None:
+ concrete_args = {}
+
+ if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs:
+ for param in sig.parameters.values():
+ if param.name in dummy_inputs:
+ continue
+ if param.default is inspect.Parameter.empty:
+ raise ValueError(f"You need to specify a default value for the parameter {param.name}.")
+ concrete_args.update(
+ {
+ p.name: p.default
+ for p in sig.parameters.values()
+ if (p.name not in dummy_inputs and p.name not in concrete_args)
+ }
+ )
+
+ input_names = sig.parameters.keys() - concrete_args.keys()
+
+ # Creating a random input shape to generate dummy inputs.
+ batch_size = _generate_random_int()
+ sequence_length = _generate_random_int()
+ shape = [batch_size, sequence_length]
+
+ if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
+ num_choices = _generate_random_int(low=2, high=5)
+ shape.insert(1, num_choices)
+
+ inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
+ for input_name in input_names:
+ if input_name in inputs:
+ continue
+ # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
+ # be able to use HFTracer._generate_dummy_input.
+ if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
+ ("_deserialize_graph_module", "_CodeOnlyModule")
+ ):
+ inputs.update(self._generate_dummy_input(root, input_name, shape, input_names=input_names))
+ else:
+ raise RuntimeError(
+ f"Could not generate input named {input_name} for because root is not a"
+ " transformers.PreTrainedModel."
+ )
+
+ concrete_metas = {
+ input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
+ for input_name, input_ in inputs.items()
+ }
+ for param in sig.parameters.values():
+ if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
+ concrete_metas[f"**{param.name}"] = {}
+ self.meta_args = concrete_metas
+ self.patched_torch_methods = {
+ target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
+ }
+ self.orig_fns = set()
+
+ for name, (wrapper, orig) in self.patched_torch_methods.items():
+ setattr(torch, name, wrapper)
+ self.orig_fns.add(orig)
+
+ try:
+ self.graph = super().trace(root, concrete_args=concrete_args)
+ finally:
+ for name, (_, orig) in self.patched_torch_methods.items():
+ setattr(torch, name, orig)
+
+ # This is necessary because concrete args are added as input to the traced module since
+ # https://github.com/pytorch/pytorch/pull/55888.
+ for node in self.graph.nodes:
+ if node.op == "placeholder":
+ # Removing default values for inputs as the forward pass will fail with them.
+ if node.target in input_names:
+ node.args = ()
+ # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
+ # It cannot infer on the attributes and methods the input should have, and fails.
+ node.type = torch.Tensor
+ # It is a concrete arg so it is not used and should be removed.
+ else:
+ to_visit = [node]
+ to_delete = collections.OrderedDict()
+ while to_visit:
+ n = to_visit.pop(0)
+ to_delete[n] = None
+ to_visit += list(n.users.keys())
+
+ for user in reversed(to_delete.keys()):
+ self.graph.erase_node(user)
+
+ # Without this, return type annotation "Tuple" is causing code execution failure.
+ if node.op == "output":
+ node.type = None
+
+ return self.graph
+
+ def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
+ """
+ Whether the module was instantiated with Proxies. If that is the case, such module cannot be a leaf module
+ because its attributes are input-dependent.
+ """
+ return any(isinstance(attr, Proxy) for attr in mod.__dict__.values())
+
+ def _insert_module_as_submodule(self, mod: nn.Module) -> str:
+ """
+ Helper method which tries to insert a module that was not declared as submodule.
+ """
+ # If one of the module attributes is a Proxy, it means that its instantiation is input-dependent.
+ # It is not possible to insert such modules, those should be traced through.
+ if self._stateless_mod_instanciation_depends_on_proxies(mod):
+ return ""
+ idx = 0
+ mod_name = mod.__class__.__name__.lower()
+ path = f"{mod_name}_{idx}"
+ already_inserted = False
+ while hasattr(self.root, path):
+ if getattr(self.root, path) is mod:
+ already_inserted = True
+ break
+ path = f"{mod_name}_{idx}"
+ idx += 1
+
+ # No need to add multiple instances of the same module.
+ if not already_inserted:
+ self.root.add_module(path, mod)
+ return path
+
+ def path_of_module(self, mod: nn.Module) -> str:
+ """
+ Helper method to find the qualified name of `mod` in the Module hierarchy of `root`. For example, if `root` has
+ a submodule named `foo`, which has a submodule named `bar`, passing `bar` into this function will return the
+ string "foo.bar".
+
+ Args:
+ mod (str): The `Module` to retrieve the qualified name for.
+ """
+ try:
+ return super().path_of_module(mod)
+ except NameError as e:
+ if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
+ path = self._insert_module_as_submodule(mod)
+ return path
+ raise e
+
+ @compatibility(is_backward_compatible=True)
+ def keys(self, obj: "Proxy") -> Any:
+ """Called when a proxy object is has the keys() method called.
+ This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work in
+ your custom tracer.
+ """
+ attribute = HFAttribute(obj, "keys")()
+ if obj.node.target == "**kwargs":
+ return attribute._metadata
+ return attribute
+
+
+def get_concrete_args(model: nn.Module, input_names: List[str]):
+ sig = inspect.signature(model.forward)
+
+ if not (set(input_names) <= set(sig.parameters.keys())):
+ formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
+ formatted_allowed_input_names = ", ".join(sig.parameters.keys())
+ raise ValueError(
+ f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
+ f" {formatted_allowed_input_names}"
+ )
+
+ return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
+
+
+def hf_symbolic_trace(
+ model: PreTrainedModel,
+ input_names: Optional[List[str]] = None,
+ disable_check: bool = False,
+ tracer_cls: Type[HFTracer] = HFTracer,
+ partition_modules: List = None,
+) -> GraphModule:
+ """
+ Performs symbolic tracing on the model.
+
+ Args:
+ model ([`PretrainedModel`]):
+ The model to trace.
+ input_names (`List[str]`, *optional*):
+ The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
+ disable_check (`bool`, *optional*, defaults to `False`):
+ If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes.
+ tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`):
+ The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead.
+ partition_modules (`List`):
+ A list of string paths to un-partitionable submodules of custom modulen classes.
+
+ Returns:
+ `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
+
+ Example:
+
+ ```python
+ from transformers.utils.fx import symbolic_trace
+
+ traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
+ ```
+ """
+ if input_names is None:
+ input_names = model.dummy_inputs.keys()
+
+ input_names = list(input_names)
+ concrete_args = get_concrete_args(model, input_names)
+
+ if not disable_check:
+ check_if_model_is_supported(model)
+
+ # Tracing.
+ if partition_modules:
+ # annotate partition modules as minimally unpartitionable units in stage split.
+ assert isinstance(partition_modules, Sequence)
+
+ def _check_legitimate_fqn(unique_paths, path):
+ return not any(path == p or path.startswith(p + ".") for p in unique_paths)
+
+ partition_modules_paths = set()
+ for fqn, sub_module in model.named_modules():
+ if (fqn in partition_modules) or (
+ type(sub_module) in partition_modules and _check_legitimate_fqn(partition_modules_paths, fqn)
+ ):
+ # elif type(sub_module) in partition_modules and _check_legitimate_fqn(partition_modules_paths, fqn):
+ partition_modules_paths.add(fqn)
+ partition_modules_paths = list(partition_modules_paths)
+ register_partition_module(model, fully_qualified_names=partition_modules_paths)
+ tracer = tracer_cls(partition_modules=partition_modules)
+ traced_graph = tracer.trace(model, concrete_args=concrete_args)
+ traced = torch.fx.GraphModule(model, traced_graph)
+
+ if hasattr(model, "config"):
+ traced.config = model.config
+ # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
+ # _generate_dummy_input, where the model class is needed.
+ traced.class_for_deserialization = model.__class__
+ if hasattr(model, "device"):
+ traced.device = model.device
+
+ return traced
+
+
+def register_partition_module(module: nn.Module, fully_qualified_names: Union[str, Sequence] = None):
+ if fully_qualified_names is None:
+ setattr(module, _IS_PARTITION_MODULE, True)
+ else:
+ if isinstance(fully_qualified_names, str):
+ fully_qualified_names = [fully_qualified_names]
+ for fqn, sub_module in module.named_modules():
+ for mod_name in fully_qualified_names:
+ if fqn == mod_name:
+ setattr(sub_module, _IS_PARTITION_MODULE, True)
diff --git a/vescale/plan/__init__.py b/vescale/plan/__init__.py
new file mode 100644
index 0000000..3d6e4cf
--- /dev/null
+++ b/vescale/plan/__init__.py
@@ -0,0 +1,20 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+
+from .pipeline_parallel import PipelineParallelPlan
+from .spec import *
diff --git a/vescale/plan/pipeline_parallel.py b/vescale/plan/pipeline_parallel.py
new file mode 100644
index 0000000..b39fdb5
--- /dev/null
+++ b/vescale/plan/pipeline_parallel.py
@@ -0,0 +1,142 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+
+from dataclasses import dataclass, field
+from typing import List, Dict
+from .spec import * # noqa: F403
+import torch
+
+__all__ = ["PipelineParallelPlan"]
+
+
+@dataclass
+class PipelineParallelPlan:
+ # PP mode:
+ mode: ModeType = ModeType.GRAPH_EAGER
+
+ ########## model graph and partition ##########
+
+ # type of tracer to obtain the model execution graph
+ # fit modes: [GRAPH_EAGER]
+ # format: Enum
+ # consumer: PipeParser
+ tracer_type: TracerType = TracerType.AUTO
+
+ # kwargs to be fed to different parser, e.g. torch.fx, dynamo, export, etc
+ # fit modes: [GRAPH_EAGER]
+ # format: Enum
+ # consumer: PipeParser
+ tracer_kwargs: Dict = None
+
+ # method of stage partitioning for all modes
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: Enum
+ # consumer: PipeParser and ManualPipeParser
+ split_method: PipelineSplitMethodType = PipelineSplitMethodType.MANUAL
+
+ # number of pipeline stages
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: int
+ # consumer: PipeParser
+ num_stages: int = 2
+
+ # number of virtual module chunks per pipeline stage
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: int
+ # consumer: ScheduleEngine, PipeModule
+ virtual_chunks: int = 1
+
+ # list of minimum un-partitionable units in model forward graph. Internal hierarchy
+ # of a partition unit is maintained during stage splitting
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: list of fqns to particular modules/callable or module classes
+ # consumer: ScheduleEngine, PipeModule
+ smallest_unsplittable_units: List = field(default_factory=list)
+
+ # stage boundaries
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: a list of fqns or index integers of particular modules / callables
+ # consumer: PipeParser and ManualParser
+ split_points: List = field(default_factory=list)
+
+ # enables to manually define boundaries of virtual stage chunks in interleaved 1F1B schedule
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: boolean
+ # consumer: PipeParser and ManualParser
+ enable_vpp_split_points: bool = False
+
+ # enables to uniformly split stages by modules and operators when split_method==PipelineSplitMethodType.UNIFORM
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: boolean
+ # consumer: PipeParser and ManualParser
+ uniform_split_ops: bool = False
+
+ ########## end of model graph generation, partition ##########
+
+ ########## pipeline runtime ##########
+
+ # executes batched p2p communication for simple 1f1b and interleaved 1f1b,
+ # mutually exclusive to overlap_p2p_comm
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: bool
+ # consumer: ScheduleEngine
+ batch_p2p_comm: bool = False
+
+ # executes overlapped p2p communication for simple 1f1b and interleaved 1f1b,
+ # mutually exclusive to batch_p2p_comm
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: bool
+ # consumer: ScheduleEngine
+ overlap_p2p_comm: bool = True
+
+ # sets to True in inference, so that pipeline schedule only executes forward propagation
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: bool
+ # consumer: ScheduleEngine
+ forward_only: bool = False
+
+ # pipeline schedule type
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: Enum
+ # consumer: ScheduleEngine
+ schedule_type: PipelineScheduleType = PipelineScheduleType.SIMPLE_1F1B
+
+ # reuses data tensor shapes in some use cases instead of communicating
+ # shapes before tensors. Use with caution!
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: bool
+ # consumer: ScheduleEngine
+ reuse_p2p_tensor_shape: bool = False
+
+ # precision types of communicated tensors during pipeline execution
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: torch.dtype
+ # consumer: ScheduleEngine
+ p2p_tensor_dtype: torch.dtype = torch.float32
+
+ ########## end of pipeline schedule ##########
+
+ ########## other information ##########
+
+ # list of groups of fqns whose parameters or gradients will be synchronized per step, e.g. embedding modules
+ # fit modes: [EAGER, GRAPH_EAGER]
+ # format: [ [word_embeddingA, word_embeddingB], [vision_embeddingA, vision_embeddingB] ]
+ # consumer: build utilities in vescale/api.py
+ shared_modules: List[List[str]] = field(default_factory=list)
+
+ ########## end of other information ##########
diff --git a/vescale/plan/spec.py b/vescale/plan/spec.py
new file mode 100644
index 0000000..e91a39e
--- /dev/null
+++ b/vescale/plan/spec.py
@@ -0,0 +1,78 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from enum import Enum, auto
+from dataclasses import dataclass
+from typing import TypeVar
+
+
+__all__ = [
+ "ModeType",
+ "PipelineP2PSpec",
+ "PipelineSplitMethodType",
+ "PipelineScheduleType",
+ "TracerType",
+]
+
+ArrayLike = TypeVar("ArrayLike")
+
+
+class ModeType(Enum):
+ """Type of parallel modes"""
+
+ EAGER = auto()
+ MANUAL_EAGER = auto()
+ GRAPH_EAGER = auto()
+
+
+class PipelineSplitMethodType(Enum):
+ """Type of pipeline stage partitioning methods"""
+
+ MANUAL = auto()
+ UNIFORM = auto()
+ PARAMETERS = auto()
+ AUTO = auto()
+ SIMULATOR = auto()
+ FLOPS = auto()
+
+
+class PipelineScheduleType(Enum):
+ """Type of pipeline parallel schedules"""
+
+ SIMPLE_1F1B = auto()
+ INTERLEAVED_1F1B = auto()
+ GPIPE = auto()
+ ZERO_BUBBLE = auto()
+ GRAPH_PIPE = auto()
+
+
+class TracerType(Enum):
+ VESCALE_FX = auto()
+ VESCALE_EXPORT = auto()
+ HF_FX = auto()
+ TORCH_FX = auto()
+ TORCH_DYNAMO = auto()
+ TORCH_EXPORT = auto()
+ AUTO = auto()
+
+
+@dataclass
+class PipelineP2PSpec:
+ """The p2p spec for communication p2p spec in manual pipeline plan."""
+
+ peer_stage_idx: int
+ peer_output_idx: int = 0