Skip to content

Commit

Permalink
feat: validate model capability before download (#3565)
Browse files Browse the repository at this point in the history
Signed-off-by: Wei Zhang <[email protected]>
  • Loading branch information
zwpaper authored Dec 14, 2024
1 parent 6d866b7 commit 0314db6
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 9 deletions.
45 changes: 43 additions & 2 deletions crates/tabby-download/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use std::{fs, io};

use aim_downloader::{bar::WrappedBar, error::DownloadError, hash::HashChecker, https};
use anyhow::{bail, Result};
use anyhow::{bail, Context, Result};
use tabby_common::registry::{parse_model_id, ModelInfo, ModelRegistry};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Expand Down Expand Up @@ -187,17 +187,58 @@ async fn download_file(
Ok(())
}

/// The role a model is expected to serve.
///
/// Passed to [`download_model`] so registry metadata can be validated against
/// the intended use before any bytes are downloaded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelKind {
    /// Embedding model; no distinguishing template metadata is required.
    Embedding,
    /// Code-completion model; must declare a prompt template in the registry.
    Completion,
    /// Chat model; must declare a chat template in the registry.
    Chat,
}

/// Fetch the model identified by `model_id` from its registry into the local
/// model cache.
///
/// When `kind` is given, the model's registry metadata is validated first so
/// we fail fast (pointing the user at the registry docs) instead of
/// downloading a model that cannot serve the requested role. Panics on
/// validation or download failure — callers are CLI/startup paths where
/// aborting is the intended behavior.
pub async fn download_model(model_id: &str, prefer_local_file: bool, kind: Option<ModelKind>) {
    let (registry, name) = parse_model_id(model_id);
    let registry = ModelRegistry::new(registry).await;

    if let Some(kind) = kind {
        // Validate before downloading: an incompatible model would only be
        // rejected later at load time, after a potentially large download.
        validate_model_kind(kind, registry.get_model_info(name))
            .context(
                "Model validation has failed. For TabbyML models, please consult https://github.com/tabbyml/registry-tabby to locate the appropriate models.",
            )
            .unwrap();
    }

    match download_model_impl(&registry, name, prefer_local_file).await {
        Ok(()) => (),
        Err(err) => panic!("Failed to fetch model '{}' due to '{}'", model_id, err),
    }
}

fn validate_model_kind(kind: ModelKind, info: &ModelInfo) -> Result<()> {
match kind {
ModelKind::Embedding => Ok(()),
ModelKind::Completion => info
.prompt_template
.as_ref()
.ok_or_else(|| {
anyhow::anyhow!(
"Model '{}' is not a completion model; it does not have a prompt template.",
info.name
)
})
.map(|_| ()),
ModelKind::Chat => info
.chat_template
.as_ref()
.ok_or_else(|| {
anyhow::anyhow!(
"Model '{}' is not a chat model, it does not have a chat template",
info.name
)
})
.map(|_| ()),
}
}

#[cfg(test)]
mod tests {
// filter_download_address tests should be serial because they rely on environment variables
Expand Down
2 changes: 1 addition & 1 deletion crates/tabby/src/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ pub struct DownloadArgs {
}

/// CLI entry point for `tabby download`: fetch the requested model and report
/// readiness.
///
/// No `ModelKind` is passed (`None`) — an explicit download should not be
/// rejected by capability validation; the user may fetch any model.
pub async fn main(args: &DownloadArgs) {
    download_model(&args.model, args.prefer_local_file, None).await;
    info!("model '{}' is ready", args.model);
}
7 changes: 4 additions & 3 deletions crates/tabby/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use tabby_common::{
config::{Config, ModelConfig},
usage,
};
use tabby_download::ModelKind;
use tabby_inference::ChatCompletionStream;
use tokio::{sync::oneshot::Sender, time::sleep};
use tower_http::timeout::TimeoutLayer;
Expand Down Expand Up @@ -212,15 +213,15 @@ pub async fn main(config: &Config, args: &ServeArgs) {

/// Download every locally-configured model (completion, chat, embedding)
/// before the server starts serving.
///
/// Each download passes the matching `ModelKind` so the registry metadata is
/// validated against the role the model will fill; remote (non-`Local`)
/// model configs need no download and are skipped.
async fn load_model(config: &Config) {
    if let Some(ModelConfig::Local(ref model)) = config.model.completion {
        download_model_if_needed(&model.model_id, ModelKind::Completion).await;
    }

    if let Some(ModelConfig::Local(ref model)) = config.model.chat {
        download_model_if_needed(&model.model_id, ModelKind::Chat).await;
    }

    // Embedding is mandatory in the config (not an Option), hence no `Some`.
    if let ModelConfig::Local(ref model) = config.model.embedding {
        download_model_if_needed(&model.model_id, ModelKind::Embedding).await;
    }
}

Expand Down
6 changes: 3 additions & 3 deletions crates/tabby/src/services/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{fs, sync::Arc};

pub use llama_cpp_server::PromptInfo;
use tabby_common::config::ModelConfig;
use tabby_download::download_model;
use tabby_download::{download_model, ModelKind};
use tabby_inference::{ChatCompletionStream, CodeGeneration, CompletionStream, Embedding};
use tracing::info;

Expand Down Expand Up @@ -80,10 +80,10 @@ async fn load_completion_and_chat(
(completion, prompt, chat)
}

pub async fn download_model_if_needed(model: &str) {
pub async fn download_model_if_needed(model: &str, kind: ModelKind) {
if fs::metadata(model).is_ok() {
info!("Loading model from local path {}", model);
} else {
download_model(model, true).await;
download_model(model, true, Some(kind)).await;
}
}

0 comments on commit 0314db6

Please sign in to comment.