This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[BUG] PipeshardParallel crashes when apply_grad part is empty #560

Open
merrymercy opened this issue Jun 30, 2022 · 1 comment
Labels: known bug (Something isn't working)

Comments

merrymercy commented Jun 30, 2022

If I change this line

grads = jax.grad(loss_func)(state.params, batch["x"], batch["y"])

from jax.grad to alpa.grad, I get this error:

WARNING:alpa.pipeline_parallel.apply_grad:the apply gradient part is empty. Hint: apply() after alpa.grad
(CompileWorker pid=None) 2022-09-09 01:01:54.866096: F external/org_tensorflow/tensorflow/compiler/xla/service/spmd/slice_auto_sharded_stages.cc:118] Check failed: ins->opcode() != HloOpcode::kParameter (parameter vs. parameter)All the inputs to a pipeline stage should be from the start marker. %Arg_0.1 = f32[64,1024]{1,0} parameter(0), sharding={replicated}
(CompileWorker pid=None) *** SIGABRT received at time=1662685314 on cpu 63 ***
(CompileWorker pid=None) PC: @     0x7f8a54ddae87  (unknown)  raise
(CompileWorker pid=None)     @     0x7f8a55d93980  1112381872  (unknown)
(CompileWorker pid=None)     @     0x7f65e03f53a5       6880  xla::spmd::CreateStageModule()
(CompileWorker pid=None)     @     0x7f65e03f63ed        928  xla::spmd::SliceAutoShardedStagesInternal()
(CompileWorker pid=None)     @     0x7f65e03f7690         96  xla::spmd::SliceAutoShardedStages::Run()
(CompileWorker pid=None)     @     0x7f65e3fa01c4       1024  xla::HloPassPipeline::RunPassesInternal<>()
(CompileWorker pid=None)     @     0x7f65e3fa0ed8        480  xla::HloPassPipeline::Run()
(CompileWorker pid=None)     @     0x7f65e03e1519         80  xla::HloPassInterface::Run()
(CompileWorker pid=None)     @     0x7f65e03ec118       4064  xla::spmd::RunAutoShardingPass()
(CompileWorker pid=None)     @     0x7f65e03cda26        192  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
(CompileWorker pid=None)     @     0x7f65e0150c2c        736  pybind11::cpp_function::dispatcher()
(CompileWorker pid=None)     @     0x56254b66b4b0  (unknown)  _PyMethodDef_RawFastCallKeywords
(CompileWorker pid=None)     @     0x7f65e0150590  (unknown)  (unknown)
(CompileWorker pid=None)     @ 0x6e6964726168735f  (unknown)  (unknown)
(CompileWorker pid=None) [2022-09-09 01:01:54,982 E 33902 33902] logging.cc:321: *** SIGABRT received at time=1662685314 on cpu 63 ***
(CompileWorker pid=None) [2022-09-09 01:01:54,982 E 33902 33902] logging.cc:321: PC: @     0x7f8a54ddae87  (unknown)  raise
(CompileWorker pid=None) 2022-09-09 01:01:54.863147: F external/org_tensorflow/tensorflow/compiler/xla/service/spmd/auto_sharding_util.cc:1276] Check failed: user->opcode() == HloOpcode::kGetTupleElement (custom-call vs. get-tuple-element)
(CompileWorker pid=None) *** SIGABRT received at time=1662685314 on cpu 3 ***
(CompileWorker pid=None) PC: @     0x7fe806b0ee87  (unknown)  raise
(CompileWorker pid=None)     @     0x7fe807ac7980  (unknown)  (unknown)
(CompileWorker pid=None)     @     0x7fc392f72c87        528  xla::spmd::PassThroughCustomCallMarkerUser()
(CompileWorker pid=None)     @     0x7fc392f792e3        352  xla::spmd::FindReplicateSet()
(CompileWorker pid=None)     @     0x7fc392f79c5c        432  xla::spmd::FindReplicateSet()
(CompileWorker pid=None)     @     0x7fc392f7e41b       2144  xla::spmd::GenerateReduceScatter()
(CompileWorker pid=None)     @     0x7fc392f6d6df       2384  xla::spmd::AutoSharding::Run()
(CompileWorker pid=None)     @     0x7fc395c8c1c4       1024  xla::HloPassPipeline::RunPassesInternal<>()
(CompileWorker pid=None)     @     0x7fc395c8ced8        480  xla::HloPassPipeline::Run()
(CompileWorker pid=None)     @     0x7fc3920cd519         80  xla::HloPassInterface::Run()
(CompileWorker pid=None)     @     0x7fc3920d8118       4064  xla::spmd::RunAutoShardingPass()
(CompileWorker pid=None)     @     0x7fc3920b9a26        192  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()

A script to reproduce:

import unittest
import os

import jax
import jax.numpy as jnp
import optax
import ray

import alpa
from alpa import init, parallelize, PipeshardParallel
from alpa.model.model_util import TrainState
from alpa.parallel_method import LocalPipelineParallel
from alpa.testing import MLPModel, assert_allclose


class PipelineMLPTest(unittest.TestCase):

    def setUp(self):
        os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
        init(cluster="ray")

    def train_2_layer_mlp(self, method):

        def train_step(state, batch):

            def loss_func(params, x, y):
                out = state.apply_fn(params, x)
                loss = jnp.mean((out - y)**2)
                return loss

            # The original test uses jax.grad; alpa.grad is used here to reproduce
            # https://github.com/alpa-projects/alpa/issues/560.
            grads = alpa.grad(loss_func)(state.params, batch["x"], batch["y"])
            return grads

        batch_size = 64
        hidden_size = 1024

        x = jnp.ones((batch_size, hidden_size))
        y = jnp.ones((batch_size, hidden_size))

        # Init model and optimizer
        model = MLPModel(num_layers=4,
                         hidden_size=hidden_size,
                         add_manual_pipeline_marker=True)
        rngkey = jax.random.PRNGKey(0)
        params = model.init(rngkey, x)
        tx = optax.sgd(learning_rate=1e-2)
        state = TrainState.create(apply_fn=model.apply,
                                  params=params,
                                  tx=tx,
                                  dynamic_scale=None)

        # Train step
        batch = {"x": x, "y": y}
        gradients = train_step(state, batch)
        p_train_step = parallelize(train_step, donate_argnums=(), method=method)
        gradients_with_pipeline = p_train_step(state, batch)

        # Check results
        assert_allclose(gradients, gradients_with_pipeline)

        # Check debug utilities
        if isinstance(method, PipeshardParallel):
            executable = p_train_step.get_last_executable()
            executable.dump_debug_info("tmp")

    def test_2_layer_mlp_local_pipeline_parallel(self):
        self.train_2_layer_mlp(LocalPipelineParallel())

    def test_2_layer_mlp_pipeshard_parallel(self):
        self.train_2_layer_mlp(PipeshardParallel(layer_option="manual"))


def suite():
    suite = unittest.TestSuite()
    #suite.addTest(PipelineMLPTest("test_2_layer_mlp_local_pipeline_parallel"))
    suite.addTest(PipelineMLPTest("test_2_layer_mlp_pipeshard_parallel"))
    return suite


if __name__ == '__main__':
    runner = unittest.TextTestRunner()
    runner.run(suite())
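
For reference, the warning at the top of the log ("the apply gradient part is empty. Hint: apply() after alpa.grad") points at the trigger: train_step computes gradients with alpa.grad but never applies them, so the apply-gradient part that PipeshardParallel slices out is empty. Below is a minimal sketch of a train_step whose apply-gradient part is non-empty, reusing the imports and TrainState from the script above and assuming the flax-style TrainState.apply_gradients API; this is a workaround sketch, not a fix for the crash itself:

def train_step_with_update(state, batch):

    def loss_func(params, x, y):
        out = state.apply_fn(params, x)
        return jnp.mean((out - y)**2)

    # alpa.grad followed by an optimizer update, so the apply-gradient
    # part sliced out by PipeshardParallel is no longer empty.
    grads = alpa.grad(loss_func)(state.params, batch["x"], batch["y"])
    new_state = state.apply_gradients(grads=grads)
    return new_state
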
merrymercy (Member, Author) commented

@zhuohan123 @ZYHowell Not very urgent though

merrymercy added the known bug label on Jul 1, 2022