diff --git a/examples/research_projects/colossalai/train_dreambooth_colossalai.py b/examples/research_projects/colossalai/train_dreambooth_colossalai.py index b37adf7124ca..d0f713a7249c 100644 --- a/examples/research_projects/colossalai/train_dreambooth_colossalai.py +++ b/examples/research_projects/colossalai/train_dreambooth_colossalai.py @@ -365,6 +365,11 @@ def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"): def main(args): + if args.seed is None: + colossalai.launch_from_torch(config={}) + else: + colossalai.launch_from_torch(config={}, seed=args.seed) + colossalai.launch_from_torch(config={}) if args.seed is not None: