Skip to content

Commit

Permalink
Add conditional normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
Systemcluster committed Dec 20, 2024
1 parent 7ceb134 commit f2e7d32
Show file tree
Hide file tree
Showing 57 changed files with 1,420,337 additions and 117 deletions.
6 changes: 4 additions & 2 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
//!
//! Defines the tokenization mode fallback, input normalization, pre-tokenization split behavior, post-tokenization processing, post-decode processing, and input templates.
use core::ops::Range;

use alloc::borrow::Cow;
use alloc::string::String;
use alloc::vec::Vec;
Expand Down Expand Up @@ -128,12 +130,12 @@ impl Configuration {

/// Normalizes the input before tokenization.
#[inline(never)]
pub fn normalize(&self, text: &mut Cow<str>) {
pub fn normalize(&self, text: &mut Cow<str>, position: Range<usize>) {
if text.is_empty() {
return;
}
for norm in &self.normalization {
norm.normalize(text);
norm.normalize(text, position.clone());
}
}

Expand Down
81 changes: 70 additions & 11 deletions src/config/normalization.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Pre-tokenization input normalization.
use core::ops::Range;

use alloc::borrow::Cow;
use alloc::boxed::Box;
use alloc::string::{String, ToString};
Expand Down Expand Up @@ -62,6 +64,14 @@ impl From<Regex> for NormalizationReplacePattern {
}
}

/// Condition for conditional normalization.
///
/// Evaluated against the position range of the current text slice within the
/// full input to decide whether the wrapped normalization is applied.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serialization", derive(Deserialize, Serialize))]
pub enum NormalizationCondition {
    /// Apply only when the slice starts at the beginning of the input
    /// (`position.start == 0`).
    StartOfText,
    /// Apply only when the slice extends to the end of the input
    /// (`position.end == usize::MAX`, used as an "until end" sentinel).
    EndOfText,
}

/// Pre-tokenization input normalization configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serialization", derive(Deserialize, Serialize))]
Expand Down Expand Up @@ -98,11 +108,16 @@ pub enum Normalization {
},
/// Precompiled character map.
CharsMap { map: CharsMap },
/// Conditional normalization.
Conditional {
condition: NormalizationCondition,
normalization: Box<Normalization>,
},
}

impl Normalization {
#[inline(never)]
pub fn normalize(&self, text: &mut Cow<str>) {
pub fn normalize(&self, text: &mut Cow<str>, position: Range<usize>) {
use Normalization::*;
match self {
Unicode { scheme } => {
Expand Down Expand Up @@ -147,6 +162,17 @@ impl Normalization {
CharsMap { map } => {
normalize_charsmap(text, map);
}
Conditional {
condition,
normalization,
} => {
if match condition {
NormalizationCondition::StartOfText => position.start == 0,
NormalizationCondition::EndOfText => position.end == usize::MAX,
} {
normalization.normalize(text, position);
}
}
}
}
}
Expand Down Expand Up @@ -324,20 +350,20 @@ mod tests {
fn test_normalization_nmt() {
    // NMT normalization removes/replaces control and zero-width characters:
    // U+200D (zero-width joiner) and U+008F (control) here.
    // Stale pre-change call without the `position` argument removed (diff residue).
    let mut text = Cow::Borrowed("aaa\u{200D}bbb\u{8f}");
    let normalization = Normalization::NMT;
    normalization.normalize(&mut text, 0..usize::MAX);
    assert_eq!(text, "aaa bbb");
}

#[test]
fn test_normalization_case_fold() {
    // Case folding: `upper: false` lowercases, `upper: true` uppercases.
    // Stale pre-change calls without the `position` argument removed (diff residue).
    let mut text = Cow::Borrowed("AAA bbb");
    let normalization = Normalization::CaseFold { upper: false };
    normalization.normalize(&mut text, 0..usize::MAX);
    assert_eq!(text, "aaa bbb");

    let mut text = Cow::Borrowed("AAA bbb");
    let normalization = Normalization::CaseFold { upper: true };
    normalization.normalize(&mut text, 0..usize::MAX);
    assert_eq!(text, "AAA BBB");
}

Expand All @@ -347,7 +373,7 @@ mod tests {
let normalization = Normalization::Append {
append: " bbb".to_string(),
};
normalization.normalize(&mut text);
normalization.normalize(&mut text, 0..usize::MAX);
assert_eq!(text, "aaa bbb");
}

Expand All @@ -357,7 +383,7 @@ mod tests {
let normalization = Normalization::Prepend {
prepend: "aaa ".to_string(),
};
normalization.normalize(&mut text);
normalization.normalize(&mut text, 0..usize::MAX);
assert_eq!(text, "aaa bbb");
}

Expand All @@ -370,7 +396,7 @@ mod tests {
right: 3,
pad: false,
};
normalization.normalize(&mut text);
normalization.normalize(&mut text, 0..usize::MAX);
assert_eq!(text, "aabbbaaa");

let mut text = Cow::Borrowed("aba");
Expand All @@ -380,7 +406,7 @@ mod tests {
right: 3,
pad: true,
};
normalization.normalize(&mut text);
normalization.normalize(&mut text, 0..usize::MAX);
assert_eq!(text, "aabaaa");
}

Expand All @@ -392,15 +418,15 @@ mod tests {
left: 2,
right: 3,
};
normalization.normalize(&mut text);
normalization.normalize(&mut text, 0..usize::MAX);
assert_eq!(text, "aba");
}

#[test]
fn test_normalization_collapse() {
    // Collapse squeezes runs of the given character down to a single one.
    // Stale pre-change call without the `position` argument removed (diff residue).
    let mut text = Cow::Borrowed("abbbba bbb");
    let normalization = Normalization::Collapse { character: 'b' };
    normalization.normalize(&mut text, 0..usize::MAX);
    assert_eq!(text, "aba b");
}

Expand All @@ -411,7 +437,40 @@ mod tests {
pattern: Regex::new(r"b").unwrap().into(),
replacement: "a".to_string(),
};
normalization.normalize(&mut text);
normalization.normalize(&mut text, 0..usize::MAX);
assert_eq!(text, "aaa aaa");
}

#[test]
fn test_normalization_conditional() {
    // Wraps a `b` -> `a` replacement in a conditional with the given condition.
    fn conditional(condition: NormalizationCondition) -> Normalization {
        Normalization::Conditional {
            condition,
            normalization: Box::new(Normalization::Replace {
                pattern: Regex::new(r"b").unwrap().into(),
                replacement: "a".to_string(),
            }),
        }
    }
    // Applies `norm` to a fresh copy of the sample text at `position`.
    fn apply(norm: &Normalization, position: Range<usize>) -> Cow<'static, str> {
        let mut input = Cow::Borrowed("aba bbb");
        norm.normalize(&mut input, position);
        input
    }

    let start = conditional(NormalizationCondition::StartOfText);
    // Fires when the slice begins at position 0...
    assert_eq!(apply(&start, 0..usize::MAX), "aaa aaa");
    // ...and is skipped for any later start position.
    assert_eq!(apply(&start, 1..usize::MAX), "aba bbb");

    let end = conditional(NormalizationCondition::EndOfText);
    // Fires when the slice runs to the end-of-input sentinel...
    assert_eq!(apply(&end, 0..usize::MAX), "aaa aaa");
    // ...and is skipped when the slice ends earlier.
    assert_eq!(apply(&end, 0..4), "aba bbb");
}
}
2 changes: 1 addition & 1 deletion src/convert/tekken.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ use ms::Tokenizer;
/// use kitoken::convert::convert_tekken;
/// use kitoken::Kitoken;
///
/// let data = std::fs::read("tests/models/tekken/nemo.json")?;
/// let data = std::fs::read("tests/models/tekken/mistral2410.json")?;
/// let definition = convert_tekken(data).unwrap();
///
/// let tokenizer = Kitoken::try_from(definition).unwrap();
Expand Down
18 changes: 14 additions & 4 deletions src/convert/tokenizers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ use hashbrown::HashMap;
use crate::convert::ConversionError;
use crate::{
Configuration, Decoding, Definition, Fallback, InsertionPosition, Kitoken, Metadata, Model,
Normalization, Processing, ProcessingDirection, Regex, Scores, SpecialToken, SpecialTokenKind,
SpecialVocab, Split, SplitBehavior, Template, Token, TokenBytes, TokenId, UnicodeNormalization,
Vocab,
Normalization, NormalizationCondition, Processing, ProcessingDirection, Regex, Scores,
SpecialToken, SpecialTokenKind, SpecialVocab, Split, SplitBehavior, Template, Token,
TokenBytes, TokenId, UnicodeNormalization, Vocab,
};

mod hf {
Expand Down Expand Up @@ -642,13 +642,23 @@ pub fn convert_tokenizers(data: impl AsRef<[u8]>) -> Result<Definition, Conversi
pattern: Regex::new(r" ")?.into(),
replacement: replacement.to_string(),
});
if prepend_scheme != PrependScheme::Never {
if prepend_scheme == PrependScheme::Always {
config.normalization.push(Normalization::Extend {
character: replacement,
left: 1,
right: 0,
pad: true,
});
} else if prepend_scheme == PrependScheme::First {
config.normalization.push(Normalization::Conditional {
condition: NormalizationCondition::StartOfText,
normalization: Box::new(Normalization::Extend {
character: replacement,
left: 1,
right: 0,
pad: true,
}),
});
}
if split {
config.split.push(Split::Pattern {
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ impl Kitoken {
if let Some(next) = extracted.pop() {
if next.0 > posit {
let mut text = text[posit..next.0].into();
self.config.normalize(&mut text);
self.config.normalize(&mut text, posit..next.0);
parts.push(TextPart {
text,
special: Token::INVALID,
Expand All @@ -295,7 +295,7 @@ impl Kitoken {
posit = next.1;
} else {
let mut rest = text[posit..text.len()].into();
self.config.normalize(&mut rest);
self.config.normalize(&mut rest, posit..usize::MAX);
parts.push(TextPart {
text: rest,
special: Token::INVALID,
Expand Down
Loading

0 comments on commit f2e7d32

Please sign in to comment.