[FEATURE] Auto-Parallelization for Graphs with Predesignated Parameter Shardings #550

zhuohan123 opened this issue Jun 28, 2022 · 0 comments

zhuohan123 commented Jun 28, 2022

Currently, Alpa only supports auto-parallelization of a graph without any existing sharding annotations. However, there are cases where the shardings of some parameters are provided in advance. For example, when performing inference on variable-length inputs, the graphs for different input lengths should share the same parameter sharding so that these graphs can be served concurrently and efficiently. This issue outlines the steps towards fully supporting auto-parallelization with user-provided shardings:

Manual inter-operator parallelism and sharding-propagation based intra-operator parallelism

This should be our first step and should be the easiest to implement (@comaniac). We can follow the logic of compile_create_state_executable to:

  1. Get the sharding spec and the mesh assignment for the parameters.
    executable = train_step.get_executable(state_aval, other_args)
    placement_specs = executable.get_input_placement_specs()[0]
    placement_specs, _ = tree_flatten(placement_specs)
  2. Slice the pipeline stages according to the user annotation. Assume the user-provided stage boundaries align with the provided mesh assignment for the parameters.
  3. Set the input shardings and run the sharding propagation pass. In the example below, we run sharding propagation with the specified output shardings (a standalone illustration of the propagation idea follows this list):
    pipeshard_config = compile_pipeshard_executable_internal(
        new_jaxpr, None, 1, in_tree, [False] * len(avals),
        [False] * len(avals), executable.mesh_group.parent, 1, "inference",
        AutoShardingOption(enable_auto_sharding=False),
        UniformStageOption(), name, output_shardings)
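
To make step 3 concrete, here is a standalone sketch of the propagation idea in plain JAX/GSPMD rather than through Alpa's pipeshard path: fix the parameter sharding on the inputs and let XLA's sharding propagation derive the rest. The mesh shape, axis names, and the `predict` function are placeholders, and `jax.jit`'s `in_shardings` argument assumes a recent JAX version.

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Placeholder 2D mesh; with a single device this degenerates to a 1x1 mesh.
    devices = np.array(jax.devices()).reshape(-1, 1)
    mesh = Mesh(devices, axis_names=("data", "model"))

    def predict(w, x):
        return jnp.dot(x, w)

    # Predesignated parameter sharding: w is split along the "model" axis.
    w_sharding = NamedSharding(mesh, P(None, "model"))
    x_sharding = NamedSharding(mesh, P("data", None))

    # Fix the input shardings; XLA's sharding propagation derives the output sharding.
    predict_sharded = jax.jit(predict, in_shardings=(w_sharding, x_sharding))

    w = jnp.ones((8, 8))
    x = jnp.ones((4, 8))
    print(predict_sharded(w, x).sharding)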

Manual inter-operator parallelism and Alpa intra-operator parallelism

In this step, we need to modify the auto-sharding ILP in Alpa to support user-provided annotations. The basic idea is to force some variables in the ILP to select the user-provided choice. This part has already been implemented within Google's internal auto-sharding pass. We are in the process of open-sourcing that part into the official tensorflow/XLA codebase, which should be out within two weeks. I will update this part after the code is open-sourced.
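
For reference, the sketch below shows the forcing mechanism on a toy ILP built with PuLP (not Alpa's actual auto-sharding ILP): each operator picks one sharding strategy through one-hot binary variables, and a user-provided sharding is honored by adding an equality constraint that pins the corresponding variable to 1. The operator names, strategy costs, and annotation are made up for illustration.

    import pulp

    ops = ["matmul_0", "matmul_1"]
    strategies = [0, 1]                       # e.g. 0 = shard dim 0, 1 = shard dim 1
    cost = {("matmul_0", 0): 3.0, ("matmul_0", 1): 5.0,
            ("matmul_1", 0): 4.0, ("matmul_1", 1): 2.0}
    forced = {"matmul_0": 1}                  # user annotation: keep strategy 1 for matmul_0

    prob = pulp.LpProblem("auto_sharding_with_annotations", pulp.LpMinimize)
    x = pulp.LpVariable.dicts(
        "choose", [(o, s) for o in ops for s in strategies], cat="Binary")

    # Objective: total sharding cost (communication/resharding costs in the real pass).
    prob += pulp.lpSum(cost[o, s] * x[o, s] for o in ops for s in strategies)
    # Each operator selects exactly one strategy.
    for o in ops:
        prob += pulp.lpSum(x[o, s] for s in strategies) == 1
    # Force annotated operators to the user-provided choice.
    for o, s in forced.items():
        prob += x[o, s] == 1

    prob.solve(pulp.PULP_CBC_CMD(msg=False))
    for o in ops:
        picked = next(s for s in strategies if pulp.value(x[o, s]) > 0.5)
        print(o, "-> strategy", picked)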

Alpa inter- and intra-operator parallelism

In this step, we need to modify the stage-slicing DP algorithm in Alpa as follows:

  1. Support running the DP algorithm with a user-provided mesh slicing. In the current algorithm, we assume the user only provides a single large 2D device mesh, and we slice the mesh and the stages together.
  2. Let the DP algorithm consider the parameters' mesh assignment when slicing stages.
  3. Let the layer clustering algorithm also consider the parameters' mesh assignment.

I can follow up with a more detailed design for the algorithmic changes here.
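
As a rough illustration of items 1 and 2 (not Alpa's actual DP), the toy dynamic program below takes the mesh slicing as given (K submeshes map to K contiguous stages) and restricts every layer whose parameters are pre-assigned to a submesh to land in that stage. Layer costs and the pinning map are placeholders.

    from functools import lru_cache

    layer_cost = [2.0, 3.0, 1.0, 4.0, 2.0, 2.0]   # per-layer compute cost (placeholder)
    num_stages = 3                                  # user-provided mesh slicing: 3 submeshes
    pinned = {1: 0, 4: 2}                           # layer -> submesh holding its parameters

    def stage_cost(lo, hi, stage):
        """Cost of placing layers [lo, hi) on `stage`; infinite if a pinned layer disagrees."""
        if any(pinned.get(l, stage) != stage for l in range(lo, hi)):
            return float("inf")
        return sum(layer_cost[lo:hi])

    @lru_cache(None)
    def best(hi, stage):
        """Minimal bottleneck stage cost for layers [0, hi) placed on stages [0, stage]."""
        if stage == 0:
            return stage_cost(0, hi, 0)
        return min(max(best(lo, stage - 1), stage_cost(lo, hi, stage))
                   for lo in range(hi + 1))

    print("bottleneck stage cost:", best(len(layer_cost), num_stages - 1))

Item 3 would apply the same kind of constraint one step earlier, when clustering operators into layers.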
