
[FEATURE] Reduce congestion of sending on one mesh #802

Open · 4 tasks
ZYHowell opened this issue Dec 6, 2022 · 16 comments

Labels: enhancement (New feature), good first issue (Good for newcomers)

Comments

@ZYHowell
Collaborator

ZYHowell commented Dec 6, 2022

This can be a starting point to learn runtime_emitter and cross_mesh_resharding.

Background

In Pipeshard Parallel, when a tensor needs to be received from a mesh, we always choose the mesh that generates it, as happens here. However, when the tensor is consumed by multiple pipeline stages, a better solution might be for a later consumer to receive the tensor from one of the prior consumers. For example, when stage 0 sends a tensor to stages 1, 2, and 3, but stages 1-3 don't have much communication, then stage 2 can receive it from stage 1, and stage 3 from stage 2 (see the sketch at the end of this comment).

TODO

  • Add a pass to optimize the case above. Given a pipeline schedule and mesh allocation, the pass decides which mesh a tensor should be received from.
  • In CrossMeshCommunicator, consume that decision.
  • In the runtime emitter, choose that decision instead of using the src or using the first mesh (using the first is even worse; it can be a bug).
  • (Optional) To do the same thing for the backward semantics, we need to do more: if stage 0 outputs t to stages 1, 2, and 3, in the backward there will be:
x = gradient of t at stage 3
x' = pipeline_end(x)
...
y = gradient of t at stage 2
y' = pipeline_end(y)
...
z = gradient of t at stage 1
z' = pipeline_end(z)
a = x' + y'
b = a + z'
grad_of_t = pipeline_start(b)
consume grad_of_t in layer 0's backward

We need to merge a = x' + y' into the backward of stage 2 and b = a + z' into the backward of stage 1 by modifying the code here.
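
As a concrete illustration of the forward case (stage ids taken from the example above; this is only a sketch of the intended rewrite, not an Alpa data structure):

# Current behavior: every consumer of tensor t receives it from the producer mesh.
sends_before = [(0, 1), (0, 2), (0, 3)]   # (src_stage, dst_stage) pairs for t

# Desired behavior: later consumers receive t from the previous consumer,
# so stage 0's mesh sends t only once and the load is spread across meshes.
sends_after = [(0, 1), (1, 2), (2, 3)]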

@ZYHowell added the enhancement and good first issue labels on Dec 6, 2022
@jon-chuang

Hello, please assign this to me :)

To my understanding, Ray's object store is not used for activation/param transfer, and Ray is only used for task orchestration, correct?

@ZYHowell
Collaborator Author

ZYHowell commented Apr 4, 2023

Right. The compilation flow is:
wrapped Jaxprs of each pipeline stage --by CrossMeshCommunicator--> SymbolicReshardingTask --by PipelineInstEmitter--> PipelineInstruction (SEND/RECV/BROADCAST)
Each pipeline instruction is orchestrated by code in the collective folder, which finally calls NCCL. The congestion issue may be solvable by considering only the compilation part.

@jon-chuang

For the first task, writing the pass, I will simply write a test to show that the desired transform is applied to the jaxpr.

As for scheduling, I guess the tensor should be queued instantly upon being received.

@jon-chuang

One complication: are all vars uniquely named in the HLO module, i.e. is it SSA?

@ZYHowell
Collaborator Author

ZYHowell commented Apr 4, 2023

For Jaxpr, each Var has its own id (an int) unless it's a DropVar (a placeholder); I'd expect most of the work to be at this level.
For HLO, each var corresponds to a specific HloInstruction.

@jon-chuang

Hello, another question: are we guaranteed that the stages are sequentially dependent? Meaning that we have a chain, not a DAG? It doesn't affect too much, but presumably, for a DAG structure:

stage 0 -> stage 1
        -> stage 2

Where there is no functional dependence of stage 2 on stage 1, we should indeed broadcast to stage 1 and stage 2 from stage 0 to prevent any stalls.

However, perhaps we can ignore it for now.

@ZYHowell
Collaborator Author

ZYHowell commented Apr 5, 2023

It's mainly sequential, but there will be some skip connections (e.g. stage 0 -> stage 2, stage 0 -> stage 3, etc.). Otherwise we wouldn't have this issue.
Besides, there are both forward and backward stages, so stage 0 and stage -1 are on the same mesh.

@jon-chuang

each Var has its own id

I presume this is unique, so I will use it as the var's uuid to act as a lookup key.

@ZYHowell
Collaborator Author

ZYHowell commented Apr 5, 2023

You can just use the Var itself; it wraps such an id.
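
A minimal sketch (plain JAX, nothing Alpa-specific) showing that jaxpr Vars are hashable and carry that integer id, so they can be used directly as lookup keys; this assumes Var exposes a count attribute, as in the JAX versions Alpa builds on:

import jax
import jax.numpy as jnp

def f(x, y):
    return x * y + 1.0

jaxpr = jax.make_jaxpr(f)(jnp.ones(4), jnp.ones(4)).jaxpr

# Vars can index a dict directly; v.count is the underlying integer id.
last_seen = {v: 0 for v in jaxpr.invars}   # e.g. "first seen at stage 0"
for v in jaxpr.invars:
    print(v, v.count, v.aval)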

@jon-chuang

jon-chuang commented Apr 5, 2023

Another question:

Given var x on stage 0 and consumed by stage 1, but not output by stage 1, do we need to now add var x to the outvars of stage 1 to be consumed from stage 1 by a downstream stage 2?

Further, is there any cost to adding all invars to outvars of every stage by default (except messiness)?

@ZYHowell
Collaborator Author

ZYHowell commented Apr 5, 2023

It depends on your algo. I think the first principle is to not increase the total comm size. E.g. if originally we send 0>2, I cannot see any advantage in making it 0>1>2. The case in the issue is: 0>1>2 is better than (0>1 and 0>2). In addition, if 2 sends x to 0 and 1 sends y to 0, but 0 only uses x + y, we can make it so that 2 sends x to 1 and 1 sends x+y to 0.

Adding invars to outvars makes them live longer, and some invars are designed to donate their memory to corresponding outvars. Besides, the messiness itself might influence later passes, so we'd hope to avoid it.
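
A rough comparison for the x + y example above, assuming x and y have the same size s (purely illustrative arithmetic):

s = 1                       # size of x and of y, in arbitrary units
total_before = s + s        # 2 -> 0 sends x, 1 -> 0 sends y
total_after = s + s         # 2 -> 1 sends x, 1 -> 0 sends x + y
into_mesh_0_before = 2 * s  # mesh 0 receives both x and y
into_mesh_0_after = s       # mesh 0 receives only x + y
# The total volume is unchanged, but the traffic into mesh 0 is halved
# and it now receives from a single peer instead of two.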

@jon-chuang

jon-chuang commented Apr 5, 2023

The algo is a simple one: take from the last-seen stage.

last_seen = {}
# Sequentially walk the stages.
for stage in stages:
  for (src, var) in cross_deps[stage.id]:
    # If var has already been received by an earlier stage, add it to that
    # stage's outvars and fetch from there instead of from the original src.
    if var in last_seen:
      src_mesh = meshes[last_seen[var]]
      upsert(src_mesh.outvars, var)
    else:
      src_mesh = src
    last_seen[var] = stage.id
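
For concreteness, a self-contained version of this heuristic run on the motivating example from the issue (stage 0 produces t, stages 1-3 consume it); cross_deps and the stage ids here are made up for illustration:

# Hypothetical dependency map: consumer stage id -> [(producer stage id, var name)]
cross_deps = {1: [(0, "t")], 2: [(0, "t")], 3: [(0, "t")]}

last_seen = {}   # var name -> latest stage that already holds it
plan = {}        # (consumer stage, var) -> stage to receive from
for stage_id in sorted(cross_deps):
    for src, var in cross_deps[stage_id]:
        plan[(stage_id, var)] = last_seen.get(var, src)
        last_seen[var] = stage_id

print(plan)   # {(1, 't'): 0, (2, 't'): 1, (3, 't'): 2}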

Is adding to outvars necessary? It seems that in our case we don't need to add to outvars; we should be able to fetch from the invars?

invars -> [model] -> outvars ==cross-shard==>
 |======================cross-shard=>

This would mean that the cross-shard invars can begin to be re-sharded prior to the model invocation.

However, I'm not sure if outvars are merely logical, and whether we can facilitate the same async transfer as soon as one of the outvars is ready, as marked by the runtime.

Adding invars to outvars makes them live longer, and some invars are designed to donate their memory to corresponding outvars.

I will avoid this then.

@ZYHowell
Collaborator Author

ZYHowell commented Apr 5, 2023

The heuristic works for the first scenario. In the above (0>1 & 0>2) case, we don't need to add it to 1's outvars. You can read the PipelineInstEmitter for more details on how we actually launch send/recv and free.
We already do the async transfer with a CUDA event and a kernel injected to record the event.

@jon-chuang

jon-chuang commented Apr 5, 2023

Btw, as far as the producer goes, every var corresponds to e.g. a single tensor sharded across the entire submesh, correct?

Anyway, adding an invar to the outvars is non-trivial. One has to deal with donation, and might also need to recompile the jaxpr. I'd prefer it if the transfer took a different pathway rather than piggybacking on outvars. Any suggestions?

@jon-chuang

jon-chuang commented Apr 5, 2023

In addition, if 2 sends x to 0 and 1 sends y to 0, but 0 only uses x + y, we can make it be 2 sends x to 1 and 1 sends x+y to 0.

This seems more complicated. For an SQL database person, it sounds like expression pushdown.

Sounds like we really do want to reconfigure the jaxpr after pipelining but before sharding. So in the pipelining pass, we should use our last-seen-stage heuristic to force the relevant invars to become outvars. I'm not sure whether the behaviour should be to skip an invar if it is donated before this pass.

We've already done the async transfer with cuda event and some kernel injected to record the event.

However, I don't understand how the async queueing would occur in this case. Will every jaxpr outvar be evaluated and queued asynchronously and concurrently?

@ZYHowell
Collaborator Author

ZYHowell commented Apr 5, 2023

A var corresponds to a logical tensor including all its shards.

In the "pipeline pass", we only decide how the computational graph is divided into pipeline stages, but not the communication between pipeline stages.

Instead, in PipelineInstEmitter, we create a schedule for each device mesh, where we manage the execution of each communication, computation, and memory deallocation (tensors are allocated only with computation, and communication reuses those allocated tensors to receive). There, we store the live tensors of each mesh in PipelineInstEmitterHelper at each time tick.

The way I think about the forward part is: we only modify code in PipelineInstEmitter to emit send/recv from the mesh with the least traffic among all meshes that have the tensor. For the backward part, things are more complicated; there might be something related to the "pipeline pass".

I'd suggest you read the following for more details: the architecture section in the project's doc, and this snippet from PipelineInstEmitter:

def _compile_exec_one_tick(self, sched, donation_mapping, instruction_lists,
                           executable_uuids, executable_config_lists):
    worker_tmp_instructions = {}
    for mesh in self.mesh_group:
        for worker in mesh.workers:
            worker_tmp_instructions[worker] = []
    for mesh_idx, task in enumerate(sched):
        if not task:
            continue
        batch_idx, stage_idx = task
        stage = self.stages[stage_idx]
        # shard_args for intermediates
        to_reshard_vars = []
        reshard_sharding_specs = []
        for invar, spec in zip(stage.invars, stage.input_sharding_specs):
            if self.env.var_at(invar, batch_idx, mesh_idx):
                # have a copy at the current mesh
                continue
            # TODO(yonghao): to avoid congestion, maybe sending from the
            # last one (a.k.a. the latest one receiving it) is better, but
            # we have to create the corresponding cross-mesh communication
            # task.
            # if len(self.env.get_var_meshes(invar, batch_idx)) > 1:
            #     raise NotImplementedError(
            #         "Not support resharding replicated")
            var_key = self.env.get_var_with_accumulate(invar, batch_idx)
            src_idx = list(
                self.env.get_var_meshes(invar, batch_idx).keys())[0]
            resharding = self._resharding_tasks[src_idx][mesh_idx][var_key]
            if resharding.is_local_allgather_task:
                spec = resharding.task_spec.dst_sharding_spec
            to_reshard_vars.append(invar)
            reshard_sharding_specs.append(spec)
        self._compile_get_vars_from_mesh(to_reshard_vars,
                                         reshard_sharding_specs, mesh_idx,
                                         batch_idx, instruction_lists,
                                         executable_config_lists)
        # execute
        self._compile_exec_one_mesh(mesh_idx, task, executable_uuids,
                                    donation_mapping,
                                    worker_tmp_instructions)
    for worker, worker_instruction in worker_tmp_instructions.items():
        instruction_lists[worker].extend(worker_instruction)
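
To make the TODO above concrete, here is a minimal sketch of the selection rule, assuming the emitter kept a per-mesh count of bytes already scheduled to be sent in the current tick (the traffic_bytes counter and pick_src_mesh helper are hypothetical, not existing Alpa APIs); the real change would also have to create the matching cross-mesh resharding task:

def pick_src_mesh(candidate_mesh_idxs, traffic_bytes):
    # Among all meshes that already hold the tensor (instead of always
    # taking the first key, as above), prefer the one with the least
    # outgoing traffic scheduled so far.
    return min(candidate_mesh_idxs, key=lambda idx: traffic_bytes[idx])

# e.g. replacing `src_idx = list(...keys())[0]` above with something like:
# src_idx = pick_src_mesh(self.env.get_var_meshes(invar, batch_idx).keys(),
#                         traffic_bytes)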

For overlapping communication and computation, please refer to #773 and alpa-projects/tensorflow-alpa#127
