
OOM on train on 48GPU with batch 1 proc 1 #49

Closed
recoilme opened this issue Nov 26, 2024 · 56 comments
Labels: Answered

Comments

@recoilme

I've been waiting for Sana for so long, hoping to train on a potato, but it's not working even on an A40 with 48GB(

(sana) root@c88159d783a4:/workspace/sana# bash train_scripts/train.sh   configs/sana_config/1024ms/Sana_1600M_img1024.yaml   --data.data_dir="[asset/example_data]"   --data.type=SanaImgDataset   --model.multi_scale=false
2024-11-26 23:33:11 - [Sana] - INFO - Distributed environment: MULTI_GPU  Backend: nccl
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda:0

Mixed precision type: fp16

2024-11-26 23:33:11 - [Sana] - INFO - Config: 
{
    "data": {
        "data_dir": [
            "asset/example_data"
        ],
        "caption_proportion": {
            "prompt": 1
        },
        "external_caption_suffixes": [
            "",
            "_InternVL2-26B",
            "_VILA1-5-13B"
        ],
        "external_clipscore_suffixes": [
            "_InternVL2-26B_clip_score",
            "_VILA1-5-13B_clip_score",
            "_prompt_clip_score"
        ],
        "clip_thr_temperature": 0.1,
        "clip_thr": 25.0,
        "sort_dataset": false,
        "load_text_feat": false,
        "load_vae_feat": false,
        "transform": "default_train",
        "type": "SanaImgDataset",
        "image_size": 1024,
        "hq_only": false,
        "valid_num": 0,
        "data": null,
        "extra": null
    },
    "model": {
        "model": "SanaMS_1600M_P1_D20",
        "image_size": 1024,
        "mixed_precision": "fp16",
        "fp32_attention": true,
        "load_from": null,
        "resume_from": {
            "checkpoint": "latest",
            "load_ema": false,
            "resume_optimizer": true,
            "resume_lr_scheduler": true
        },
        "aspect_ratio_type": "ASPECT_RATIO_1024",
        "multi_scale": false,
        "pe_interpolation": 1.0,
        "micro_condition": false,
        "attn_type": "linear",
        "autocast_linear_attn": false,
        "ffn_type": "glumbconv",
        "mlp_acts": [
            "silu",
            "silu",
            null
        ],
        "mlp_ratio": 2.5,
        "use_pe": false,
        "qk_norm": false,
        "class_dropout_prob": 0.1,
        "linear_head_dim": 32,
        "cross_norm": false,
        "cfg_scale": 4,
        "guidance_type": "classifier-free",
        "pag_applied_layers": [
            8
        ],
        "extra": null
    },
    "vae": {
        "vae_type": "dc-ae",
        "vae_pretrained": "mit-han-lab/dc-ae-f32c32-sana-1.0",
        "scale_factor": 0.41407,
        "vae_latent_dim": 32,
        "vae_downsample_rate": 32,
        "sample_posterior": true,
        "extra": null
    },
    "text_encoder": {
        "text_encoder_name": "gemma-2-2b-it",
        "caption_channels": 2304,
        "y_norm": true,
        "y_norm_scale_factor": 0.01,
        "model_max_length": 300,
        "chi_prompt": [
            "Given a user prompt, generate an \"Enhanced prompt\" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
            "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
            "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
            "Here are examples of how to transform or refine prompts:",
            "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
            "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
            "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
            "User Prompt: "
        ],
        "extra": null
    },
    "scheduler": {
        "train_sampling_steps": 1000,
        "predict_v": true,
        "noise_schedule": "linear_flow",
        "pred_sigma": false,
        "learn_sigma": true,
        "vis_sampler": "flow_dpm-solver",
        "flow_shift": 3.0,
        "weighting_scheme": "logit_normal",
        "logit_mean": 0.0,
        "logit_std": 1.0,
        "extra": null
    },
    "train": {
        "num_workers": 1,
        "seed": 1,
        "train_batch_size": 1,
        "num_epochs": 100,
        "gradient_accumulation_steps": 1,
        "grad_checkpointing": true,
        "gradient_clip": 0.1,
        "gc_step": 1,
        "optimizer": {
            "betas": [
                0.9,
                0.999,
                0.9999
            ],
            "eps": [
                1e-30,
                1e-16
            ],
            "lr": 0.0001,
            "type": "CAMEWrapper",
            "weight_decay": 0.0
        },
        "lr_schedule": "constant",
        "lr_schedule_args": {
            "num_warmup_steps": 2000
        },
        "auto_lr": {
            "rule": "sqrt"
        },
        "ema_rate": 0.9999,
        "eval_batch_size": 16,
        "use_fsdp": false,
        "use_flash_attn": false,
        "eval_sampling_steps": 500,
        "lora_rank": 4,
        "log_interval": 1,
        "mask_type": "null",
        "mask_loss_coef": 0.0,
        "load_mask_index": false,
        "snr_loss": false,
        "real_prompt_ratio": 1.0,
        "save_image_epochs": 1,
        "save_model_epochs": 5,
        "save_model_steps": 500,
        "visualize": true,
        "null_embed_root": "output/pretrained_models/",
        "valid_prompt_embed_root": "output/tmp_embed/",
        "validation_prompts": [
            "dog",
            "portrait photo of a girl, photograph, highly detailed face, depth of field",
            "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece"
        ],
        "local_save_vis": true,
        "deterministic_validation": true,
        "online_metric": false,
        "eval_metric_step": 2000,
        "online_metric_dir": "metric_helper",
        "work_dir": "output/debug",
        "skip_step": 0,
        "loss_type": "huber",
        "huber_c": 0.001,
        "num_ddim_timesteps": 50,
        "w_max": 15.0,
        "w_min": 3.0,
        "ema_decay": 0.95,
        "debug_nan": false,
        "extra": null
    },
    "work_dir": "output/debug",
    "resume_from": "latest",
    "load_from": null,
    "debug": true,
    "caching": false,
    "report_to": "tensorboard",
    "tracker_project_name": "t2i-evit-baseline",
    "name": "tmp",
    "loss_report_name": "loss"
}
2024-11-26 23:33:11 - [Sana] - INFO - World_size: 1, seed: 1
2024-11-26 23:33:11 - [Sana] - INFO - Initializing: DDP for training
[DC-AE] Loading model from mit-han-lab/dc-ae-f32c32-sana-1.0
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.42it/s]
2024-11-26 23:33:16 - [Sana] - INFO - vae type: dc-ae
2024-11-26 23:33:16 - [Sana] - INFO - Complex Human Instruct: Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.
Here are examples of how to transform or refine prompts:
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
User Prompt: 
2024-11-26 23:33:16 - [Sana] - INFO - v-prediction: True, noise schedule: linear_flow, flow shift: 3.0, flow weighting: logit_normal, logit-mean: 0.0, logit-std: 1.0
2024-11-26 23:33:28 - [Sana] - WARNING - use pe: False, position embed interpolation: 1.0, base size: 32
2024-11-26 23:33:28 - [Sana] - WARNING - attention type: linear; ffn type: glumbconv; autocast linear attn: false
2024-11-26 23:33:41 - [Sana] - INFO - SanaMS:SanaMS_1600M_P1_D20, Model Parameters: 1604.46M
2024-11-26 23:33:41 - [Sana] - INFO - Constructing dataset SanaImgDataset...
2024-11-26 23:33:41 - [Sana] - INFO - Dataset is repeat 2000 times for toy dataset
2024-11-26 23:33:41 - [Sana] - INFO - Dataset samples: 4000
2024-11-26 23:33:41 - [Sana] - INFO - Loading external caption json from: original_filename['', '_InternVL2-26B', '_VILA1-5-13B'].json
2024-11-26 23:33:41 - [Sana] - INFO - Loading external clipscore json from: original_filename['_InternVL2-26B_clip_score', '_VILA1-5-13B_clip_score', '_prompt_clip_score'].json
2024-11-26 23:33:41 - [Sana] - INFO - external caption clipscore threshold: 25.0, temperature: 0.1
2024-11-26 23:33:41 - [Sana] - INFO - T5 max token length: 300
2024-11-26 23:33:41 - [Sana] - INFO - Dataset SanaImgDataset constructed: time: 0.00 s, length (use/ori): 4000/4000
2024-11-26 23:33:41 - [Sana] - INFO - Automatically adapt lr to 0.00001 (using sqrt scaling rule).
2024-11-26 23:33:41 - [Sana] - INFO - CAMEWrapper Optimizer: total 316 param groups, 316 are learnable, 0 are fix. Lr group: 316 params with lr 0.00001; Weight decay group: 316 params with weight decay 0.0.
2024-11-26 23:33:41 - [Sana] - INFO - Lr schedule: constant, num_warmup_steps:2000.
2024-11-26 23:33:41 - [Sana] - WARNING - Basic Setting: lr: 0.00001, bs: 1, gc: True, gc_accum_step: 1, qk norm: False, fp32 attn: True, attn type: linear, ffn type: glumbconv, text encoder: gemma-2-2b-it, captions: {'prompt': 1}, precision: fp16
2024-11-26 23:33:58 - [Sana] - INFO - Epoch: 1 | Global Step: 1 | Local Step: 1 // 4000, total_eta: 71 days, 3:34:13, epoch_eta:17:04:17, time: all:15.368, model:14.333, data:0.201, lm:0.304, vae:0.529, lr:3.125e-09, Cap: VILA1-5-13B, s:(32, 32), loss:4.3361, grad_norm:61.1122
2024-11-26 23:33:58 - [Sana] - INFO - Running validation... 
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 441, in model_fn
[rank0]:     noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 386, in noise_pred_fn
[rank0]:     output = model(x, t_input, cond, **model_kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_multi_scale.py", line 348, in forward_with_dpmsolver
[rank0]:     model_out = self.forward(x, timestep, y, data_info=data_info, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/accelerate/utils/operations.py", line 823, in forward
[rank0]:     return model_forward(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/accelerate/utils/operations.py", line 811, in __call__
[rank0]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_multi_scale.py", line 327, in forward
[rank0]:     x = auto_grad_checkpoint(
[rank0]:   File "/workspace/sana/diffusion/model/utils.py", line 72, in auto_grad_checkpoint
[rank0]:     return checkpoint(module, *args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
[rank0]:     return CheckpointFunction.apply(function, preserve, *args)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 255, in forward
[rank0]:     outputs = run_function(*args)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_multi_scale.py", line 162, in forward
[rank0]:     x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_blocks.py", line 160, in forward
[rank0]:     qkv = self.qkv(x).reshape(B, N, 3, C)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
[rank0]:     return F.linear(input, self.weight, self.bias)
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 28.00 MiB. GPU 0 has a total capacity of 44.34 GiB of which 20.81 MiB is free. Process 265335 has 44.31 GiB memory in use. Of the allocated memory 42.84 GiB is allocated by PyTorch, and 837.16 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

[rank0]: During handling of the above exception, another exception occurred:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/sana/train_scripts/train.py", line 974, in <module>
[rank0]:     main()
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/pyrallis/argparsing.py", line 158, in wrapper_inner
[rank0]:     response = fn(cfg, *args, **kwargs)
[rank0]:   File "/workspace/sana/train_scripts/train.py", line 959, in main
[rank0]:     train(
[rank0]:   File "/workspace/sana/train_scripts/train.py", line 479, in train
[rank0]:     log_validation(
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/workspace/sana/train_scripts/train.py", line 154, in log_validation
[rank0]:     image_logs += run_sampling(init_z=None, label_suffix="", vae=vae, sampler=vis_sampler)
[rank0]:   File "/workspace/sana/train_scripts/train.py", line 127, in run_sampling
[rank0]:     denoised = dpm_solver.sample(
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 1529, in sample
[rank0]:     model_prev_list = [self.model_fn(x, t)]
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 689, in model_fn
[rank0]:     return self.data_prediction_fn(x, t)
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 677, in data_prediction_fn
[rank0]:     noise = self.noise_prediction_fn(x, t)
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 671, in noise_prediction_fn
[rank0]:     return self.model(x, t)
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 616, in <lambda>
[rank0]:     self.model = lambda x, t: model_fn(x, t.expand(x.shape[0]))
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 443, in model_fn
[rank0]:     noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in)[0].chunk(2)
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 386, in noise_pred_fn
[rank0]:     output = model(x, t_input, cond, **model_kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_multi_scale.py", line 348, in forward_with_dpmsolver
[rank0]:     model_out = self.forward(x, timestep, y, data_info=data_info, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/accelerate/utils/operations.py", line 823, in forward
[rank0]:     return model_forward(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/accelerate/utils/operations.py", line 811, in __call__
[rank0]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_multi_scale.py", line 327, in forward
[rank0]:     x = auto_grad_checkpoint(
[rank0]:   File "/workspace/sana/diffusion/model/utils.py", line 72, in auto_grad_checkpoint
[rank0]:     return checkpoint(module, *args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
[rank0]:     return CheckpointFunction.apply(function, preserve, *args)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 255, in forward
[rank0]:     outputs = run_function(*args)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_multi_scale.py", line 162, in forward
[rank0]:     x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_blocks.py", line 160, in forward
[rank0]:     qkv = self.qkv(x).reshape(B, N, 3, C)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
[rank0]:     return F.linear(input, self.weight, self.bias)
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 28.00 MiB. GPU 0 has a total capacity of 44.34 GiB of which 20.81 MiB is free. Process 265335 has 44.31 GiB memory in use. Of the allocated memory 42.87 GiB is allocated by PyTorch, and 800.33 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
E1126 15:34:00.242000 124914645387072 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 0 (pid: 18768) of binary: /root/miniconda3/envs/sana/bin/python
Traceback (most recent call last):
  File "/root/miniconda3/envs/sana/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
    return f(*args, **kwargs)
  File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/distributed/run.py", line 901, in main
    run(args)
  File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train_scripts/train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-11-26_15:34:00
  host      : c88159d783a4
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 18768)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
(sana) root@c88159d783a4:/workspace/sana# 
@lawrence-cj
Collaborator

Seems my training is under 48GB

bash train_scripts/train.sh configs/sana_config/1024ms/Sana_1600M_img1024.yaml --data.data_dir="[asset/example_data]" --data.type=SanaImgDataset --model.multi_scale=false --data.load_vae_feat=false --train.train_batch_size=1

refer to: [screenshot]

Actually, if you switch the optimizer type in the config file to AdamW, GPU memory usage will be lower. We will release a newer CAME in the future, which will occupy even less memory than AdamW.

train:
  optimizer:
    lr: 1.0e-4
    type: AdamW
    weight_decay: 0.01
    eps: 1.0e-8
    betas: [0.9, 0.999]

refer to: [screenshot]

@lawrence-cj added the Answered label Nov 26, 2024
@recoilme
Author

But how do you do it? You have a 32x VAE vs 8x in SDXL and a smaller model, yet it needs more than 2.5x what SDXL needs to train at 1024?

@lawrence-cj
Collaborator

lawrence-cj commented Nov 26, 2024

The model you are using is 1.6B, and both the VAE and the text encoder are extracting features online.

@FurkanGozukara

It is all about the training scripts.

Currently, using Kohya, we are able to fully fine-tune the 12-billion-parameter FLUX dev in 16-bit precision even on 6 GB GPUs via block swapping :)

@lawrence-cj
Collaborator

Cool. Hh, I can't imagine the speed. lol : )

@recoilme
Author

I very much hope you will add some minimal optimizations in the future.
It's dead for full fine-tuning for now; an A100 is very expensive.
Thanks for the reply and for the good model!

@lawrence-cj
Collaborator

What do you mean it's dead?

@FurkanGozukara

Cool. Hh, I can't imagine the speed. lol : )

With the latest improvements, speeds are really decent:

RTX 3090: ~7 seconds per sample image (batch size 1)
RTX 4090: ~5 seconds per sample image
RTX A6000: ~6 seconds per sample image

@recoilme
Author

dead

It's dead for full fine-tuning for GPU-poor guys.
We rent GPUs for training, and it's very expensive. Needing 48GB+ to train with batch size 1 is a stop factor for most of us.
We need latent/TE caching and multi-aspect-ratio support,
and probably a slow optimizer like Adafactor for training fine details like eyes with a low LR.

@recoilme reopened this Nov 26, 2024
@Muinez

Muinez commented Nov 26, 2024

Could you please try this:
#50

@lawrence-cj
Collaborator

@Muinez, nice man. I'll try this out.

@recoilme
Author

recoilme commented Nov 27, 2024

@Muinez thank you! Bucketing works nicely, the code is clean and very simple!

I added checkpoints dir creation around line 500 (it fails on the first run):

checkpoints_dir = osp.join(config.work_dir, "checkpoints")
os.makedirs(checkpoints_dir, exist_ok=True)

It does not work for me on a 48GB A40 with batch 1 and the CAME optimizer (default settings, apart from batch/num proc): OOM after generating validation images with wandb (probably after decoding; you can see the memory spikes at validation).

But it looks like it does something strange and works with Adam: https://wandb.ai/recoilme/sana/runs/1?nw=nwuserrecoilme
[screenshot]
Looks like I'm training from a clean checkpoint; is something wrong?

What do you think about adding Adafactor? Is it possible? I get much better results with this optimizer on SDXL. Adam is like a tank: good only for the pretraining case.

@lawrence-cj
Collaborator

lawrence-cj commented Nov 27, 2024

The loss means you are training from scratch (no ckpts loaded).

@recoilme
Author

Probably I forgot the model path and will have my personal Sana soon.

@lawrence-cj
Collaborator

Looking forward to your news. If you hit a NaN problem, you could try changing your config file with:

model:
  mixed_precision: bf16

This is mentioned in another issue.

@lawrence-cj
Collaborator

I'll look into Adafactor when I'm free. My bandwidth has been full recently.

@recoilme
Author

recoilme commented Nov 28, 2024

Looking forward to your news.

First of all, I'm GPU poor and train on an old A40, slowly, sorry for that.

I don't train in one big run, but step by step:

  • train a little on a small dataset
  • analyze the result
  • change params like LR/optimizer and add more data to the dataset
  • train again

So instead of one big run with a high LR on billions of samples for 100 epochs, I train for 4-8 epochs on a small dataset, to which I add new data to learn on each iteration.
It's like how human children learn: once it learns the base knowledge, I add new knowledge.

So far, 2 steps are done.

Step 1: train 4 epochs on a 4k dataset (just a mix of very aesthetic images) with lr 1e-4
https://wandb.ai/recoilme/sana/runs/3

[screenshot]

Step 2: add 2k images to the dataset (a new concept, illustrations) and train 4 epochs (it's a small set; 8 epochs would work better here)
https://wandb.ai/recoilme/sana/runs/6

[screenshot]

For now I'm thinking of adding 4k anime images and maybe switching to bf16? (In SDXL, bf16 works much better than fp16; fp16 training just doesn't work there.)
May I just change the precision and that's all? Maybe a switch to adambf16 is needed? Must I recalculate the DC-AE latents in bf16, @Muinez?

What do you think in general? Is it a good result for 1 day? Any suggestions? (I haven't trained from zero before.)
PS: the last checkpoint is at https://huggingface.co/datasets/recoilme/ae/blob/main/potato.fp16.pth

@recoilme
Author

Maybe a switch to adambf16 is needed?

Looks like just switching the precision doesn't work out of the box:

Traceback (most recent call last):
  File "/home/Sana/train_scripts/train_local.py", line 940, in <module>
    main()
  File "/usr/local/lib/python3.10/dist-packages/pyrallis/argparsing.py", line 158, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/home/Sana/train_scripts/train_local.py", line 846, in main
    optimizer = build_optimizer(model, config.train.optimizer)
  File "/home/Sana/diffusion/utils/optimizer.py", line 156, in build_optimizer
    optimizer = mm_build_optimizer(model, optimizer_cfg)
  File "/usr/local/lib/python3.10/dist-packages/mmcv/runner/optimizer/builder.py", line 65, in build_optimizer
    optimizer = optim_constructor(model)
  File "/usr/local/lib/python3.10/dist-packages/mmcv/runner/optimizer/default_constructor.py", line 255, in __call__
    self.add_params(params, model)
  File "/home/Sana/diffusion/utils/optimizer.py", line 133, in add_params
    self.add_params(params, child_mod, prefix=child_prefix, is_dcn_module=is_dcn_module)
  File "/home/Sana/diffusion/utils/optimizer.py", line 133, in add_params
    self.add_params(params, child_mod, prefix=child_prefix, is_dcn_module=is_dcn_module)
  File "/home/Sana/diffusion/utils/optimizer.py", line 76, in add_params
    base_lr *= bias_lr_mult
TypeError: can't multiply sequence by non-int of type 'float'

Looking at this code now, but not sure it's good: https://github.com/bghira/SimpleTuner/blob/main/helpers/training/optimizers/adamw_bfloat16/__init__.py

@lawrence-cj
Collaborator

no need to change the optimizer to bf16

@lawrence-cj
Collaborator

Are you training from scratch or did you load a pretrained checkpoint?

@recoilme
Author

recoilme commented Nov 28, 2024

Are you training from scratch or did you load a pretrained checkpoint?

I train from scratch (it started by my mistake, lol, I forgot load_from).

no need to change the optimizer to bf16

OK, switched to bf16, let's try it.
Looks like I'm starting from zero noise again (switching from fp16 to bf16 broke the model). So slow.. Do you have some GPU grants for research? )
https://wandb.ai/recoilme/sana/runs/8

@recoilme
Author

Looks like just switching the precision doesn't work out of the box

It was a mistake in the config: lr: 9e-5 (which YAML parses as a string, not a float) instead of 9.0e-5.
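
For reference, a minimal sketch of that YAML quirk (assuming the config is read with PyYAML's default YAML 1.1 resolver, which only recognizes floats that contain a dot):

import yaml

cfg = yaml.safe_load("lr_bad: 9e-5\nlr_good: 9.0e-5")
print(type(cfg["lr_bad"]), cfg["lr_bad"])    # <class 'str'> 9e-5
print(type(cfg["lr_good"]), cfg["lr_good"])  # <class 'float'> 9e-05
# A string lr later trips code like `base_lr *= bias_lr_mult` with
# "can't multiply sequence by non-int of type 'float'", matching the traceback above.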

@recoilme
Author

no need to change the optimizer to bf16

If I make the buckets in fp16, make the model in bf16, and start training,

does it mean we may train in bf16 and convert to fp16 for inference, and it will work? What do you think?
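
For illustration, a minimal sketch of what such a bf16-to-fp16 conversion could look like (hypothetical file names; assumes the checkpoint is a plain dict of tensors, adjust if the weights sit under a "state_dict" key):

import torch

ckpt = torch.load("checkpoint_bf16.pth", map_location="cpu")
state = ckpt.get("state_dict", ckpt)  # handle flat or nested layouts
# cast floating-point tensors to fp16 for inference, leave everything else untouched
state_fp16 = {k: (v.half() if torch.is_tensor(v) and v.is_floating_point() else v)
              for k, v in state.items()}
torch.save(state_fp16, "checkpoint_fp16.pth")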

@Muinez

Muinez commented Nov 28, 2024

You don't need to redo the VAE latents for bf16, and I think it's pointless to train from scratch because it will take too long to achieve reasonable results

@Muinez

Muinez commented Nov 28, 2024

I'm also currently trying to train a 512 model with a batch size of 96, and I'm noticing that the model requires a high learning rate (5e-5) and is learning extremely slowly. After over 10,000 steps, it still hasn't fully grasped the concept, and the samples don't look great. It seems the model is indeed struggling to work with this VAE. I'm also training on low-quality images for the negative prompt, and while it helps a bit, the improvements are minimal.

@recoilme
Author

@Muinez please take a look at #50 (comment); it looks critical if I'm right.
Also some minor things like checkpoints dir creation. I noticed some other bugs too, but they're minor and I'm not sure how to do it better (continue-training counters, reusing config and params, and so on); I will take a look later.

About training from scratch: it's research into how fast training goes in bf16, and how fast it goes in general.

@lawrence-cj
Collaborator

lawrence-cj commented Nov 28, 2024

We have a totally different diffusion training schedule from SDXL: flow matching vs DDPM. The loss level is caused by this and has nothing to do with the VAE. If you try to train SD3, I think the loss should be larger than 0.6. Sana's loss will converge at ~0.7 at 1024px and ~0.8 at 512px.

Refer to the SD3 paper: [figure]

Refer to Lumina: [figure]
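
For intuition, a schematic sketch of a rectified-flow ("linear_flow") training step with the v-prediction and logit-normal timestep sampling shown in the config above; this is an illustrative reconstruction, not Sana's actual code, and it uses plain MSE where the config above uses a Huber loss:

import torch
import torch.nn.functional as F

def flow_matching_step(model, x0, cond, logit_mean=0.0, logit_std=1.0):
    b = x0.shape[0]
    # logit-normal timestep sampling: t = sigmoid(n), n ~ N(logit_mean, logit_std)
    t = torch.sigmoid(torch.randn(b, device=x0.device) * logit_std + logit_mean)
    t_ = t.view(b, 1, 1, 1)
    noise = torch.randn_like(x0)
    x_t = (1.0 - t_) * x0 + t_ * noise   # linear interpolation between data and noise
    target = noise - x0                  # velocity target ("v-prediction" for linear flow)
    pred = model(x_t, t, cond)
    return F.mse_loss(pred, target)

Because the target here is a velocity (noise minus data) rather than pure noise, converged loss values are not directly comparable to DDPM epsilon-prediction losses.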

@lawrence-cj
Collaborator

lawrence-cj commented Nov 28, 2024

I'm noticing that the model requires a high learning rate (5e-5) and is learning extremely slowly.

Actually, we train with 1e-4 if the new dataset has many new concepts.

@Muinez

Muinez commented Nov 28, 2024

This might be a silly question since I don't have much experience with this, but if we add a sampling step to the training pipeline, could it help the model mitigate its own errors? I mean, instead of providing the model with an almost perfect noise-to-signal ratio, it would also learn to correct itself along the way

@lawrence-cj
Collaborator

if we add a sampling step to the training pipeline

Could you pls explain more about it?

@Muinez

Muinez commented Nov 28, 2024

if we add a sampling step to the training pipeline

Could you pls explain more about it?

Make the model predict from a noised sample, treat it as a sampling step, add (or subtract, can’t remember) a random percentage of the prediction back to the noised sample, then make another prediction based on this modified version. After that, calculate the MSE loss between this final prediction and the actual target

@lawrence-cj
Collaborator

What's the insight? How is this different from directly computing the MSE loss between the actual target and the model's prediction?

@Muinez

Muinez commented Nov 28, 2024

What's the insight? How is this different from directly computing the MSE loss between the actual target and the model's prediction?

I’m not entirely sure, but the idea is that by letting the model make a prediction, we’re assuming it will introduce some errors, and the loss would then train the model to predict the correct image from that "incorrect" noise. Whereas if we’re just training from a perfect sample where we added the noise ourselves, it doesn’t contain those kinds of errors. Or am I wrong?

@Muinez

Muinez commented Nov 28, 2024

#pseudocode
noisy_data = add_noise(real_data, time_steps)

predicted_sample = predict(noisy_data, time_steps)

weighted_mean_sample = (noisy_data + time_steps * predicted_sample) / 2

final_prediction = predict(weighted_mean_sample, time_steps)

loss = (criterion(real_data, final_prediction) + criterion(real_data, predicted_sample)) / 2

Or maybe something like this could work, though I'm not sure if I’ve got it right
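
For concreteness, a minimal runnable version of that pseudocode (assuming a model that takes (x, t) and predicts the clean data, and a linear-flow noising step; the mixing rule is the guess from the pseudocode above, not an established recipe):

import torch
import torch.nn.functional as F

def two_pass_loss(model, real_data, t):
    t_ = t.view(-1, 1, 1, 1)
    noise = torch.randn_like(real_data)
    noisy = (1.0 - t_) * real_data + t_ * noise   # add_noise(real_data, time_steps)

    pred1 = model(noisy, t)                        # first prediction of the clean data
    # fold the imperfect prediction back in; detached here so the second pass does not
    # backprop through the first (the original pseudocode leaves this choice open)
    mixed = (noisy + t_ * pred1.detach()) / 2
    pred2 = model(mixed, t)                        # second prediction from the self-corrupted input

    return (F.mse_loss(pred1, real_data) + F.mse_loss(pred2, real_data)) / 2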

@lawrence-cj
Collaborator

Hh, this is more like something done by CM (consistency models)

@Muinez

Muinez commented Nov 28, 2024

Hh, this is more like something done by CM (consistency models)

I didn't know; I was just coming from the idea that the model is trained to predict data without noise. That's why I decided to take the average, mixing the noisy data with the sample predicted by the model (which still has errors), and then predict the target data based on that. Anyway, I hope you get the idea. I'm not sure if you'll develop this further or do anything with it. Looking at the code now, I probably didn't do it quite right. Maybe we should have applied add_noise to time_steps * predicted_sample, but I'm not sure.

@lawrence-cj
Collaborator

Yeah, I totally understand your statement. Cool idea. I have tried a similar idea before; refer to PixArt-delta, which combines PixArt and LCM. LCM or CM is something like what you said.

Refer to: https://arxiv.org/abs/2401.05252

@Muinez

Muinez commented Nov 29, 2024

I'm also currently trying to train a 512 model

https://huggingface.co/Muinez/sana-512-anime

This is what I ended up with. Negative prompt: bad artwork

@Muinez

Muinez commented Nov 29, 2024

@lawrence-cj
Collaborator

I'm also currently trying to train a 512 model

https://huggingface.co/Muinez/sana-512-anime

This is what I ended up with. Negative prompt: bad artwork

Oh, looks pretty cool on my side. 👍

@lawrence-cj
Collaborator

https://github.com/NVlabs/Sana/blob/5765093947cbf7053463a8835974575da1278017/sana/tools/hf_utils.py#L32C71-L32C94

why download_full_repo=True 😭

Since we have a separate func for not downloading the full repo:
[screenshot]
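
For context, a rough sketch of the distinction being discussed, using the standard huggingface_hub API rather than Sana's own helper (repo id and filename are only placeholders):

from huggingface_hub import hf_hub_download, snapshot_download

# fetch a single checkpoint file from the repo
ckpt_path = hf_hub_download(repo_id="some-user/some-sana-finetune", filename="model_state_dict.pth")

# mirror every file in the repo, which is roughly what download_full_repo=True ends up doing
repo_dir = snapshot_download(repo_id="some-user/some-sana-finetune")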

@Muinez

Muinez commented Nov 29, 2024

https://github.com/NVlabs/Sana/blob/5765093947cbf7053463a8835974575da1278017/sana/tools/hf_utils.py#L32C71-L32C94
why download_full_repo=True 😭

Since we have a separate func for not downloading the full repo:

from_pretrained uses find_model, which uses hf_download_or_fpath. I asked about this because when I run the app with my model, it downloads a 19GB file that isn't needed since I specifically moved the state_dict to a separate file

@Muinez

Muinez commented Nov 29, 2024

Oh, looks pretty cool on my side. 👍

Yes, the model has generally learned the quality of the dataset (style, etc.). For example, it generates scenery beautifully, but the details are quite poor, and it's not very coherent

@lawrence-cj
Collaborator

from_pretrained uses find_model, which uses hf_download_or_fpath. I asked about this because when I run the app with my model, it downloads a 19GB file that isn't needed since I specifically moved the state_dict to a separate file

lol, got it. We assume there is only the related ckpt in the HF hub when using from_pretrained. Not a big deal anyway; it could be refined. If you want, you can also draft a PR. 👍

@lawrence-cj
Collaborator

Yes, the model has generally learned the quality of the dataset (style, etc.). For example, it generates scenery beautifully, but the details are quite poor, and it's not very coherent

Makes sense; we will try to address the detail issue in our next version, this year or at the beginning of next year.

@lawrence-cj
Collaborator

Besides, the 1024 model will have far fewer of the detail issues you mentioned, since for a 512px image our VAE compresses it into only 16x16 tokens. If you want, the 1024px model may give you much better results.

@Muinez

Muinez commented Nov 29, 2024

Besides, the 1024 model will have far fewer of the detail issues you mentioned, since for a 512px image our VAE compresses it into only 16x16 tokens. If you want, the 1024px model may give you much better results.

Well, I’m already out of money to rent a GPU 😅. Besides, I’ve already tried, and the results weren’t great—either I ended up with NaN loss or the model just trained much slower. Maybe I should have increased the learning rate back then, but I’m not sure how stable it would have been with a batch size of 30

@lawrence-cj
Collaborator

OK, I see. I tried bf16 fine-tuning: no NaN and it converges fast. Maybe we will release this more stable version later.

@Muinez

Muinez commented Nov 29, 2024

By the way, about Sana’s low variability—maybe it’s because fewer tokens with higher dimensions make the model focus on precision over diversity. More tokens might allow for greater variation since each one wouldn’t need to carry as much global information. Or maybe it just feels like the model generates overly similar images, or it could even be a dataset issue, idk. What do you think?

@lawrence-cj
Collaborator

Actually, I don't think the variability is related to the lower token count. First, the CFG and PAG guidance scales affect variability. Then, model size is another core element affecting generalization. The data may also play a role.

@Muinez

Muinez commented Dec 1, 2024

@lawrence-cj It seems like the 600M model has a lower loss than the 1600M model. The loss is around 0.6 on the 600M, while on the 1600M, it's around 0.7. It also feels like the model is learning new concepts much faster on its own. Shouldn't it be the other way around, or is it just me?

@recoilme
Author

recoilme commented Dec 2, 2024

Looking forward to your news.

[screenshot]

@recoilme
Author

recoilme commented Dec 2, 2024

Why does it understand a complex prompt (not from the dataset), but not "1girl"?

@lawrence-cj
Collaborator

Seems like the BF16 training corrupts the model at the very beginning. We will release a BF16 version for fine-tuning later.

@recoilme
Author

recoilme commented Dec 4, 2024

Seems like the BF16 training corrupts the model at the very beginning

OK, but it's not retraining fp16 in bf16; it's a new model trained from empty, from zero.

@recoilme closed this as completed Dec 4, 2024