From db3155cae10b465f9335d52acfa605456841237d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C5=91rinc?=
Date: Sun, 11 Feb 2024 22:48:30 +0100
Subject: [PATCH 1/5] Add tests for humongous encodings

---
 tests/test_encoding.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/tests/test_encoding.py b/tests/test_encoding.py
index 27b21925..3d903c64 100644
--- a/tests/test_encoding.py
+++ b/tests/test_encoding.py
@@ -11,6 +11,22 @@
 from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES
 
 
+@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
+def test_extremely_big_encoding(make_enc: Callable[[], tiktoken.Encoding]):
+    enc = make_enc()
+    for c in ["^", "0", "a", "'s"]:  # TODO " ", "\n" are still failing
+        print(f"Validating `{c}`")
+
+        big_value = c * 1_000_000
+        assert big_value == enc.decode(enc.encode(big_value))
+
+        big_value = " " + big_value
+        assert big_value == enc.decode(enc.encode(big_value))
+
+        big_value = big_value + "\n"
+        assert big_value == enc.decode(enc.encode(big_value))
+
+
 def test_simple():
     enc = tiktoken.get_encoding("gpt2")
     assert enc.encode("hello world") == [31373, 995]

From 58cf8f69f40ef77c64199275e8f483a7e153e2a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C5=91rinc?=
Date: Sun, 11 Feb 2024 22:48:53 +0100
Subject: [PATCH 2/5] Add possessive quantifiers to legacy encodings as well

---
 tests/test_encoding.py        |  1 +
 tiktoken_ext/openai_public.py | 18 ++++++++++--------
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/tests/test_encoding.py b/tests/test_encoding.py
index 3d903c64..0427d6f5 100644
--- a/tests/test_encoding.py
+++ b/tests/test_encoding.py
@@ -11,6 +11,7 @@
 from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES
 
 
+@pytest.mark.skip(reason="Takes a really long time to finish, but was added to reproduce a crash.")
 @pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
 def test_extremely_big_encoding(make_enc: Callable[[], tiktoken.Encoding]):
     enc = make_enc()
diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py
index 330ecabb..c7b41541 100644
--- a/tiktoken_ext/openai_public.py
+++ b/tiktoken_ext/openai_public.py
@@ -6,6 +6,11 @@
 FIM_SUFFIX = "<|fim_suffix|>"
 ENDOFPROMPT = "<|endofprompt|>"
 
+# The pattern in the original GPT-2 release is:
+# r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+# This is equivalent, but executes faster:
+_legacy_splitter_regex = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s+(?!\S)|\s++"""
+
 
 def gpt2():
     mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
@@ -17,10 +22,7 @@ def gpt2():
     return {
         "name": "gpt2",
         "explicit_n_vocab": 50257,
-        # The pattern in the original GPT-2 release is:
-        # r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
-        # This is equivalent, but executes faster:
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": _legacy_splitter_regex,
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": {ENDOFTEXT: 50256},
     }
@@ -34,7 +36,7 @@ def r50k_base():
     return {
         "name": "r50k_base",
         "explicit_n_vocab": 50257,
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": _legacy_splitter_regex,
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": {ENDOFTEXT: 50256},
     }
@@ -48,7 +50,7 @@ def p50k_base():
     return {
         "name": "p50k_base",
         "explicit_n_vocab": 50281,
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": _legacy_splitter_regex,
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": {ENDOFTEXT: 50256},
     }
@@ -62,7 +64,7 @@ def p50k_edit():
     special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
     return {
         "name": "p50k_edit",
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": _legacy_splitter_regex,
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": special_tokens,
     }
@@ -82,7 +84,7 @@ def cl100k_base():
     }
     return {
         "name": "cl100k_base",
-        "pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""",
+        "pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s*[\r\n]|\s+(?!\S)|\s++""",
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": special_tokens,
     }
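The rewrite above is intended to be behavior-preserving: a possessive quantifier such as `\p{L}++` can only change the result when the engine would otherwise backtrack into a finished repetition, and nothing after these repetitions ever needs characters given back. A minimal equivalence spot-check, not part of the patch series, using fancy-regex's `Regex::find_iter` API and assuming a crate version with possessive-quantifier support (such as the 0.13.0 pinned in the next patch); the sample text is arbitrary:

use fancy_regex::Regex;

fn main() {
    // Greedy original vs. possessive rewrite of the legacy splitter.
    let greedy = Regex::new(
        r"'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+",
    )
    .unwrap();
    let possessive = Regex::new(
        r"'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s+(?!\S)|\s++",
    )
    .unwrap();

    let text = "Hello  world, it's 2024!\n";
    // find_iter yields Result<Match> items; both patterns should produce
    // exactly the same token boundaries.
    let a: Vec<&str> = greedy.find_iter(text).map(|m| m.unwrap().as_str()).collect();
    let b: Vec<&str> = possessive.find_iter(text).map(|m| m.unwrap().as_str()).collect();
    assert_eq!(a, b);
}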
From 21c56885e04f14d237cc5d2858ea55717aa1932d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C5=91rinc?=
Date: Mon, 12 Feb 2024 11:46:49 +0100
Subject: [PATCH 3/5] Update regex dependencies

---
 Cargo.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 14588580..fb284741 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,7 @@ crate-type = ["cdylib"]
 pyo3 = { version = "0.20.0", features = ["extension-module"] }
 
 # tiktoken dependencies
-fancy-regex = "0.11.0"
-regex = "1.8.3"
+fancy-regex = "0.13.0"
+regex = "1.10.3"
 rustc-hash = "1.1.0"
 bstr = "1.5.0"

From 5f07fc29ba9719fe4d037c6f4f5214555eb6b347 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C5=91rinc?=
Date: Mon, 12 Feb 2024 14:08:12 +0100
Subject: [PATCH 4/5] Lower backtrack_limit to fail earlier for invalid input

---
 src/lib.rs             | 16 +++++++++++++++-
 tests/test_encoding.py |  3 +--
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index b466edd1..46712ecd 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,6 +6,7 @@ use std::num::NonZeroU64;
 use std::thread;
 
 use fancy_regex::Regex;
+use fancy_regex::RegexBuilder;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::pyclass;
@@ -417,7 +418,7 @@ impl CoreBPE {
         special_tokens_encoder: HashMap<String, Rank>,
         pattern: &str,
     ) -> PyResult<Self> {
-        let regex = Regex::new(pattern)
+        let regex = RegexBuilder::new(pattern).backtrack_limit(10_000).build()
             .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;
 
         let special_regex = {
@@ -572,6 +573,7 @@ fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
 
 #[cfg(test)]
 mod tests {
+    use fancy_regex::RegexBuilder;
     use rustc_hash::FxHashMap as HashMap;
 
     use crate::{byte_pair_split, Rank};
@@ -596,4 +598,16 @@ mod tests {
         let res = byte_pair_split(b"abab", &ranks);
         assert_eq!(res, vec![b"ab", b"ab"]);
     }
+
+    #[test]
+    fn test_effect_of_backtrack_limit() {
+        let regex = RegexBuilder::new(r"(a|b|ab)*(?=c)")
+            .backtrack_limit(10)
+            .build()
+            .expect("Failed to build regex")
+            .clone();
+
+        let input = "ab".repeat(100) + "c";
+        assert!(regex.is_match(&input).is_err(), "Should throw");
+    }
 }
diff --git a/tests/test_encoding.py b/tests/test_encoding.py
index 0427d6f5..687dbdcc 100644
--- a/tests/test_encoding.py
+++ b/tests/test_encoding.py
@@ -11,14 +11,13 @@
 from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES
 
 
-@pytest.mark.skip(reason="Takes a really long time to finish, but was added to reproduce a crash.")
 @pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
 def test_extremely_big_encoding(make_enc: Callable[[], tiktoken.Encoding]):
     enc = make_enc()
     for c in ["^", "0", "a", "'s"]:  # TODO " ", "\n" are still failing
         print(f"Validating `{c}`")
 
-        big_value = c * 1_000_000
+        big_value = c * 10_000
         assert big_value == enc.decode(enc.encode(big_value))
 
         big_value = " " + big_value
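A note on the mechanics, not part of the patch series: fancy-regex's backtracking engine counts backtracking steps during a match and returns an `Err` once the configured `backtrack_limit` is exceeded, so lowering the limit turns a pathological input into a fast, catchable failure instead of a seemingly endless match. The companion sketch below reuses the pattern and input from `test_effect_of_backtrack_limit` above to show that under the crate's much larger default limit the same match completes normally, i.e. the limit only changes how quickly we give up:

use fancy_regex::Regex;

fn main() {
    // Same pattern and input as the new unit test, but with the default
    // backtrack limit instead of 10: the handful of backtracking steps
    // needed here stays well within budget and the match succeeds.
    let regex = Regex::new(r"(a|b|ab)*(?=c)").unwrap();

    let input = "ab".repeat(100) + "c";
    assert!(regex.is_match(&input).unwrap(), "Should match");
}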
From 51c8a8a22c052add1158da8fae1e5772ad990d3b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C5=91rinc?=
Date: Tue, 13 Feb 2024 12:01:44 +0100
Subject: [PATCH 5/5] Fix whitespace catastrophic backtracking

---
 tests/test_encoding.py        | 2 +-
 tiktoken_ext/openai_public.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_encoding.py b/tests/test_encoding.py
index 687dbdcc..0e02b47a 100644
--- a/tests/test_encoding.py
+++ b/tests/test_encoding.py
@@ -14,7 +14,7 @@
 @pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
 def test_extremely_big_encoding(make_enc: Callable[[], tiktoken.Encoding]):
     enc = make_enc()
-    for c in ["^", "0", "a", "'s"]:  # TODO " ", "\n" are still failing
+    for c in ["^", "0", "a", "'s", " ", "\n"]:
         print(f"Validating `{c}`")
 
         big_value = c * 10_000
diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py
index c7b41541..ce33973b 100644
--- a/tiktoken_ext/openai_public.py
+++ b/tiktoken_ext/openai_public.py
@@ -9,7 +9,7 @@
 # The pattern in the original GPT-2 release is:
 # r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
 # This is equivalent, but executes faster:
-_legacy_splitter_regex = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s+(?!\S)|\s++"""
+_legacy_splitter_regex = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s"""
 
 
 def gpt2():
@@ -84,7 +84,7 @@ def cl100k_base():
     }
     return {
         "name": "cl100k_base",
-        "pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s*[\r\n]|\s+(?!\S)|\s++""",
+        "pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s""",
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": special_tokens,
    }
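One last spot-check, again not part of the series: the reworked whitespace tail `\s++$|\s+(?!\S)|\s` is meant to split text exactly as the previous `\s+(?!\S)|\s++` did, while letting a trailing whitespace run be consumed possessively in a single step instead of via backtracking. A small sketch under the same fancy-regex assumptions as above, with hand-picked whitespace-heavy inputs:

use fancy_regex::Regex;

fn main() {
    let before = Regex::new(
        r"'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s+(?!\S)|\s++",
    )
    .unwrap();
    let after = Regex::new(
        r"'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s",
    )
    .unwrap();

    // Interior runs, leading/trailing runs, and newlines.
    for text in ["a  b", "a   b", "  a", "a \n", "   ", "it's 1\n2"] {
        let a: Vec<&str> = before.find_iter(text).map(|m| m.unwrap().as_str()).collect();
        let b: Vec<&str> = after.find_iter(text).map(|m| m.unwrap().as_str()).collect();
        assert_eq!(a, b, "tokenization differs for {text:?}");
    }
}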