
[LoweringContext] SPMD support #8418

Closed

Conversation

@rpsilva-aws (Contributor) commented Nov 26, 2024

In this PR, we extend the lowering context to support SPMD.

Testing

  • TestOperations (reference HLO):
HloModule SomeFn.12, entry_computation_layout={(f32[2048]{0}, f32[], f32[32,2048]{1,0})->(f32[2048]{0}, f32[32,2048]{1,0})}

ENTRY %SomeFn.12 (p0.3: f32[2048], p1.7: f32[], p2.8: f32[32,2048]) -> (f32[2048], f32[32,2048]) {
  %p0.3 = f32[2048]{0} parameter(0), sharding={devices=[4,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31 last_tile_dim_replicate}
  %constant.2 = f32[] constant(1)
  %constant.1 = f32[] constant(1)
  %multiply.4 = f32[] multiply(f32[] %constant.2, f32[] %constant.1)
  %broadcast.5 = f32[2048]{0} broadcast(f32[] %multiply.4), dimensions={}
  %add.6 = f32[2048]{0} add(f32[2048]{0} %p0.3, f32[2048]{0} %broadcast.5)
  %p2.8 = f32[32,2048]{1,0} parameter(2), sharding={devices=[1,8,4]0,8,16,24,1,9,17,25,2,10,18,26,3,11,19,27,4,12,20,28,5,13,21,29,6,14,22,30,7,15,23,31 last_tile_dim_replicate}
  %p1.7 = f32[] parameter(1), sharding={replicated}
  %broadcast.9 = f32[32,2048]{1,0} broadcast(f32[] %p1.7), dimensions={}
  %multiply.10 = f32[32,2048]{1,0} multiply(f32[32,2048]{1,0} %p2.8, f32[32,2048]{1,0} %broadcast.9)
  ROOT %tuple.11 = (f32[2048]{0}, f32[32,2048]{1,0}) tuple(f32[2048]{0} %add.6, f32[32,2048]{1,0} %multiply.10)
}
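
For context, the following is a minimal sketch of how a sharded graph with the same shape as the reference HLO above can be built and its HLO text inspected. This is not the actual TestOperations body: the mesh axis names, the 32-device mesh assumption, and the use of _get_xla_tensors_hlo for the dump are illustrative choices only.

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Illustrative only: mirrors the [4,8] and [1,8,4] shardings in the reference
# HLO above, assuming 32 SPMD devices; axis names are arbitrary.
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (4, 8), ("data", "model"))

device = torch_xla.device()
p0 = torch.zeros(2048, device=device)
p1 = torch.tensor(1.0, device=device)
p2 = torch.zeros(32, 2048, device=device)
xs.mark_sharding(p0, mesh, ("data",))        # dim 0 sharded across 4, replicated on "model"
xs.mark_sharding(p2, mesh, (None, "model"))  # dim 1 sharded across 8, replicated on "data"

out0 = p0 + 1.0
out1 = p2 * p1

# Dump the HLO for the traced outputs and check that the sharding annotations survive.
print(torch_xla._XLAC._get_xla_tensors_hlo([out0, out1]))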

@rpsilva-aws (Contributor, Author) commented Nov 26, 2024

@JackCaoG @tengyifei, do we have a plan for handling IR values whose sharding we cannot deduce from the inputs alone?

Referencing Jack's example (imports added; mesh, spec, layers, and scan are assumed to be defined elsewhere):

import torch
import torch_xla
import torch_xla.distributed.spmd as xs

device = torch_xla.device()
t1 = torch.tensor(100, device=device)
xs.mark_sharding(t1, mesh, spec)
t2 = t1 * 2  # derived value with no explicit sharding spec of its own

scan(layers, t2)

I could look into building a minimal framework that traces the IR to identify the nearest ancestor node carrying a sharding spec, as long as there are no clashes and no ops along the way that prevent it (see the sketch below). Do we have any other suggestions for this, or does something already exist? Worst case, we enforce as an invariant that the inputs carry sharding specs.
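
To make the idea concrete, here is a purely hypothetical sketch of that ancestor-tracing approach. IrNode, .operands, and .sharding are illustrative stand-ins rather than existing torch_xla types; a real implementation would walk the lazy-tensor IR instead.

from typing import Optional

class IrNode:
    # Hypothetical stand-in for a lazy-tensor IR node.
    def __init__(self, name, operands=(), sharding=None):
        self.name = name
        self.operands = list(operands)
        self.sharding = sharding  # sharding spec attached via mark_sharding, if any

def infer_ancestor_sharding(node: IrNode) -> Optional[str]:
    """Return the sharding of the nearest sharded ancestor(s), or None if
    there is no sharded ancestor or the candidate specs clash."""
    seen, frontier, found = set(), [node], set()
    while frontier:
        cur = frontier.pop()
        if id(cur) in seen:
            continue
        seen.add(id(cur))
        if cur.sharding is not None:
            found.add(cur.sharding)
            continue  # stop walking this path at the first sharded node
        frontier.extend(cur.operands)
    if len(found) == 1:
        return next(iter(found))
    return None  # no sharded ancestor, or conflicting specs: fall back to the invariant

# e.g., t2 = t1 * 2 where only t1 was explicitly sharded:
t1 = IrNode("t1", sharding="devices=[4,8]...")
two = IrNode("2")
t2 = IrNode("t2", operands=[t1, two])
assert infer_ancestor_sharding(t2) == "devices=[4,8]..."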

@rpsilva-aws (Contributor, Author) commented Nov 26, 2024

FYI, if we do not eagerly instantiate the tensors on the CPU, but instead forward them to the device data with RNG:

## BEGIN_GRAPH
HloModule IrToHlo.75, entry_computation_layout={(s64[], f32[])->(f32[2048]{0}, f32[32,2048]{1,0}, f32[2048]{0}, f32[32,2048]{1,0})}

ENTRY %IrToHlo.75 (p0.7: s64[], p1.71: f32[]) -> (f32[2048], f32[32,2048], f32[2048], f32[32,2048]) {
  %constant.10 = s64[] constant(2531011)
  %constant.8 = s64[] constant(214013)
  %p0.7 = s64[] parameter(0), sharding={replicated}
  %multiply.9 = s64[] multiply(s64[] %constant.8, s64[] %p0.7)
  %add.11 = s64[] add(s64[] %constant.10, s64[] %multiply.9)
  %convert.22 = u64[] convert(s64[] %add.11)
  %reshape.26 = u64[1]{0} reshape(u64[] %convert.22)
  %constant.23 = u64[] constant(0)
  %reshape.27 = u64[1]{0} reshape(u64[] %constant.23)
  %concatenate.28 = u64[2]{0} concatenate(u64[1]{0} %reshape.26, u64[1]{0} %reshape.27), dimensions={0}
  %rng-bit-generator.29 = (u64[2]{0}, u32[16,2,2048]{2,1,0}) rng-bit-generator(u64[2]{0} %concatenate.28), algorithm=rng_default
  %get-tuple-element.31 = u64[2]{0} get-tuple-element((u64[2]{0}, u32[16,2,2048]{2,1,0}) %rng-bit-generator.29), index=0
  %constant.1 = f32[] constant(1)
  %reshape.2 = f32[1]{0} reshape(f32[] %constant.1)
  %broadcast.3 = f32[1]{0} broadcast(f32[1]{0} %reshape.2), dimensions={0}
  %reshape.4 = f32[] reshape(f32[1]{0} %broadcast.3)
  %broadcast.5 = f32[2048]{0} broadcast(f32[] %reshape.4), dimensions={}
  %custom-call.6 = f32[2048]{0} custom-call(f32[2048]{0} %broadcast.5), custom_call_target="Sharding", sharding={devices=[4,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31 last_tile_dim_replicate}
  %constant.17 = f32[] constant(0)
  %reshape.18 = f32[1,1]{1,0} reshape(f32[] %constant.17)
  %broadcast.19 = f32[1,1]{1,0} broadcast(f32[1,1]{1,0} %reshape.18), dimensions={0,1}
  %reshape.20 = f32[] reshape(f32[1,1]{1,0} %broadcast.19)
  %broadcast.21 = f32[32,2048]{1,0} broadcast(f32[] %reshape.20), dimensions={}
  %constant.49 = f32[] constant(6.28318548)
  %broadcast.50 = f32[16,1,2048]{2,1,0} broadcast(f32[] %constant.49), dimensions={}
  %get-tuple-element.30 = u32[16,2,2048]{2,1,0} get-tuple-element((u64[2]{0}, u32[16,2,2048]{2,1,0}) %rng-bit-generator.29), index=1
  %constant.32 = u32[] constant(9)
  %broadcast.33 = u32[16,2,2048]{2,1,0} broadcast(u32[] %constant.32), dimensions={}
  %shift-right-logical.34 = u32[16,2,2048]{2,1,0} shift-right-logical(u32[16,2,2048]{2,1,0} %get-tuple-element.30, u32[16,2,2048]{2,1,0} %broadcast.33)
  %convert.35 = f32[16,2,2048]{2,1,0} convert(u32[16,2,2048]{2,1,0} %shift-right-logical.34)
  %constant.36 = f32[] constant(1.1920929e-07)
  %broadcast.37 = f32[16,2,2048]{2,1,0} broadcast(f32[] %constant.36), dimensions={}
  %multiply.38 = f32[16,2,2048]{2,1,0} multiply(f32[16,2,2048]{2,1,0} %convert.35, f32[16,2,2048]{2,1,0} %broadcast.37)
  %constant.24 = f32[] constant(1)
  %constant.25 = f32[] constant(0)
  %subtract.39 = f32[] subtract(f32[] %constant.24, f32[] %constant.25)
  %broadcast.40 = f32[16,2,2048]{2,1,0} broadcast(f32[] %subtract.39), dimensions={}
  %multiply.41 = f32[16,2,2048]{2,1,0} multiply(f32[16,2,2048]{2,1,0} %multiply.38, f32[16,2,2048]{2,1,0} %broadcast.40)
  %broadcast.42 = f32[16,2,2048]{2,1,0} broadcast(f32[] %constant.25), dimensions={}
  %add.43 = f32[16,2,2048]{2,1,0} add(f32[16,2,2048]{2,1,0} %multiply.41, f32[16,2,2048]{2,1,0} %broadcast.42)
  %slice.45 = f32[16,1,2048]{2,1,0} slice(f32[16,2,2048]{2,1,0} %add.43), slice={[0:16], [1:2], [0:2048]}
  %multiply.51 = f32[16,1,2048]{2,1,0} multiply(f32[16,1,2048]{2,1,0} %broadcast.50, f32[16,1,2048]{2,1,0} %slice.45)
  %sine.57 = f32[16,1,2048]{2,1,0} sine(f32[16,1,2048]{2,1,0} %multiply.51)
  %constant.53 = f32[] constant(-2)
  %broadcast.54 = f32[16,1,2048]{2,1,0} broadcast(f32[] %constant.53), dimensions={}
  %slice.44 = f32[16,1,2048]{2,1,0} slice(f32[16,2,2048]{2,1,0} %add.43), slice={[0:16], [0:1], [0:2048]}
  %constant.46 = f32[] constant(1e-07)
  %broadcast.47 = f32[16,1,2048]{2,1,0} broadcast(f32[] %constant.46), dimensions={}
  %maximum.48 = f32[16,1,2048]{2,1,0} maximum(f32[16,1,2048]{2,1,0} %slice.44, f32[16,1,2048]{2,1,0} %broadcast.47)
  %log.52 = f32[16,1,2048]{2,1,0} log(f32[16,1,2048]{2,1,0} %maximum.48)
  %multiply.55 = f32[16,1,2048]{2,1,0} multiply(f32[16,1,2048]{2,1,0} %broadcast.54, f32[16,1,2048]{2,1,0} %log.52)
  %sqrt.56 = f32[16,1,2048]{2,1,0} sqrt(f32[16,1,2048]{2,1,0} %multiply.55)
  %multiply.58 = f32[16,1,2048]{2,1,0} multiply(f32[16,1,2048]{2,1,0} %sine.57, f32[16,1,2048]{2,1,0} %sqrt.56)
  %cosine.59 = f32[16,1,2048]{2,1,0} cosine(f32[16,1,2048]{2,1,0} %multiply.51)
  %multiply.60 = f32[16,1,2048]{2,1,0} multiply(f32[16,1,2048]{2,1,0} %cosine.59, f32[16,1,2048]{2,1,0} %sqrt.56)
  %concatenate.61 = f32[16,2,2048]{2,1,0} concatenate(f32[16,1,2048]{2,1,0} %multiply.58, f32[16,1,2048]{2,1,0} %multiply.60), dimensions={1}
  %reshape.62 = f32[32,2048]{1,0} reshape(f32[16,2,2048]{2,1,0} %concatenate.61)
  %constant.12 = f32[] constant(1)
  %reshape.13 = f32[1,1]{1,0} reshape(f32[] %constant.12)
  %broadcast.14 = f32[1,1]{1,0} broadcast(f32[1,1]{1,0} %reshape.13), dimensions={0,1}
  %reshape.15 = f32[] reshape(f32[1,1]{1,0} %broadcast.14)
  %broadcast.16 = f32[32,2048]{1,0} broadcast(f32[] %reshape.15), dimensions={}
  %multiply.63 = f32[32,2048]{1,0} multiply(f32[32,2048]{1,0} %reshape.62, f32[32,2048]{1,0} %broadcast.16)
  %add.64 = f32[32,2048]{1,0} add(f32[32,2048]{1,0} %broadcast.21, f32[32,2048]{1,0} %multiply.63)
  %custom-call.65 = f32[32,2048]{1,0} custom-call(f32[32,2048]{1,0} %add.64), custom_call_target="Sharding", sharding={devices=[1,8,4]0,8,16,24,1,9,17,25,2,10,18,26,3,11,19,27,4,12,20,28,5,13,21,29,6,14,22,30,7,15,23,31 last_tile_dim_replicate}
  %constant.67 = f32[] constant(1)
  %constant.66 = f32[] constant(1)
  %multiply.68 = f32[] multiply(f32[] %constant.67, f32[] %constant.66)
  %broadcast.69 = f32[2048]{0} broadcast(f32[] %multiply.68), dimensions={}
  %add.70 = f32[2048]{0} add(f32[2048]{0} %custom-call.6, f32[2048]{0} %broadcast.69)
  %p1.71 = f32[] parameter(1), sharding={replicated}
  %broadcast.72 = f32[32,2048]{1,0} broadcast(f32[] %p1.71), dimensions={}
  %multiply.73 = f32[32,2048]{1,0} multiply(f32[32,2048]{1,0} %custom-call.65, f32[32,2048]{1,0} %broadcast.72)
  ROOT %tuple.74 = (f32[2048]{0}, f32[32,2048]{1,0}, f32[2048]{0}, f32[32,2048]{1,0}) tuple(f32[2048]{0} %custom-call.6, f32[32,2048]{1,0} %custom-call.65, f32[2048]{0} %add.70, f32[32,2048]{1,0} %multiply.73)
}


Graph Hash: 49385ca12d8df41cd1a2ecc9c910fcfa

## END_GRAPH


#OUTPUT_SHARDING_BEGIN

f32[2048] {devices=[4,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31 last_tile_dim_replicate}
f32[32,2048] {devices=[1,8,4]0,8,16,24,1,9,17,25,2,10,18,26,3,11,19,27,4,12,20,28,5,13,21,29,6,14,22,30,7,15,23,31 last_tile_dim_replicate}
f32[2048] {devices=[4,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31 last_tile_dim_replicate}
f32[32,2048] {devices=[1,8,4]0,8,16,24,1,9,17,25,2,10,18,26,3,11,19,27,4,12,20,28,5,13,21,29,6,14,22,30,7,15,23,31 last_tile_dim_replicate}

#OUTPUT_SHARDING_END

It just seems that the HLO text from the computation does not show the sharding, whereas the dump produced with XLA_SAVE_TENSORS_FILE does. For now I am avoiding that flag so the test does not have to manage an external file, but let me know if we would prefer to use it instead.
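
For reference, a hedged sketch of what using that flag in the test could look like is below. The dump path and the setup placement are illustrative; in practice the env vars have to be set before the XLA client initializes, so this would likely live in the test launcher rather than the test body.

import os
import tempfile

# Route the graph dump through XLA_SAVE_TENSORS_FILE and clean it up afterwards.
dump_path = os.path.join(tempfile.gettempdir(), "spmd_lc_hlo_dump.txt")
os.environ["XLA_SAVE_TENSORS_FILE"] = dump_path
os.environ["XLA_SAVE_TENSORS_FMT"] = "hlo"

try:
    # ... run the SPMD test body and force a sync so the graph is saved ...
    pass
finally:
    # Avoid lingering dump files when the test fails partway through.
    if os.path.exists(dump_path):
        os.remove(dump_path)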

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_spmd_lc branch 2 times, most recently from 3fae03a to 476602d Compare November 26, 2024 23:27
@rpsilva-aws rpsilva-aws marked this pull request as ready for review November 26, 2024 23:54
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_spmd_lc branch 5 times, most recently from f79218f to 449a6ee Compare December 4, 2024 00:16
@rpsilva-aws (Contributor, Author) commented

Added more coverage with IR dump graphs, plus a test cleanup commit to avoid lingering files in case of failures.

@JackCaoG (Collaborator) commented Dec 4, 2024

@tengyifei

@tengyifei (Collaborator) commented

I'll take a look

@rpsilva-aws (Contributor, Author) commented

Due to various conflicts, I synced and moved this to #8471.

@rpsilva-aws rpsilva-aws deleted the rpsilva_spmd_lc branch December 9, 2024 18:42