Merge pull request #171 from kyutai-labs/merge-2

More internal changes merge.

LaurentMazare authored Dec 11, 2024
2 parents af1e918 + 0b9f272 commit f301c16

Showing 5 changed files with 95 additions and 7 deletions.

8 changes: 8 additions & 0 deletions rust/moshi-backend/Cargo.toml
@@ -56,3 +56,11 @@ vergen = { version = "8.3.1", features = ["build", "cargo", "git", "gitcl", "rus
 default = []
 cuda = ["moshi/cuda", "candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
 metal = ["moshi/metal", "candle/metal", "candle-nn/metal", "candle-transformers/metal"]
+
+[profile.release]
+debug = true
+
+[profile.release-no-debug]
+inherits = "release"
+debug = false
+
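The release profile now keeps debug info (debug = true), which preserves symbols for profiling and readable backtraces, while the new release-no-debug profile inherits from release and turns it back off. Assuming a standard Cargo setup, cargo build --release therefore produces a debug-annotated binary, and cargo build --profile release-no-debug the previous, smaller binary without debug info.
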
4 changes: 2 additions & 2 deletions rust/moshi-backend/src/benchmark.rs
@@ -131,8 +131,8 @@ pub async fn run(args: &crate::BenchmarkArgs, config: &Config) -> Result<()> {
                 tokio::time::sleep_until(target_time).await;
                 in_pcm_tx.send(zeros.to_vec())?;
             }
-            let _ = task.await;
-            let _ = w.await;
+            task.await?;
+            w.await??;
         }
     }
     Ok(())
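The benchmark now surfaces failures from its spawned tasks instead of discarding their results. A minimal sketch of the pattern, using an illustrative helper name rather than the actual benchmark code: awaiting a tokio JoinHandle yields a Result for the join itself, so a task that additionally returns its own Result needs two question marks.

use anyhow::Result;

// Hypothetical helper showing the double '?':
//   handle.await   -> Result<Result<()>, JoinError>   (join layer)
//   handle.await?  -> Result<()>                       (the task's own result)
//   handle.await?? -> ()                               (or propagate the task's error)
async fn join_checked(handle: tokio::task::JoinHandle<Result<()>>) -> Result<()> {
    handle.await??;
    Ok(())
}
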
3 changes: 3 additions & 0 deletions rust/moshi-backend/src/main.rs
@@ -51,6 +51,9 @@ pub struct BenchmarkArgs {
     #[clap(long)]
     chrome_tracing: bool,
 
+    #[clap(long)]
+    asr: bool,
+
     #[clap(long)]
     mimi_only: bool,
 }
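BenchmarkArgs gains an asr field, which clap exposes as an --asr flag on the benchmark command. The server-side ASR behaviour itself is selected through the new asr_delay_in_tokens config field added in stream_both.rs below.
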
2 changes: 1 addition & 1 deletion rust/moshi-backend/src/standalone.rs
@@ -82,7 +82,7 @@ impl stream_both::AppStateInner {
         let codes = mimi_model.encode_step(&fake_pcm.into())?;
         let ys = mimi_model.decode_step(&codes)?;
         if ys.as_option().is_none() {
-            anyhow::bail!("Expected mimi to output some stuff, but nothing came out.");
+            anyhow::bail!("Expected Mimi to output some stuff, but nothing came out.");
         }
         device.synchronize()?;
         tracing::info!("model is ready to roll!");

85 changes: 81 additions & 4 deletions rust/moshi-backend/src/stream_both.rs
@@ -13,6 +13,7 @@ use std::sync::Arc;
 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
 pub struct Config {
     pub instance_name: String,
+    #[serde(default)]
     pub hf_repo: String,
     pub lm_model_file: String,
     pub log_dir: String,
@@ -22,6 +23,7 @@ pub struct Config {
     pub lm_config: Option<moshi::lm_generate_multistream::Config>,
     #[serde(default = "default_false")]
     pub use_cpu_for_mimi: bool,
+    pub asr_delay_in_tokens: Option<usize>,
 }
 
 fn default_false() -> bool {
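Two config-level changes here: #[serde(default)] lets a config file that omits hf_repo still deserialize (the field falls back to an empty string), and the new optional asr_delay_in_tokens is what routes a session into the ASR code path added below (when it is set, run_with_state_asr is used instead of run_with_state).
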
@@ -322,6 +324,62 @@ pub struct StreamingModel {
 }
 
 impl StreamingModel {
+    fn run_with_state_asr(
+        &self,
+        state: &mut moshi::lm_generate_multistream::State,
+        receiver: std::sync::mpsc::Receiver<Vec<f32>>,
+        sender: tokio::sync::mpsc::UnboundedSender<StreamOut>,
+        asr_delay_in_tokens: usize,
+    ) -> Result<()> {
+        use candle::IndexOp;
+
+        let app_state = &self.state;
+
+        let mut mimi = app_state.mimi_model.clone();
+        let config = state.config().clone();
+
+        mimi.reset_state();
+        tracing::info!("processing loop");
+        let mut prev_text_token = config.text_start_token;
+        let mimi_device =
+            if self.state.config.use_cpu_for_mimi { &candle::Device::Cpu } else { &self.device };
+        mimi_device.synchronize()?;
+        sender.send(StreamOut::Ready)?;
+        while let Ok(in_pcm) = receiver.recv() {
+            if in_pcm.is_empty() {
+                continue;
+            }
+            let pcm_len = in_pcm.len();
+            sender.send(StreamOut::InputPcm { pcm_len })?;
+            let pcms = candle::Tensor::from_vec(in_pcm, (1, 1, pcm_len), mimi_device)?;
+            let audio_tokens = mimi.encode_step(&pcms.into())?;
+            let audio_tokens = match audio_tokens.as_option() {
+                None => continue,
+                Some(audio_tokens) => audio_tokens,
+            };
+            let (_one, _codebooks, steps) = audio_tokens.dims3()?;
+
+            for step in 0..steps {
+                let codes = audio_tokens.i((0, .., step))?.to_vec1::<u32>()?;
+                // For the ASR, we don't provide text tokens during the initial steps except the
+                // initial one.
+                if state.step_idx() > 0 && state.step_idx() < asr_delay_in_tokens {
+                    prev_text_token = state.step_(None, &codes, None)?;
+                } else {
+                    sender.send(StreamOut::StepStart { step })?;
+                    let text_token = state.step(prev_text_token, &codes, None)?;
+                    sender.send(StreamOut::StepPostSampling { step })?;
+                    if let Some(text) = app_state.text(prev_text_token, text_token, &config) {
+                        sender.send(StreamOut::Text { text })?;
+                    }
+                    prev_text_token = text_token;
+                }
+            }
+        }
+        tracing::info!("finished the processing loop");
+        Ok(())
+    }
+
     fn run_with_state(
         &self,
         state: &mut moshi::lm_generate_multistream::State,
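The new ASR loop gates the text stream: apart from the very first step, every step whose index is still below asr_delay_in_tokens advances the state with step_(None, ...), so the model only consumes audio codes (the sampled token is kept as prev_text_token, but nothing is sent to the client); once the delay has elapsed, each step samples a text token, emits the StepStart and StepPostSampling markers, decodes the transition with app_state.text, and streams the result as StreamOut::Text.
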
@@ -374,7 +432,6 @@ impl StreamingModel {
                         sender.send(StreamOut::Pcm { pcm })?;
                     }
                 }
-
                 if let Some(text) = app_state.text(prev_text_token, text_token, &config) {
                     sender.send(StreamOut::Text { text })?;
                 }
@@ -550,6 +607,8 @@ impl StreamingModel {
         // We want to log the output even if the run function returns an error.
         let run_result = if self.state.config.use_cpu_for_mimi {
             self.run_with_state_mt(&mut state, receiver, sender)
+        } else if let Some(asr_delay_in_tokens) = self.state.config.asr_delay_in_tokens {
+            self.run_with_state_asr(&mut state, receiver, sender, asr_delay_in_tokens)
         } else {
             self.run_with_state(&mut state, receiver, sender)
         };
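Note the dispatch order in this chain: use_cpu_for_mimi is tested before asr_delay_in_tokens, so a config that sets both still runs the multi-threaded CPU path (run_with_state_mt) rather than the ASR loop.
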
@@ -577,8 +636,22 @@ impl StreamingModel {
             .unwrap_or_else(|_| String::new())
         };
         let audio_tokens = state.audio_tokens(false);
-        let audio_tokens = audio_tokens.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
-        let text_tokens = candle::Tensor::new(text_tokens, &candle::Device::Cpu)?;
+        let audio_tokens = audio_tokens
+            .iter()
+            .map(|v| {
+                v.iter()
+                    .map(|v| {
+                        if *v == moshi::lm_generate_multistream::UNGENERATED {
+                            -1
+                        } else {
+                            *v as i64
+                        }
+                    })
+                    .collect::<Vec<_>>()
+            })
+            .collect::<Vec<_>>();
+        let text_tokens = candle::Tensor::new(text_tokens, &candle::Device::Cpu)?
+            .to_dtype(candle::DType::I64)?;
         let audio_tokens = candle::Tensor::new(audio_tokens, &candle::Device::Cpu)?;
         let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?;
         let (secs, us) = (since_epoch.as_secs(), since_epoch.subsec_micros());
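With this change both logged tensors use a signed integer layout: the text tokens are cast to I64, and the audio tokens are widened to i64 with the UNGENERATED sentinel rewritten as -1, so slots that were never generated remain distinguishable after the tensors are written out.
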
@@ -718,7 +791,11 @@ pub async fn handle_socket(
     let (in_pcm_tx, in_pcm_rx) = std::sync::mpsc::channel();
     let (stream_out_tx, stream_out_rx) = tokio::sync::mpsc::unbounded_channel();
     let (loop1, loop2) = spawn_recv_loops(receiver, in_pcm_tx)?;
-    std::thread::spawn(move || sm.run(in_pcm_rx, stream_out_tx, addr));
+    std::thread::spawn(move || {
+        if let Err(err) = sm.run(in_pcm_rx, stream_out_tx, addr) {
+            tracing::error!("{err}")
+        }
+    });
     let sender_loop = tokio::spawn(async move {
         match sender_loop(stream_out_rx, sender).await {
             Ok(()) => tracing::info!("sender closed"),
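sm.run executes on a detached std::thread whose JoinHandle is dropped, so an Err returned by run previously disappeared with the thread; logging it through tracing::error! makes those failures visible in the server logs.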
