Skip to content

Commit

Permalink
refactor: extract registry to its sub dir.
Browse files Browse the repository at this point in the history
  • Loading branch information
wsxiaoys committed Sep 26, 2023
1 parent 585aba6 commit 88947aa
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 34 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/tabby-download/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ tokio-retry = "0.3.0"
urlencoding = "2.1.3"
serde_json = { workspace = true }
cached = { version = "0.46.0", features = ["async", "proc_macro"] }
async-trait = { workspace = true }
12 changes: 6 additions & 6 deletions crates/tabby-download/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use anyhow::{anyhow, Result};
use cache_info::CacheInfo;
use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use registry::HuggingFaceRegistry;
use registry::{create_registry, Registry};
use tabby_common::path::ModelDir;
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Expand All @@ -17,15 +17,15 @@ use tokio_retry::{
pub struct Downloader {
model_id: String,
prefer_local_file: bool,
registry: HuggingFaceRegistry,
registry: Box<dyn Registry>,
}

impl Downloader {
pub fn new(model_id: &str, prefer_local_file: bool) -> Self {
Self {
model_id: model_id.to_owned(),
prefer_local_file,
registry: HuggingFaceRegistry::default(),
registry: create_registry(),
}
}

Expand Down Expand Up @@ -62,7 +62,7 @@ impl Downloader {
let mut cache_info = CacheInfo::from(&self.model_id).await;
for (path, required) in files {
download_model_file(
&self.registry,
self.registry.as_ref(),
&mut cache_info,
&self.model_id,
path,
Expand All @@ -76,7 +76,7 @@ impl Downloader {
}

async fn download_model_file(
registry: &HuggingFaceRegistry,
registry: &dyn Registry,
cache_info: &mut CacheInfo,
model_id: &str,
path: &str,
Expand Down Expand Up @@ -122,7 +122,7 @@ async fn download_model_file(
}

async fn download_file(
registry: &HuggingFaceRegistry,
registry: &dyn Registry,
url: &str,
path: &str,
local_cache_key: Option<&str>,
Expand Down
24 changes: 24 additions & 0 deletions crates/tabby-download/src/registry/huggingface.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;

use crate::Registry;

#[derive(Default)]
pub struct HuggingFaceRegistry {}

#[async_trait]
impl Registry for HuggingFaceRegistry {
fn build_url(&self, model_id: &str, path: &str) -> String {
format!("https://huggingface.co/{}/resolve/main/{}", model_id, path)
}

async fn build_cache_key(&self, url: &str) -> Result<String> {
let res = reqwest::get(url).await?;
let cache_key = res
.headers()
.get("etag")
.ok_or(anyhow!("etag key missing"))?
.to_str()?;
Ok(cache_key.to_owned())
}
}
25 changes: 25 additions & 0 deletions crates/tabby-download/src/registry/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
mod huggingface;
mod modelscope;

use anyhow::Result;
use async_trait::async_trait;
use huggingface::HuggingFaceRegistry;

use self::modelscope::ModelScopeRegistry;

#[async_trait]
pub trait Registry {
fn build_url(&self, model_id: &str, path: &str) -> String;
async fn build_cache_key(&self, url: &str) -> Result<String>;
}

pub fn create_registry() -> Box<dyn Registry> {
let registry = std::env::var("TABBY_REGISTRY").unwrap_or("huggingface".to_owned());
if registry == "huggingface" {
Box::<HuggingFaceRegistry>::default()
} else if registry == "modelscope" {
Box::<ModelScopeRegistry>::default()
} else {
panic!("Unsupported registry {}", registry);
}
}
Original file line number Diff line number Diff line change
@@ -1,45 +1,30 @@
use std::collections::HashMap;

use anyhow::{anyhow, Result};
use async_trait::async_trait;
use cached::proc_macro::cached;
use reqwest::Url;
use serde::Deserialize;

#[derive(Default)]
pub struct HuggingFaceRegistry {}

impl HuggingFaceRegistry {
pub fn build_url(&self, model_id: &str, path: &str) -> String {
format!("https://huggingface.co/{}/resolve/main/{}", model_id, path)
}

pub async fn build_cache_key(&self, url: &str) -> Result<String> {
let res = reqwest::get(url).await?;
let cache_key = res
.headers()
.get("etag")
.ok_or(anyhow!("etag key missing"))?
.to_str()?;
Ok(cache_key.to_owned())
}
}
use crate::Registry;

#[derive(Default)]
pub struct ModelScopeRegistry {
}
pub struct ModelScopeRegistry {}

impl ModelScopeRegistry {
pub fn build_url(&self, model_id: &str, path: &str) -> String {
#[async_trait]
impl Registry for ModelScopeRegistry {
fn build_url(&self, model_id: &str, path: &str) -> String {
format!(
"https://modelscope.cn/api/v1/models/{}/repo?FilePath={}",
model_id,
urlencoding::encode(path)
)
}

pub async fn build_cache_key(&self, url: &str) -> Result<String> {
async fn build_cache_key(&self, url: &str) -> Result<String> {
let url = Url::parse(url)?;
let model_id = url.path()
let model_id = url
.path()
.strip_prefix("/api/v1/models/")
.ok_or(anyhow!("Invalid url"))?
.strip_suffix("/repo")
Expand All @@ -49,13 +34,10 @@ impl ModelScopeRegistry {
let path = query
.get("FilePath")
.ok_or(anyhow!("Failed to extract FilePath"))?;
self.get_cache_key(model_id, path).await
}

async fn get_cache_key(&self, model_id: &str, path: &str) -> Result<String> {
let revision_map = fetch_revision_map(model_id.to_owned()).await?;
for x in revision_map.data.files {
if x.path == path {
if x.path == *path {
return Ok(x.sha256);
}
}
Expand Down

0 comments on commit 88947aa

Please sign in to comment.