exploration of LoRA using composition #167

Open · wants to merge 2 commits into main

127 changes: 123 additions & 4 deletions Libraries/MLXLLM/Models/Gemma2.swift
@@ -5,6 +5,7 @@ import MLX
import MLXFast
import MLXLMCommon
import MLXNN
import MLXRandom

// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py

@@ -33,10 +34,10 @@ private class Attention: Module {
let nKVHeads: Int
let repeats: Int

@ModuleInfo(key: "q_proj") var wq: Linear
@ModuleInfo(key: "k_proj") var wk: Linear
@ModuleInfo(key: "v_proj") var wv: Linear
@ModuleInfo(key: "o_proj") var wo: Linear
@ModuleInfo(key: "q_proj") var wq: UnaryLayer
@ModuleInfo(key: "k_proj") var wk: UnaryLayer
@ModuleInfo(key: "v_proj") var wv: UnaryLayer
@ModuleInfo(key: "o_proj") var wo: UnaryLayer

Comment (Collaborator Author): To use composition, we would declare the layers with an appropriate protocol instead of a concrete type.

let rope: RoPE

@@ -288,3 +289,121 @@ extension Gemma2Model: LoRAModel {
model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
}
}

// TODO - notes
//
// - make UnaryLayer extend Module
// - make a Quantized protocol that provides the groupSize and bits
// - make the QuantizedLinear shape produce the expanded shape
// - make `items()` open
// - make `updateModule(key:_:)` open

Comment (Collaborator Author): Some ideas that I think I should do regardless of the outcome of this exploration.

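For example, the "Quantized protocol" item in the notes above might look roughly like this (a sketch; the protocol name and the empty conformance are assumptions, not code from this diff):

public protocol Quantized {
    var groupSize: Int { get }
    var bits: Int { get }
}

// QuantizedLinear already exposes groupSize and bits, so the conformance adds nothing
extension QuantizedLinear: Quantized {}
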
// - make `noGrad` overridable (turn into function?)
//
// - evaluation and training should work as expected
// - this flattens the weights and modules into one layer
// - to match the normal lora implementation
// - see items() and updateModule()

// TODO: make UnaryLayer extend Module
public protocol UnaryLayer2: Module {

Comment (Collaborator Author): Here just so the types work out -- every UnaryLayer is also a Module.

func callAsFunction(_ x: MLXArray) -> MLXArray
}
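
For the LoRA wrapper below to accept the existing layers, those layers would also need to conform to this protocol. A sketch of what that could look like (an assumption -- these conformances are not shown in this diff):

// Linear (and its subclass QuantizedLinear) already implements callAsFunction(_:)
// and subclasses Module, so the conformance needs no additional members
extension Linear: UnaryLayer2 {}
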

/// LoRA layer that can wrap any UnaryLayer
class LoRA: Module, UnaryLayer2 {

let adapts: UnaryLayer2
let scale: Float

@ParameterInfo(key: "lora_a") var loraA: MLXArray
@ParameterInfo(key: "lora_b") var loraB: MLXArray

public init(
adapts: UnaryLayer2, inputDimensions: Int, outputDimensions: Int, rank: Int = 8,
scale: Float = 20.0
) {
self.adapts = adapts

self.scale = scale

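// standard LoRA initialization: loraA ~ U(-1/sqrt(inputDimensions), +1/sqrt(inputDimensions)),
// loraB = 0, so the adapter contributes nothing until it has been trained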
let loraScale = 1 / sqrt(Float(inputDimensions))
self._loraA.wrappedValue = MLXRandom.uniform(
low: -loraScale, high: loraScale, [inputDimensions, rank])
self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions])

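// freeze only the wrapped layer (the freeze override below forwards to it), leaving lora_a and lora_b trainable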
freeze()
}

// TODO: in LoRALinear this is
// public static func from(linear: Linear, rank: Int = 8) -> LoRA
public convenience init(linear: Linear, rank: Int = 8, scale: Float = 20.0) {

Comment (Collaborator Author): So rather than calling:

qProj = LoRALinear.from(linear: qProj)

you would:

qProj = LoRA(linear: qProj)

var (outputDimensions, inputDimensions) = linear.shape

if let linear = linear as? QuantizedLinear {
// TODO Linear should probably have a property to return these directly
// rather than shape which represents the physical shape of the layers
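// each stored UInt32 in a QuantizedLinear packs 32 / bits values, so the logical
// input width is the stored width times 32 / bits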
inputDimensions = inputDimensions * 32 / linear.bits
}

self.init(
adapts: linear,
inputDimensions: inputDimensions, outputDimensions: outputDimensions,
rank: rank, scale: scale)
}

// produce a merged view of properties (flatten LoRA into adapts)
override func items() -> ModuleItems {
var result = adapts.items()
for (key, value) in super.items() {
if key == "adapts" { continue }
result[key] = value
}
return result
}

// forward module updates -> adapts
func updateModule(key: String, _ value: Any) throws {
try adapts.updateModule(key: key, value)
}

Comment (Collaborator Author) on lines +354 to +367:
This doesn't work as-is because these methods can't be overridden (see TODOs).

The idea is that the LoRA composition would flatten itself into what it adapts -- the Linear and LoRA keys would be merged for the purpose of updates, etc.

As per the notes, noGrad would need to be overridable (it is a property with storage and cannot be used that way right now).

Comment (Collaborator Author):
I think this is necessary for a couple of reasons:

  • generally weight saving and loading doesn't want to see the adaptor layer
  • this matches the typical shape of a graph with LoRA (mixed into the linear)

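To make the flattened view concrete, a sketch of what it is meant to produce once the TODOs above are resolved (dimensions and printed keys are illustrative, not verified against this PR):

let adapted = LoRA(linear: Linear(16, 16))
for (key, _) in adapted.parameters().flattened() {
    print(key)
}
// intended output: "weight", "bias", "lora_a", "lora_b" -- no intermediate "adapts"
// prefix, so saved weights look like a plain Linear plus the LoRA arrays
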
Comment (Collaborator Author):
But I think this forwarding is the worst part of it. We could potentially make a subclass of Module that encapsulates this if it becomes a common thing. That would help, but I suspect there would be complications.


override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws
{
try adapts.freeze(recursive: recursive, keys: keys, strict: strict)
}

// TODO: this requires knowledge of the innards of the adapted layer so it
// is specific to Linear (and QuantizedLinear).
public func toLinear(deQuantize: Bool = false) -> Linear {

Comment (Collaborator Author): The fuse operation requires knowledge of how to combine the LoRA weights with the target. Type-wise this could easily return a UnaryLayer, but it has to understand the implementation in order to fuse.

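In weight terms, the merge in the body below computes W_fused = W + scale * (loraB.T @ loraA.T): W is [outputDimensions, inputDimensions], loraB.T is [outputDimensions, rank], and loraA.T is [rank, inputDimensions], so the update has the same shape as the original weight.
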
// TODO throws? failable?
guard let linear = adapts as? Linear else { fatalError("Not a Linear") }

var weight: MLXArray
if let quantized = linear as? QuantizedLinear {
weight = dequantized(
quantized.weight, scales: quantized.scales, biases: quantized.biases,
groupSize: quantized.groupSize, bits: quantized.bits)
} else {
weight = linear.weight
}

let loraB = (scale * loraB.T).asType(.float16)
let loraA = loraA.T.asType(.float16)
let mergedWeight = weight + matmul(loraB, loraA)

// TODO maybe add a protocol for Quantized
if let quantized = linear as? QuantizedLinear {
return QuantizedLinear(
weight: mergedWeight, bias: quantized.bias,
groupSize: quantized.groupSize, bits: quantized.bits)
} else {
return Linear(weight: mergedWeight, bias: linear.bias)
}
}

public func callAsFunction(_ x: MLXArray) -> MLXArray {
// TODO let y = super.callAsFunction(x.asType(scales.dtype)) -- ignoring the asType here
let y = adapts(x)
let z = matmul(matmul(x, self.loraA), self.loraB)
return y + scale * z

Comment (Collaborator Author): The nicest part of it -- since LoRA is an adaptor, we can easily express it via composition.

}
}
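
As a usage sketch of that composition (dimensions and values are illustrative, not taken from this PR):

let base = Linear(64, 64)
let adapted = LoRA(linear: base)    // wraps base, freezes it, adds lora_a / lora_b
let x = MLXRandom.normal([1, 64])
let y = adapted(x)                  // base(x) + scale * ((x @ lora_a) @ lora_b)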