exploration of LoRA using composition #167
base: main
Changes from 1 commit
@@ -5,6 +5,7 @@ import MLX
    import MLXFast
    import MLXLMCommon
    import MLXNN
+   import MLXRandom

    // Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py

@@ -33,10 +34,10 @@ private class Attention: Module {
    let nKVHeads: Int
    let repeats: Int

-   @ModuleInfo(key: "q_proj") var wq: Linear
-   @ModuleInfo(key: "k_proj") var wk: Linear
-   @ModuleInfo(key: "v_proj") var wv: Linear
-   @ModuleInfo(key: "o_proj") var wo: Linear
+   @ModuleInfo(key: "q_proj") var wq: UnaryLayer
+   @ModuleInfo(key: "k_proj") var wk: UnaryLayer
+   @ModuleInfo(key: "v_proj") var wv: UnaryLayer
+   @ModuleInfo(key: "o_proj") var wo: UnaryLayer

    let rope: RoPE

@@ -288,3 +289,120 @@ extension Gemma2Model: LoRAModel {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}

// TODO - notes
//
// - make UnaryLayer extend Module
// - make a Quantized protocol that provides the groupSize and bits
// - make the QuantizedLinear shape produce the expanded shape
// - make `items()` open
// - make `updateModule(key:_:)` open
Some ideas that I think I should do regardless of the outcome of this.
//
// - evaluation and training should work as expected
// - this flattens the weights and modules into one layer
// - to match the normal lora implementation
// - see items() and updateModule()
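As an illustration of the "Quantized protocol" note above, the protocol might look something like the sketch below. This is my sketch, not code from this diff; QuantizedLinear already exposes groupSize and bits, so the conformance would be purely declarative.

public protocol Quantized {
    var groupSize: Int { get }
    var bits: Int { get }
}

// Assumed conformance; QuantizedLinear's existing properties satisfy the requirements.
extension QuantizedLinear: Quantized {}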

// TODO: make UnaryLayer extend Module
public protocol UnaryLayer2: Module {
Here just so the types work out -- all UnaryLayer are also Module.
    func callAsFunction(_ x: MLXArray) -> MLXArray
}
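Since the existing Linear is already a Module and already provides callAsFunction(_:) -> MLXArray, a retroactive conformance is all that is needed to use it where a UnaryLayer2 is expected. The extension below is my assumption of how the pieces would connect; it is not in the diff.

// Assumed: let the existing MLXNN layer satisfy the protocol.
// QuantizedLinear inherits the conformance from Linear.
extension Linear: UnaryLayer2 {}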

/// LoRA layer that can wrap any UnaryLayer
class LoRA: Module, UnaryLayer2 {
For reference, here is the current implementation of LoRA:

    let adapts: UnaryLayer2
    let scale: Float

    @ParameterInfo(key: "lora_a") var loraA: MLXArray
    @ParameterInfo(key: "lora_b") var loraB: MLXArray

    public init(
        adapts: UnaryLayer2, inputDimensions: Int, outputDimensions: Int, rank: Int = 8,
        scale: Float = 20.0
    ) {
        self.adapts = adapts

        self.scale = scale

        let loraScale = 1 / sqrt(Float(inputDimensions))
        self._loraA.wrappedValue = MLXRandom.uniform(
            low: -loraScale, high: loraScale, [inputDimensions, rank])
        self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions])

        freeze()
    }

    // TODO: in LoRALinear this is
    // public static func from(linear: Linear, rank: Int = 8) -> LoRA
    public convenience init(linear: Linear, rank: Int = 8, scale: Float = 20.0) {
So rather than calling:

qProj = LoRALinear.from(linear: qProj)

you would:

qProj = LoRA(linear: qProj)

        var (outputDimensions, inputDimensions) = linear.shape

        if let linear = linear as? QuantizedLinear {
            // TODO Linear should probably have a property to return these directly
            // rather than shape which represents the physical shape of the layers
            inputDimensions = inputDimensions * 32 / linear.bits
        }

        self.init(
            adapts: linear,
            inputDimensions: inputDimensions, outputDimensions: outputDimensions,
            rank: rank, scale: scale)
    }

    // produce a merged view of properties (flatten LoRA into adapts)
    override func items() -> ModuleItems {
        var result = adapts.items()
        for (key, value) in super.items() {
            if key == "adapts" { continue }
            result[key] = value
        }
        return result
    }

    // forward module updates -> adapt
    func updateModule(key: String, _ value: Any) throws {
        try adapts.updateModule(key: key, value)
    }
Comment on lines +354 to +367:

This doesn't work as-is because these methods can't be overridden (see TODOs). The idea is that the LoRA composition would flatten itself into what it adapts -- the Linear and LoRA keys would be merged for the purpose of updates, etc., as per the notes.

I think this is necessary for a couple reasons:

But I think this forwarding is the worst part of it. We could potentially make a subclass of Module that encapsulates this if it becomes a common thing. That would help, but I suspect there would be complications.
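To make the merged-keys point concrete, here is an illustrative (assumed, not taken from this PR) view of the parameter key paths for a wrapped q_proj:

without the forwarding: q_proj.adapts.weight, q_proj.lora_a, q_proj.lora_b
with the merged items() view: q_proj.weight, q_proj.lora_a, q_proj.lora_b

The second form matches what the existing LoRALinear subclass exposes, which is presumably what the notes mean by "to match the normal lora implementation".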

    override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws
    {
        try adapts.freeze(recursive: recursive, keys: keys, strict: strict)
    }

    // TODO: this requires knowledge of the innards of the adapted layer so it
    // is specific to Linear (and QuantizedLinear).
    public func toLinear(deQuantize: Bool = false) -> Linear {
        // TODO throws? failable?
        guard let linear = adapts as? Linear else { fatalError("Not a Linear") }

        var weight: MLXArray
        if let quantized = linear as? QuantizedLinear {
            weight = dequantized(
                quantized.weight, scales: quantized.scales, biases: quantized.biases,
                groupSize: quantized.groupSize, bits: quantized.bits)
        } else {
            weight = linear.weight
        }

        let loraB = (scale * loraB.T).asType(.float16)
        let loraA = loraA.T.asType(.float16)
        let mergedWeight = weight + matmul(loraB, loraA)

        // TODO maybe add a protocol for Quantized
        if let quantized = linear as? QuantizedLinear {
            return QuantizedLinear(
                weight: mergedWeight, bias: quantized.bias,
                groupSize: quantized.groupSize, bits: quantized.bits)
        } else {
            return Linear(weight: mergedWeight, bias: linear.bias)
        }
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        // TODO let y = super.callAsFunction(x.asType(scales.dtype)) -- ignoring the asType here
        let y = adapts(x)
        let z = matmul(matmul(x, self.loraA), self.loraB)
        return y + scale * z
The nicest part of it -- since LoRA is an adaptor we can easily express it via composition.
    }
}
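A quick usage sketch under the assumptions above (Linear conforming to UnaryLayer2, and the override TODOs resolved). Because lora_b starts as zeros, the wrapper should be a no-op before training, and toLinear() folds the low-rank update back into a single layer. Hypothetical usage, not part of the diff:

let base = Linear(256, 256)
let adapted = LoRA(linear: base)

let x = MLXRandom.normal([1, 256])
// lora_b is zero-initialized, so initially adapted(x) == base(x)
print(allClose(adapted(x), base(x)).item(Bool.self))

// after training, merge the adapter back into a plain Linear
let fused = adapted.toLinear()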
To use composition we would declare the layers with an appropriate protocol instead of a concrete type.
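For example, where the current flow replaces a Linear with the LoRALinear subclass via LoRALinear.from(linear:), the composed version would simply wrap whatever the protocol-typed property holds. A hypothetical helper, assuming the Linear conformance sketched earlier:

// Hypothetical: wrap Linear projections in a LoRA adapter, leave anything else untouched.
func loraWrapped(_ layer: UnaryLayer2, rank: Int = 8) -> UnaryLayer2 {
    if let linear = layer as? Linear {
        return LoRA(linear: linear, rank: rank)   // instead of LoRALinear.from(linear:)
    }
    return layer
}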