Skip to content

Commit

Permalink
[PEFT] Fix PEFT multi adapters support (#26407)
Browse files Browse the repository at this point in the history
* fix PEFT multi adapters support

* refactor a bit

* save pretrained + BC + added tests

* Update src/transformers/integrations/peft.py

Co-authored-by: Benjamin Bossan <[email protected]>

* add more tests

* add suggestion

* final changes

* adapt a bit

* fixup

* Update src/transformers/integrations/peft.py

Co-authored-by: Patrick von Platen <[email protected]>

* adapt from suggestions

---------

Co-authored-by: Benjamin Bossan <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
  • Loading branch information
3 people authored Sep 27, 2023
1 parent 946bac7 commit 3ca18d6
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 11 deletions.
57 changes: 47 additions & 10 deletions src/transformers/integrations/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from ..utils import (
check_peft_version,
Expand Down Expand Up @@ -245,20 +245,27 @@ def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> Non

self.set_adapter(adapter_name)

def set_adapter(self, adapter_name: str) -> None:
def set_adapter(self, adapter_name: Union[List[str], str]) -> None:
"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters.
Args:
adapter_name (`str`):
The name of the adapter to set.
adapter_name (`Union[List[str], str]`):
The name of the adapter to set. Can be also a list of strings to set multiple adapters.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
elif isinstance(adapter_name, list):
missing = set(adapter_name) - set(self.peft_config)
if len(missing) > 0:
raise ValueError(
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
f" current loaded adapters are: {list(self.peft_config.keys())}"
)
elif adapter_name not in self.peft_config:
raise ValueError(
f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
Expand All @@ -270,7 +277,11 @@ def set_adapter(self, adapter_name: str) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.active_adapter = adapter_name
# For backward compatbility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
_adapters_has_been_set = True

if not _adapters_has_been_set:
Expand All @@ -294,7 +305,11 @@ def disable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = True
# The recent version of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
else:
module.disable_adapters = True

def enable_adapters(self) -> None:
"""
Expand All @@ -312,14 +327,22 @@ def enable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = False
# The recent version of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=True)
else:
module.disable_adapters = False

def active_adapter(self) -> str:
def active_adapters(self) -> List[str]:
"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
Gets the current active adapter of the model.
Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters
for inference) returns the list of all active adapters so that users can deal with them accordingly.
For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return
a single string.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)

Expand All @@ -333,7 +356,21 @@ def active_adapter(self) -> str:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
return module.active_adapter
active_adapters = module.active_adapter
break

# For previous PEFT versions
if isinstance(active_adapters, str):
active_adapters = [active_adapters]

return active_adapters

def active_adapter(self) -> str:
logger.warning(
"The `active_adapter` method is deprecated and will be removed in a future version. ", FutureWarning
)

return self.active_adapters()[0]

def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
"""
Expand Down
11 changes: 10 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2006,7 +2006,16 @@ def save_pretrained(
peft_state_dict[f"base_model.model.{key}"] = value
state_dict = peft_state_dict

current_peft_config = self.peft_config[self.active_adapter()]
active_adapter = self.active_adapters()

if len(active_adapter) > 1:
raise ValueError(
"Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one "
"by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`"
)
active_adapter = active_adapter[0]

current_peft_config = self.peft_config[active_adapter]
current_peft_config.save_pretrained(save_directory)

# Save the model
Expand Down
19 changes: 19 additions & 0 deletions tests/peft_integration/test_peft_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,11 @@ def test_peft_add_multi_adapter(self):
_ = model.generate(input_ids=dummy_input)

model.set_adapter("default")
self.assertTrue(model.active_adapters() == ["default"])
self.assertTrue(model.active_adapter() == "default")

model.set_adapter("adapter-2")
self.assertTrue(model.active_adapters() == ["adapter-2"])
self.assertTrue(model.active_adapter() == "adapter-2")

# Logits comparison
Expand All @@ -276,6 +278,23 @@ def test_peft_add_multi_adapter(self):
)
self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6))

model.set_adapter(["adapter-2", "default"])
self.assertTrue(model.active_adapters() == ["adapter-2", "default"])
self.assertTrue(model.active_adapter() == "adapter-2")

logits_adapter_mixed = model(dummy_input)
self.assertFalse(
torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
)

self.assertFalse(
torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
)

# multi active adapter saving not supported
with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

@require_torch_gpu
def test_peft_from_pretrained_kwargs(self):
"""
Expand Down

0 comments on commit 3ca18d6

Please sign in to comment.