Skip to content

Commit

Permalink
fix: use 10x less memory
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Sep 22, 2024
1 parent a20dab5 commit 6d02797
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 37 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ resolver = "2"


[workspace.package]
version = "0.1.89"
version = "0.1.90"
authors = ["louis030195 <[email protected]>"]
description = ""
repository = "https://github.com/mediar-ai/screenpipe"
Expand Down
2 changes: 1 addition & 1 deletion screenpipe-app-tauri/src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "screenpipe-app"
version = "0.2.62"
version = "0.2.63"
description = ""
authors = ["you"]
license = ""
Expand Down
9 changes: 5 additions & 4 deletions screenpipe-audio/benches/record_and_transcribe_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use screenpipe_audio::vad_engine::VadSensitivity;
use screenpipe_audio::{
create_whisper_channel, default_input_device, record_and_transcribe, AudioDevice, AudioInput,
AudioTranscriptionEngine,
Expand All @@ -7,12 +8,11 @@ use std::path::PathBuf;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;

async fn setup_test() -> (
Arc<AudioDevice>,
PathBuf,
mpsc::UnboundedSender<AudioInput>,
crossbeam::channel::Sender<AudioInput>,
Arc<AtomicBool>,
) {
let audio_device = default_input_device().unwrap(); // TODO feed voice in automatically somehow
Expand All @@ -22,6 +22,8 @@ async fn setup_test() -> (
Arc::new(AudioTranscriptionEngine::WhisperDistilLargeV3),
screenpipe_audio::VadEngineEnum::Silero,
None,
&output_path,
VadSensitivity::High,
)
.await
.unwrap();
Expand Down Expand Up @@ -51,14 +53,13 @@ fn bench_record_and_transcribe(c: &mut Criterion) {
let mut total_duration = Duration::new(0, 0);

for _ in 0..iters {
let (audio_device, output_path, whisper_sender, is_running) = setup_test().await;
let (audio_device, _, whisper_sender, is_running) = setup_test().await;
let duration = Duration::from_secs(5); // 5 seconds of recording

let start = std::time::Instant::now();
let result = record_and_transcribe(
black_box(audio_device),
black_box(duration),
black_box(output_path),
black_box(whisper_sender),
black_box(is_running),
)
Expand Down
71 changes: 40 additions & 31 deletions screenpipe-audio/src/core.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use crate::AudioInput;
use anyhow::{anyhow, Result};
use chrono::Utc;
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::StreamError;
use crossbeam::queue::ArrayQueue;
use log::{debug, error, info, warn};
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, thread};
use tokio::sync::Mutex;
use std::{fmt, thread}; // Note: We're using parking_lot for better performance

#[derive(Clone, Debug, PartialEq)]
pub enum AudioTranscriptionEngine {
Expand Down Expand Up @@ -167,13 +166,15 @@ pub async fn record_and_transcribe(
"Audio device config: sample_rate={}, channels={}",
sample_rate, channels
);
let start_time = Utc::now();

let audio_data = Arc::new(Mutex::new(Vec::new()));
// Create an ArrayQueue with a capacity of 100 chunks (adjust as needed)
let audio_queue = Arc::new(ArrayQueue::new(100));
let audio_queue_clone = Arc::clone(&audio_queue);

let is_running_weak = Arc::downgrade(&is_running);
let is_running_weak_2 = Arc::downgrade(&is_running);
let is_running_weak_3 = Arc::downgrade(&is_running);
let audio_data_clone = Arc::clone(&audio_data);
let is_running_weak_4 = Arc::downgrade(&is_running);

// Define the error callback function
let error_callback = move |err: StreamError| {
Expand All @@ -196,8 +197,7 @@ pub async fn record_and_transcribe(
.upgrade()
.map_or(false, |arc| arc.load(Ordering::Relaxed))
{
let mut audio_data = audio_data_clone.blocking_lock();
audio_data.extend_from_slice(bytemuck::cast_slice::<i8, f32>(data));
let _ = audio_queue_clone.push(bytemuck::cast_slice(data).to_vec());
}
},
error_callback,
Expand All @@ -210,8 +210,7 @@ pub async fn record_and_transcribe(
.upgrade()
.map_or(false, |arc| arc.load(Ordering::Relaxed))
{
let mut audio_data = audio_data_clone.blocking_lock();
audio_data.extend_from_slice(bytemuck::cast_slice::<i16, f32>(data));
let _ = audio_queue_clone.push(bytemuck::cast_slice(data).to_vec());
}
},
error_callback,
Expand All @@ -224,8 +223,7 @@ pub async fn record_and_transcribe(
.upgrade()
.map_or(false, |arc| arc.load(Ordering::Relaxed))
{
let mut audio_data = audio_data_clone.blocking_lock();
audio_data.extend_from_slice(bytemuck::cast_slice::<i32, f32>(data));
let _ = audio_queue_clone.push(bytemuck::cast_slice(data).to_vec());
}
},
error_callback,
Expand All @@ -238,8 +236,7 @@ pub async fn record_and_transcribe(
.upgrade()
.map_or(false, |arc| arc.load(Ordering::Relaxed))
{
let mut audio_data = audio_data_clone.blocking_lock();
audio_data.extend_from_slice(data);
let _ = audio_queue_clone.push(data.to_vec());
}
},
error_callback,
Expand Down Expand Up @@ -275,36 +272,48 @@ pub async fn record_and_transcribe(
audio_device.to_string(),
duration.as_secs()
);

// wait for the duration unless is_running is false
while is_running.load(Ordering::Relaxed) {
std::thread::sleep(Duration::from_millis(100));
if Utc::now().timestamp() - start_time.timestamp() > duration.as_secs() as i64 {
debug!("Recording duration reached");
break;
// Spawn another thread to collect audio data
let collector_handle = tokio::spawn(async move {
let mut collected_audio = Vec::new();
while is_running_weak_4
.upgrade()
.map_or(false, |arc| arc.load(Ordering::Relaxed))
{
while let Some(chunk) = audio_queue.pop() {
collected_audio.extend(chunk);
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
collected_audio
});

// Wait for the duration
tokio::time::sleep(duration).await;

// Signal the recording thread to stop
// Signal the recording to stop
is_running.store(false, Ordering::Relaxed);

// Wait for the native thread to finish
// Wait for the audio thread to finish
if let Err(e) = audio_handle.join() {
error!("Error joining audio thread: {:?}", e);
error!("error joining audio thread: {:?}", e);
}

debug!("Sending audio to audio model");
let data = audio_data.lock().await;
debug!("Sending audio of length {} to audio model", data.len());
// Collect the final audio data
let audio_data = collector_handle.await.unwrap_or_else(|e| {
error!("error joining collector thread: {:?}", e);
Vec::new()
});

debug!("sending audio to audio model");
if let Err(e) = whisper_sender.send(AudioInput {
data: Arc::new(data.to_vec()),
data: Arc::new(audio_data),
device: audio_device.clone(),
sample_rate,
channels,
}) {
error!("Failed to send audio to audio model: {}", e);
error!("failed to send audio to audio model: {}", e);
}
debug!("Sent audio to audio model");
debug!("sent audio to audio model");

Ok(())
}
Expand Down

0 comments on commit 6d02797

Please sign in to comment.