
exploration of LoRA using composition #167

Open · wants to merge 2 commits into base: main

Conversation

@davidkoski (Collaborator) commented Dec 16, 2024

Not to be merged -- just an exploration of composition and LoRA

@ModuleInfo(key: "q_proj") var wq: UnaryLayer
@ModuleInfo(key: "k_proj") var wk: UnaryLayer
@ModuleInfo(key: "v_proj") var wv: UnaryLayer
@ModuleInfo(key: "o_proj") var wo: UnaryLayer
@davidkoski (Collaborator, Author) commented:

To use composition we would declare the layers with an appropriate protocol instead of a concrete type.
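
For contrast, a minimal sketch of what can sit in a protocol-typed slot. It assumes the LoRA wrapper introduced later in this conversation conforms to the same protocol; Linear and QuantizedLinear already conform to UnaryLayer.

import MLX
import MLXNN

// Illustrative only: any UnaryLayer fits the same slot, so the attention
// code never needs to know which concrete layer it holds.
var wq: UnaryLayer = Linear(1024, 1024, bias: false)
wq = QuantizedLinear(1024, 1024)                              // quantized model
wq = LoRA(linear: Linear(1024, 1024, bias: false), rank: 8)   // adapted for training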

// - make a Quantized protocol that provides the groupSize and bits
// - make the QuantizedLinear shape produce the expanded shape
// - make `items()` open
// - make `updateModule(key:_:)` open
@davidkoski (Collaborator, Author) commented:

Some ideas that I think I should pursue regardless of the outcome of this exploration:
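
For the first item above (a Quantized protocol that provides groupSize and bits), a rough sketch of what it could look like. The protocol name and the retroactive conformance are assumptions, not existing API; it also assumes QuantizedLinear already exposes those values.

import MLXNN

// Hypothetical: lets callers read quantization parameters without knowing
// the concrete layer type, e.g. when re-quantizing a wrapped layer.
public protocol Quantized {
    var groupSize: Int { get }
    var bits: Int { get }
}

// If QuantizedLinear already stores groupSize and bits, the conformance is
// just a declaration.
extension QuantizedLinear: Quantized {}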

// - see items() and updateModule()

// TODO: make UnaryLayer extend Module
public protocol UnaryLayer2: Module {
@davidkoski (Collaborator, Author) commented:

Here just so the types work out -- every UnaryLayer is also a Module.
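
A sketch of the full protocol as this exploration would use it; the callAsFunction requirement mirrors the existing UnaryLayer protocol, the rest is an assumption.

import MLX
import MLXNN

// A unary layer that is also a Module, so it participates in parameter
// traversal, updates, and quantization like any other layer.
public protocol UnaryLayer2: Module {
    func callAsFunction(_ x: MLXArray) -> MLXArray
}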


// TODO: in LoRALinear this is
// public static func from(linear: Linear, rank: Int = 8) -> LoRA
public convenience init(linear: Linear, rank: Int = 8, scale: Float = 20.0) {
@davidkoski (Collaborator, Author) commented:

So rather than calling:

qProj = LoRALinear.from(linear: qProj)

you would:

qProj = LoRA(linear: qProj)

Comment on lines +353 to +366
// produce a merged view of properties (flatten LoRA into adapts)
override func items() -> ModuleItems {
var result = adapts.items()
for (key, value) in super.items() {
if key == "adapts" { continue }
result[key] = value
}
return result
}

// forward module updates -> adapt
func updateModule(key: String, _ value: Any) throws {
try adapts.updateModule(key: key, value)
}
@davidkoski (Collaborator, Author) commented:

This doesn't work as-is because these methods can't be overridden (see TODOs).

The idea is that the LoRA composition would flatten itself into what it adapts -- the Linear and LoRA keys would be merged for the purpose of updates, etc.

As per the notes, noGrad would also need to be overridable (it is currently a stored property, which cannot be overridden).

@davidkoski (Collaborator, Author) commented:

I think this is necessary for a couple of reasons:

  • generally, weight saving and loading shouldn't see the adaptor layer
  • it matches the typical shape of a graph with LoRA (mixed into the linear); see the sketch below
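
To make that concrete, a sketch of the key spaces involved; the parameter names follow the hunks above and the exact strings are illustrative.

// What a naive traversal of the wrapper would expose: the Linear leaks
// through under the "adapts" prefix, and so would saved weight files.
let unmerged = ["adapts.weight", "adapts.bias", "loraA", "loraB"]

// What the merged items() view presents instead: the LoRA module looks like
// the Linear it adapts, which is the shape checkpoints normally have.
let merged = ["weight", "bias", "loraA", "loraB"]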

@davidkoski (Collaborator, Author) commented:

But I think this forwarding is the worst part of it. We could potentially make a subclass of Module that encapsulates this if it becomes a common thing. That would help, but I suspect there would be complications.
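
As a rough sketch of that direction (assuming items() and updateModule(key:_:) become open per the TODOs, and with a hypothetical class name), the forwarding could live in a small base class so that concrete adaptors like LoRA only add their own parameters and math:

import MLX
import MLXNN

// Hypothetical: a Module subclass that owns the "flatten into the adapted
// layer" behavior sketched in the hunk above.
class AdaptingModule: Module {
    // The wrapped layer whose key space this module merges into its own.
    @ModuleInfo var adapts: Module

    init(adapting layer: Module) {
        self._adapts.wrappedValue = layer
        super.init()
    }

    // Merged view of properties (assumes items() is made open).
    override func items() -> ModuleItems {
        var result = adapts.items()
        for (key, value) in super.items() where key != "adapts" {
            result[key] = value
        }
        return result
    }

    // Forward module updates to the adapted layer (assumes updateModule is open).
    override func updateModule(key: String, _ value: Any) throws {
        try adapts.updateModule(key: key, value)
    }
}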


// TODO: this requires knowledge of the innards of the adapted layer so it
// is specific to Linear (and QuantizedLinear).
public func toLinear(deQuantize: Bool = false) -> Linear {
@davidkoski (Collaborator, Author) commented:

The fuse operation requires knowledge of how to combine the LoRA weights with the target. Type-wise this could easily return a UnaryLayer, but it has to understand the adapted layer's implementation in order to fuse.
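
As an illustration for the plain Linear case, a sketch of what the body could compute, using the adapts/loraA/loraB/scale names from the surrounding code and assuming adapts is a Linear exposing weight and bias (the dequantize path is omitted). Since the forward pass below computes x @ W.T + scale * (x @ loraA @ loraB), the fused weight is W + scale * (loraA @ loraB).T.

// Sketch only: fold the low-rank update into the adapted Linear's weight.
public func toLinear(deQuantize: Bool = false) -> Linear {
    guard let linear = adapts as? Linear else {
        fatalError("fusing is only sketched for plain Linear here")
    }
    let dtype = linear.weight.dtype
    let delta = (scale * matmul(loraA, loraB)).T.asType(dtype)
    return Linear(weight: linear.weight + delta, bias: linear.bias)
}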

// TODO let y = super.callAsFunction(x.asType(scales.dtype)) -- ignoring the asType here
let y = adapts(x)
let z = matmul(matmul(x, self.loraA), self.loraB)
return y + scale * z
@davidkoski (Collaborator, Author) commented:

The nicest part of it -- since LoRA is an adaptor we can easily express it via composition.

}

/// LoRA layer that can wrap any UnaryLayer
class LoRA: Module, UnaryLayer2 {

davidkoski added a commit to davidkoski/mlx-swift that referenced this pull request Dec 17, 2024
- see ml-explore/mlx-swift-examples#167
- also fixes issue where quantize() could quantize a quantized layer!