Add feature dim attributes to BitLinear for easier PEFT integration #34946
Conversation
LGTM!
LGTM @agostinv, thanks for the feature!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Great little addition, thanks. Is this change sufficient to enable PEFT LoRA with bitlinear? Do you have a snippet to show its usage? I could imagine that training and inference work out of the box with this change, but some features like merging don't work or need special handling in PEFT. Edit: As
@BenjaminBossan You're exactly right! Based on my experience, training works but merging is non-trivial (also should clarify I forked

The attributes allow us to get caught by the following code in peft.
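Paraphrasing that check (it lives in peft/tuners/lora/layer.py; the exact wording may differ across peft versions, and the helper name below is made up for illustration):

import warnings

import torch.nn as nn


def _infer_feature_dims(base_layer: nn.Module):
    # Paraphrased sketch: peft first handles known layer types (nn.Linear, nn.Conv2d,
    # Conv1D, various quantized linears, ...) and otherwise falls back to reading the
    # `in_features` / `out_features` attributes if the base layer exposes them.
    if isinstance(base_layer, nn.Linear):
        return base_layer.in_features, base_layer.out_features
    if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):
        return base_layer.in_features, base_layer.out_features
    warnings.warn(f"Unsupported layer type '{type(base_layer)}' encountered; LoRA may not work as expected.", UserWarning)
    return None, None

Without in_features/out_features, BitLinear falls into that last unsupported-layer branch; with this PR it is picked up by the hasattr fallback instead.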
As far as a quick example goes, I have the following snippet that's pretty ad-hoc but is generally based on the BitsAndBytes implementations:

import warnings
from typing import Any, Optional

import torch

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from peft.utils.other import transpose
from peft.tuners.lora.layer import LoraLayer


class BitNetLinearLora(torch.nn.Module, LoraLayer):
    # Lora implemented in a dense layer
    def __init__(
        self,
        base_layer: torch.nn.Module,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        LoraLayer.__init__(self, base_layer)
        self.fan_in_fan_out = False

        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
        )

    # Merging/unmerging the LoRA delta into packed ternary weights is non-trivial,
    # so these are intentionally left unimplemented for now.
    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        raise NotImplementedError

    def unmerge(self) -> None:
        raise NotImplementedError

    def get_delta_weight(self, adapter):
        return (
            transpose(
                self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
                False,
            )
            * self.scaling[adapter]
        )

    def _mixed_batch_forward(
        self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
    ) -> torch.Tensor:
        # This is a special method that handles the case when users pass the argument `adapter_names`. This is an
        # extra argument that allows mixing different adapters in the same batch at inference time.
        result = self.base_layer(x, *args, **kwargs)

        unique_adapters = set(adapter_names)
        sub_batch_indices_list = []
        for adapter in unique_adapters:
            sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])

        for i, active_adapter in enumerate(unique_adapters):
            if active_adapter == "__base__":
                continue
            if active_adapter not in self.lora_A.keys():
                continue

            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]

            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = result.dtype
                x = x.to(lora_A.weight.dtype)

            # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
            # layer output
            sub_batch = x[sub_batch_indices_list[i]]
            output = lora_B(lora_A(dropout(sub_batch))) * scaling
            if requires_conversion:
                output = output.to(expected_dtype)
            result[sub_batch_indices_list[i]] += output

        return result

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        self._check_forward_args(x, *args, **kwargs)
        adapter_names = kwargs.pop("adapter_names", None)

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif adapter_names is not None:
            result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)
            # As per Tim Dettmers, for 4bit, we need to defensively clone here.
            # The reason is that in some cases, an error can occur that backprop
            # does not work on a manipulated view. This issue may be solved with
            # newer PyTorch versions but this would need extensive testing to be
            # sure.
            result = result.clone()

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]

                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
                    x = x.to(lora_A.weight.dtype)

                if not self.use_dora[active_adapter]:
                    output = lora_B(lora_A(dropout(x))) * scaling
                else:
                    x = dropout(x)
                    output = self.lora_magnitude_vector[active_adapter](
                        x,
                        lora_A=lora_A,
                        lora_B=lora_B,
                        scaling=scaling,
                        base_layer=self.get_base_layer(),
                    )
                if requires_conversion:
                    output = output.to(expected_dtype)

                result = result + output

        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora." + rep


def dispatch_bitnet(target: torch.nn.Module, adapter_name: str, **kwargs):
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    # NOTE: a full dispatcher would check here that `target_base_layer` really is a
    # BitLinear before wrapping; this ad-hoc version wraps the target unconditionally.
    bitnet_kwargs = kwargs.copy()
    new_module = BitNetLinearLora(target, adapter_name, **bitnet_kwargs)

    return new_module
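One way to actually wire this up is peft's experimental custom-module mapping on LoraConfig; since that is a private API, consider the following a sketch (the BitLinear import path and the model id are assumptions and may vary across versions):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
from transformers.integrations.bitnet import BitLinear  # import path assumed

# illustrative 1.58-bit checkpoint; any BitNet-quantized model should work
model = AutoModelForCausalLM.from_pretrained("HF1BitLLM/Llama3-8B-1.58-100B-tokens")

config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
# map the quantized base layer type to the custom LoRA layer defined above
config._register_custom_module({BitLinear: BitNetLinearLora})

peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()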
What does this PR do?

This PR is an extremely simple two-liner (adding in_features and out_features as attributes to BitLinear) whose only purpose is to improve accessibility of BitLinear for users who want to employ peft. Currently, BitLinear is not usable with LoRAs in peft out of the box.
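A minimal sketch of what the change amounts to, assuming BitLinear's constructor signature (the real layer also registers packed ternary weights and scales, which are elided here):

import torch.nn as nn


class BitLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = False, device=None, dtype=None):
        super().__init__()
        # ... existing initialization of packed weights and scales elided ...
        # the two attributes this PR adds, mirroring nn.Linear so that peft's
        # LoRA layer setup can read the layer's dimensions:
        self.in_features = in_features
        self.out_features = out_features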
The typical flow for enabling LoRAs for custom layers in peft is to construct a custom class that describes the LoRA's behavior and then register it with a private API. The problem is that peft still needs additional information on input and output dimensionality via in_features and out_features, which BitLinear currently lacks. The current workaround is to wrap BitLinear with another module that adds these attributes during initialization and then replace all instances of BitLinear with that new module. Alternatively, the LoRA source code would have to be revised to support BitLinear and derive the feature dimensions from its weight matrix. From the perspective of potential users, adding the aforementioned attributes improves accessibility and avoids requiring hacky-looking fixes on their end.
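For contrast, the wrapper workaround mentioned above would look something like the following hypothetical sketch (class name and attribute handling are illustrative only):

import torch


class BitLinearWithDims(torch.nn.Module):
    # Thin wrapper that only exists to re-expose the feature dims peft expects;
    # every BitLinear instance in the model would have to be swapped for this.
    def __init__(self, bitlinear: torch.nn.Module, in_features: int, out_features: int):
        super().__init__()
        self.bitlinear = bitlinear
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x):
        return self.bitlinear(x)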
Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Other checkmarks are left untouched, as they don't look relevant.
Who can review?