
skip the gpu memory checks if the device is set to 'auto' #609

Merged
merged 4 commits into main from auto-device-skip on Sep 21, 2023

Conversation

winglian
Collaborator

resolves #456

@Napuh
Contributor

Napuh commented Sep 20, 2023

Unfortunately I haven't been able to get this branch to work, either with examples/llama-2/lora.yml (with load_in_8bit: false, as in #456) or with examples/llama-2/gptq-lora.yml (#599).

The same error keeps coming up.

I've added a little bit of logging to try to understand where the error is located:

import functools
import logging

import torch

LOG = logging.getLogger("axolotl")

def check_cuda_device(default_value):
    """
    wraps a function and returns the default value instead of running the
    wrapped function if cuda isn't available or the device is auto
    :param default_value:
    :return:
    """

    def actual_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            device = kwargs.get("device", args[0] if args else None)

            LOG.debug(f"Device in check_cuda_device: {device}")
            LOG.debug(f"Default value on check_cuda_device: {default_value}")

            if not torch.cuda.is_available() or device == "auto":
                return default_value

            return func(*args, **kwargs)

        return wrapper

    return actual_decorator
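For context, this decorator is applied to the memory-usage helpers in src/axolotl/utils/bench.py. The sketch below is hypothetical apart from the memory_allocated line, which appears in the traceback; the function names and the 0.0 / (0.0, 0.0, 0.0) defaults are inferred from the debug output shown further down:

@check_cuda_device(0.0)
def gpu_memory_usage(device=0):
    return torch.cuda.memory_allocated(device) / 1024.0**3


@check_cuda_device((0.0, 0.0, 0.0))
def gpu_memory_usage_all(device=0):
    usage = torch.cuda.memory_allocated(device) / 1024.0**3
    reserved = torch.cuda.memory_reserved(device) / 1024.0**3
    return usage, reserved, 0.0  # third slot ("misc" usage) omitted in this sketch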

When I examine the output of the error, I find something interesting:

[2023-09-20 07:00:15,180] [WARNING] [axolotl.scripts.finetune.do_cli:25] [PID:5374] [RANK:0] scripts/finetune.py will be replaced with calling axolotl.cli.train                                                                                                        
[2023-09-20 07:00:15,184] [INFO] [axolotl.validate_config:95] [PID:5374] [RANK:0] bf16 support detected, but not enabled for this configuration.                                                                                                                        
[2023-09-20 07:00:15,184] [WARNING] [axolotl.validate_config:155] [PID:5374] [RANK:0] We recommend setting `load_in_8bit: true` for LORA finetuning                                                                                                                     
[2023-09-20 07:00:15,257] [DEBUG] [axolotl.wrapper:24] [PID:5374] [RANK:0] Device in check_cuda_device: cuda:0                      
[2023-09-20 07:00:15,257] [DEBUG] [axolotl.wrapper:25] [PID:5374] [RANK:0] Default value on check_cuda_device: (0.0, 0.0, 0.0)      
[2023-09-20 07:00:15,258] [DEBUG] [axolotl.wrapper:24] [PID:5374] [RANK:0] Device in check_cuda_device: cuda:0                      
[2023-09-20 07:00:15,258] [DEBUG] [axolotl.wrapper:25] [PID:5374] [RANK:0] Default value on check_cuda_device: 0.0                  
[2023-09-20 07:00:15,259] [INFO] [axolotl.normalize_config:89] [PID:5374] [RANK:0] GPU memory usage baseline: 0.000GB (+0.345GB misc)                                                                                                                                   
[2023-09-20 07:00:15,387] [DEBUG] [axolotl.load_tokenizer:74] [PID:5374] [RANK:0] EOS: 2 / </s>                                     
[2023-09-20 07:00:15,387] [DEBUG] [axolotl.load_tokenizer:75] [PID:5374] [RANK:0] BOS: 1 / <s>                                      
[2023-09-20 07:00:15,387] [DEBUG] [axolotl.load_tokenizer:76] [PID:5374] [RANK:0] PAD: 2 / </s>                                     
[2023-09-20 07:00:15,387] [DEBUG] [axolotl.load_tokenizer:77] [PID:5374] [RANK:0] UNK: 0 / <unk>                                    
[2023-09-20 07:00:15,462] [INFO] [axolotl.load_tokenized_prepared_datasets:128] [PID:5374] [RANK:0] Loading prepared dataset from disk at last_run_prepared/ad149256d2226c66eef84cba1806c06f...                                                                         
[2023-09-20 07:00:15,463] [INFO] [axolotl.load_tokenized_prepared_datasets:130] [PID:5374] [RANK:0] Prepared dataset loaded from disk...                                                                                                                                
[2023-09-20 07:00:15,509] [INFO] [axolotl.calculate_total_num_steps:513] [PID:5374] [RANK:0] total_num_steps: 5940                  
[2023-09-20 07:00:15,511] [INFO] [axolotl.train.train:49] [PID:5374] [RANK:0] loading tokenizer... TheBloke/Llama-2-7B-GPTQ         
[2023-09-20 07:00:15,627] [DEBUG] [axolotl.load_tokenizer:74] [PID:5374] [RANK:0] EOS: 2 / </s>                                     
[2023-09-20 07:00:15,627] [DEBUG] [axolotl.load_tokenizer:75] [PID:5374] [RANK:0] BOS: 1 / <s>                                      
[2023-09-20 07:00:15,627] [DEBUG] [axolotl.load_tokenizer:76] [PID:5374] [RANK:0] PAD: 2 / </s>                                     
[2023-09-20 07:00:15,627] [DEBUG] [axolotl.load_tokenizer:77] [PID:5374] [RANK:0] UNK: 0 / <unk>                                    
[2023-09-20 07:00:15,701] [INFO] [axolotl.train.train:57] [PID:5374] [RANK:0] loading model and (optionally) peft_config...
[2023-09-20 07:01:16,385] [INFO] [axolotl.load_model:386] [PID:5374] [RANK:0] converting modules to torch.float16 for flash attention
trainable params: 8,388,608 || all params: 270,798,848 || trainable%: 3.097726619575575
[2023-09-20 07:01:16,614] [DEBUG] [axolotl.wrapper:24] [PID:5374] [RANK:0] Device in check_cuda_device: cpu
[2023-09-20 07:01:16,614] [DEBUG] [axolotl.wrapper:25] [PID:5374] [RANK:0] Default value on check_cuda_device: (0.0, 0.0, 0.0)
Traceback (most recent call last):
  File "/home/axolotl/scripts/finetune.py", line 52, in <module>
    fire.Fire(do_cli)
  File "/opt/conda/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/axolotl/scripts/finetune.py", line 48, in do_cli
    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
  File "/home/axolotl/src/axolotl/train.py", line 58, in train
    model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
  File "/home/axolotl/src/axolotl/utils/models.py", line 422, in load_model
    log_gpu_memory_usage(LOG, "after adapters", model.device)
  File "/home/axolotl/src/axolotl/utils/bench.py", line 69, in log_gpu_memory_usage
    usage, cache, misc = gpu_memory_usage_all(device)
  File "/home/axolotl/src/axolotl/utils/bench.py", line 30, in wrapper
    return func(*args, **kwargs)
  File "/home/axolotl/src/axolotl/utils/bench.py", line 44, in gpu_memory_usage_all
    usage = torch.cuda.memory_allocated(device) / 1024.0**3
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/memory.py", line 351, in memory_allocated
    return memory_stats(device=device).get("allocated_bytes.all.current", 0)
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/memory.py", line 230, in memory_stats
    stats = memory_stats_as_nested_dict(device=device)
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/memory.py", line 241, in memory_stats_as_nested_dict
    device = _get_device_index(device, optional=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/_utils.py", line 32, in _get_device_index
    raise ValueError('Expected a cuda device, but got: {}'.format(device))
ValueError: Expected a cuda device, but got: cpu
Traceback (most recent call last):
  File "/opt/conda/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/commands/launch.py", line 986, in launch_command
    simple_launcher(args)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/commands/launch.py", line 628, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/opt/conda/bin/python', 'scripts/finetune.py', 'examples/llama-2/gptq-lora.yml']' returned non-zero exit status 1.

Before loading the model, the log traces are:

[2023-09-20 07:00:15,257] [DEBUG] [axolotl.wrapper:24] [PID:5374] [RANK:0] Device in check_cuda_device: cuda:0                      
[2023-09-20 07:00:15,257] [DEBUG] [axolotl.wrapper:25] [PID:5374] [RANK:0] Default value on check_cuda_device: (0.0, 0.0, 0.0)      
[2023-09-20 07:00:15,258] [DEBUG] [axolotl.wrapper:24] [PID:5374] [RANK:0] Device in check_cuda_device: cuda:0                      
[2023-09-20 07:00:15,258] [DEBUG] [axolotl.wrapper:25] [PID:5374] [RANK:0] Default value on check_cuda_device: 0.0

So I assume the GPU gets picked up correctly, but after the model is loaded, these are the logs:

[2023-09-20 07:01:16,614] [DEBUG] [axolotl.wrapper:24] [PID:5374] [RANK:0] Device in check_cuda_device: cpu
[2023-09-20 07:01:16,614] [DEBUG] [axolotl.wrapper:25] [PID:5374] [RANK:0] Default value on check_cuda_device: (0.0, 0.0, 0.0)

And suddenly the device detected by PyTorch is cpu.
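For reference, torch.cuda.memory_allocated only accepts CUDA devices, so the failure can be reproduced in isolation (a minimal sketch, not part of the PR):

import torch

# Passing a CPU device to a torch.cuda memory query raises the same error
# seen in the traceback above.
torch.cuda.memory_allocated(torch.device("cpu"))
# ValueError: Expected a cuda device, but got: cpu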

@NanoCode012
Collaborator

NanoCode012 commented Sep 21, 2023

I added some print statements, and they show model.device == 'cpu' within the load_model function. I think we can simply add another check for cpu?
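A minimal sketch of that extra check inside the decorator's wrapper (assuming the device arrives as a torch.device or a plain string; the merged commits below describe the change as "skip gpu mem logging if cpu too"):

def wrapper(*args, **kwargs):
    device = kwargs.get("device", args[0] if args else None)

    # Bail out with the default value when the model ended up on the CPU,
    # not only when CUDA is unavailable or the device is "auto".
    if (
        not torch.cuda.is_available()
        or device == "auto"
        or str(device) == "cpu"
    ):
        return default_value

    return func(*args, **kwargs)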

winglian merged commit 196ff11 into main on Sep 21, 2023
4 checks passed
winglian deleted the auto-device-skip branch on September 21, 2023 19:20
mkeoliya pushed a commit to mkeoliya/axolotl that referenced this pull request Dec 15, 2023
…-cloud#609)

* skip the gpu memory checks if the device is set to 'auto'

* skip gpu mem logging if cpu too

* don't worry about log_gpu_memory_usage since it calls another annotated fn

* rename decorator internal
djsaunde pushed a commit that referenced this pull request Dec 17, 2024
* skip the gpu memory checks if the device is set to 'auto'

* skip gpu mem logging if cpu too

* don't worry about log_gpu_memory_usage since it calls another annotated fn

* rename decorator internal
Successfully merging this pull request may close these issues:

example config llama-2/lora.yml fails when load_in_8bit is set to False