exploration of LoRA using composition #167
base: main
Conversation
@ModuleInfo(key: "q_proj") var wq: UnaryLayer | ||
@ModuleInfo(key: "k_proj") var wk: UnaryLayer | ||
@ModuleInfo(key: "v_proj") var wv: UnaryLayer | ||
@ModuleInfo(key: "o_proj") var wo: UnaryLayer |
To use composition we would declare the layers with an appropriate protocol instead of a concrete type.
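For contrast, the concrete declarations this replaces look roughly like this (a sketch of the existing attention code, not copied from the diff):

@ModuleInfo(key: "q_proj") var wq: Linear
@ModuleInfo(key: "k_proj") var wk: Linear
@ModuleInfo(key: "v_proj") var wv: Linear
@ModuleInfo(key: "o_proj") var wo: Linear

Declaring the properties as UnaryLayer means either a plain Linear or a LoRA wrapper can be dropped in without touching the attention code.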
// - make a Quantized protocol that provides the groupSize and bits
// - make the QuantizedLinear shape produce the expanded shape
// - make `items()` open
// - make `updateModule(key:_:)` open
Some ideas that I think I should implement regardless of the outcome of this exploration:
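A minimal sketch of the first bullet, assuming QuantizedLinear already exposes groupSize and bits; the protocol name and the retroactive conformance are hypothetical:

import MLX
import MLXNN

// Hypothetical protocol so callers can read quantization parameters generically.
public protocol Quantized {
    var groupSize: Int { get }
    var bits: Int { get }
}

// Assumed conformance -- QuantizedLinear already stores these properties.
extension QuantizedLinear: Quantized {}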
// - see items() and updateModule()

// TODO: make UnaryLayer extend Module
public protocol UnaryLayer2: Module {
Here just so the types work out -- every UnaryLayer is also a Module.
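The full declaration is presumably just the existing UnaryLayer requirement plus the Module constraint; the body below is assumed, not copied from the diff:

// TODO: make UnaryLayer extend Module
public protocol UnaryLayer2: Module {
    func callAsFunction(_ x: MLXArray) -> MLXArray
}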
// TODO: in LoRALinear this is
// public static func from(linear: Linear, rank: Int = 8) -> LoRA
public convenience init(linear: Linear, rank: Int = 8, scale: Float = 20.0) {
So rather than calling:
qProj = LoRALinear.from(linear: qProj)
you would:
qProj = LoRA(linear: qProj)
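A plausible body for that initializer, following the initialization used by the existing LoRALinear.from(linear:rank:); the designated init(adapts:loraA:loraB:scale:) it calls is an assumption for this sketch:

public convenience init(linear: Linear, rank: Int = 8, scale: Float = 20.0) {
    let outputDimensions = linear.weight.dim(0)
    let inputDimensions = linear.weight.dim(1)

    // same initialization as the existing LoRALinear: small uniform A, zero B
    let bound = 1 / Float(inputDimensions).squareRoot()
    self.init(
        adapts: linear,
        loraA: MLXRandom.uniform(low: -bound, high: bound, [inputDimensions, rank]),
        loraB: MLXArray.zeros([rank, outputDimensions]),
        scale: scale)
}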
// produce a merged view of properties (flatten LoRA into adapts)
override func items() -> ModuleItems {
    var result = adapts.items()
    for (key, value) in super.items() {
        if key == "adapts" { continue }
        result[key] = value
    }
    return result
}
// forward module updates -> adapts
func updateModule(key: String, _ value: Any) throws {
    try adapts.updateModule(key: key, value)
}
This doesn't work as-is because these methods can't be overridden (see TODOs).
The idea is that the LoRA composition would flatten itself into what it adapts -- the Linear and LoRA keys would be merged for the purpose of updates, etc.
As per the notes, noGrad would also need to be overridable (it is a property with storage and cannot be overridden right now).
I think this is necessary for a couple reasons:
- generally weight saving and loading doesn't want to see the adaptor layer (see the key sketch after this list)
- this matches the typical shape of a graph with LoRA (mixed into the linear)
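To illustrate (key names assumed from the existing LoRALinear): with flattening, a checkpoint for an adapted q_proj would contain keys like

q_proj.weight      // from the wrapped Linear
q_proj.lora_a      // from the LoRA adaptor
q_proj.lora_b

rather than q_proj.adapts.weight, which is what the naive nesting would produce and which existing checkpoints would not expect.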
But I think this forwarding is the worst part of it. We could potentially make a subclass of Module that encapsulates this if it becomes a common thing. That would help, but I suspect there would be complications.
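A hypothetical sketch of such a base class, assuming the TODOs land and items() / updateModule(key:_:) become open (the class name and shape are invented for illustration; this does not compile against MLX Swift today):

import MLXNN

// Hypothetical base class: an adaptor Module that flattens its wrapped layer
// into its own key space so checkpoints and updates see the merged view.
open class AdaptorModule: Module {
    // the wrapped layer; subclasses decide how (and whether) to call it
    public let adapts: Module

    public init(adapts: Module) {
        self.adapts = adapts
        super.init()
    }

    // merged view of properties: the wrapped layer's items plus our own,
    // minus the "adapts" nesting
    open override func items() -> ModuleItems {
        var result = adapts.items()
        for (key, value) in super.items() where key != "adapts" {
            result[key] = value
        }
        return result
    }

    // forward structural updates to the wrapped layer
    open override func updateModule(key: String, _ value: Any) throws {
        try adapts.updateModule(key: key, value)
    }
}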
// TODO: this requires knowledge of the innards of the adapted layer so it
// is specific to Linear (and QuantizedLinear).
public func toLinear(deQuantize: Bool = false) -> Linear {
The fuse operation requires knowledge of how to combine the LoRA weights with the target. Type-wise this could easily return a UnaryLayer, but it has to understand the implementation in order to fuse.
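A rough sketch of the fuse for the plain Linear case (ignoring quantization and dtype handling); since y = x W^T + scale * (x A) B, the fused weight is W' = W + scale * (A B)^T:

public func toLinear(deQuantize: Bool = false) -> Linear {
    // only the plain Linear case is sketched here
    guard let linear = adapts as? Linear else {
        fatalError("fuse is only sketched for Linear")
    }
    // W' = W + scale * B^T A^T
    let fused = linear.weight + matmul(scale * loraB.T, loraA.T)
    return Linear(weight: fused, bias: linear.bias)
}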
// TODO: let y = super.callAsFunction(x.asType(scales.dtype)) -- ignoring the asType here
let y = adapts(x)
let z = matmul(matmul(x, self.loraA), self.loraB)
return y + scale * z
The nicest part of it -- since LoRA is an adaptor we can easily express it via composition.
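A small usage sketch (dimensions and rank are arbitrary, and it assumes the convenience init shown earlier):

import MLX
import MLXNN
import MLXRandom

let base = Linear(512, 512)
let adapted = LoRA(linear: base, rank: 8)

let x = MLXRandom.normal([1, 512])
let y = adapted(x)  // = adapts(x) + scale * matmul(matmul(x, loraA), loraB)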
}
/// LoRA layer that can wrap any UnaryLayer
class LoRA: Module, UnaryLayer2 {
For reference, here is the current implementation of LoRA:
- see ml-explore/mlx-swift-examples#167 -- also fixes an issue where quantize() could quantize a quantized layer!
Not to be merged -- just an exploration of composition and LoRA