
Commit

calpt committed Nov 18, 2023
1 parent dcbe7ba commit 2035917
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/adapters/methods/lora.py
@@ -3,6 +3,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import math
from typing import Dict, List, NamedTuple, Optional, Union

@@ -18,6 +19,9 @@
from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase


logger = logging.getLogger(__name__)


class LoRA(nn.Module):
def __init__(
self,
@@ -44,6 +48,8 @@ def __init__(
self.lora_B = nn.Parameter(torch.zeros(lora_B_shape))
self.scaling = self.lora_alpha / self.r

# For compatibility with (IA)^3, allow all init_weights types here.
# Usually should be "lora".
if config.init_weights == "lora":
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
@@ -114,7 +120,10 @@ def __init__(
self.lora_B = nn.Parameter(torch.zeros(lora_B_shape))
self.scaling = self.lora_alpha

# For compatibility with LoRA, allow all init_weights types here.
# Usually should be "ia3".
if config.init_weights == "lora":
logger.warning("(IA)^3 module initialized with LoRA zero init. Ignore if this is intended.")
nn.init.zeros_(self.lora_B)
elif config.init_weights == "bert":
nn.init.normal_(self.lora_B, std=0.02)
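For context, a minimal usage sketch of how the new warning can be triggered, assuming the public adapters API (adapters.init, IA3Config, add_adapter); the model checkpoint and adapter name are hypothetical placeholders, and the commit itself only adds the logging and comments shown above:

import adapters
from adapters import IA3Config
from transformers import AutoModel

# Hypothetical example: checkpoint and adapter name are placeholders.
model = AutoModel.from_pretrained("bert-base-uncased")
adapters.init(model)

# IA3Config defaults to init_weights="ia3"; passing "lora" is accepted for
# compatibility, zero-initializes lora_B, and now logs the warning above.
config = IA3Config(init_weights="lora")
model.add_adapter("example_ia3", config=config)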