Add feature dim attributes to BitLinear for easier PEFT integration #34946

Open · wants to merge 2 commits into main
Conversation

@agostinv commented Nov 26, 2024

What does this PR do?

This PR is an extremely simple two-liner (adding in_features and out_features as attributes to BitLinear) whose only purpose is to make BitLinear more accessible to users who want to employ peft. Currently, BitLinear is not usable with LoRAs in peft out-of-the-box.

The typical flow for enabling LoRAs for custom layers in peft is to construct a custom class that describes the LoRA's behavior and then register it with a private API. The problem is that peft still needs additional information about input and output dimensionality via in_features and out_features, which BitLinear currently lacks. The current workaround is to wrap BitLinear in another module that adds these attributes during initialization and then replace every instance of BitLinear with that wrapper. Alternatively, the LoRA source code would have to be revised to support BitLinear and derive the feature dimensions from its weight matrix. From the perspective of potential users, adding the aforementioned attributes improves accessibility and avoids requiring hacky-looking fixes on their end.
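For illustration, a minimal sketch of where the two lines land inside BitLinear's constructor (the signature and the rest of the body are abbreviated here and may differ slightly from the actual source in src/transformers/integrations/bitnet.py):

import torch.nn as nn

class BitLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
        super().__init__()
        # The two added attributes: expose the feature dimensions so tooling
        # such as peft's LoraLayer can read them without inspecting the packed
        # uint8 weight buffer.
        self.in_features = in_features
        self.out_features = out_features
        # ... existing packed-weight buffer, weight_scale, and bias setup unchanged ...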

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Other checkmarks are left untouched, as they don't look relevant.

Who can review?

@SunMarc (Member) left a comment

LGTM!

@MekkCyber (Contributor) commented

LGTM @agostinv, thanks for the feature!

@BenjaminBossan (Member) commented Nov 28, 2024

Great little addition, thanks. Is this change sufficient to enable PEFT LoRA with BitLinear? Do you have a snippet to show its usage? I could imagine that training and inference work out of the box with this change, but some features like merging don't work or need special handling in PEFT.

Edit: As BitLinear is not a subclass of nn.Linear, we probably need extra handling in PEFT to make it work.

@agostinv (Author) commented Nov 28, 2024

@BenjaminBossan You're exactly right! In my experience, training works but merging is non-trivial (I should also clarify that I forked peft to have it deduce the dimensionality from the weight matrix back when I was looking for a hacky solution for my own experiments). Personally, I am keeping my adapters separate for academic experiments so far.

src/transformers/integrations/bitnet.py has a number of functions that can be used or adapted as helpers here if users really want to maintain a fully 1.58b layer rather than mixed-precision weights with two parallel forward paths through the layer. Ultimately, this all has to be user-defined behavior anyway unless peft integrates direct support for BitLinear on its own, so it shouldn't be too big of a deal on the transformers side of things for the moment.
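To make the merging caveat concrete, here is a rough sketch of what materializing a merged weight might look like. It assumes the unpack_weights helper and the weight_scale buffer used in bitnet.py (names and the direction of the scale correction should be double-checked against that file), and it deliberately produces a dense floating-point weight, i.e. the result is no longer a pure 1.58b layer:

import torch
from transformers.integrations.bitnet import unpack_weights  # helper from the file mentioned above; verify signature

def materialize_merged_weight(bitlinear, delta_weight: torch.Tensor) -> torch.Tensor:
    # Dequantize the packed ternary weights into a dense tensor.
    w = unpack_weights(bitlinear.weight, dtype=delta_weight.dtype)
    # Undo the stored scale; BitLinear applies weight_scale to its output, so
    # dividing here mirrors that (an assumption: adjust if the forward multiplies instead).
    w = w / bitlinear.weight_scale
    # Add the LoRA delta. Re-quantizing back to ternary would be lossy, which
    # is why the snippet further below simply leaves merge() unimplemented.
    return w + delta_weight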

With the new attributes, BitLinear gets picked up by the following code in peft/tuners/lora/layer.py:

https://github.com/huggingface/peft/blob/131efba5d48753a3355ecd4f3833ae010a0510d6/src/peft/tuners/lora/layer.py#L93-L101
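Roughly speaking (an illustration, not the actual peft source), the fallback that matters here boils down to reading the attributes straight off the module when it is not one of the layer types peft special-cases:

import torch.nn as nn

def infer_feature_dims(base_layer: nn.Module) -> tuple[int, int]:
    # peft special-cases nn.Linear, nn.Embedding, conv layers, etc.; for any
    # other module it falls back to in_features/out_features attributes, which
    # is exactly what this PR adds to BitLinear.
    if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):
        return base_layer.in_features, base_layer.out_features
    raise ValueError(f"Cannot determine feature dimensions for {type(base_layer)}")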

As far as a quick example goes, here is a snippet that's pretty ad-hoc but is generally based on the bitsandbytes LoRA implementations in peft. I'm currently using it for a small, private project; it probably still needs some changes to function properly during inference.

from typing import Any, Optional

import torch

from peft.tuners.lora.layer import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils.other import transpose

class BitNetLinearLora(torch.nn.Module, LoraLayer):
    # LoRA adapter wrapped around a BitLinear base layer; adapted from the bnb LoRA layers in peft
    def __init__(
        self,
        base_layer: torch.nn.Module,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        LoraLayer.__init__(self, base_layer)
        self.fan_in_fan_out = False

        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
        )

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        # Merging the LoRA delta into the packed 1.58-bit weights would require
        # dequantizing and re-quantizing (see the note above); left unimplemented.
        raise NotImplementedError

    def unmerge(self) -> None:
        raise NotImplementedError

    def get_delta_weight(self, adapter):
        return (
            transpose(
                self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
                False,
            )
            * self.scaling[adapter]
        )

    def _mixed_batch_forward(
        self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
    ) -> torch.Tensor:
        # This is a special method that handles the case when users pass the argument `adapter_names`. This is an
        # extra argument that allows mixing different adapters in the same batch at inference time.
        result = self.base_layer(x, *args, **kwargs)

        unique_adapters = set(adapter_names)
        sub_batch_indices_list = []
        for adapter in unique_adapters:
            sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])

        for i, active_adapter in enumerate(unique_adapters):
            if active_adapter == "__base__":
                continue
            if active_adapter not in self.lora_A.keys():
                continue

            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]

            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = result.dtype
                x = x.to(lora_A.weight.dtype)

            # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
            # layer output
            sub_batch = x[sub_batch_indices_list[i]]
            output = lora_B(lora_A(dropout(sub_batch))) * scaling
            if requires_conversion:
                output = output.to(expected_dtype)
            result[sub_batch_indices_list[i]] += output

        return result

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        self._check_forward_args(x, *args, **kwargs)
        adapter_names = kwargs.pop("adapter_names", None)

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif adapter_names is not None:
            result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)
            # Defensive clone, carried over from the bnb 4-bit LoRA layer this
            # snippet is based on: in some cases backprop can fail on a
            # manipulated view, so clone before adding adapter outputs in place.
            result = result.clone()

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]

                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
                    x = x.to(lora_A.weight.dtype)

                if not self.use_dora[active_adapter]:
                    output = lora_B(lora_A(dropout(x))) * scaling
                else:
                    x = dropout(x)
                    output = self.lora_magnitude_vector[active_adapter](
                        x,
                        lora_A=lora_A,
                        lora_B=lora_B,
                        scaling=scaling,
                        base_layer=self.get_base_layer(),
                    )
                if requires_conversion:
                    output = output.to(expected_dtype)

                result = result + output

        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora." + rep

def dispatch_bitnet(target: torch.nn.Module, adapter_name: str, **kwargs):
    # BitLinear lives in transformers.integrations.bitnet; imported locally so
    # the class definition above stays importable without transformers.
    from transformers.integrations.bitnet import BitLinear

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    # Only wrap actual BitLinear layers; returning None lets peft fall through
    # to its other dispatchers for everything else.
    new_module = None
    if isinstance(target_base_layer, BitLinear):
        new_module = BitNetLinearLora(target, adapter_name, **kwargs)

    return new_module
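For completeness, a rough sketch of how this could be hooked up end-to-end. The _register_custom_module call is the private peft API alluded to above (per the peft custom-models guide), and the checkpoint path and target_modules names are placeholders; check both against the peft and transformers versions you are using. If the registration API handles dispatch for you (as it does in recent peft versions), the dispatch_bitnet helper above becomes optional.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
from transformers.integrations.bitnet import BitLinear

# Placeholder checkpoint: any model whose linear layers were replaced by BitLinear.
model = AutoModelForCausalLM.from_pretrained("path/to/a-1.58bit-checkpoint")

config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
# Private, version-dependent API: map the base layer class to the LoRA layer above.
config._register_custom_module({BitLinear: BitNetLinearLora})

peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()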
