diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 208d3feb..1f1993ff 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -140,9 +140,8 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node. """ parallel = dict( - zero1=-1, - tensor=2, - # pipeline=dict(size=2, interleaved_overlap=True), + zero1=8, + pipeline=dict(size=1, interleaved_overlap=True), sequence_parallel=False, )