Add regex loading from tokenizer.json and code refinement #863

Status: Draft · wants to merge 2 commits into main
onnxruntime_extensions/pp_api.py (3 additions, 1 deletion)

@@ -49,10 +49,12 @@ def __init__(self, tokenizer_dir):
         self.tokenizer = create_tokenizer(tokenizer_dir)
 
     def tokenize(self, text):
+        if isinstance(text, (list, tuple)):
+            return batch_tokenize(self.tokenizer, text)
         return batch_tokenize(self.tokenizer, [text])[0]
 
     def detokenize(self, tokens):
-        return batch_detokenize(self.tokenizer, [tokens])[0]
+        return batch_detokenize(self.tokenizer, [tokens])
 
     def __del__(self):
         if delete_object and self.tokenizer:
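
For context, a minimal usage sketch of the updated Python API. The model id below is the one exercised by the new test; any tokenizer directory accepted by pp_api.Tokenizer works the same way:

    from onnxruntime_extensions import pp_api

    tokenizer = pp_api.Tokenizer("amd/AMD-OLMo-1B-SFT-DPO")

    # A single string still returns one list of token ids ...
    ids = tokenizer.tokenize("I like walking my cute dog")

    # ... while a list or tuple is now tokenized as a batch, one id list per string.
    batch_ids = tokenizer.tokenize(["first sentence", "second sentence"])

    # Note: detokenize now returns the batch_detokenize result itself rather than
    # its first element, so callers get a list back.
    text = tokenizer.detokenize(ids)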
operators/tokenizer/bpe_kernels.cc (4 additions, 10 deletions)

@@ -262,7 +262,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
 
   // Parse input
   auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
-  bpe::TokenWithRegularExp regcmp;
+  bpe::PreTokenizerWithRegEx reg_splitter;
 
   for (auto& seg_id : special_token_split_res) {
     if (static_cast<int64_t>(res.size()) >= max_length) break;
@@ -274,7 +274,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
 
     // Note: keep ptr to make sure the string_view is valid in the following process
     std::u32string str(seg_id.first);
-    regcmp.Set(str.c_str());
+    reg_splitter.Set(str.c_str());
 
     size_t offset = 0;
     OffsetMappingType offset_mapping;
@@ -287,14 +287,8 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
     }
 
     while (static_cast<int64_t>(res.size()) < max_length) {
-      std::string regex_expr = "";
-      if (ModelName() == kModel_Llama) {
-        regex_expr = regcmp.LLAMA_REGEX_PATTERN;
-      } else {
-        // default to GPT2 regex
-        regex_expr = regcmp.GPT2_REGEX_PATTERN;
-      }
-      auto [b, tok] = regcmp.GetNextToken(regex_expr);
+      std::string regex_expr = bbpe_tokenizer_->GetPreTokenizerRegex(ModelName());
+      auto [b, tok] = reg_splitter.GetNextToken(regex_expr);
 
       if (!b) break;
operators/tokenizer/bpe_tokenizer_model.hpp (68 additions, 0 deletions)

@@ -44,6 +44,58 @@ class BpeModel {
     }
   }
 
+  OrtxStatus LoadPreTokenizer(const json& bpe_model) {
+    auto node_pre_tokenizer = bpe_model.find("pre_tokenizer");
+    if (node_pre_tokenizer == bpe_model.end() || node_pre_tokenizer->is_null()) {
+      return {};
+    }
+
+    auto iter_type = node_pre_tokenizer->find("type");
+    if (iter_type == node_pre_tokenizer->end() || iter_type->is_null()) {
+      return {};
+    }
+
+    if (iter_type->get<std::string>() != "Sequence") {
+      return {kOrtxErrorNotImplemented, "Unsupported pretokenizer type!"};

[Review comment, sayanshaw24 (Contributor), Dec 20, 2024]
Something in the test seems off, or maybe my understanding is wrong. In test_pp_api.py you use the "amd/AMD-OLMo-1B-SFT-DPO" model, and in its tokenizer.json the pre_tokenizer type is "ByteLevel", not "Sequence". (Note: "ByteLevel" is supported for the entries inside "pretokenizers" below, which I guess is what you expect within "pre_tokenizer", but not for "pre_tokenizer" itself.) So it should fail here, right? Yet it passes in CI.

So maybe we should add "ByteLevel" to the supported types for "pre_tokenizer" here as well, but first identify why the test is not failing currently; perhaps the type is not being extracted correctly, or "pretokenizers" and "pre_tokenizer" are being conflated.

+    }
+
+    auto iter_node_list = node_pre_tokenizer->find("pretokenizers");
+    if (iter_node_list == node_pre_tokenizer->end() || iter_node_list->is_null()) {
+      return {};
+    }
+
+    for (const auto& node : *iter_node_list) {
+      auto iter_type = node.find("type");
+      if (iter_type == node.end() || iter_type->is_null()) {
+        continue;  // ignore unknown pre-tokenizer type
+      }
+
+      auto pre_type = iter_type->get<std::string>();
+      if (pre_type == "Split") {
+        auto iter_pattern = node.find("pattern");
+        if (iter_pattern == node.end() || iter_pattern->is_null()) {
+          continue;
+        }
+
+        auto regex_str = iter_pattern->find("Regex");

[Review comment, sayanshaw24 (Contributor), Dec 20, 2024]
I am seeing some examples of lowercase "regex" in tokenizer.json as well; perhaps we should make this lookup case-insensitive?

+        if (regex_str == iter_pattern->end() || regex_str->is_null()) {
+          continue;
+        }
+
+        pre_tokenizer_regex_ = regex_str->get<std::string>();
+      } else if (pre_type == "ByteLevel") {
+        ;  // need to add more flag support here in the future
+      } else {
+        return {kOrtxErrorNotImplemented, "Unsupported pretokenizer type!"};
+      }
+    }
+
+    return {};
+  }
+
   OrtxStatus Load(std::istream& vocab_stream, std::istream& merges_stream, const char* unk_token,
                   const char* special_tokens, bool spm_converted) {
     nlohmann::json tok_json;
@@ -121,6 +173,8 @@ class BpeModel {
   }
 
   OrtxStatus Load(const json& bpe_model, const char* /* special_tokens */, bool spm_converted) {
+    ORTX_RETURN_IF_ERROR(LoadPreTokenizer(bpe_model));
+
     const json& vocab_json = bpe_model["vocab"];
     const json& merges_json = bpe_model["merges"];
     vocab_json.get_to(vocab_map_);
@@ -358,6 +412,19 @@ class BpeModel {
 
   const std::string& GetEndOfWordSuffix() const { return end_of_word_suffix_; }
 
+  std::string GetPreTokenizerRegex(const std::string& model_name) const {
+    if (!pre_tokenizer_regex_.empty()) {
+      return pre_tokenizer_regex_;
+    }
+
+    if (model_name == "Llama") {
+      return bpe::PreTokenizerWithRegEx::LLAMA_REGEX_PATTERN;
+    }
+
+    // by default, use the GPT2 pretokenizer regex
+    return bpe::PreTokenizerWithRegEx::GPT2_REGEX_PATTERN;
+  }
+
  private:
   struct BpeNode {
     uint32_t id;
@@ -379,6 +446,7 @@ class BpeModel {
   uint32_t unk_id_ = (std::numeric_limits<uint32_t>::max)();
   bpe::SpecialTokenMap special_tokens_;
   TrieTree<char32_t> added_tokens_;
+  std::string pre_tokenizer_regex_;
 };
 
 }  // namespace ort_extensions
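
For reference, a sketch of the tokenizer.json shape that LoadPreTokenizer walks, written as Python for illustration. The keys ("pre_tokenizer", "type", "Sequence", "pretokenizers", "Split", "pattern", "Regex", "ByteLevel") are exactly the ones the parser looks for; the regex value here is illustrative, not taken from a real model:

    import json

    config = {
        "pre_tokenizer": {
            "type": "Sequence",
            "pretokenizers": [
                # A "Split" entry supplies the regex stored in pre_tokenizer_regex_.
                {"type": "Split", "pattern": {"Regex": "\\s+|\\p{L}+"}},
                # A "ByteLevel" entry is accepted, but its flags are ignored for now.
                {"type": "ByteLevel"},
            ],
        }
    }
    print(json.dumps(config, indent=2))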
operators/tokenizer/bpe_utils.hpp (5 additions, 5 deletions)

@@ -97,8 +97,12 @@ class SpecialTokenMap {
   std::unordered_map<ustring, int> token_map_;
 };
 
-class TokenWithRegularExp {
+class PreTokenizerWithRegEx {
  public:
+  static constexpr const char* GPT2_REGEX_PATTERN = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
+  static constexpr const char* LLAMA_REGEX_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+  static constexpr const char* LLAMA_REGEX_PATTERN_2 = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+
   void Set(std::u32string_view val) {
     m_text = val;
   }
@@ -115,10 +119,6 @@ class PreTokenizerWithRegEx {
     return {false, {}};
   }
 
-  const std::string GPT2_REGEX_PATTERN = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
-  const std::string LLAMA_REGEX_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
-  const std::string LLAMA_REGEX_PATTERN_2 = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
-
  public:
 
   // Although we have RegexMatchGeneral which performs regex matching given any general regex string
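
As a rough illustration of what these patterns do, the third-party Python regex package (unlike the stdlib re, it supports \p{...} categories) can apply the GPT-2 pattern directly; the split shown in the comment is what the pattern produces by inspection, not an output from this PR:

    import regex  # pip install regex; stdlib 're' has no \p{...} support

    GPT2_PATTERN = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
    print(regex.findall(GPT2_PATTERN, "it's 2024, really"))
    # ['it', "'s", ' 2024', ',', ' really']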
operators/tokenizer/tokenizer_jsconfig.hpp (1 addition, 0 deletions)

@@ -26,6 +26,7 @@ constexpr std::pair<const char*, TokenType> kTokenizerDict[] = {
     {"GPT2Tokenizer", TokenType::kBPE},
     {"Qwen2Tokenizer", TokenType::kBPE},
     {"BaichuanTokenizer", TokenType::kBPE},
+    {"GPTNeoXTokenizer", TokenType::kBPE},
 
     {"", TokenType::kUnigram},
     {"T5Tokenizer", TokenType::kUnigram},
pyop/py_c_api.cc (1 addition, 1 deletion)

@@ -121,7 +121,7 @@ void AddGlobalMethodsCApi(pybind11::module& m) {
         OrtxTokenizer* tokenizer = nullptr;
         auto err = OrtxCreateTokenizer(&tokenizer, tokenizer_def_json.c_str());
         if (err != kOrtxOK) {
-          throw std::runtime_error(std::string("Failed to create tokenizer") + OrtxGetLastErrorMessage());
+          throw std::runtime_error(std::string("Failed to create tokenizer\n") + OrtxGetLastErrorMessage());
         }
         return reinterpret_cast<std::uintptr_t>(tokenizer);
       },
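
On the Python side this surfaces as a RuntimeError; after this change the summary and the detail from OrtxGetLastErrorMessage land on separate lines. A hypothetical failure sketch (the path is made up):

    from onnxruntime_extensions import pp_api

    try:
        pp_api.Tokenizer("/path/that/does/not/exist")
    except RuntimeError as e:
        print(e)  # "Failed to create tokenizer" plus the detailed message on the next line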
shared/api/tokenizer_impl.cc (28 additions, 49 deletions)

@@ -11,33 +11,15 @@
 
 namespace ort_extensions {
 
-std::set<std::string> TokenizerImpl::supported_bpe_models_ = {
-    "PreTrainedTokenizerFast",
-    "CLIPTokenizer",
-    "WhisperTokenizer",
-    "GemmaTokenizer",
-    "LlamaTokenizer",
-    "Phi3Tokenizer",
-    "CodeLlamaTokenizer",
-    "CodeGenTokenizer",
-    "GPT2Tokenizer",
-    "Qwen2Tokenizer",
-    "BaichuanTokenizer"
-};
-
-std::set<std::string> TokenizerImpl::supported_ugm_models_ = {
-    "XLMRobertaTokenizer",
-    "T5Tokenizer",
-    "ChatGLMTokenizer"
-};
-
 TokenizerImpl::TokenizerImpl()
     : OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenizer) {};
 TokenizerImpl::~TokenizerImpl() {};
 
 OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
-  if (tok_config_->tokenizer_class_.empty() ||
-      supported_ugm_models_.count(tok_config_->tokenizer_class_)) {
+  auto type = TokenJsonConfig::GetTokenType(tok_config_->tokenizer_class_);
+  if (type == TokenType::kUnigram) {
     auto tokenizer = std::make_unique<SpmUgmTokenizer>();
     auto status = tokenizer->Load(*tok_config_);
     if (!status.IsOk()) {
@@ -53,42 +35,39 @@ OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
       tokenizer_ = std::move(tokenizer);
       detokenizer_ = std::move(detok);
     }
 
     return status;
-  }
-
-  if (!supported_bpe_models_.count(tok_config_->tokenizer_class_)) {
-    return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
-  }
-
-  auto tokenizer = std::make_unique<JsonFastTokenizer>();
-  auto fx_load = &JsonFastTokenizer::Load;
-  if (blob == nullptr) {
-    auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
-    // vocab file is checked in TokenJsonConfig::Load
-    if (vocab_file_path.extension() != ".json") {
-      fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
-    }
-  } else {
-    if (blob->raw_model_blob_len > 0) {
-      fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
-    }
-  }
-
-  auto status = (tokenizer.get()->*fx_load)(*tok_config_);
-  if (!status.IsOk()) {
-    return status;
-  }
-
-  auto detok = std::make_unique<BpeStreamingDecoder>();
-  status = detok->Load(tok_config_, *tokenizer);
-
-  if (status.IsOk()) {
-    tokenizer_ = std::move(tokenizer);
-    detokenizer_ = std::move(detok);
+  } else if (type == TokenType::kBPE) {
+    auto tokenizer = std::make_unique<JsonFastTokenizer>();
+    auto fx_load = &JsonFastTokenizer::Load;
+    if (blob == nullptr) {
+      auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
+      // vocab file is checked in TokenJsonConfig::Load
+      if (vocab_file_path.extension() != ".json") {
+        fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
+      }
+    } else {
+      if (blob->raw_model_blob_len > 0) {
+        fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
+      }
+    }
+
+    auto status = (tokenizer.get()->*fx_load)(*tok_config_);
+    if (!status.IsOk()) {
+      return status;
+    }
+
+    auto detok = std::make_unique<BpeStreamingDecoder>();
+    status = detok->Load(tok_config_, *tokenizer);
+
+    if (status.IsOk()) {
+      tokenizer_ = std::move(tokenizer);
+      detokenizer_ = std::move(detok);
+    }
+
+    return status;
   }
 
-  return status;
+  return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
 }
 
 OrtxStatus TokenizerImpl::Load(const OrtxTokenizerBlob& blob) {
shared/api/tokenizer_impl.h (0 additions, 3 deletions)

@@ -15,9 +15,6 @@ namespace ort_extensions {
 
 class TokenizerImpl : public OrtxObjectImpl {
  public:
-  static std::set<std::string> supported_bpe_models_;
-  static std::set<std::string> supported_ugm_models_;
-
   TokenizerImpl();
   virtual ~TokenizerImpl();
test/pp_api_test/test_tokenizer.cc (4 additions, 4 deletions)

@@ -67,7 +67,7 @@ TEST(CApiTest, StreamApiTest) {
 
 TEST(OrtxTokenizerTest, RegexTest) {
   std::u32string str = U"CAN'T \r\n 2413m";
-  auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::u32string> res;
   std::vector<std::u32string> out_tokens = {U"CAN", U"'T", U" \r\n", U" ", U"241", U"3", U"m"};
@@ -91,7 +91,7 @@ TEST(OrtxTokenizerTest, RegexMatchSTDTest) {
   std::vector<std::u32string> input_strings = {U"not its, or IT'S, but it's",
                                                U" ",
                                                U"AbCd"};
-  auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::vector<std::u32string>> res_vector;
   std::vector<std::vector<std::u32string>> out_tokens = {{U"'s"},
@@ -118,7 +118,7 @@ TEST(OrtxTokenizerTest, WrapStandaloneCategoriesTest) {
       "\\p{rn}\\p{L}\\p{N}\\p{L}",
       "\\p{Z}*[\\p{rn}]+",
       "\\p{Z}+"};
-  auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::string> res;
   std::vector<std::string> out_regex = {"[^\\p{rn}\\p{L}\\p{N}]?[\\p{L}]+",
@@ -152,7 +152,7 @@ TEST(OrtxTokenizerTest, RegexMatchGeneralTest) {
       U"241356m",
       U"Ich liebe München <3 \r\n ",
      U"生活的真谛是"};
-  auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::vector<std::u32string>> res_vector;
   std::vector<std::vector<std::u32string>> out_tokens = {{U"CAN", U"'T", U"", U""},
test/test_pp_api.py (14 additions, 3 deletions)

@@ -8,10 +8,12 @@
 
 
 is_pp_api_available = False
+hf_token_id = None
 try:
-    from transformers import AutoImageProcessor
+    from transformers import AutoImageProcessor, AutoTokenizer
     from onnxruntime_extensions import pp_api
     is_pp_api_available = True
+    hf_token_id = os.environ.get("HF_TOKEN", None)
 except ImportError:
     pass
 
@@ -46,7 +48,6 @@ def setUpClass(cls):
         else:
             cls.temp_dir = tempfile.mkdtemp()
             print(f"Created temp dir: {cls.temp_dir}")
-            cls.token_id = os.environ.get("HF_TOKEN", None)
 
     def test_CLIP_image_processing(self):
         model_id = "openai/clip-vit-large-patch14"
@@ -76,6 +77,7 @@ def test_CLIP_image_processing(self):
             a_image = regen_image(np.transpose(actual, (1, 2, 0)))
             a_image.save(f"{self.temp_dir}/CLIP_a_{i}.png")
 
+    @unittest.skipIf(hf_token_id is None, "HF_TOKEN is not available")
     def test_llama3_2_image_processing(self):
         model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 
@@ -91,7 +93,7 @@ def test_llama3_2_image_processing(self):
                       "test/data/processor/exceltable.png"]
         (image, image2, image3) = [Image.open(f) for f in image_list]
 
-        processor = AutoImageProcessor.from_pretrained(model_id, token=TestPPAPI.token_id)
+        processor = AutoImageProcessor.from_pretrained(model_id, token=hf_token_id)
         inputs = processor.preprocess(
             [image, image2, image3], return_tensors="np")
         print({k: v.shape if k == "pixel_values" else v for k, v in inputs.items()})
@@ -114,6 +116,15 @@ def test_llama3_2_image_processing(self):
             a_image = regen_image(np.transpose(actual, (1, 2, 0)))
             a_image.save(f"{self.temp_dir}/a_{idx}_{i}.png")
 
+    def test_OLMa_tokenizer(self):
+        test_sentence = ["I like walking my cute dog\n and\x17 then 生活的真谛是 \t\t\t\t \n\n61" + " |||IP_ADDRESS|||"]
+        model_id = "amd/AMD-OLMo-1B-SFT-DPO"
+        hf_enc = AutoTokenizer.from_pretrained(model_id)
+        inputs = hf_enc(test_sentence)["input_ids"]
+        tokenizer = pp_api.Tokenizer(model_id)
+        ortx_inputs = tokenizer.tokenize(test_sentence)
+        # self.assertEqual(inputs, ortx_inputs)
+        np.testing.assert_array_equal(ortx_inputs, inputs)
+
 
 if __name__ == '__main__':
     unittest.main()