NaN error when using a GPU with no support for igemmlt #165

Closed
0cc4m opened this issue Feb 25, 2023 · 20 comments · May be fixed by #408

Comments

@0cc4m

0cc4m commented Feb 25, 2023

I get RuntimeError: probability tensor contains either inf, nan or element < 0 on most language models when trying to run them in 8bit.

I adapted a script made by lorr1 in #42 (comment) into a small script that first runs the model in 8-bit with igemmlt, then disables igemmlt support and runs it again. I tested this on an RTX 3060, and the result is the RuntimeError when running without igemmlt. I think there is a bug in the code path that replaces igemmlt on older GPUs.

Interestingly, it works on some models, like EleutherAI/pythia-70m-deduped, EleutherAI/gpt-neo-125M, facebook/opt-6.7b, but on most others it fails with the RuntimeError. When run with EleutherAI/pythia-410m-deduped it outputs the following:

» python 8bit_test.py

===================================BUG REPORT===================================
Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
================================================================================
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
8bit-reg:
Q: On average Joe throws 25 punches per minute. A fight lasts 5 rounds of 3 minutes.
How many punches did he throw?

A: Let’s think step by step.

First, Joe threw a baseball cap.
Next, he threw a bat in the air.
Joe threw a bat in the air.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Traceback (most recent call last):
  File "/media/veryhighspeed/koboldai/client/8bit_test.py", line 57, in <module>
    generated_ids_8bit = model_8bit.generate(input_ids, max_length=len(input_ids[0]) + MAX_NEW_TOKENS, do_sample=True)
  File "/media/veryhighspeed/koboldai/client/8bit-venv/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/media/veryhighspeed/koboldai/client/8bit-venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 1437, in generate
    return self.sample(
  File "/media/veryhighspeed/koboldai/client/8bit-venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 2479, in sample
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

@Ph0rk0z in #131 (comment) also ran into this issue.
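
For reference, a minimal sketch of the 8-bit generation step from the script above (the igemmlt-disabling part of lorr1's script is not reproduced here, since it patches bitsandbytes internals; MAX_NEW_TOKENS is an assumed value, and transformers, accelerate and bitsandbytes are assumed installed):

from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_NEW_TOKENS = 64  # assumed value; the constant from the original script is not shown
name = "EleutherAI/pythia-410m-deduped"

tokenizer = AutoTokenizer.from_pretrained(name)
prompt = ("Q: On average Joe throws 25 punches per minute. A fight lasts 5 rounds of 3 minutes.\n"
          "How many punches did he throw?\n\nA: Let's think step by step.\n")
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

# Load in 8-bit (LLM.int8). The second run in the original script repeats this after
# disabling igemmlt support so the fallback matmul path is exercised.
model_8bit = AutoModelForCausalLM.from_pretrained(name, device_map="auto", load_in_8bit=True)
generated_ids_8bit = model_8bit.generate(
    input_ids, max_length=len(input_ids[0]) + MAX_NEW_TOKENS, do_sample=True
)
print(tokenizer.decode(generated_ids_8bit[0], skip_special_tokens=True))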

@Ph0rk0z

Ph0rk0z commented Feb 25, 2023

The only fix so far is to lower the int8 threshold like they did here: https://gist.github.com/whjms/2505ef082a656e7a80a3f663c16f4277

It's still buggy and a bit slow.

@0cc4m

0cc4m commented Feb 25, 2023

Thank you @Ph0rk0z, I was not aware of that gist. I'll check whether it works on AMD as well.

@0cc4m

0cc4m commented Feb 25, 2023

It does resolve the RuntimeError, which is great! But when trying to use 8-bit inference on larger models it seems to use a ton of VRAM, negating any advantage of switching to 8-bit. Strange.

@Ph0rk0z

Ph0rk0z commented Feb 25, 2023

It's because it keeps a lot of stuff in FP16, and that shows up when generating. It goes POOF and you are back to out of memory again.

For me this was slower than offloading or FlexGen. Either this is a hardware problem or a bug, and the dev appears to be busy.

@ZQ-Dev8

ZQ-Dev8 commented Feb 28, 2023

> The only fix so far is to lower the int8 threshold like they did here: https://gist.github.com/whjms/2505ef082a656e7a80a3f663c16f4277
>
> It's still buggy and a bit slow.

This worked for me as a band-aid to run inference on OPT-66B, but I don't understand exactly what changing the threshold is doing. I'm assuming by lowering the threshold, we are increasing the number of weights that are considered large outliers, thus converting less of the model into int8? If so, what's the default threshold?

@Ph0rk0z

Ph0rk0z commented Mar 6, 2023

I assume the default is 1.0.

@0cc4m

0cc4m commented Mar 6, 2023

No, I think it corresponds to the threshold mentioned in the bitsandbytes README, which defaults to 6.0. That explains why it works on older cards with 0.8: it doesn't convert much to 8-bit anymore.
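
For anyone who wants to try this without patching transformers, a minimal sketch, assuming a transformers version that exposes the threshold through BitsAndBytesConfig (older versions took a load_in_8bit_threshold argument on from_pretrained instead):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# llm_int8_threshold defaults to 6.0. Lowering it makes more activation columns
# count as fp16 outliers, which works around the NaN on non-igemmlt GPUs at the
# cost of extra VRAM and speed.
quant_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=1.5)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-6.7b",  # example model; any causal LM works the same way
    device_map="auto",
    quantization_config=quant_config,
)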

@Ph0rk0z

Ph0rk0z commented Mar 9, 2023

Threshold at 1 works and isn't too bad. No more OOM. Maybe something between 1 and 6 needs to be figured out for speed vs. NaN. This will be absolutely crucial when we try 4-bit, because I expect the behavior will be similar since it's not "supported" either.

[Screenshot: 8bitPascal]

@0cc4m

0cc4m commented Mar 12, 2023

Pretty cool. It seems AMD and Nvidia Pascal will be back in business soon anyway, once 4-bit gets released. Pascal supports DP4A, and so do AMD Vega 20 and the 6000 series. Looking forward to it.

@LoopControl

> Pretty cool. It seems AMD and Nvidia Pascal will be back in business soon anyway, once 4-bit gets released. Pascal supports DP4A, and so do AMD Vega 20 and the 6000 series. Looking forward to it.

Can confirm as a person with a Pascal card -- 4bit works great on it.

Llama 30b isn't a problem and is pretty fast. OPT also works but runs slowly.

@richardwth

richardwth commented Apr 3, 2023

I came across a similar problem when finetuning Llama 7B: the hidden states became inf in LlamaMLP (specifically, down_proj). I used a V100 with device capability 7.0, so igemmlt is naturally not supported. I then found that the inf happens at these lines of autograd._functions.MatMul8bitLt:

# (line 390) 3. Matmul, else branch
output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))

The inf happens because output has some values larger than 65536 (beyond the fp16 range) after F.linear.

As I understand it, state.CB ranges between -127 and 127, which is relatively large compared to A_wo_outliers (which is bounded by the 6.0 threshold). Wouldn't it be safer to dequantize CB first and then do F.linear? That is,

CB = state.CB.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
output = torch.nn.functional.linear(A_wo_outliers, CB)

Is it designed to prevent underflow? I also notice that CB is calculated first in the backward pass (line 455).

CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
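
To illustrate why the order matters, a toy numeric sketch (made-up shapes and scales, run in fp16 on a CUDA device; CB_int8 and SCB stand in for state.CB and state.SCB):

import torch

# A_wo_outliers: activations below the 6.0 outlier threshold, in fp16.
A_wo_outliers = torch.full((1, 1024), 5.0, dtype=torch.float16, device="cuda")
# state.CB analogue: int8 codes near the maximum magnitude of 127.
CB_int8 = torch.full((1, 1024), 127, dtype=torch.int8, device="cuda")
# state.SCB analogue: per-output-row absmax scale of the weight.
SCB = torch.tensor([8.0], dtype=torch.float16, device="cuda")

# Original order: matmul on the raw codes first. The sums reach
# 5 * 127 * 1024 = 650240, far above the fp16 maximum of 65504, so the
# result is already inf before the rescale by SCB / 127 can shrink it.
out = torch.nn.functional.linear(A_wo_outliers, CB_int8.to(torch.float16))
out = out.mul_(SCB.unsqueeze(0).mul(1.0 / 127.0))
print(out)  # inf

# Dequantize CB first: the weights become 127 * 8 / 127 = 8.0, the matmul
# sums to 5 * 8 * 1024 = 40960, which fits comfortably in fp16.
CB_fp16 = CB_int8.to(torch.float16).mul_(SCB.unsqueeze(1).mul(1.0 / 127.0))
print(torch.nn.functional.linear(A_wo_outliers, CB_fp16))  # 40960, finite

With the dequantize-first order the intermediate values stay on the scale of the dequantized weights, which is what the backward pass quoted above already does.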

@Opdoop

Opdoop commented Apr 22, 2023

@richardwth Hi Richard, I'm facing the same problem. Did you solve this bug?

@Ph0rk0z

Ph0rk0z commented Apr 22, 2023

@0cc4m
@Opdoop

I edited /site-packages/bitsandbytes/autograd/_functions.py,

first at line 406:

        else:
            A_wo_outliers = A.clone()
            if state.idx is not None:
                A_wo_outliers[:, state.idx.long()] = 0
            #output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
            #output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
            CB = state.CB.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
            output = torch.nn.functional.linear(A_wo_outliers, CB)
            if bias is not None:
                output = output.add_(bias)

then at line 468:


        if req_gradA:
            if state.CBt is not None:
                C32grad, Sgrad = F.transform(Cgrad, "col32")
                if state.CxBt is None:
                    state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
                #grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

and now pythia-12b in 8-bit at a 1.5 threshold no longer NaNs on me.

I then switched back to the full 6.0 threshold and ran inference again!
It works!

@richardwth you are a hero, you fixed this bug and nobody noticed!

wahoo! #335

@zhaoqf123

> (quoting @richardwth's comment above in full)

It's very informative! May I ask how you found out that the inf occurs at exactly that line?

@zhaoqf123

> (quoting @Ph0rk0z's patch comment above in full)

Do we really need to correct line 468? Based on Richard's finding, we only need to correct line 406, right?

@Ph0rk0z

Ph0rk0z commented May 17, 2023

Try with just the one change and see if you get the error. I'm not sure what they ended up doing in the latest version. Does it just work now, and was it rewritten?

@zhaoqf123

> Try with just the one change and see if you get the error. I'm not sure what they ended up doing in the latest version. Does it just work now, and was it rewritten?

In version 0.38.1, I replaced lines 410-411 of _functions.py with

CB = state.CB.to(A.dtype, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
output = torch.nn.functional.linear(A_wo_outliers, CB)

and left the second part untouched, and it works fine.

@zhaoqf123

zhaoqf123 commented May 24, 2023

> (quoting @richardwth's comment above in full)

Inspired by "Finding source of NaN in forward pass", I used the following script to trace the source of the NaN, but I don't think it is very good:

import torch
import bitsandbytes as bnb
from functools import partial


def nan_hook_bnb(module, args, output, name=None):
    # Forward hook: check the output of every Linear8bitLt layer for NaN.
    if isinstance(module, bnb.nn.Linear8bitLt):
        if not isinstance(output, tuple):
            outputs = [output]
        else:
            outputs = output

        for i, out in enumerate(outputs):
            nan_mask = torch.isnan(out)
            if nan_mask.any():
                raise RuntimeError(f"In module {name} of type {module.__class__.__name__}, found NaN in output {i} at indices: ",
                                   nan_mask.nonzero(), "where:",
                                   out[nan_mask.nonzero()[:, 0].unique(sorted=True)])


def register_nan_hook(model: torch.nn.Module):
    # Register the hook on every submodule; the isinstance check above does the filtering.
    for name, submodule in model.named_modules():
        new_hook = partial(nan_hook_bnb, name=name)
        submodule.register_forward_hook(new_hook)


debug = True
register_nan_hook(model) if debug else None  # `model` is the model being finetuned

Do you have a better method? @richardwth

@richardwth

@zhaoqf123 Hi buddy, sorry for the late reply. I did not use any advanced method like the one you used here. I manually inserted breakpoints and used torch.isnan, torch.isinf, or torch.isfinite to check which transformer layer, and later on which exact line, gave the infinite results.
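
For the backward side of a finetuning run, PyTorch's built-in anomaly detection can also help; a minimal sketch (model and batch are placeholders for a hypothetical training step; anomaly mode only flags non-finite values produced during backward, so the forward-hook approach above is still needed for inf activations):

import torch

# Enable only while debugging: anomaly mode adds significant overhead, but it
# raises as soon as a backward op produces NaN/inf and reports which forward
# op created the offending tensor.
torch.autograd.set_detect_anomaly(True)

loss = model(**batch).loss  # hypothetical model / batch from the finetuning loop
loss.backward()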


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
