
FSDP2 integration: torch.chunks(Params4bit) not returning Params4bit subclass #1424

Open
mreso opened this issue Nov 21, 2024 · 1 comment
Labels: bug (Something isn't working) · FSDP · help wanted (Extra attention is needed) · high priority (first issues that will be worked on)

Comments

mreso commented Nov 21, 2024

System Info

Hi, I'm trying to get FSDP2 working with a Llama model quantized with bitsandbytes, but it seems that bitsandbytes' tensor subclasses such as Params4bit are not compatible with the way FSDP2 shards the model.
When FSDP2 creates the DTensors to shard the model, it applies torch.chunk to the parameters, and torch.chunk returns ordinary Tensors instead of the original subclass (e.g. Params4bit), which leads to errors down the line.

Is this a known issue and are there plans to make bitsandbytes composable with FSDP2?

Reproduction

Created a simple repro:

import torch
import torch.nn as nn

import bitsandbytes as bnb
from bitsandbytes.nn import Params4bit
        
blocksize = 64
compress_statistics = True
quant_type = "fp4"
quant_storage = torch.uint8

w = torch.ones(4).to("cuda")

w_4bit, quant_state = bnb.functional.quantize_4bit(
    w,
    blocksize=blocksize,
    compress_statistics=compress_statistics,
    quant_type=quant_type,
    quant_storage=quant_storage,
)

b = Params4bit.from_prequantized(w_4bit, quant_state.as_dict(packed=True))
print(f"{b=}")

chunks = torch.chunk(b, 2, dim=0)

print(f"{chunks=}")

Output:

b=Parameter containing:
Parameter(Params4bit([[51],
            [51]], device='cuda:0', dtype=torch.uint8))
chunks=(tensor([[51]], device='cuda:0', dtype=torch.uint8), tensor([[51]], device='cuda:0', dtype=torch.uint8))

Expected behavior

Expected torch.chunk to return a tuple of Params4bit instances instead of plain Tensors.
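For context, one likely explanation (an assumption based on PyTorch's subclass-dispatch rules, not something confirmed in the bitsandbytes code): nn.Parameter, which Params4bit extends, installs torch._C._disabled_torch_function_impl as its __torch_function__, so ops like torch.chunk run with subclass dispatch disabled and return plain Tensors. The CPU-only sketch below reproduces the type loss without bitsandbytes and shows one possible direction for a fix; all class names here are illustrative, not part of any library.

```python
import torch
import torch.nn as nn

# A plain torch.Tensor subclass: the default __torch_function__ re-wraps
# results in the subclass, so torch.chunk preserves the type.
class MyTensor(torch.Tensor):
    pass

t = torch.ones(4).as_subclass(MyTensor)
print([type(c).__name__ for c in torch.chunk(t, 2)])  # ['MyTensor', 'MyTensor']

# An nn.Parameter subclass (the situation Params4bit is in): Parameter
# disables __torch_function__, so the subclass is dropped by torch.chunk.
class MyParam(nn.Parameter):
    pass

p = MyParam(torch.ones(4))
print([type(c).__name__ for c in torch.chunk(p, 2)])  # ['Tensor', 'Tensor']

# One possible direction for a fix: override __torch_function__ to re-wrap
# plain-Tensor results (and tuples of them; torch.chunk returns a tuple).
class PreservingParam(nn.Parameter):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Run the op with subclass dispatch disabled to avoid recursion.
        with torch._C.DisableTorchFunctionSubclass():
            out = func(*args, **(kwargs or {}))
        if isinstance(out, torch.Tensor) and not isinstance(out, cls):
            return out.as_subclass(cls)
        if isinstance(out, tuple):
            return tuple(
                o.as_subclass(cls) if isinstance(o, torch.Tensor) else o
                for o in out
            )
        return out

q = PreservingParam(torch.ones(4))
print([type(c).__name__ for c in torch.chunk(q, 2)])
# ['PreservingParam', 'PreservingParam']
```

Note this only preserves the Python type of the chunks; an actual Params4bit fix would also have to decide how quant_state travels with each chunk, which this sketch does not address.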

@matthewdouglas matthewdouglas added the bug Something isn't working label Nov 26, 2024
Titus-von-Koeller (Collaborator) commented

We've marked this as high priority, but our current focus is on tasks with even greater impact. Contributions are very welcome, and we're happy to assist with anything needed along the way.

Let's collaborate to address this as soon as possible. Thanks for taking the initiative and highlighting its importance, along with the limitations you've encountered.

@Titus-von-Koeller Titus-von-Koeller pinned this issue Dec 11, 2024
@Titus-von-Koeller Titus-von-Koeller changed the title torch.chunks(Params4bit) not returning Params4bit subclass FSDP2 integration: torch.chunks(Params4bit) not returning Params4bit subclass Dec 11, 2024
@Titus-von-Koeller Titus-von-Koeller added help wanted Extra attention is needed high priority (first issues that will be worked on) FSDP labels Dec 11, 2024