[FEATURE] Reduce congestion of sending on one mesh #802
Hello, please assign this to me :) To my understanding, Ray's object store is not used for activation/param transfer, and Ray is only used for task orchestration, correct?
Right. The compilation is:
For the first task, writing the pass, I will simply write a test to show the desired transform is applied to the jaxpr. As for scheduling, I guess the tensor should be queued instantly upon being received.
One complication: are all vars uniquely named in the HLO module, i.e. is it SSA?
Hello, another question: are we guaranteed that the stages are sequentially dependent, i.e. that we have a chain rather than a DAG? It doesn't affect too much, but presumably, for a DAG structure where there is no functional dependence of stage 2 on stage 1, we should indeed broadcast to stage 1 and stage 2 from stage 0 to prevent any stalls. However, perhaps we can ignore it for now.
It's mainly sequential, but there will be some skip connections (e.g. stage 0 -> stage 2, stage 0 -> stage 3, etc.). Otherwise we wouldn't have this issue.
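For illustration only, the dependency pattern described above might look like the following toy structure (hypothetical stage ids, not Alpa's real data structures):

```python
# Mostly a chain (stage i feeds stage i+1), plus skip connections out of stage 0.
stage_inputs = {
    0: [],
    1: [0],
    2: [1, 0],  # skip connection: stage 0 -> stage 2
    3: [2, 0],  # skip connection: stage 0 -> stage 3
}
```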
I presume this is unique, so I will use it as the var's uuid to act as a lookup key.
You can just use `var`; it wraps such an id.
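As an illustration, here is a minimal sketch in plain JAX (not Alpa internals; the function `f` is made up) showing that each `Var` in a jaxpr is defined exactly once and that `Var` objects can be used directly as dictionary keys, so no separate uuid is needed:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    a = x * y
    return a, a + x

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(4), jnp.ones(4))
jaxpr = closed_jaxpr.jaxpr

# Map each Var to where it is defined; SSA means each Var is defined only once.
definitions = {}
for v in jaxpr.invars:
    definitions[v] = "invar"
for eqn in jaxpr.eqns:
    for v in eqn.outvars:
        assert v not in definitions, "a Var should be defined exactly once"
        definitions[v] = eqn.primitive.name

print({str(v): site for v, site in definitions.items()})
```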
Another question: is there any cost to adding all the invars to a stage's outvars?
It depends on your algo. I think the first principle is to not increase the total comm size. E.g. if originally we send 0>2, I cannot see any advantage in making it 0>1>2. The case in the issue is that 0>1>2 is better than (0>1 and 0>2).

Adding invars to outvars makes them live longer, and some invars are designed to donate their memory to the corresponding outvars. Besides, the messiness itself might influence later passes, so we'd hope to avoid it.
The algo is a simple one. It is:

```python
last_seen = {}  # var -> id of the latest stage that has already read it

# Sequentially walk the stages
for stage in stages:
    for (src, var) in cross_deps[stage.id]:
        # If var is a cross-mesh dep, check whether an earlier stage has
        # already read it. If so, add it to that stage's outvars and fetch
        # from that (latest) stage instead of the original producer.
        if var in last_seen:
            src_mesh = meshes[last_seen[var]]
            upsert(src_mesh.outvars, var)
        else:
            src_mesh = src
        last_seen[var] = stage.id
```

Is adding to outvars necessary? It seems that in our case we don't need to add to outvars; we should be able to fetch from the invars instead.
This would mean that the cross-shard invars can begin to be resharded prior to the model invocation. However, I am not sure whether outvars are merely logical, and whether we can facilitate the same async transfer as soon as one of the outvars is ready, as marked by the runtime.
I will avoid this then.
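As a concrete illustration of the heuristic above, here is a hypothetical driver on the fan-out example from the issue (the data structures are simplified stand-ins, not Alpa's real ones):

```python
producer = {"x": 0}            # var -> stage that produces it
consumers = {"x": [1, 2, 3]}   # var -> stages that consume it

last_seen = dict(producer)     # var -> latest stage known to hold the tensor
decisions = []                 # (src_stage, dst_stage, var)
for stage in [1, 2, 3]:
    for var, readers in consumers.items():
        if stage in readers:
            decisions.append((last_seen[var], stage, var))
            last_seen[var] = stage

print(decisions)  # [(0, 1, 'x'), (1, 2, 'x'), (2, 3, 'x')]: a chain instead of a fan-out
```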
The heuristic works for the first scenario. In the above (0>1 & 0>2) case, we don't need to add it to 1's outvars. You can read the `PipelineInstEmitter` for more details on how we actually launch send/recv and free.
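For intuition only, a heavily simplified, hypothetical view of the per-mesh instruction streams that the chain-style decision would imply (this is not `PipelineInstEmitter`'s real API or instruction set):

```python
instructions = {
    "mesh0": ["RUN stage0", "SEND x -> mesh1", "FREE x"],
    "mesh1": ["RECV x <- mesh0", "SEND x -> mesh2", "RUN stage1", "FREE x"],
    "mesh2": ["RECV x <- mesh1", "SEND x -> mesh3", "RUN stage2", "FREE x"],
    "mesh3": ["RECV x <- mesh2", "RUN stage3", "FREE x"],
}
```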
Btw, as far as the producer goes, every var corresponds to e.g. a single tensor sharded across the entire submesh, correct? Anyway, adding an invar to the outvars is non-trivial: one has to deal with donation, and might also need to recompile the jaxpr. I would prefer the transfer to take a different pathway rather than piggybacking on the outvars.
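On the donation point, a minimal plain-JAX illustration (not Alpa code; the function and names are hypothetical) of why keeping an invar alive as an extra outvar conflicts with buffer donation:

```python
import functools
import jax

# The buffer of `params` may be reused in place for the output, so `params`
# cannot simply be re-exported as an additional outvar after compilation;
# doing so would mean changing the donation setup and recompiling.
@functools.partial(jax.jit, donate_argnums=(0,))
def update(params, grads):
    return params - 0.1 * grads
```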
This seems more complicated. To an SQL database person, it sounds like expression pushdown. It sounds like we really do want to reconfigure the jaxpr after pipelining but before sharding. So at the pipelining pass, we should use our
However, I don't understand how the async queueing would occur in this case. Will every jaxpr
A var corresponds to a logical tensor including all its shards. In the "pipeline pass", we only decide how the computational graph is divided into pipeline stages, but not the communication between pipeline stages; the cross-mesh communication is decided afterwards.

So the way I think about the forward part is: for the forward part, we only modify code in the runtime emitter.

I'd suggest you read the following for more details: the architecture section in the project's doc, and alpa/alpa/pipeline_parallel/runtime_emitter.py, lines 545 to 591 at commit 1ddb2dc.
For overlapping communication and computation, please refer to #773 and alpa-projects/tensorflow-alpa#127
This can be a starting point to learn `runtime_emitter` and `cross_mesh_resharding`.

Background
In Pipeshard Parallel, when a tensor is required to be received from a mesh, we always choose the mesh that exactly generated it, as happens here. However, when the tensor is consumed by multiple pipeline stages, a better solution MIGHT be that a later consumer receives the tensor from one of the prior consumers. For example, when stage 0 sends a tensor to stages 1, 2, and 3, but stages 1-3 don't have much communication of their own, then stage 2 can receive from stage 1, and stage 3 from stage 2.
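To make the congestion argument concrete, a back-of-the-envelope sketch (the tensor size is a made-up number): the chain keeps the total communication volume the same but balances the outgoing traffic across meshes instead of piling it all on mesh 0.

```python
tensor_bytes = 4 * 1024**2              # assume a 4 MiB activation
fan_out = [(0, 1), (0, 2), (0, 3)]      # (src_mesh, dst_mesh) pairs
chain = [(0, 1), (1, 2), (2, 3)]

def per_mesh_send_bytes(edges):
    out = {}
    for src, _ in edges:
        out[src] = out.get(src, 0) + tensor_bytes
    return out

print(per_mesh_send_bytes(fan_out))  # {0: 12582912}: mesh 0 sends everything
print(per_mesh_send_bytes(chain))    # {0: 4194304, 1: 4194304, 2: 4194304}
```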
TODO
In `CrossMeshCommunicator`, consume that designed decision.
If the forward sends a tensor from stage 0 to stages 1, 2, and 3, in the backward there will be gradient contributions x', y', and z' for that tensor. We need to merge a = x' + y' into the backward of stage 2 and b = a + z' into the backward of stage 1, by modifying the code here.
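As a sanity check on the backward merging described above, a minimal numeric sketch (plain numpy, hypothetical shapes) showing that accumulating along the reverse chain is equivalent to summing all gradient contributions directly on stage 0:

```python
import numpy as np

rng = np.random.default_rng(0)
x_grad, y_grad, z_grad = (rng.standard_normal(8) for _ in range(3))

a = x_grad + y_grad                 # merged in the backward of stage 2
b = a + z_grad                      # merged in the backward of stage 1, then sent to stage 0
direct = x_grad + y_grad + z_grad   # what stage 0 would otherwise accumulate itself

np.testing.assert_allclose(b, direct)
```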