refactor: use RegexSet for clearer stop regex construction
wsxiaoys committed Oct 2, 2023
1 parent 0643e57 commit 175f4ad
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions crates/tabby-inference/src/decoding.rs
@@ -1,11 +1,11 @@
 use std::sync::Arc;
 
 use dashmap::DashMap;
-use regex::Regex;
+use regex::RegexSet;
 use tokenizers::tokenizer::Tokenizer;
 
 pub struct DecodingFactory {
-    stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
+    stop_regex_cache: DashMap<&'static Vec<&'static str>, RegexSet>,
 }
 
 fn reverse<T>(s: T) -> String
@@ -33,7 +33,7 @@ impl DecodingFactory {
         IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
     }
 
-    fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
+    fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<RegexSet> {
         if stop_words.is_empty() {
             None
         } else {
@@ -48,18 +48,19 @@ impl DecodingFactory {
     }
 }
 
-fn create_stop_regex(stop_words: &[&str]) -> Regex {
-    let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect();
-
+fn create_stop_regex(stop_words: &[&str]) -> RegexSet {
     // (?m) enables multi-line matching mode.
     // \A means the absolute beginning of the string.
-    let regex_string = r"(?m)\A".to_owned() + &tokens.join("|");
-    Regex::new(&regex_string).unwrap()
+    let tokens: Vec<String> = stop_words
+        .iter()
+        .map(|x| r"(?m)\A".to_owned() + &reverse(*x))
+        .collect();
+    RegexSet::new(tokens).expect("Failed to create regex set")
 }
 
 pub struct IncrementalDecoding {
     tokenizer: Arc<Tokenizer>,
-    stop_re: Option<Regex>,
+    stop_re: Option<RegexSet>,
 
     token_ids: Vec<u32>,
     prefix_offset: usize,
@@ -69,7 +70,11 @@ pub struct IncrementalDecoding {
 }
 
 impl IncrementalDecoding {
-    pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
+    pub fn new(
+        tokenizer: Arc<Tokenizer>,
+        stop_re: Option<RegexSet>,
+        input_token_ids: &[u32],
+    ) -> Self {
         let text = tokenizer
             .decode(input_token_ids, /* skip_special_token = */ true)
             .expect("Cannot decode token from tokenizer.");
@@ -112,7 +117,7 @@ impl IncrementalDecoding {
         self.reversed_text = reverse(new_text) + &self.reversed_text;
 
         if let Some(re) = &self.stop_re {
-            if re.find(&self.reversed_text).is_some() {
+            if re.is_match(&self.reversed_text) {
                 return None;
             }
         }
@@ -131,7 +136,6 @@ mod tests {
         let stop_words = vec!["\n\n", "\n\n "];
         let re = create_stop_regex(&stop_words);
         let text = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid");
-        let matched = re.find(&text).is_some();
-        assert!(!matched);
+        assert!(!re.is_match(&text))
     }
 }
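For context, the change leans on the file's reversed-text trick: the decoder keeps the generated text reversed, so "the generation ends with a stop word" becomes "the reversed text starts with a reversed stop word", which \A can test cheaply as tokens stream in. It also tightens a subtlety of the old construction: in r"(?m)\A" + tokens.join("|"), the \A anchor binds only to the first alternative because | has the lowest precedence, whereas RegexSet compiles one \A-anchored pattern per stop word and still answers match/no-match in a single scan. Below is a minimal standalone sketch of the idea; the reverse helper is a simplified stand-in for the crate's generic version, and the stop words are illustrative.

use regex::RegexSet;

// Simplified stand-in for the crate's generic `reverse` helper.
fn reverse(s: &str) -> String {
    s.chars().rev().collect()
}

// Mirrors the new `create_stop_regex`: one `\A`-anchored pattern per
// reversed stop word, compiled together into a single set.
fn create_stop_regex(stop_words: &[&str]) -> RegexSet {
    let tokens: Vec<String> = stop_words
        .iter()
        .map(|x| r"(?m)\A".to_owned() + &reverse(x))
        .collect();
    RegexSet::new(tokens).expect("Failed to create regex set")
}

fn main() {
    let set = create_stop_regex(&["\n\n", "\nuser:"]);

    // Streaming check in the spirit of `IncrementalDecoding`: prepend each
    // newly decoded fragment to the reversed text, then ask the whole set
    // for a match in one pass.
    let mut reversed_text = String::new();
    for fragment in ["fn main() {", "}", "\n\n"] {
        reversed_text = reverse(fragment) + &reversed_text;
        if set.is_match(&reversed_text) {
            println!("stop word hit after {fragment:?}");
            return; // decoding would stop here
        }
    }
}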
