Skip to content

Commit

Permalink
[hotfix] Fix ShardFormer test execution path when using sequence para…
Browse files Browse the repository at this point in the history
…llelism (#5230)
  • Loading branch information
KKZ20 authored Jan 17, 2024
1 parent 46e0916 commit 5d9a0ae
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _criterion(outputs, inputs):

data = data_gen_fn()

if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0:
seq_len = data["input_ids"].shape[-1]
lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
times = lcm // seq_len
Expand Down

0 comments on commit 5d9a0ae

Please sign in to comment.