address issues that prevent using composition for layers like LoRA
- see ml-explore/mlx-swift-examples#167
- also fixes an issue where quantize() could quantize an already quantized layer!
davidkoski committed Dec 17, 2024
1 parent 15f12e4 commit 739f84b
Showing 5 changed files with 204 additions and 29 deletions.
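
For context, these changes are about building adapters by composition instead of subclassing (see ml-explore/mlx-swift-examples#167). The sketch below is illustrative only and not part of this commit; `LoRALinear`, `rank`, and `scale` are invented names, and it simply shows the kind of wrapper that becomes possible once `UnaryLayer` requires `Module` and layer properties accept any `UnaryLayer`.

```swift
import MLX
import MLXNN
import MLXRandom

// Hypothetical adapter built by composition (not part of this commit).
// It conforms to UnaryLayer, which now requires Module, instead of
// subclassing Linear, so it can stand in wherever a UnaryLayer is accepted.
class LoRALinear: Module, UnaryLayer {
    let base: Linear      // wrapped layer; could also be a QuantizedLinear
    let loraA: MLXArray   // low-rank factor on the input side
    let loraB: MLXArray   // low-rank factor on the output side
    let scale: Float

    init(base: Linear, rank: Int = 8, scale: Float = 1.0) {
        self.base = base
        let (outputDimensions, inputDimensions) = base.shape
        self.loraA = MLXRandom.normal([inputDimensions, rank]) * (1 / Float(inputDimensions).squareRoot())
        self.loraB = MLXArray.zeros([rank, outputDimensions])
        self.scale = scale
        super.init()
        // in practice the wrapped layer would also be frozen via
        // freeze(recursive:keys:strict:) so that only loraA / loraB are trained
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        // y = base(x) + scale * ((x @ A) @ B)
        base(x) + matmul(matmul(x, loraA), loraB) * scale
    }
}
```

Because the adapter is itself a `Module`, `parameters()`, `freeze(recursive:keys:strict:)`, and `train(_:)` treat it like any other layer.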
2 changes: 1 addition & 1 deletion Source/MLXNN/Linear.swift
@@ -73,7 +73,7 @@ open class Linear: Module, UnaryLayer, Quantizable {
public let weight: MLXArray
public let bias: MLXArray?

public var shape: (Int, Int) {
open var shape: (Int, Int) {
weight.shape2
}

85 changes: 66 additions & 19 deletions Source/MLXNN/Module.swift
Expand Up @@ -98,12 +98,13 @@ open class Module {

/// Flag to indicate whether the module is being trained. Manipulated via
/// ``train(_:)``.
///
/// ### See Also
/// - ``didSetTrain(_:)``
public private(set) var training = true

/// Set of property names that are frozen. Manipulated via
/// ``freeze(recursive:keys:strict:)`` and
/// ``unfreeze(recursive:keys:strict:)``.
public private(set) var noGrad = Set<String>()
/// See ``noGrad()``
private var _noGrad = Set<String>()

private var _items: ModuleItems!
private var _setters: [String: TypeErasedSetter]!
@@ -139,7 +140,7 @@ open class Module {
/// and ``update(parameters:)`` for example.
///
/// Subclasses could potentially override this to provide custom introspection.
public func items() -> ModuleItems {
open func items() -> ModuleItems {
_items
}

@@ -222,7 +223,7 @@
/// - ``mapParameters(map:isLeaf:)``
/// - ``modules()``
/// - ``items()``
public func filterMap<Result>(
open func filterMap<Result>(
filter: (Module, String, ModuleItem) -> Bool,
map: (ModuleItem) -> Result? = { $0 },
isLeaf: (Module, String, ModuleItem) -> Bool = Module.isLeafDefault
@@ -331,7 +332,7 @@
/// ### See Also
/// - <doc:module-filters>
/// - ``mapParameters(map:)``
public func mapParameters<Result>(
open func mapParameters<Result>(
map: @escaping (MLXArray) -> Result? = { $0 },
isLeaf: (Module, String, ModuleItem) -> Bool = Module.isLeafDefault
) -> NestedDictionary<String, Result> {
@@ -343,28 +344,28 @@

/// Return a `NestedDictionary<String, MLXArray>` for all parameters in the
/// model (all layers).
public func parameters() -> ModuleParameters {
open func parameters() -> ModuleParameters {
filterMap(filter: Self.filterValidParameters, map: Self.mapParameters())
}

/// Return a `NestedDictionary<String, MLXArray>` for all trainable parameters in the
/// model (all layers).
///
/// This omits ``freeze(recursive:keys:strict:)`` (frozen) parameters.
public func trainableParameters() -> ModuleParameters {
open func trainableParameters() -> ModuleParameters {
filterMap(filter: Self.filterTrainableParameters, map: Self.mapParameters())
}

/// Produces a `NestedDictionary<String, Module>` for all direct children of the module.
public func children() -> ModuleChildren {
open func children() -> ModuleChildren {
filterMap(filter: Self.filterValidChild, map: Self.mapModule(), isLeaf: Self.isLeafModule)
}

/// Produces a `NestedDictionary<String, Module>` for all leaf modules.
///
/// ### See Also
/// - ``isLeafModuleNoChildren``
public func leafModules() -> ModuleChildren {
open func leafModules() -> ModuleChildren {
filterMap(
filter: Self.filterValidChild, map: Self.mapModule(),
isLeaf: Self.isLeafModuleNoChildren)
@@ -710,7 +711,23 @@
return self
}

private func updateModule(key: String, _ value: Any) throws {
/// Set a module to a new value.
///
/// The module property must be wrapped in a ``ModuleInfo``:
///
/// ```swift
/// @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
/// ```
///
/// and the value must be a compatible type.
///
/// This method is called via ``update(modules:)`` and is not typically called directly. This
/// is exposed as an overridable method for subclasses.
///
/// - Parameters:
/// - key: module key, see ``ModuleInfo``
/// - value: the replacement module
open func updateModule(key: String, _ value: Any) throws {
if let setter = _setters[key] {
do {
try setter.updateModule(value)
@@ -727,7 +744,7 @@
}

// `apply_to_modules()`
public func visit(modules visitor: (String, Module) throws -> Void) rethrows {
open func visit(modules visitor: (String, Module) throws -> Void) rethrows {
var stack = [(String, Module)]()
stack.append(("", self))

@@ -746,7 +763,7 @@
/// - ``namedModules()``
/// - ``children()``
/// - ``leafModules()``
public func modules() -> [Module] {
open func modules() -> [Module] {
var result = [Module]()
visit {
result.append($1)
@@ -760,7 +777,7 @@
/// - ``modules()``
/// - ``children()``
/// - ``leafModules()``
public func namedModules() -> [(String, Module)] {
open func namedModules() -> [(String, Module)] {
var result = [(String, Module)]()
visit {
result.append(($0, $1))
@@ -822,7 +839,8 @@ open class Module {
/// - ``unfreeze(recursive:keys:strict:)``
open func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws {
let visitor = freezeVisitor(keys: keys, strict: strict) {
$0.noGrad.formUnion($1)
$0._noGrad.formUnion($1)
$0.didSetNoGrad($0._noGrad)
}

if recursive {
@@ -859,7 +877,8 @@ open class Module {
/// - ``Module/unfreeze(recursive:keys:strict:)``
open func unfreeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws {
let visitor = freezeVisitor(keys: keys, strict: strict) {
$0.noGrad.subtract($1)
$0._noGrad.subtract($1)
$0.didSetNoGrad($0._noGrad)
}

if recursive {
@@ -869,6 +888,24 @@
}
}

/// Set of property names that are frozen. Manipulated via
/// ``freeze(recursive:keys:strict:)`` and
/// ``unfreeze(recursive:keys:strict:)``.
open func noGrad() -> Set<String> {
_noGrad
}

/// Called when ``noGrad()`` is updated.
///
/// This is provided for subclasses to override.
///
/// - Parameter noGrad: set of properties that are frozen
///
/// ### See Also
/// - ``noGrad()``
open func didSetNoGrad(_ noGrad: Set<String>) {
}

/// Recursively set the model's training mode.
///
/// Training mode only applies to certain layers. For example
@@ -877,11 +914,21 @@
///
/// ### See Also
/// - ``training``
/// - ``didSetTrain(_:)``
public func train(_ mode: Bool = true) {
visit(modules: {
$1.training = mode
$1.didSetTrain(mode)
})
}

/// Called when the training mode is updated via ``train(_:)``.
///
/// This is provided for subclasses to override.
///
/// - Parameter mode: `true` if training is enabled
open func didSetTrain(_ mode: Bool) {
}
}

extension Module: IndentedDescription {
@@ -922,7 +969,7 @@ extension Module: Updatable, Evaluatable {
/// ### See Also
/// - <doc:layers>
/// - ``Sequential``
public protocol UnaryLayer {
public protocol UnaryLayer: Module {
func callAsFunction(_ x: MLXArray) -> MLXArray
}

@@ -996,7 +1043,7 @@ extension Module {
(module: Module, key: String, item: ModuleItem) in
switch item {
case .array, .dictionary, .value(.parameters), .value(.module):
!key.hasPrefix("_") && !module.noGrad.contains(key)
!key.hasPrefix("_") && !module.noGrad().contains(key)
default: false
}
}
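
The new `didSetNoGrad(_:)` and `didSetTrain(_:)` hooks give subclasses one place to react when `freeze(recursive:keys:strict:)`, `unfreeze(recursive:keys:strict:)`, or `train(_:)` touch a module. A hypothetical sketch (the `NoisyLayer` type and its behavior are invented for illustration):

```swift
import MLX
import MLXNN
import MLXRandom

// Hypothetical layer: adds a small amount of noise to activations, but only
// while the module is training and not frozen. The state is derived in the
// new hooks instead of being recomputed on every forward pass.
class NoisyLayer: Module, UnaryLayer {
    @ModuleInfo(key: "inner") var inner: Linear
    private var noiseEnabled = true

    init(_ inner: Linear) {
        self._inner.wrappedValue = inner
        super.init()
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        let y = inner(x)
        return noiseEnabled ? y + MLXRandom.normal(y.shape) * 0.01 : y
    }

    override func didSetTrain(_ mode: Bool) {
        // train(_:) has already updated `training`; react to the change here
        noiseEnabled = mode && noGrad().isEmpty
    }

    override func didSetNoGrad(_ noGrad: Set<String>) {
        // freeze(...) / unfreeze(...) pass in the updated set of frozen keys
        noiseEnabled = training && noGrad.isEmpty
    }
}
```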
24 changes: 19 additions & 5 deletions Source/MLXNN/Quantized.swift
@@ -11,9 +11,18 @@ public protocol Quantizable {
func toQuantized(groupSize: Int, bits: Int) -> Module
}

public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) -> Module? {
if let quantizable = layer as? Quantizable {
quantizable.toQuantized(groupSize: groupSize, bits: bits)
/// Protocol for layers that are quantized.
public protocol Quantized: Module {
var groupSize: Int { get }
var bits: Int { get }
}

public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) -> Quantized? {
if layer is Quantized {
// already quantized
nil
} else if let quantizable = layer as? Quantizable {
quantizable.toQuantized(groupSize: groupSize, bits: bits) as? Quantized
} else {
nil
}
@@ -52,7 +61,7 @@ public func quantize(
}

/// The same as ``Embedding`` but with a quantized weight matrix.
open class QuantizedEmbedding: Embedding {
open class QuantizedEmbedding: Embedding, Quantized {

public let groupSize: Int
public let bits: Int
@@ -121,14 +130,19 @@ open class QuantizedEmbedding: Embedding {
///
/// ### See Also
/// - ``init(weight:bias:groupSize:bits:)``
open class QuantizedLinear: Linear {
open class QuantizedLinear: Linear, Quantized {

public let groupSize: Int
public let bits: Int

public let scales: MLXArray
public let biases: MLXArray

open override var shape: (Int, Int) {
let shape = weight.shape2
return (shape.0, shape.1 * 32 / bits)
}

/// Applies an affine transformation to the input using a quantized weight matrix.
///
/// This is the quantized version of ``Linear``. Typically this is used via ``quantize(model:groupSize:bits:predicate:)``.
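
A small usage sketch of the fix mentioned in the commit message: `quantizeSingle(layer:groupSize:bits:)` now returns `nil` for a layer that already conforms to `Quantized`, so a second quantization pass leaves it alone.

```swift
import MLXNN

// a plain Linear is converted because Linear conforms to Quantizable
let plain = Linear(256, 256)

if let quantized = quantizeSingle(layer: plain, groupSize: 64, bits: 4) as? QuantizedLinear {
    // a second pass is now a no-op: the layer already conforms to Quantized
    let again = quantizeSingle(layer: quantized, groupSize: 64, bits: 4)
    print(again == nil)   // true
}
```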
8 changes: 4 additions & 4 deletions Source/MLXNN/Transformer.swift
@@ -12,10 +12,10 @@ open class MultiHeadAttention: Module {

public let numHeads: Int

@ModuleInfo(key: "query_proj") public var queryProjection: Linear
@ModuleInfo(key: "key_proj") public var keyProjection: Linear
@ModuleInfo(key: "value_proj") public var valueProjection: Linear
@ModuleInfo(key: "out_proj") public var outProjection: Linear
@ModuleInfo(key: "query_proj") public var queryProjection: UnaryLayer
@ModuleInfo(key: "key_proj") public var keyProjection: UnaryLayer
@ModuleInfo(key: "value_proj") public var valueProjection: UnaryLayer
@ModuleInfo(key: "out_proj") public var outProjection: UnaryLayer

/// Implements the scaled dot product attention with multiple heads.
///
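
Since the projections are now `public var` properties of type `UnaryLayer` rather than `Linear`, they can be replaced with any conforming layer. A hypothetical sketch, assuming the standard `MultiHeadAttention` initializer and the `LoRALinear` wrapper from the earlier sketch:

```swift
import MLXNN

let attention = MultiHeadAttention(dimensions: 256, numHeads: 8)

// any UnaryLayer can be assigned now, not just a Linear subclass;
// LoRALinear is the hypothetical wrapper sketched near the top of this commit
if let base = attention.queryProjection as? Linear {
    attention.queryProjection = LoRALinear(base: base, rank: 8)
}
```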
