LoRA for Whisper speech transcription #483

Open · wants to merge 1 commit into base: main
74 changes: 74 additions & 0 deletions whisper/.vscode/launch.json
@@ -0,0 +1,74 @@
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Transcribe: Telugu using Whisper Medium",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/run_transcribe.py",
            "args": [
                "--audio", "${workspaceFolder}/whisper/assets/adityateluguthursday.m4a",
                "--model", "${workspaceFolder}/mlx_models/whisper-medium-mlx"
            ],
            "console": "integratedTerminal",
            "justMyCode": true
        },
        {
            "name": "Train: Whisper Medium on Telugu using LoRA",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/lora/lora.py",
            "args": [
                "--model", "${workspaceFolder}/mlx_models/whisper-medium-mlx",
                "--train",
                "--adapter-file", "${workspaceFolder}/lora_adapters_whisper_with_telugu.npz",
                "--hf-dataset", "mozilla-foundation/common_voice_16_1",
                "--hf-dataset-lang", "te",
                "--batch-size", "2",
                "--lora-layers", "2"
            ],
            "console": "integratedTerminal",
            "justMyCode": true
        },
        {
            "name": "Fuse & Save: Whisper Medium on Telugu using LoRA",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/lora/fuse.py",
            "args": [
                "--model", "${workspaceFolder}/mlx_models/whisper-medium-mlx",
                "--adapter-file", "${workspaceFolder}/lora_adapters_whisper_with_telugu.npz",
                "--save-path", "${workspaceFolder}/lora_fused_model_whisper_with_telugu"
            ],
            "console": "integratedTerminal",
            "justMyCode": true
        },
        {
            "name": "Transcribe: Telugu using Whisper Medium LoRA-Telugu",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/run_transcribe.py",
            "args": [
                "--audio", "${workspaceFolder}/whisper/assets/adityateluguthursday.m4a",
                "--model", "${workspaceFolder}/lora_fused_model_whisper_with_telugu"
            ],
            "console": "integratedTerminal",
            "justMyCode": true
        },
        {
            "name": "Python: Attach using Process Id",
            "type": "python",
            "request": "attach",
            "processId": "${command:pickProcess}",
            "justMyCode": true
        },
        {
            "name": "Python: Current File",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "justMyCode": true
        }
    ]
}
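Note: the first and last configurations launch run_transcribe.py, which is referenced here but not part of this diff. A minimal sketch of what such a driver might look like, assuming the transcribe() API exposed by the existing whisper example; the flag names mirror the launch configurations above, everything else is an assumption:

# Hypothetical sketch of run_transcribe.py (not part of this PR).
import argparse

from whisper import transcribe  # transcribe() from the existing whisper example

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Transcribe audio with an MLX Whisper model.")
    parser.add_argument("--audio", required=True, help="Path to the audio file.")
    parser.add_argument("--model", default="mlx_model", help="Path to the MLX Whisper model directory.")
    args = parser.parse_args()

    # transcribe() loads the model from the given path and returns a dict
    # whose "text" key holds the decoded transcript.
    result = transcribe(args.audio, path_or_hf_repo=args.model)
    print(result["text"])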
Empty file added whisper/lora/__init__.py
56 changes: 56 additions & 0 deletions whisper/lora/fuse.py
@@ -0,0 +1,56 @@
# Copyright © 2023 Apple Inc.

import argparse
from pathlib import Path

import mlx.core as mx
import utils
from mlx.utils import tree_flatten, tree_unflatten
from models.lora import LoRALinear

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Fuse trained LoRA adapters back into a Whisper model."
    )
    parser.add_argument(
        "--model",
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--save-path",
        default="lora_fused_model",
        help="The path to save the fused model.",
    )
    parser.add_argument(
        "--adapter-file",
        type=str,
        default="adapters.npz",
        help="Path to the trained adapter weights (npz or safetensors).",
    )

    args = parser.parse_args()

    model, tokenizer, config = utils.load(args.model)

    # Load the trained adapters and count how many blocks were adapted.
    # Encoder and decoder entries both contain "query.lora_a", so count
    # only the encoder's to avoid doubling the layer count.
    adapters = list(mx.load(args.adapter_file).items())
    lora_layers = len(
        [m for m in adapters if "encoder" in m[0] and "query.lora_a" in m[0]]
    )

    # Freeze all layers other than the LoRA linears
    model.freeze()
    for block in model.encoder.blocks[len(model.encoder.blocks) - lora_layers :]:
        block.attn.query = LoRALinear.from_linear(block.attn.query)
        block.attn.value = LoRALinear.from_linear(block.attn.value)
    for block in model.decoder.blocks[len(model.decoder.blocks) - lora_layers :]:
        block.cross_attn.query = LoRALinear.from_linear(block.cross_attn.query)
        block.cross_attn.value = LoRALinear.from_linear(block.cross_attn.value)

    # Load the trained low-rank weights into the wrapped layers.
    model.update(tree_unflatten(adapters))

    # Fold each adapter into its base weights and swap plain Linear
    # modules back into the model.
    fused_linears = [
        (n, m.to_linear())
        for n, m in model.named_modules()
        if isinstance(m, LoRALinear)
    ]
    model.update_modules(tree_unflatten(fused_linears))

    weights = dict(tree_flatten(model.parameters()))
    utils.save_model(args.save_path, weights, tokenizer, config)
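
Note: fuse.py imports LoRALinear from models/lora.py, which is not included in this diff. For context, a minimal sketch of what from_linear / to_linear could look like, following the LoRA pattern used in the mlx-examples LLM recipes; the rank, scale, and initialization below are assumptions:

# Hypothetical sketch of models/lora.py's LoRALinear (not part of this diff).
import math

import mlx.core as mx
import mlx.nn as nn


class LoRALinear(nn.Module):
    @staticmethod
    def from_linear(linear: nn.Linear, rank: int = 8):
        output_dims, input_dims = linear.weight.shape
        lora_lin = LoRALinear(input_dims, output_dims, rank)
        lora_lin.linear = linear  # keep the frozen base layer
        return lora_lin

    def __init__(self, input_dims: int, output_dims: int, rank: int = 8):
        super().__init__()
        self.linear = nn.Linear(input_dims, output_dims)
        self.scale = 2.0
        # A starts small and random, B starts at zero, so the adapter
        # initially contributes nothing to the output.
        bound = 1.0 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(low=-bound, high=bound, shape=(input_dims, rank))
        self.lora_b = mx.zeros((rank, output_dims))

    def to_linear(self) -> nn.Linear:
        linear = self.linear
        bias = "bias" in linear
        output_dims, input_dims = linear.weight.shape
        fused = nn.Linear(input_dims, output_dims, bias=bias)
        # Fold the low-rank update into the base weights:
        # y = x W^T + scale * x A B  =>  W' = W + scale * B^T A^T
        fused.weight = linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
        if bias:
            fused.bias = linear.bias
        return fused

    def __call__(self, x):
        return self.linear(x) + self.scale * (x @ self.lora_a) @ self.lora_b

Because lora_b starts at zero, a freshly wrapped layer is a no-op until adapter weights are loaded, so calling to_linear() on an untrained layer simply reproduces the original weights.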