From b1dfad15234333554ad79b679b1a5f39b6d8352b Mon Sep 17 00:00:00 2001 From: Saibo Desktop Date: Sat, 25 May 2024 18:47:58 +0200 Subject: [PATCH 01/17] style: run black to format code --- README.md | 2 +- docs/benchmarking.md | 4 ++-- docs/json_grammar.md | 16 ++++++++++++++ examples/benchmarking/run_generation.sh | 6 +++-- examples/grammars/SMILES/acrylates.ebnf | 10 ++++----- examples/grammars/SMILES/chain_extenders.ebnf | 8 +++---- examples/grammars/SMILES/generic.ebnf | 4 ++-- examples/grammars/SMILES/isocyanates.ebnf | 6 ++--- examples/grammars/calflow.ebnf | 4 ++-- examples/grammars/geo_query.ebnf | 8 +++---- examples/grammars/overnight.ebnf | 6 ++--- tests/test_parsing/test_parsing.py | 1 - tests/test_string_recognizer/test_smiles.py | 22 ++++++++++++++----- transformers_cfg/tokenization/vocab_struct.py | 6 +++-- 14 files changed, 67 insertions(+), 36 deletions(-) create mode 100644 docs/json_grammar.md diff --git a/README.md b/README.md index 0312ce8..b45e909 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ - **[Token masking optimization](#efficiency)(** (2024-04-25) - **Online [Demo with JSON Grammar](https://huggingface.co/spaces/saibo/transformers-CFG-JSON-demo) at HF space** (2024-04-10) - + - **Support for Unicode(multilingual) grammars** (2024-02-29) - **Integration with Text-Generation-WebUI** (2023-12-17) diff --git a/docs/benchmarking.md b/docs/benchmarking.md index 9c60e48..8e48bc5 100644 --- a/docs/benchmarking.md +++ b/docs/benchmarking.md @@ -1,6 +1,6 @@ # Benchmarking constrained generation overhead in transformers-CFG -This document provides guidelines and on benchmarking grammar constrained decoding when working with the `transformers_cfg` library. +This document provides guidelines and on benchmarking grammar constrained decoding when working with the `transformers_cfg` library. ## Table of Contents @@ -30,7 +30,7 @@ The output of the script will be saved in `transformers_cfg/examples/benchmarkin The output contains the following columns: -- `prompt`: the text of the prompt (see more on the benchmarking prompt design in the `examples/benchmarking/process_benchmarking_logs.ipynb`) +- `prompt`: the text of the prompt (see more on the benchmarking prompt design in the `examples/benchmarking/process_benchmarking_logs.ipynb`) - `n_tokens`: number of tokens generated (can be affected by the `max_new_tokens` parameter) - `run_id`: run id (each generation is performed 5 times per prompt to account for noise in the execution time measurmnet) - `total_time`: total overhead (depends on the complexity of the grammar, the model, the prompt and the device) diff --git a/docs/json_grammar.md b/docs/json_grammar.md new file mode 100644 index 0000000..8af6471 --- /dev/null +++ b/docs/json_grammar.md @@ -0,0 +1,16 @@ +# JSON(JavaScript Object Notation) Grammar + + +## JSON standard + +https://datatracker.ietf.org/doc/html/rfc7159 + +## Clarification + +- JSON doesn't support comments.(JSON5 does but it's not in Python's standard library) +- JSON doesn't support trailing commas. 
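+
+For example, the standard-library `json` module rejects both of these constructs (the commented-out lines assume the third-party `json5` package from PyPI, which is not part of the standard library):
+
+```python
+import json
+
+text = '{"a": 1, /* a comment */ "b": 2,}'  # contains a comment and a trailing comma
+
+try:
+    json.loads(text)
+except json.JSONDecodeError as err:
+    print("rejected by json:", err)
+
+# pip install json5
+# import json5
+# json5.loads(text)  # -> {'a': 1, 'b': 2}
+```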
+ + +## JSON5 VS JSON + +https://spec.json5.org/ diff --git a/examples/benchmarking/run_generation.sh b/examples/benchmarking/run_generation.sh index ec765d1..96f31f0 100755 --- a/examples/benchmarking/run_generation.sh +++ b/examples/benchmarking/run_generation.sh @@ -1,4 +1,6 @@ -grammar_path=$1 +#!/bin/bash + +grammar_path=$1 grammar_name=$(basename $grammar_path) prompts_path=$2 model_id=${3:-"openai-community/gpt2"} @@ -19,7 +21,7 @@ do do echo "Prompt: $prompt" for run_id in {1..5} - do + do echo "Measurment: $run_id" kernprof -b --skip-zero -v time_benchmarking.py $grammar_path "$prompt" $max_new_tokens $model_id > $tmp_file unconstrained_time=$(cat $tmp_file | grep "Unconstrained time: " | awk '{print $3;}') diff --git a/examples/grammars/SMILES/acrylates.ebnf b/examples/grammars/SMILES/acrylates.ebnf index f8cadac..96fe14f 100644 --- a/examples/grammars/SMILES/acrylates.ebnf +++ b/examples/grammars/SMILES/acrylates.ebnf @@ -1,12 +1,12 @@ -root ::= (smiles bond?)* ( group_symbol_left group_bond? | group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right | group_bond? group_symbol_right) (bond? smiles)* +root ::= (smiles bond?)* ( group_symbol_left group_bond? | group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right | group_bond? group_symbol_right) (bond? smiles)* -group_radical_left ::= "(" ( group_symbol_left (group_bond smiles+)? )+ ")" +group_radical_left ::= "(" ( group_symbol_left (group_bond smiles+)? )+ ")" -group_radical_right ::= "(" ( (smiles+ group_bond )? group_symbol_right )+ ")" +group_radical_right ::= "(" ( (smiles+ group_bond )? group_symbol_right )+ ")" group_bond ::= ( "-" | "\\" | "/" ) -group_symbol_left ::= "C=CC(=O)O" | "C=CC(O)=O" | "C(=C)C(=O)O" | "C(=C)C(O)=O" | "CC(=C)(=O)O" | "CC(=C)(O)=O" +group_symbol_left ::= "C=CC(=O)O" | "C=CC(O)=O" | "C(=C)C(=O)O" | "C(=C)C(O)=O" | "CC(=C)(=O)O" | "CC(=C)(O)=O" group_symbol_right ::= "OC(=O)C=C" | "O=C(O)C=C" | "OC(=O)C(=C)" | "O=C(O)C(=C)" | "O(O=)(C=)CC" | "O=(O)(C=)CC" @@ -49,7 +49,7 @@ element_symbol ::= "A" ( "c" | "g" | "l" | "m" | "r" | "s" | "t" | "u" ) | "S" ( "b" | "c" | "e" | "g" | "i" | "m" | "n" | "r" )? | "T" ( "a" | "b" | "c" | "e" | "h" | "i" | "l" | "m" | "s" ) | "U" | "V" | "W" | "Xe" | "Y" "b"? | - "Z" ( "n" | "r" ) + "Z" ( "n" | "r" ) ring_closure ::= "%" [1-9] [0-9] | [0-9] diff --git a/examples/grammars/SMILES/chain_extenders.ebnf b/examples/grammars/SMILES/chain_extenders.ebnf index a5ca486..2b83e0c 100644 --- a/examples/grammars/SMILES/chain_extenders.ebnf +++ b/examples/grammars/SMILES/chain_extenders.ebnf @@ -1,8 +1,8 @@ -root ::= (smiles bond?)* ( group_symbol_left group_bond? | group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right | group_bond? group_symbol_right) (bond? smiles)* +root ::= (smiles bond?)* ( group_symbol_left group_bond? | group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right | group_bond? group_symbol_right) (bond? smiles)* -group_radical_left ::= "(" ( group_symbol_left group_bond? smiles* )+ ")" +group_radical_left ::= "(" ( group_symbol_left group_bond? smiles* )+ ")" -group_radical_right ::= "(" ( smiles* group_bond? group_symbol_right )+ ")" +group_radical_right ::= "(" ( smiles* group_bond? group_symbol_right )+ ")" group_bond ::= ( "-" | "\\" | "/" ) @@ -50,7 +50,7 @@ element_symbol ::= "A" ( "c" | "g" | "l" | "m" | "r" | "s" | "t" | "u" ) | "S" ( "b" | "c" | "e" | "g" | "i" | "m" | "n" | "r" )? | "T" ( "a" | "b" | "c" | "e" | "h" | "i" | "l" | "m" | "s" ) | "U" | "V" | "W" | "Xe" | "Y" "b"? 
| - "Z" ( "n" | "r" ) + "Z" ( "n" | "r" ) ring_closure ::= "%" [1-9] [0-9] | [0-9] diff --git a/examples/grammars/SMILES/generic.ebnf b/examples/grammars/SMILES/generic.ebnf index 1471600..4886cb3 100644 --- a/examples/grammars/SMILES/generic.ebnf +++ b/examples/grammars/SMILES/generic.ebnf @@ -16,7 +16,7 @@ wildcard ::= "*" atom_spec ::= "[" isotope? ( "se" | "as" | aromatic_symbol | element_symbol | wildcard ) chiral_class? h_count? ( charge | class? ) "]" -organic_symbol ::= "B" | "C" | "N" | "O" | "P" | "S" | "F" | "I" | "Br" | "Cl" | "At" | "Ts" +organic_symbol ::= "B" | "C" | "N" | "O" | "P" | "S" | "F" | "I" | "Br" | "Cl" | "At" | "Ts" aromatic_symbol ::= "b" | "c" | "n" | "o" | "p" | "s" @@ -39,7 +39,7 @@ element_symbol ::= "A" ( "c" | "g" | "l" | "m" | "r" | "s" | "t" | "u" ) | "S" ( "b" | "c" | "e" | "g" | "i" | "m" | "n" | "r" )? | "T" ( "a" | "b" | "c" | "e" | "h" | "i" | "l" | "m" | "s" ) | "U" | "V" | "W" | "Xe" | "Y" "b"? | - "Z" ( "n" | "r" ) + "Z" ( "n" | "r" ) ring_closure ::= "%" [1-9] [0-9] | [0-9] diff --git a/examples/grammars/SMILES/isocyanates.ebnf b/examples/grammars/SMILES/isocyanates.ebnf index 606e2c3..f99f75a 100644 --- a/examples/grammars/SMILES/isocyanates.ebnf +++ b/examples/grammars/SMILES/isocyanates.ebnf @@ -1,8 +1,8 @@ root ::= ( group_symbol_left group_bond? | (smiles bond?)* group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right (bond? smiles)* | group_bond? group_symbol_right ) -group_radical_left ::= "(" ( group_symbol_left group_bond? smiles* )+ ")" +group_radical_left ::= "(" ( group_symbol_left group_bond? smiles* )+ ")" -group_radical_right ::= "(" ( smiles* group_bond? group_symbol_right )+ ")" +group_radical_right ::= "(" ( smiles* group_bond? group_symbol_right )+ ")" group_bond ::= ( "-" | "\\" | "/" ) @@ -50,7 +50,7 @@ element_symbol ::= "A" ( "c" | "g" | "l" | "m" | "r" | "s" | "t" | "u" ) | "S" ( "b" | "c" | "e" | "g" | "i" | "m" | "n" | "r" )? | "T" ( "a" | "b" | "c" | "e" | "h" | "i" | "l" | "m" | "s" ) | "U" | "V" | "W" | "Xe" | "Y" "b"? 
| - "Z" ( "n" | "r" ) + "Z" ( "n" | "r" ) ring_closure ::= "%" [1-9] [0-9] | [0-9] diff --git a/examples/grammars/calflow.ebnf b/examples/grammars/calflow.ebnf index e3eaaae..ad6fdd4 100644 --- a/examples/grammars/calflow.ebnf +++ b/examples/grammars/calflow.ebnf @@ -1,4 +1,4 @@ -root ::= call +root ::= call call ::= event | "(Yield " org ")" | "(Yield " "(size" org "))" | "(Yield " event ")" | "(Yield " weather ")" | "(Yield " "(> " "(size" event ")" number "))" | "(do " datetime " " call ")" | "(do " call " " call ")" | "(do " org " " call ")" @@ -57,7 +57,7 @@ event_constraint ::= "(& " event_constraint " " event_constraint ")" | "(FindLastEvent " event_constraint ")" | "(^(Event) " "EmptyStructConstraint)" -location_constraint ::= "(?= " location ")" | "(roomRequest)" | "(&" location_constraint " " location_constraint ")" +location_constraint ::= "(?= " location ")" | "(roomRequest)" | "(&" location_constraint " " location_constraint ")" location ::= "(Event.location " event_constraint ")" | "(LocationKeyphrase.apply " string ")" diff --git a/examples/grammars/geo_query.ebnf b/examples/grammars/geo_query.ebnf index d1d7c70..3a29b35 100644 --- a/examples/grammars/geo_query.ebnf +++ b/examples/grammars/geo_query.ebnf @@ -22,7 +22,7 @@ city ::= "city(" city ")" | "smallest_one(density_1(" city "))" | ALL_CITY -place ::= "placeid('" PLACENAME "')" | +place ::= "placeid('" PLACENAME "')" | "lake(" place ")" | "mountain(" place ")" | "place(" place ")" | @@ -41,10 +41,10 @@ place ::= "placeid('" PLACENAME "')" | "exclude(" place coma_sep place ")" | ALL_PLACE -river ::= "river(" river ")" | +river ::= "river(" river ")" | "riverid('" RIVERNAME "')" | - "major(" river ")" | - "loc_2(" country ")" | + "major(" river ")" | + "loc_2(" country ")" | "loc_2(" state ")" | "longer(" river ")" | "traverse_2(" city ")" | diff --git a/examples/grammars/overnight.ebnf b/examples/grammars/overnight.ebnf index b74ba43..a9ce8f9 100644 --- a/examples/grammars/overnight.ebnf +++ b/examples/grammars/overnight.ebnf @@ -1,4 +1,4 @@ -root ::= "(listValue " list_value ")" +root ::= "(listValue " list_value ")" list_value ::= "(filter " ( list_value " " PROPERTY | list_value " " PROPERTY OP list_value | list_value " " "(ensureNumericProperty " PROPERTY ")" OP "(ensureNumericEntity " list_value ")" ) ")" | @@ -15,7 +15,7 @@ PROPERTY ::= "shape" | "color" | "length" | "is_special" | "width" | "height" | "(reverse " ( "left" | "right" | "above" | "below" ) ")" -SINGLETON_VALUE ::= "en.block" | "en.shape" | "en.color" +SINGLETON_VALUE ::= "en.block" | "en.shape" | "en.color" ENTITY_VALUE ::= "en.block.block1" | "en.block.block2" | "en.shape.pyramid" | "en.shape.cube" | "en.color.red" | "en.color.green" @@ -24,4 +24,4 @@ NUMBER_VALUE ::= ( "3" | "6" ) " " "en.inch" | "2" OP ::= " " ( "=" | ">" | "<" | ">=" | "<=" | "!=" ) " " -AGGREGATE ::= " " ("sum" | "max" | "min" | "avg" ) " " +AGGREGATE ::= " " ("sum" | "max" | "min" | "avg" ) " " diff --git a/tests/test_parsing/test_parsing.py b/tests/test_parsing/test_parsing.py index 81793de..97d8063 100644 --- a/tests/test_parsing/test_parsing.py +++ b/tests/test_parsing/test_parsing.py @@ -138,7 +138,6 @@ def test__parse_literal_string(self): self.assertEqual(3, len(outbuf), f"len(outbuf): {len(outbuf)} != 3") self.assertListEqual([2, ord("你"), ord("你")], outbuf) - def test__parse_escape(self): escaped_char_src = '"\\n"' outbuf = [] diff --git a/tests/test_string_recognizer/test_smiles.py b/tests/test_string_recognizer/test_smiles.py index 15e8b3a..62d3c58 100644 --- 
a/tests/test_string_recognizer/test_smiles.py +++ b/tests/test_string_recognizer/test_smiles.py @@ -64,9 +64,17 @@ class MoleculeTestCase: MoleculeTestCase("trans_bond_left", "O=C=N\\C1CC(C\\N=C=O)(CC(C1)(C)C)C"), MoleculeTestCase("trans_bond", "O=C=N\\CCCCCC/N=C=O"), MoleculeTestCase("group_radicals", "CCOC(C(N=C=O)CCCCN=C=O)=O"), - MoleculeTestCase("simple_atom", "O=C=NC1=CC=CC(CC2=CC=C(C=C2N=C=O)CC3=CC=C(C=C3)N=C=O)=C1"), - MoleculeTestCase("single_bond_no_hyphen", "O=C=NC1=CC(CC2=C(C=C(C=C2)CC3=CC=C(C=C3N=C=O)CC4=CC=C(C=C4)N=C=O)N=C=O)=CC=C1"), - MoleculeTestCase("double_bond", "O=C=NC1=CC=C(C=C1)CC2=CC=C(C=C2N=C=O)CC3=C(C=C(C=C3)CC4=CC=C(C=C4N=C=O)CC5=CC=C(C=C5)N=C=O)N=C=O"), + MoleculeTestCase( + "simple_atom", "O=C=NC1=CC=CC(CC2=CC=C(C=C2N=C=O)CC3=CC=C(C=C3)N=C=O)=C1" + ), + MoleculeTestCase( + "single_bond_no_hyphen", + "O=C=NC1=CC(CC2=C(C=C(C=C2)CC3=CC=C(C=C3N=C=O)CC4=CC=C(C=C4)N=C=O)N=C=O)=CC=C1", + ), + MoleculeTestCase( + "double_bond", + "O=C=NC1=CC=C(C=C1)CC2=CC=C(C=C2N=C=O)CC3=C(C=C(C=C3)CC4=CC=C(C=C4N=C=O)CC5=CC=C(C=C5)N=C=O)N=C=O", + ), MoleculeTestCase("interleaved_cycle_explicit", "CC1(CC(CC(CN=C=O)(C1)C)N=C=O)C"), MoleculeTestCase("interleaved_cycle_colon", "CC1=C(C=C(C=C1)CN=C=O)N=C=O"), MoleculeTestCase("cycles", "O=C=N\\c1ccc(cc1)Cc2ccc(\\N=C=O)cc2"), @@ -98,8 +106,12 @@ class MoleculeTestCase: MoleculeTestCase("", "C=CC(=O)OCC(CO)(COC(=O)C=C)COC(=O)C=C"), MoleculeTestCase("", "CCC(COCCCOC(=O)C=C)(COCCCOC(=O)C=C)COCCCOC(=O)C=C"), MoleculeTestCase("", "CCC(COCC(CC)(COC(=O)C=C)COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C"), - MoleculeTestCase("", "C=CC(=O)OCC(CO)(COCC(COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C)COC(=O)C=C"), - MoleculeTestCase("", "C=CC(=O)OCC(COCC(COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C") + MoleculeTestCase( + "", "C=CC(=O)OCC(CO)(COCC(COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C)COC(=O)C=C" + ), + MoleculeTestCase( + "", "C=CC(=O)OCC(COCC(COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C" + ), ] valid_chain_extender_sentences = [ diff --git a/transformers_cfg/tokenization/vocab_struct.py b/transformers_cfg/tokenization/vocab_struct.py index eb06a50..f1c606d 100644 --- a/transformers_cfg/tokenization/vocab_struct.py +++ b/transformers_cfg/tokenization/vocab_struct.py @@ -28,8 +28,10 @@ def replace_hex(match): hex_value = match.group(1) return chr(int(hex_value, 16)) - if ("gpt2" in tokenizer.__class__.__name__.lower() - or "pretrained" in tokenizer.__class__.__name__.lower()): # llama3 tokenizer + if ( + "gpt2" in tokenizer.__class__.__name__.lower() + or "pretrained" in tokenizer.__class__.__name__.lower() + ): # llama3 tokenizer special = tokenizer.additional_special_tokens_ids # Here, the decoder does a string replace on a bunch of sequences From f7f929e4e8196306a37bdf783da4b82901855e91 Mon Sep 17 00:00:00 2001 From: Saibo Desktop Date: Sat, 25 May 2024 23:18:31 +0200 Subject: [PATCH 02/17] style: rename TokenTrie into CodepointTrie and refactor implementation --- transformers_cfg/token_grammar_recognizer.py | 20 ++-- .../tokenization/{trie.py => byte_trie.py} | 0 .../tokenization/codepoint_trie.py | 88 ++++++++++++++++++ transformers_cfg/tokenization/vocab_struct.py | 92 ------------------- 4 files changed, 98 insertions(+), 102 deletions(-) rename transformers_cfg/tokenization/{trie.py => byte_trie.py} (100%) create mode 100644 transformers_cfg/tokenization/codepoint_trie.py delete mode 100644 transformers_cfg/tokenization/vocab_struct.py diff --git a/transformers_cfg/token_grammar_recognizer.py b/transformers_cfg/token_grammar_recognizer.py index 
d8195e6..a8cb58d 100644 --- a/transformers_cfg/token_grammar_recognizer.py +++ b/transformers_cfg/token_grammar_recognizer.py @@ -7,8 +7,8 @@ from transformers_cfg.recognizer import StringRecognizer, AcceptState from transformers_cfg.parser import parse_ebnf -from transformers_cfg.tokenization.trie import ByteTrie -from transformers_cfg.tokenization.vocab_struct import LEAF, TokenTrie +from transformers_cfg.tokenization.byte_trie import ByteTrie +from transformers_cfg.tokenization.codepoint_trie import LEAF, CodePointTrie from transformers_cfg.tokenization.mapping import get_mapping logger = logging.getLogger(__name__) @@ -30,14 +30,14 @@ def __init__(self, grammar_str, tokenizer, start_rule_name="root", unicode=False ) self.eos_token_id = tokenizer.eos_token_id - self.token_trie = TokenTrie(tokenizer) + self.code_point_token_trie = CodePointTrie(tokenizer) self.tokenizer = tokenizer self.string_recognizer = StringRecognizer(grammar_encoding, self.start_rule_id) self.unicode_trie = ByteTrie.from_tokenizer(tokenizer, unicode=unicode) self.mapping = get_mapping(tokenizer, unicode=unicode) assert len(self.mapping) == len( - self.token_trie - ), f"{len(self.mapping)}, {len(self.token_trie)}" + self.code_point_token_trie + ), f"{len(self.mapping)}, {len(self.code_point_token_trie)}" def _consume_token_id( self, token_id: int, accept_state: AcceptState @@ -142,7 +142,7 @@ def get_token_acceptance_array_for_stack(self, stack, partial_utf8, device): else: accepts = [False] * len(self.mapping) token_acceptance = check_token_acceptance_in_trie( - self.token_trie.trie, + self.code_point_token_trie.trie, [stack], self.string_recognizer, self.eos_token_id, @@ -252,11 +252,11 @@ def check_token_acceptance_in_trie(trie, stacks, grammar, eos_token_id, accepts) continue new_stacks = set() - for stk in stacks: - if not stk: + for stack in stacks: + if not stack: continue - next_element_offset = stk[-1] + next_element_offset = stack[-1] num_chars = grammar.grammar_encoding[next_element_offset] if not grammar.char_acceptance_at_element(next_element_offset).get( @@ -266,7 +266,7 @@ def check_token_acceptance_in_trie(trie, stacks, grammar, eos_token_id, accepts) continue next_element_offset += num_chars + 1 - new_stack = list(stk[:-1]) + new_stack = list(stack[:-1]) if grammar.grammar_encoding[next_element_offset]: new_stack.append(next_element_offset) new_stacks.update(grammar.expand_stack_head(tuple(new_stack))) diff --git a/transformers_cfg/tokenization/trie.py b/transformers_cfg/tokenization/byte_trie.py similarity index 100% rename from transformers_cfg/tokenization/trie.py rename to transformers_cfg/tokenization/byte_trie.py diff --git a/transformers_cfg/tokenization/codepoint_trie.py b/transformers_cfg/tokenization/codepoint_trie.py new file mode 100644 index 0000000..53080c8 --- /dev/null +++ b/transformers_cfg/tokenization/codepoint_trie.py @@ -0,0 +1,88 @@ +################# +# DATA STRUCTURES +################# + +import logging +import re +from typing import List + +logger = logging.getLogger(__name__) + +LEAF = -1 + + +def fmt_token_as_codepoints(token_id, tokenizer, only_ascii=True) -> List[int]: + + special_token_ids = tokenizer.additional_special_tokens_ids + + tokenizer_class_name = tokenizer.__class__.__name__.lower() + + if "gpt2" in tokenizer_class_name or "pretrained" in tokenizer_class_name: + # GPT-2 or Pretrained tokenizers + # No additional space handling needed + handle_spaces = False + elif "llama" in tokenizer_class_name or "t5" in tokenizer_class_name: + # Llama or T5 tokenizers + # 
Handle leading space in token + handle_spaces = True + else: + # logger.warning( + # "Warning: unrecognized tokenizer: using default token formatting" + # ) + handle_spaces = False + + if token_id in special_token_ids: + return None + token = tokenizer.decode([token_id], clean_up_tokenization_spaces=False) + if handle_spaces: + raw_token = tokenizer.convert_ids_to_tokens(token_id) + if raw_token.startswith("▁"): + token = " " + token + code_points = [ord(c) for c in token] + # keep only code points within ASCII range + code_points = code_points if all(c < 128 for c in code_points) else None + return code_points + + +class CodePointTrie: + def __init__(self, tokenizer, only_ascii=True): + self.eos_token_id = tokenizer.eos_token_id + self.all_token_codepoints = [] + self.trie = {} + # we only keep ASCII code points + # the reason why we should do this is because to handle unicode properly, we need to handle multi-byte characters + # this can not be done with a simple code point trie + # if we set only_ascii to False, we will be able to handle a subset of unicode characters + # this behavior is probably not what we want + self.only_ascii = only_ascii + self.load_tokens(tokenizer) + + def id2str(self, token_id): + return self.all_token_codepoints[token_id] + + def __len__(self): + return len(self.all_token_codepoints) + + def load_tokens(self, tokenizer): + self.all_token_codepoints = [ + fmt_token_as_codepoints(token_id, tokenizer, self.only_ascii) + for token_id in range(len(tokenizer.get_vocab())) + ] + for token_id, token_codepoints in enumerate(self.all_token_codepoints): + if token_codepoints is not None: + self.insert_into_trie(self.trie, token_codepoints, token_id) + + def insert_into_trie(self, trie, token_bytes, token_id): + current = trie + for byte in token_bytes: + if byte not in current: + current[byte] = {} + current = current[byte] + current[LEAF] = token_id + + +if __name__ == "__main__": + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + token_trie = CodePointTrie(tokenizer) diff --git a/transformers_cfg/tokenization/vocab_struct.py b/transformers_cfg/tokenization/vocab_struct.py deleted file mode 100644 index f1c606d..0000000 --- a/transformers_cfg/tokenization/vocab_struct.py +++ /dev/null @@ -1,92 +0,0 @@ -################# -# DATA STRUCTURES -################# - -import logging -import re - -logger = logging.getLogger(__name__) - -LEAF = -1 - - -class TokenTrie: - def __init__(self, tokenizer): - self.eos_token_id = tokenizer.eos_token_id - self.tokens = [] - self.trie = {} - self.load_tokens(tokenizer) - - def id2str(self, token_id): - return self.tokens[token_id] - - def __len__(self): - return len(self.tokens) - - def load_tokens(self, tokenizer): - def replace_hex(match): - hex_value = match.group(1) - return chr(int(hex_value, 16)) - - if ( - "gpt2" in tokenizer.__class__.__name__.lower() - or "pretrained" in tokenizer.__class__.__name__.lower() - ): # llama3 tokenizer - special = tokenizer.additional_special_tokens_ids - - # Here, the decoder does a string replace on a bunch of sequences - # like ' .' for '.'. This interferes with our assumptions, where a - # token should always have exactly one representation. - # Fortunately(?) text-generation-inference doesn't seem to run this - # cleanup, so we get extraneous spaces. So, in order to generate - # the right token set for TGI, we have to skip the space trimming. 
- # See: - # https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3588-L3600 - def fmt_token(id): - if id in special: - return None - return bytes( - tokenizer.decode([id], clean_up_tokenization_spaces=False), "utf-8" - ) - - elif ( - "llama" in tokenizer.__class__.__name__.lower() - or "t5" in tokenizer.__class__.__name__.lower() - ): - - def fmt_token(id): - token = tokenizer.convert_ids_to_tokens(id) - token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token) - token = token.replace("▁", " ") - return bytes(token, "utf-8") - - else: - logger.warning( - "Warning: unrecognized tokenizer: using default token formatting" - ) - - def fmt_token(id): - token = tokenizer.convert_ids_to_tokens(id) - return bytes(token, "utf-8") - - # note: vocab_size doesn't work here because there are also - # get_added_vocab() tokens - self.tokens = [fmt_token(i) for i in range(len(tokenizer.get_vocab()))] - for token_id, token_bytes in enumerate(self.tokens): - if token_bytes is not None: - self.insert_into_trie(self.trie, token_bytes, token_id) - - def insert_into_trie(self, trie, token_bytes, token_id): - current = trie - for byte in token_bytes: - if byte not in current: - current[byte] = {} - current = current[byte] - current[LEAF] = token_id - - -if __name__ == "__main__": - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained("gpt2") - token_trie = TokenTrie(tokenizer) From 218820b5ba6156fcf5eca9d46df4ea21c8eb36d6 Mon Sep 17 00:00:00 2001 From: Arina Rak Date: Mon, 27 May 2024 14:16:51 +0200 Subject: [PATCH 03/17] trie related recognizer refactoring --- transformers_cfg/recognizer.py | 73 +++--------------- transformers_cfg/token_grammar_recognizer.py | 78 ++++++++------------ transformers_cfg/tokenization/byte_trie.py | 22 +++--- transformers_cfg/utf8_utils.py | 52 +++++++------ 4 files changed, 81 insertions(+), 144 deletions(-) diff --git a/transformers_cfg/recognizer.py b/transformers_cfg/recognizer.py index dfbc92e..27f2eec 100644 --- a/transformers_cfg/recognizer.py +++ b/transformers_cfg/recognizer.py @@ -100,9 +100,6 @@ def init_stack(self, start_rule_id: int) -> Set[Tuple[int]]: def get_initial_accept_state(self) -> AcceptState: return AcceptState(self.init_stack(self.start_rule_id), PartialUTF8()) - def get_termination_accept_state(self) -> AcceptState: - return AcceptState(set(), PartialUTF8()) - @lru_cache(maxsize=32768) def expand_stack_head(self, stack: Tuple[int]) -> Set[Tuple[int]]: """ @@ -130,7 +127,6 @@ def expand_stack_head(self, stack: Tuple[int]) -> Set[Tuple[int]]: new_stacks: Set[Tuple[int]] = set() # Loop over alternates of referenced rule to build new stacks while self.grammar_encoding[ref_subrule_offset] != END_OF_RULE_MARKER: - # copy the original stack without the last element new_stack = list(stack[:-1]) # if the rule ref is followed by another element, we add it to the stack next_element_offset = cur_element_offset + 2 @@ -150,12 +146,6 @@ def expand_stack_head(self, stack: Tuple[int]) -> Set[Tuple[int]]: return new_stacks - def _consume_byte(self, byte: int, accept_state: AcceptState) -> AcceptState: - # suppose we have code point 一, ord('一') = 19968, we need to match 3 bytes - # we need to match 3 bytes, so we need to call _consume_byte_partial_match 3 times - return self._consume_bytes(bytes([byte]), accept_state) - - # @lru_cache(maxsize=32768) def _try_accept_bytes( self, byte_seq: bytes, @@ -167,8 +157,6 @@ def _try_accept_bytes( The difference between accept_bytes and consume_bytes is 
that accept_bytes returns a boolean and consume_bytes returns a new accept state """ - if type(byte_seq) is list: - byte_seq = bytes(byte_seq) code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8) if verbose: logging.debug( @@ -177,7 +165,6 @@ def _try_accept_bytes( new_stacks = self._consume_code_points_for_all_stacks(code_points, stacks) for stack in new_stacks: - # stack is empty, meaning that the variables are all consumed if len(stack) == 0: return True @@ -192,12 +179,14 @@ def _consume_bytes( accept_state: Optional[AcceptState] = None, verbose=True, ) -> AcceptState: + if accept_state is None: accept_state = self.get_initial_accept_state() stacks = accept_state.stacks partial_utf8 = accept_state.partial_utf8 if type(byte_seq) is list: byte_seq = bytes(byte_seq) + code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8) if verbose: logging.debug( @@ -226,7 +215,7 @@ def _consume_code_point_for_all_stacks( ) -> Set[Tuple[int]]: """ consume a character from the stack - char_code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127] + code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127] """ new_stacks: Set[Tuple[int]] = set() @@ -244,7 +233,7 @@ def _consume_code_point_for_single_stack( ) -> Set[Tuple[int]]: """ consume a character from the stack - char_code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127] + code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127] """ # TODO, the below code will raise an error when the stack is empty, but why is this happening? # if len(stacks) == 0: @@ -253,15 +242,13 @@ def _consume_code_point_for_single_stack( # to indicate that the character is not accepted new_stacks: Set[Tuple[int]] = set() - if code_point == 0: - return new_stacks - # stack is empty - if len(stack) == 0: + + if code_point == 0 or len(stack) == 0: return new_stacks element_offset = stack[-1] - found = self.accept_code_point_at_element(code_point, element_offset) + if not found: return new_stacks @@ -398,20 +385,6 @@ def _accept_string(self, string: str, accept_state: Optional[AcceptState] = None ) return at_least_one_stack_is_empty - def _can_stop(self, stacks: Set[Tuple[int]]): - # This happens in practice, but maybe it shouldn't? TODO - if len(stacks) == 0: - return True - # if any of the stack is empty, we can stop - for stack in stacks: - if len(stack) == 0: - return True - else: - return False - - def _must_stop(self, stacks: Set[Tuple[int]]): - return len(stacks) == 0 or all(len(stack) == 0 for stack in stacks) - ############################# # # Not Used @@ -422,50 +395,28 @@ def _must_stop(self, stacks: Set[Tuple[int]]): @lru_cache(maxsize=None) def char_acceptance_at_element(self, element_offset): """ - Caches and returns a dictionary indicating whether a Unicode character is accepted + Caches and returns a set of accepted Unicode characters at a given rule position. This function considers Unicode characters, dynamically - inserting accepted ranges into a dictionary to optimize memory usage. + inserting accepted ranges into the set to optimize memory usage. Args: - rule_offset: The offset in the grammar encoding where the rule starts. Returns: - - A dictionary where each key is a Unicode character (or range) and the value is True if accepted. + - A set of accepted Unicode characters (or range). 
""" logging.debug(f"element_offset: {element_offset}") - acceptance = {} + acceptance = set() num_chars = self.grammar_encoding[element_offset] element_offset += 1 for i in range(0, num_chars, 2): start = self.grammar_encoding[element_offset + i] end = self.grammar_encoding[element_offset + i + 1] for j in range(start, end + 1): - acceptance[j] = True + acceptance.add(j) logging.debug(acceptance) return acceptance - # def _consume_code_points_new( - # self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False - # ) -> Set[Tuple[int]]: - # new_stacks: Set[Tuple[int]] = set() - # for stack in stacks: - # new_stacks.update( - # self._consume_code_points_per_stack(tuple(code_points), stack, verbose) - # ) - # return new_stacks - # - # @lru_cache(maxsize=30000) - # def _consume_code_points_per_stack( - # self, code_points: Tuple[int], stack: Tuple[int], verbose=False - # ) -> Set[Tuple[int]]: - # stacks = {stack} - # - # for code_point in code_points: - # # Update the stacks variable by consuming each code point. - # stacks = self._consume_code_point_for_all_stacks(code_point, (stack,)) - # - # return stacks - if __name__ == "__main__": # set logging level diff --git a/transformers_cfg/token_grammar_recognizer.py b/transformers_cfg/token_grammar_recognizer.py index a8cb58d..90272c1 100644 --- a/transformers_cfg/token_grammar_recognizer.py +++ b/transformers_cfg/token_grammar_recognizer.py @@ -10,6 +10,7 @@ from transformers_cfg.tokenization.byte_trie import ByteTrie from transformers_cfg.tokenization.codepoint_trie import LEAF, CodePointTrie from transformers_cfg.tokenization.mapping import get_mapping +from typing import Set, Tuple logger = logging.getLogger(__name__) @@ -35,25 +36,32 @@ def __init__(self, grammar_str, tokenizer, start_rule_name="root", unicode=False self.string_recognizer = StringRecognizer(grammar_encoding, self.start_rule_id) self.unicode_trie = ByteTrie.from_tokenizer(tokenizer, unicode=unicode) self.mapping = get_mapping(tokenizer, unicode=unicode) - assert len(self.mapping) == len( + self.vocab_size = len(self.mapping) + assert self.vocab_size == len( self.code_point_token_trie - ), f"{len(self.mapping)}, {len(self.code_point_token_trie)}" + ), f"{self.vocab_size}, {len(self.code_point_token_trie)}" + + def _must_stop(self, stacks: Set[Tuple[int]]): + return len(stacks) == 0 or all(len(stack) == 0 for stack in stacks) + + def _can_stop(self, stacks: Set[Tuple[int]]): + # if at least one of the stack is empty, we can stop + return len(stacks) == 0 or any(len(stack) == 0 for stack in stacks) def _consume_token_id( self, token_id: int, accept_state: AcceptState ) -> AcceptState: - if self.string_recognizer._must_stop(accept_state.stacks): + if self._must_stop(accept_state.stacks): if token_id == self.eos_token_id: - return self.string_recognizer.get_termination_accept_state() + return AcceptState.empty_state() else: raise ValueError( f"All stacks are empty, so the only token accepted is EOS({self.eos_token_id}), but got {token_id}" ) if token_id == self.eos_token_id: - if self.string_recognizer._can_stop(accept_state.stacks): - # if at least one of the stack is empty, we can stop + if self._can_stop(accept_state.stacks): # we clear all the stacks, meaning that we don't accept any token after EOS - return self.string_recognizer.get_termination_accept_state() + return AcceptState.empty_state() else: raise ValueError( f"At least one of the stack should be empty when EOS is reached. 
However, " @@ -66,28 +74,6 @@ def _consume_token_id( ) return accept_state - def try_accept_token_id(self, token_id: int, accept_state: AcceptState) -> bool: - stacks = accept_state.stacks - if self.string_recognizer._must_stop(stacks): - if token_id == self.eos_token_id: - return True - else: - return False - if token_id == self.eos_token_id: - if self.string_recognizer._can_stop(stacks): - # if at least one of the stack is empty, we can stop - # we clear all the stacks, meaning that we don't accept any token after EOS - return True - else: - return False - # for code_point in self.mapping.map(token_id): - # stacks = self.grammar._consume_char_code_point(code_point, stacks) - bytes_or_codepoints = self.mapping.map(token_id, verbose=False) - new_acc_state = self.string_recognizer._consume_bytes( - bytes_or_codepoints, accept_state, verbose=False - ) - return len(new_acc_state.stacks) > 0 - def consume_token_ids(self, *args, **kwargs): """Process a list of tokens according to the grammar rules.""" raise NotImplementedError @@ -102,10 +88,9 @@ def filter_vocab(self, accept_state, device) -> torch.Tensor: if not accept_state.stacks: # Check if stacks is empty # Handle the empty case: for example, return a tensor of False # The size of the tensor should match the size of your vocabulary - vocab_size = len(self.mapping) logger.debug(f"Empty stack, sum of acceptance: {0}") # size of the vocab - accepts = [False] * vocab_size + accepts = [False] * self.vocab_size accepts[self.eos_token_id] = True return torch.tensor(accepts, dtype=torch.bool, device=device) @@ -127,26 +112,30 @@ def get_token_acceptance(self, accept_state, device) -> torch.Tensor: return acceptance @lru_cache(maxsize=32768) - def get_token_acceptance_array_for_stack(self, stack, partial_utf8, device): - # stack = list(stack) # needs to come in as a tuple for lru_cache - assert isinstance(stack, tuple) + def get_token_acceptance_array_for_stack(self, stack: Tuple, partial_utf8, device): + assert isinstance(stack, tuple) + + token_acceptance = [False] * self.vocab_size + if self.byte_encoding: - + # boolean function checking if a byte sequence is accepted by the grammar accept_f = lambda x: self.string_recognizer._try_accept_bytes( - x, {stack}, partial_utf8=partial_utf8 + bytes(x), {stack}, partial_utf8=partial_utf8 ) - token_acceptance = self.unicode_trie.get_token_acceptance( - accept=accept_f, accept_eos=False, eos_token_id=self.eos_token_id + self.unicode_trie.get_token_acceptance( + accept=accept_f, + accept_eos=False, + eos_token_id=self.eos_token_id, + token_acceptance=token_acceptance ) else: - accepts = [False] * len(self.mapping) - token_acceptance = check_token_acceptance_in_trie( + check_token_acceptance_in_trie( self.code_point_token_trie.trie, [stack], self.string_recognizer, self.eos_token_id, - accepts, + token_acceptance, ) x = torch.tensor(token_acceptance, dtype=torch.bool, device=device) x_eos = self.validate_and_set_eos_acceptance(x) @@ -241,7 +230,6 @@ def _consume_token_ids( def check_token_acceptance_in_trie(trie, stacks, grammar, eos_token_id, accepts): - for byte, next_trie in trie.items(): if byte == LEAF: token_id = next_trie @@ -259,10 +247,8 @@ def check_token_acceptance_in_trie(trie, stacks, grammar, eos_token_id, accepts) next_element_offset = stack[-1] num_chars = grammar.grammar_encoding[next_element_offset] - if not grammar.char_acceptance_at_element(next_element_offset).get( - byte, False - ): - # if the current byte is not accepted by the current rule, we need to try next rule + # if the 
current byte is not accepted by the current rule, we need to try next rule + if not grammar.accept_code_point_at_element(byte, next_element_offset): continue next_element_offset += num_chars + 1 diff --git a/transformers_cfg/tokenization/byte_trie.py b/transformers_cfg/tokenization/byte_trie.py index dd4cbe3..fcb48e0 100644 --- a/transformers_cfg/tokenization/byte_trie.py +++ b/transformers_cfg/tokenization/byte_trie.py @@ -5,17 +5,8 @@ from transformers_cfg.tokenization.mapping import get_mapping -# from transformers_cfg.parser import parse_ebnf -# from transformers_cfg.recognizer import GrammarRecognizer -# from transformers_cfg.token_grammar_recognizer import IncrementalTokenGrammarRecognizer - logger = logging.getLogger(__name__) -# def check_token_acceptance_in_trie(trie, stacks, grammar, partial_utf8, accept_eos=True, eos_token_id=None) -> List[bool]: -# accept_f = lambda x: grammar._probe_bytes_partial_match(x, stack=stacks, partial_utf8=partial_utf8) -# accepts = trie.get_token_acceptance(accept=accept_f, accept_eos=accept_eos, eos_token_id=eos_token_id) -# return accepts - class TrieNode: def __init__(self): @@ -76,14 +67,16 @@ def dfs(self, accept=lambda x: True, verbose=False) -> List[Tuple[List[int], int def bfs( self, predicate=lambda x: True, verbose=False ) -> List[Tuple[List[int], int]]: + queue = deque([(self.root, [])]) + # TODO: do we need to keep track of the byte sequence? valid_byte_seqs: List[Tuple[List[int], int]] = [] counter = {"visited": 0, "pruned": 0} while queue: counter["visited"] += 1 node, byte_seq = queue.popleft() - if predicate(byte_seq): + if predicate(bytes(byte_seq)): if node.is_end_of_word: valid_byte_seqs.append((byte_seq, node.token_id)) for char, next_node in node.children.items(): @@ -95,18 +88,21 @@ def bfs( return valid_byte_seqs def get_token_acceptance( - self, accept=lambda x: True, accept_eos=True, eos_token_id=None + self, accept, accept_eos, eos_token_id, token_acceptance ) -> List[bool]: + """ + Finds all acceptable tokens for a fixed stack (verified with accept function). + Modifies token_acceptance: a list of booleans, where the ith element is True if the ith token is acceptable, False otherwise. 
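+        The `accept` callable receives a byte sequence (a prefix of some token) and should
+        return True if the grammar can still consume it; the BFS prunes an entire subtree of
+        the trie as soon as a prefix is rejected.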
+ """ valid_byte_seqs: List[Tuple[List[int], int]] = self.bfs(accept, verbose=True) valid_token_ids: List[int] = [token_id for _, token_id in valid_byte_seqs] - token_acceptance: List[bool] = [False] * (len(self)) + for token_id in valid_token_ids: token_acceptance[token_id] = True if not accept_eos: # eos_token is mapped to an empty string, so it's always accepted regardless of the accept function # this can be undesirable, so we can set it to False to ignore it token_acceptance[eos_token_id] = False - return token_acceptance def _dfs( diff --git a/transformers_cfg/utf8_utils.py b/transformers_cfg/utf8_utils.py index f82dbaf..7bcf490 100644 --- a/transformers_cfg/utf8_utils.py +++ b/transformers_cfg/utf8_utils.py @@ -1,7 +1,6 @@ from dataclasses import dataclass -from typing import Tuple - -from dataclasses import dataclass +from typing import List, Tuple +from functools import lru_cache @dataclass @@ -36,34 +35,41 @@ def __eq__(self, other): return self.value == other.value and self.n_remain == other.n_remain -from typing import List, Tuple -from functools import lru_cache +@lru_cache(maxsize=3000000) +def decode_utf8_intermediate( + src: bytes, pos: int, value: int, n_remain: int +) -> Tuple[int, int, int]: + while pos < len(src) and n_remain > 0: + next_byte = src[pos] # Get the next byte to process + # Check if the continuation byte format is correct (`10xxxxxx`) + if (next_byte >> 6) != 2: + return -1, -1, -1 + + # Accumulate the value by shifting left and adding the relevant 6 bits + value = (value << 6) + (next_byte & 0x3F) + pos += 1 # Move to the next byte + n_remain -= 1 # Decrement the number of remaining bytes + return value, n_remain, pos @lru_cache(maxsize=3000000) def decode_utf8( src: bytes, partial_start: PartialUTF8 ) -> Tuple[List[int], PartialUTF8]: + # Lookup table for determining the total bytes based on the first byte's high 4 bits lookup = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4] pos = 0 # Position in the src bytes to start decoding from code_points = [] # List to store the decoded Unicode code points + value = partial_start.value # Start with any previously partial decoded value n_remain = partial_start.n_remain # Number of bytes remaining from a partial decode # If there's a partial sequence left from last decode, try to continue decoding it - while pos < len(src) and n_remain > 0: - next_byte = src[pos] # Get the next byte to process - # Check if the continuation byte format is correct (`10xxxxxx`) - if (next_byte >> 6) != 2: - # If not, it's an invalid sequence. Abort and return a special error state. - code_points = [0] - return code_points, PartialUTF8(0, -1) - - # Accumulate the value by shifting left and adding the relevant 6 bits - value = (value << 6) + (next_byte & 0x3F) - pos += 1 # Move to the next byte - n_remain -= 1 # Decrement the number of remaining bytes + value, n_remain, pos = decode_utf8_intermediate(src, pos, value, n_remain) + # Invalid sequence, return a special error state. 
+ if value == -1 and n_remain == -1 and pos == -1: + return [0], PartialUTF8(0, -1) # If we've completed a partial sequence, add its value to the code points if partial_start.n_remain > 0 and n_remain == 0: @@ -86,13 +92,11 @@ def decode_utf8( value = first_byte & mask # Apply the mask to get the initial value pos += 1 # Move to the next byte - # Process the continuation bytes - while pos < len(src) and n_remain > 0: - next_byte = src[pos] - # Shift the accumulated value and add the next 6 significant bits - value = (value << 6) + (next_byte & 0x3F) - pos += 1 # Move to the next byte - n_remain -= 1 # Decrement the number of remaining bytes + # Decode the continuation bytes + value, n_remain, pos = decode_utf8_intermediate(src, pos, value, n_remain) + # Invalid sequence, return a special error state. + if value == -1 and n_remain == -1 and pos == -1: + return [0], PartialUTF8(0, -1) # If the sequence is complete, add its decoded value to the code points if n_remain == 0: From 657d1b0194ccf60334211101aa0b4b1765073806 Mon Sep 17 00:00:00 2001 From: Arina Rak Date: Mon, 27 May 2024 16:02:45 +0200 Subject: [PATCH 04/17] mapping refactored + sanity checks for t5/bloom/llama3 and t5 added --- tests/test_tokenizers/test_bloom.py | 1 - tests/test_tokenizers/test_llama3.py | 15 +++++ tests/test_tokenizers/test_phi3.py | 16 +++++ tests/test_tokenizers/test_t5.py | 1 - transformers_cfg/tokenization/mapping.py | 85 ++++++++++++++---------- 5 files changed, 80 insertions(+), 38 deletions(-) create mode 100644 tests/test_tokenizers/test_llama3.py create mode 100644 tests/test_tokenizers/test_phi3.py diff --git a/tests/test_tokenizers/test_bloom.py b/tests/test_tokenizers/test_bloom.py index ebee494..6253698 100644 --- a/tests/test_tokenizers/test_bloom.py +++ b/tests/test_tokenizers/test_bloom.py @@ -7,7 +7,6 @@ import logging -# @unittest.skip("GPTNeoXTokenizerFast is not available for testing") class BloomTokenizerTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = BloomTokenizerFast diff --git a/tests/test_tokenizers/test_llama3.py b/tests/test_tokenizers/test_llama3.py new file mode 100644 index 0000000..8a32e68 --- /dev/null +++ b/tests/test_tokenizers/test_llama3.py @@ -0,0 +1,15 @@ +import unittest + +from transformers import GPT2TokenizerFast +from tests._tokenizer_common import TokenizerTesterMixin + +import logging + + +class Llama3TokenizerTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = GPT2TokenizerFast + pretrained_name = "meta-llama/Meta-Llama-3-8B" + + def setUp(self): + super().setUp() diff --git a/tests/test_tokenizers/test_phi3.py b/tests/test_tokenizers/test_phi3.py new file mode 100644 index 0000000..29bca94 --- /dev/null +++ b/tests/test_tokenizers/test_phi3.py @@ -0,0 +1,16 @@ +import unittest + +from transformers import T5TokenizerFast + +from tests._tokenizer_common import TokenizerTesterMixin + +import logging + + +class Phi3TokenizerTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = T5TokenizerFast + pretrained_name = "microsoft/Phi-3-mini-4k-instruct" + + def setUp(self): + super().setUp() diff --git a/tests/test_tokenizers/test_t5.py b/tests/test_tokenizers/test_t5.py index 936bd39..f4d6bbf 100644 --- a/tests/test_tokenizers/test_t5.py +++ b/tests/test_tokenizers/test_t5.py @@ -7,7 +7,6 @@ import logging -@unittest.skip("T5Tokenizer's mapping is not well defined, not working") class T5TokenizerTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = T5TokenizerFast diff --git 
a/transformers_cfg/tokenization/mapping.py b/transformers_cfg/tokenization/mapping.py index 7cd1913..a89f1c2 100644 --- a/transformers_cfg/tokenization/mapping.py +++ b/transformers_cfg/tokenization/mapping.py @@ -2,6 +2,7 @@ from transformers_cfg.utils import get_tokenizer_model_type, ints2bytes from transformers import AutoTokenizer +import re import logging log = logging.getLogger(__name__) @@ -10,25 +11,22 @@ def get_mapping(tokenizer, unicode=False): log.debug(f"tokenizer type: {tokenizer.__class__.__name__}") log.debug(f"tokenizer model type: {get_tokenizer_model_type(tokenizer)}") + tokenizer_name = tokenizer.__class__.__name__.lower() if not unicode: - if ( - "gpt2" in tokenizer.__class__.__name__.lower() - or "bloom" in tokenizer.__class__.__name__.lower() - or "pretrainedtokenizer" in tokenizer.__class__.__name__.lower() - or "codegen" in tokenizer.__class__.__name__.lower() - or "gptneox" in tokenizer.__class__.__name__.lower() + if re.match( + r"gpt2|bloom|pretrainedtokenizer|codegen|gptneox|Llama-3", tokenizer_name ): return BBPEMapping(tokenizer) - elif "t5" in tokenizer.__class__.__name__.lower(): + elif re.match(r"t5|Phi-3", tokenizer_name): return BPEMapping(tokenizer) - elif "llama" in tokenizer.__class__.__name__.lower(): + elif "llama" in tokenizer_name: return LlamaBPEMapping(tokenizer) - elif "xglm" in tokenizer.__class__.__name__.lower(): + elif "xglm" in tokenizer_name: return UniGramMapping(tokenizer) else: raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__.__name__}") else: - if "gpt2" in tokenizer.__class__.__name__.lower(): + if "gpt2" in tokenizer_name: return UnicodeBBPEMapping(tokenizer) else: raise NotImplementedError( @@ -36,6 +34,16 @@ def get_mapping(tokenizer, unicode=False): ) +class ReplacePrefixMixin: + def __init__(self, prefix): + self.prefix = prefix + + def _replace_prefix(self, token: str) -> str: + if token.startswith(self.prefix): + return token.replace(self.prefix, "", 1) + return token + + class Mapping: def __init__(self, tokenizer): self.eos_token_id = tokenizer.eos_token_id @@ -48,20 +56,23 @@ def __len__(self): return self._length def _map(self, token_id: int) -> str: - # This is the case for BOS, - if token_id in self.special: - return "" # if token_id is tensor, convert it to int if hasattr(token_id, "item"): token_id = token_id.item() + # This is the case for BOS, + if token_id in self.special: + return "" raw_token = self.tokenizer.convert_ids_to_tokens(token_id) return raw_token + def _encode(self, token: str) -> bytes: + return bytes(token, "utf-8") + def map(self, token_id: int, verbose=False) -> bytes: token = self._map(token_id) if verbose: log.debug(f"token_id: {token_id}, token: {token}") - return bytes(token, "utf-8") + return self._encode(token) class BBPEMapping(Mapping): @@ -71,7 +82,7 @@ def __init__(self, *args, **kwargs): def _map(self, token_id: int) -> str: raw_token = super()._map(token_id) if raw_token.startswith("Ġ"): - raw_token = raw_token.replace("Ġ", " ") + raw_token = raw_token.replace("Ġ", " ", 1) return raw_token @@ -82,17 +93,8 @@ def __init__(self, *args, **kwargs): self.tokenizer ) - def _map(self, token_id: int, verbose=False) -> str: - raw_token = super()._map(token_id) - # if raw_token.startswith("Ġ"): - # raw_token = raw_token.replace("Ġ", " ") - return raw_token - - def map(self, token_id: int, verbose=False) -> bytes: - raw_token = self._map(token_id, verbose) - if verbose: - log.debug(f"token_id: {token_id}, raw_token: {raw_token}") - return 
self.intermediate_encoding.token2bytes(raw_token) + def _encode(self, token: str) -> bytes: + return self.intermediate_encoding.token2bytes(token) @staticmethod def get_intermediate_encoding(tokenizer): @@ -107,17 +109,19 @@ def __init__(self, tokenizer): super().__init__(tokenizer) self.last_token_id = None + def _check_bos_token(self, token_id: int) -> bool: + # specific to BPE + at_bos = self.last_token_id is None + self.last_token_id = token_id if token_id != self.eos_token_id else None + return at_bos + def _map(self, token_id: int) -> str: raw_token = super()._map(token_id) - # we need to check if the token is at the beginning of the sentence to remove the space # specific to BPE - at_bos = False - if self.last_token_id is not None and self.last_token_id == self.bos_token_id: - at_bos = True - self.last_token_id = token_id + at_bos = self._check_bos_token(token_id) if raw_token.startswith("▁"): - raw_token = raw_token.replace("▁", " ") + raw_token = raw_token.replace("▁", " ", 1) if at_bos: # remove space at the beginning of the sentence raw_token = raw_token[1:] @@ -128,6 +132,11 @@ class LlamaBPEMapping(BPEMapping): def __init__(self, tokenizer): super().__init__(tokenizer) + def _check_bos_token(self, token_id: int) -> bool: + at_bos = self.last_token_id and (self.last_token_id == self.bos_token_id) + self.last_token_id = token_id + return at_bos + def _map(self, token_id: int) -> str: raw_token = super()._map(token_id) # if the token is hex, token is a string like "<0x00>" @@ -183,11 +192,11 @@ def __init__(self, tokenizer): self.char2byte: Dict[str, int] = tokenizer.byte_decoder # code point to byte self.cdp2byte: Dict[int, int] = {ord(c): b for c, b in self.char2byte.items()} - self.byte2cdp: Dict[int, int] = {v: k for k, v in self.cdp2byte.items()} + self.byte2cdp: Dict[int, int] = {b: c for c, b in self.cdp2byte.items()} def map(self, byte: int) -> int: assert 0 <= byte < 256, f"byte: {byte} is not in the range [0, 256)" - return ord(self.byte2char[byte]) + return self.byte2cdp[byte] def token_ids2bytes(self, token_ids: List[int]) -> bytes: tokens: List[str] = self.tokenizer.convert_ids_to_tokens(token_ids) @@ -196,10 +205,14 @@ def token_ids2bytes(self, token_ids: List[int]) -> bytes: tokens = [ "" if token in self.tokenizer.all_special_ids else token for token in tokens ] - bytes: List[List[int]] = [self.token2bytes(token) for token in tokens] + bytes_per_token: List[List[int]] = [self.token2bytes(token) for token in tokens] # join the bytes - return ints2bytes(sum(bytes, [])) + bytes = sum(bytes_per_token, []) + # verify range and convert to bytes + bytes = ints2bytes(bytes) + return bytes + # Not used def token_id2bytes(self, token_id: int) -> bytes: token: str = self.tokenizer.convert_ids_to_tokens(token_id) return self.token2bytes(token) From 315e0f9af876f83e2a49ee3ee706148300ed5e78 Mon Sep 17 00:00:00 2001 From: Arina Rak Date: Tue, 4 Jun 2024 15:21:47 +0200 Subject: [PATCH 05/17] fix for t5 --- tests/_tokenizer_common.py | 58 ++++++++++++-------- tests/test_tokenizers/test_phi3.py | 4 +- transformers_cfg/token_grammar_recognizer.py | 15 +++-- 3 files changed, 47 insertions(+), 30 deletions(-) diff --git a/tests/_tokenizer_common.py b/tests/_tokenizer_common.py index d0660ee..b5501b5 100644 --- a/tests/_tokenizer_common.py +++ b/tests/_tokenizer_common.py @@ -29,6 +29,15 @@ class TokenizerTesterMixin: # test_sentencepiece must also be set to True test_sentencepiece_ignore_case = False + def _check_for_unk(self, token_ids): + for token_id in token_ids: + if token_id == 
self.tokenizer.unk_token_id: + warnings.warn( + f"unk token found in input_token_ids: {token_ids}, skipping test" + ) + return True + return False + def setUp(self): self.tokenizer = self.get_tokenizer() @@ -51,12 +60,8 @@ def test_json_parsable(self): pprint_token_ids(self.tokenizer, token_ids) # check if there is unk token - for token_id in token_ids: - if token_id == self.tokenizer.unk_token_id: - warnings.warn( - f"unk token found in input_token_ids: {token_ids}, skipping test" - ) - return + if self._check_for_unk(token_ids): + return acc_state = JsontokenRecognizer._consume_token_ids(token_ids, as_string=False) # the json object is complete, so the stacks should be empty @@ -78,12 +83,8 @@ def test_balanced_parentheses(self): pprint_token_ids(self.tokenizer, token_ids) # check if there is unk token - for token_id in token_ids: - if token_id == self.tokenizer.unk_token_id: - warnings.warn( - f"unk token found in input_token_ids: {token_ids}, skipping test" - ) - return + if self._check_for_unk(token_ids): + return accept_state = recognizer._consume_token_ids(token_ids, as_string=False) # the json object is complete, so the stacks should be empty @@ -92,16 +93,29 @@ def test_balanced_parentheses(self): f"stacks: {accept_state.stacks}, not empty", ) - # inbalanced_parentheses = "((((((((()))))))))))))" - # token_ids = self.tokenizer.encode(inbalanced_parentheses) - # pprint_token_ids(self.tokenizer, token_ids) - # - # # check if there is unk token - # stacks = recognizer._consume_token_ids( - # token_ids, recognizer.grammar.stacks, as_string=False - # ) - # - # self.assertTrue(stacks != [] and stacks != [[]], f"stacks: {stacks}, empty") + def test_multiple_sequences(self): + # Test that the global bos setting works with multiple sequences + with open("examples/grammars/balanced_parentheses.ebnf", "r") as file: + input_text = file.read() + recognizer = IncrementalTokenRecognizer( + grammar_str=input_text, start_rule_name="root", tokenizer=self.tokenizer + ) + + balanced_parentheses_samples = ["((((((((()))))))))", "()"] + + # check if there is unk token + for sample in balanced_parentheses_samples: + token_ids = self.tokenizer.encode(sample) + pprint_token_ids(self.tokenizer, token_ids) + if self._check_for_unk(token_ids): + return + + accept_state = recognizer._consume_token_ids(token_ids, as_string=False) + # the json object is complete, so the stacks should be empty + self.assertTrue( + accept_state.stacks == set() or accept_state.stacks == set(tuple()), + f"stacks: {accept_state.stacks}, not empty", + ) @unittest.skip("Not implemented") def test_emoji(self): diff --git a/tests/test_tokenizers/test_phi3.py b/tests/test_tokenizers/test_phi3.py index 29bca94..32fc586 100644 --- a/tests/test_tokenizers/test_phi3.py +++ b/tests/test_tokenizers/test_phi3.py @@ -1,6 +1,6 @@ import unittest -from transformers import T5TokenizerFast +from transformers import LlamaTokenizerFast from tests._tokenizer_common import TokenizerTesterMixin @@ -9,7 +9,7 @@ class Phi3TokenizerTest(TokenizerTesterMixin, unittest.TestCase): - tokenizer_class = T5TokenizerFast + tokenizer_class = LlamaTokenizerFast pretrained_name = "microsoft/Phi-3-mini-4k-instruct" def setUp(self): diff --git a/transformers_cfg/token_grammar_recognizer.py b/transformers_cfg/token_grammar_recognizer.py index 90272c1..dffdc2a 100644 --- a/transformers_cfg/token_grammar_recognizer.py +++ b/transformers_cfg/token_grammar_recognizer.py @@ -53,6 +53,7 @@ def _consume_token_id( ) -> AcceptState: if self._must_stop(accept_state.stacks): if 
token_id == self.eos_token_id: + self.mapping.last_token_id = None return AcceptState.empty_state() else: raise ValueError( @@ -60,6 +61,7 @@ def _consume_token_id( ) if token_id == self.eos_token_id: if self._can_stop(accept_state.stacks): + self.mapping.last_token_id = None # we clear all the stacks, meaning that we don't accept any token after EOS return AcceptState.empty_state() else: @@ -115,19 +117,19 @@ def get_token_acceptance(self, accept_state, device) -> torch.Tensor: def get_token_acceptance_array_for_stack(self, stack: Tuple, partial_utf8, device): assert isinstance(stack, tuple) - + token_acceptance = [False] * self.vocab_size - + if self.byte_encoding: # boolean function checking if a byte sequence is accepted by the grammar accept_f = lambda x: self.string_recognizer._try_accept_bytes( bytes(x), {stack}, partial_utf8=partial_utf8 ) self.unicode_trie.get_token_acceptance( - accept=accept_f, - accept_eos=False, - eos_token_id=self.eos_token_id, - token_acceptance=token_acceptance + accept=accept_f, + accept_eos=False, + eos_token_id=self.eos_token_id, + token_acceptance=token_acceptance, ) else: check_token_acceptance_in_trie( @@ -219,6 +221,7 @@ def _consume_token_ids( string = self.tokenizer.decode(token_ids) accept_state = self.string_recognizer._consume_string(string, accept_state) else: + print(self.tokenizer.eos_token_id in token_ids) for i, token_id in enumerate(token_ids): accept_state = self._consume_token_id(token_id, accept_state) if len(accept_state.stacks) > 0: From 3a4667f3e8e8aa0996137a9c02a15800f066cd97 Mon Sep 17 00:00:00 2001 From: Arina Rak Date: Mon, 17 Jun 2024 15:05:25 +0200 Subject: [PATCH 06/17] mcqa and arithmetic CoT --- examples/CoT_aqua.py | 177 ++++++++++++++++++ examples/generate_chain_of_though.py | 127 +++++++++++++ .../grammars/chain_of_thought_arithmetic.ebnf | 7 + examples/grammars/chain_of_thought_mcqa.ebnf | 5 + examples/grammars/mcqa.ebnf | 1 + 5 files changed, 317 insertions(+) create mode 100644 examples/CoT_aqua.py create mode 100644 examples/generate_chain_of_though.py create mode 100644 examples/grammars/chain_of_thought_arithmetic.ebnf create mode 100644 examples/grammars/chain_of_thought_mcqa.ebnf create mode 100644 examples/grammars/mcqa.ebnf diff --git a/examples/CoT_aqua.py b/examples/CoT_aqua.py new file mode 100644 index 0000000..e3e32ea --- /dev/null +++ b/examples/CoT_aqua.py @@ -0,0 +1,177 @@ +import re +import torch +import argparse +from sklearn.metrics import accuracy_score +from transformers import AutoModelForCausalLM, AutoTokenizer +import evaluate +from transformers_cfg.grammar_utils import IncrementalGrammarConstraint +from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor +from datasets import load_dataset +from tqdm import tqdm +from collections import defaultdict + + +def parse_args(): + parser = argparse.ArgumentParser(description="Generate calflow strings") + parser.add_argument( + "--model-id", + type=str, + default="unsloth/mistral-7b-instruct-v0.2-bnb-4bit", + help="Model ID", + ) + parser.add_argument("--device", type=str, help="Device to put the model on") + return parser.parse_args() + + +def create_prompts(sample): + cot_in_context = "Think step-by-step, Question: How many keystrokes are needed to type the numbers from 1 to 500?\nAnswer Choices: A)1156 B)1392 C)1480 D)1562 E)1788\nReasoning: There are 9 one-digit numbers from 1 to 9. There are 90 two-digit numbers from 10 to 99. There are 401 three-digit numbers from 100 to 500. 
9 + 90 * 2 + 401 * 3 = 1392.\nAnswer: B);\n" + in_context = "Question: How many keystrokes are needed to type the numbers from 1 to 500?\nAnswer Choices: A)1156 B)1392 C)1480 D)1562 E)1788.\nAnswer: B);\n" + + sample_text = f"Question: {sample['question']}\nAnswer Choices: {' '.join(sample['options'])}\n" + + prompt_cot = f"{cot_in_context}{sample_text}Reasoning: " + sample["prompt_cot"] = prompt_cot + + prompt_1_shot = f"{in_context}{sample_text}Answer: " + sample["prompt_1_shot"] = prompt_1_shot + + return sample + + +def extract_answers(batch, generations, answers): + def _parse_prediction(prediction): + pattern = r"[A-E]\)" + predcted_answer = re.search(pattern, prediction) + return predcted_answer[0][0] if predcted_answer else "" + + batch_size = len(batch["prompt_cot"]) + + for i in range(batch_size): + prompt_1_shot = batch["prompt_1_shot"][i] + prompt_cot = batch["prompt_cot"][i] + batch_size = len(batch["prompt_cot"]) + + unconstrained_prediction = generations[i][len(prompt_cot) :] + constrained_cot_prediction = generations[i + batch_size][len(prompt_cot) :] + constrained_mcqa_prediction = generations[i + 2 * batch_size][ + len(prompt_1_shot) : + ] + + answers["gt"].append(batch["correct"][i]) + answers["unconstrained"].append(_parse_prediction(unconstrained_prediction)) + answers["constrained_cot"].append(_parse_prediction(constrained_cot_prediction)) + answers["constrained_mcqa"].append( + _parse_prediction(constrained_mcqa_prediction) + ) + + +def count_empty(predictions): + return sum(1 for pred in predictions if not pred) + + +def load_grammar_processor(grammar_path, tokenizer): + with open(grammar_path, "r") as file: + grammar_str = file.read() + + grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) + grammar_processor = GrammarConstrainedLogitsProcessor(grammar) + return grammar_processor + + +def main(): + args = parse_args() + model_id = args.model_id + + # Detect if GPU is available, otherwise use CPU + device = torch.device( + args.device or ("cuda" if torch.cuda.is_available() else "cpu") + ) + print(f"Using device: {device}") + + # Load model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + # Load model to defined device + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") + model.generation_config.pad_token_id = model.generation_config.eos_token_id + + test_dataset = load_dataset("deepmind/aqua_rat", split="test") + test_dataset = test_dataset.map(create_prompts) + + max_new_tokens = 300 + batch_size = 8 + + answers = defaultdict(list) + + for i, batch in enumerate(tqdm(test_dataset.iter(batch_size=batch_size))): + # Load grammars + cot_grammar_processor = load_grammar_processor( + "examples/grammars/chain_of_thought_mcqa.ebnf", tokenizer + ) + mcqa_grammar_processor = load_grammar_processor( + "examples/grammars/mcqa.ebnf", tokenizer + ) + + input_ids_1_shot = tokenizer( + batch["prompt_1_shot"], + add_special_tokens=False, + return_tensors="pt", + padding=True, + )["input_ids"].to(device) + + input_ids_cot = tokenizer( + batch["prompt_cot"], + add_special_tokens=False, + return_tensors="pt", + padding=True, + )["input_ids"].to(device) + + unconstrained_output = model.generate( + input_ids_cot, + do_sample=False, + max_new_tokens=max_new_tokens, + repetition_penalty=1.1, + num_return_sequences=1, + ) + + constrained_output_cot = model.generate( + input_ids_cot, + do_sample=False, + max_new_tokens=max_new_tokens, + 
logits_processor=[cot_grammar_processor], + repetition_penalty=1.1, + num_return_sequences=1, + ) + + constrained_output_mcqa = model.generate( + input_ids_1_shot, + do_sample=False, + max_new_tokens=max_new_tokens, + logits_processor=[mcqa_grammar_processor], + repetition_penalty=1.1, + num_return_sequences=1, + ) + + # decode outputs (possibly of different lengths across decoding modes) + generations = ( + tokenizer.batch_decode(unconstrained_output, skip_special_tokens=True) + + tokenizer.batch_decode(constrained_output_cot, skip_special_tokens=True) + + tokenizer.batch_decode(constrained_output_mcqa, skip_special_tokens=True) + ) + + extract_answers(batch, generations, answers) + + print( + f"Unconstrained accuracy: {accuracy_score(y_true=answers['gt'], y_pred=answers['unconstrained']):.3f}, empty: {count_empty(answers['unconstrained'])} out of {len(answers['unconstrained'])}", + ) + print( + f"Constrained accuracy (COT): {accuracy_score(y_true=answers['gt'], y_pred=answers['constrained_cot']):.3f}, empty: {count_empty(answers['constrained_cot'])} out of {len(answers['constrained_cot'])}" + ) + print( + f"Constrained accuracy (MCQA): {accuracy_score(y_true=answers['gt'], y_pred=answers['constrained_mcqa']):.3f}, , empty: {count_empty(answers['constrained_mcqa'])} out of {len(answers['constrained_mcqa'])}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/generate_chain_of_though.py b/examples/generate_chain_of_though.py new file mode 100644 index 0000000..45c51ac --- /dev/null +++ b/examples/generate_chain_of_though.py @@ -0,0 +1,127 @@ +import torch +import argparse +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers_cfg.grammar_utils import IncrementalGrammarConstraint +from transformers_cfg.recognizer import StringRecognizer +from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor +from transformers_cfg.parser import parse_ebnf + + +def parse_args(): + parser = argparse.ArgumentParser(description="Generate calflow strings") + parser.add_argument( + "--model-id", + type=str, + default="unsloth/mistral-7b-instruct-v0.2-bnb-4bit", + help="Model ID", + ) + parser.add_argument("--device", type=str, help="Device to put the model on") + return parser.parse_args() + + +def main(): + args = parse_args() + model_id = args.model_id + + # Detect if GPU is available, otherwise use CPU + device = torch.device( + args.device or ("cuda" if torch.cuda.is_available() else "cpu") + ) + print(f"Using device: {device}") + + # Load model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + # Load model to defined device + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") + model.generation_config.pad_token_id = model.generation_config.eos_token_id + + # Load grammar + with open(f"examples/grammars/chain_of_thought_arithmetic.ebnf", "r") as file: + grammar_str = file.read() + + grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) + grammar_processor = GrammarConstrainedLogitsProcessor(grammar) + + # Generate + prompts = [ + "179*12+34=", # no CoT + "think step-by-step, 12+7*19=12+133=145 >>> 145; 7*8+6*9=56+54=110 >>> 110; 179*12+34=", # CoT + ] + + input_ids = tokenizer( + prompts, add_special_tokens=False, return_tensors="pt", padding=True + )["input_ids"].to( + device + ) # Move input_ids to the same device as model + + n_examples = input_ids.shape[0] + + max_new_tokens = 30 + + 
unconstrained_output = model.generate( + input_ids, + do_sample=False, + max_new_tokens=max_new_tokens, + repetition_penalty=1.9, + num_return_sequences=1, + ) + + constrained_output = model.generate( + input_ids, + do_sample=False, + max_new_tokens=max_new_tokens, + logits_processor=[grammar_processor], + repetition_penalty=1.9, + num_return_sequences=1, + ) + + # decode outputs (possibly of different lengths across decoding modes) + generations = tokenizer.batch_decode( + unconstrained_output, skip_special_tokens=True + ) + tokenizer.batch_decode(constrained_output, skip_special_tokens=True) + + parsed_grammar = parse_ebnf(grammar_str) + string_grammar = StringRecognizer( + parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"] + ) + + print() + for i in range(n_examples): + print(f"Unconstrained: {generations[i]}") + constrained_generation = generations[i + n_examples] + print(f"Constrained: {constrained_generation}") + print( + f"The constrained generation matches the grammar: {string_grammar._accept_string(constrained_generation[len(prompts[i]):])}" + ) + print( + f"The generated prefix matches the grammar: {string_grammar._accept_prefix(constrained_generation[len(prompts[i]):])}" + ) + print() + + +if __name__ == "__main__": + main() + +########################## +# Example output (no chain of thought): +# Unconstrained: +# 179*12+34=0, +# -568. Вторемьте в некоторых другие позиции (включая и +# +# Constrained: +# 179*12+34=0; +# The constrained generation matches the grammar: True +# The generated prefix matches the grammar: True +# +# Example output (with chain of thought): +# Unconstrained: +# think step-by-step, 12+7*19=12+133=145 >>> 145; 7*8+6*9=56+54=110 >>> 110; 179*12+34=2148.0 + 117 = <<< error: invalid type comparison >>>; +# ``` | ```vbnet +# ' +# Constrained: +# think step-by-step, 12+7*19=12+133=145 >>> 145; 7*8+6*9=56+54=110 >>> 110; 179*12+34=2148+34=2182 >>> 2182; +# The constrained generation matches the grammar: True +# The generated prefix matches the grammar: True +########################## diff --git a/examples/grammars/chain_of_thought_arithmetic.ebnf b/examples/grammars/chain_of_thought_arithmetic.ebnf new file mode 100644 index 0000000..860df16 --- /dev/null +++ b/examples/grammars/chain_of_thought_arithmetic.ebnf @@ -0,0 +1,7 @@ +root ::= cot | result + +cot ::= ([-+*/=0-9])* " " result_mark " " result + +result_mark ::= ">>>" + +result ::= [0-9]+ ";" diff --git a/examples/grammars/chain_of_thought_mcqa.ebnf b/examples/grammars/chain_of_thought_mcqa.ebnf new file mode 100644 index 0000000..8c064db --- /dev/null +++ b/examples/grammars/chain_of_thought_mcqa.ebnf @@ -0,0 +1,5 @@ +root ::= cot | result + +cot ::= [\[\]-+*/=% 0-9a-zA-Z., ]* "." 
"\n" result + +result ::= "Answer: " [A-E] ")" ";" \ No newline at end of file diff --git a/examples/grammars/mcqa.ebnf b/examples/grammars/mcqa.ebnf new file mode 100644 index 0000000..d66532f --- /dev/null +++ b/examples/grammars/mcqa.ebnf @@ -0,0 +1 @@ +root ::= [A-E] ")" ";" \ No newline at end of file From 266cbc7534a49ab3da2797c2e115cd513a8eedc0 Mon Sep 17 00:00:00 2001 From: Arina Rak Date: Mon, 17 Jun 2024 15:21:22 +0200 Subject: [PATCH 07/17] formatting upd --- examples/grammars/chain_of_thought_mcqa.ebnf | 2 +- examples/grammars/mcqa.ebnf | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/grammars/chain_of_thought_mcqa.ebnf b/examples/grammars/chain_of_thought_mcqa.ebnf index 8c064db..2030f6a 100644 --- a/examples/grammars/chain_of_thought_mcqa.ebnf +++ b/examples/grammars/chain_of_thought_mcqa.ebnf @@ -2,4 +2,4 @@ root ::= cot | result cot ::= [\[\]-+*/=% 0-9a-zA-Z., ]* "." "\n" result -result ::= "Answer: " [A-E] ")" ";" \ No newline at end of file +result ::= "Answer: " [A-E] ")" ";" diff --git a/examples/grammars/mcqa.ebnf b/examples/grammars/mcqa.ebnf index d66532f..8707a2f 100644 --- a/examples/grammars/mcqa.ebnf +++ b/examples/grammars/mcqa.ebnf @@ -1 +1 @@ -root ::= [A-E] ")" ";" \ No newline at end of file +root ::= [A-E] ")" ";" From f58c0554bb589f5390bacff76eb85bf5aafd87cc Mon Sep 17 00:00:00 2001 From: Saibo Desktop Date: Sat, 25 May 2024 18:47:58 +0200 Subject: [PATCH 08/17] style: run black to format code --- README.md | 2 +- docs/benchmarking.md | 4 ++-- docs/json_grammar.md | 16 ++++++++++++++ examples/benchmarking/run_generation.sh | 6 +++-- examples/grammars/SMILES/acrylates.ebnf | 10 ++++----- examples/grammars/SMILES/chain_extenders.ebnf | 8 +++---- examples/grammars/SMILES/generic.ebnf | 4 ++-- examples/grammars/SMILES/isocyanates.ebnf | 6 ++--- examples/grammars/calflow.ebnf | 4 ++-- examples/grammars/geo_query.ebnf | 8 +++---- examples/grammars/overnight.ebnf | 6 ++--- tests/test_parsing/test_parsing.py | 1 - tests/test_string_recognizer/test_smiles.py | 22 ++++++++++++++----- transformers_cfg/tokenization/vocab_struct.py | 6 +++-- 14 files changed, 67 insertions(+), 36 deletions(-) create mode 100644 docs/json_grammar.md diff --git a/README.md b/README.md index 0d0b9a5..739f027 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ - **[Token masking optimization](#efficiency)(** (2024-04-25) - **Online [Demo with JSON Grammar](https://huggingface.co/spaces/saibo/transformers-CFG-JSON-demo) at HF space** (2024-04-10) - + - **Support for Unicode(multilingual) grammars** (2024-02-29) - **Integration with Text-Generation-WebUI** (2023-12-17) diff --git a/docs/benchmarking.md b/docs/benchmarking.md index 9c60e48..8e48bc5 100644 --- a/docs/benchmarking.md +++ b/docs/benchmarking.md @@ -1,6 +1,6 @@ # Benchmarking constrained generation overhead in transformers-CFG -This document provides guidelines and on benchmarking grammar constrained decoding when working with the `transformers_cfg` library. +This document provides guidelines and on benchmarking grammar constrained decoding when working with the `transformers_cfg` library. 
## Table of Contents @@ -30,7 +30,7 @@ The output of the script will be saved in `transformers_cfg/examples/benchmarkin The output contains the following columns: -- `prompt`: the text of the prompt (see more on the benchmarking prompt design in the `examples/benchmarking/process_benchmarking_logs.ipynb`) +- `prompt`: the text of the prompt (see more on the benchmarking prompt design in the `examples/benchmarking/process_benchmarking_logs.ipynb`) - `n_tokens`: number of tokens generated (can be affected by the `max_new_tokens` parameter) - `run_id`: run id (each generation is performed 5 times per prompt to account for noise in the execution time measurmnet) - `total_time`: total overhead (depends on the complexity of the grammar, the model, the prompt and the device) diff --git a/docs/json_grammar.md b/docs/json_grammar.md new file mode 100644 index 0000000..8af6471 --- /dev/null +++ b/docs/json_grammar.md @@ -0,0 +1,16 @@ +# JSON(JavaScript Object Notation) Grammar + + +## JSON standard + +https://datatracker.ietf.org/doc/html/rfc7159 + +## Clarification + +- JSON doesn't support comments.(JSON5 does but it's not in Python's standard library) +- JSON doesn't support trailing commas. + + +## JSON5 VS JSON + +https://spec.json5.org/ diff --git a/examples/benchmarking/run_generation.sh b/examples/benchmarking/run_generation.sh index ec765d1..96f31f0 100755 --- a/examples/benchmarking/run_generation.sh +++ b/examples/benchmarking/run_generation.sh @@ -1,4 +1,6 @@ -grammar_path=$1 +#!/bin/bash + +grammar_path=$1 grammar_name=$(basename $grammar_path) prompts_path=$2 model_id=${3:-"openai-community/gpt2"} @@ -19,7 +21,7 @@ do do echo "Prompt: $prompt" for run_id in {1..5} - do + do echo "Measurment: $run_id" kernprof -b --skip-zero -v time_benchmarking.py $grammar_path "$prompt" $max_new_tokens $model_id > $tmp_file unconstrained_time=$(cat $tmp_file | grep "Unconstrained time: " | awk '{print $3;}') diff --git a/examples/grammars/SMILES/acrylates.ebnf b/examples/grammars/SMILES/acrylates.ebnf index f8cadac..96fe14f 100644 --- a/examples/grammars/SMILES/acrylates.ebnf +++ b/examples/grammars/SMILES/acrylates.ebnf @@ -1,12 +1,12 @@ -root ::= (smiles bond?)* ( group_symbol_left group_bond? | group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right | group_bond? group_symbol_right) (bond? smiles)* +root ::= (smiles bond?)* ( group_symbol_left group_bond? | group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right | group_bond? group_symbol_right) (bond? smiles)* -group_radical_left ::= "(" ( group_symbol_left (group_bond smiles+)? )+ ")" +group_radical_left ::= "(" ( group_symbol_left (group_bond smiles+)? )+ ")" -group_radical_right ::= "(" ( (smiles+ group_bond )? group_symbol_right )+ ")" +group_radical_right ::= "(" ( (smiles+ group_bond )? group_symbol_right )+ ")" group_bond ::= ( "-" | "\\" | "/" ) -group_symbol_left ::= "C=CC(=O)O" | "C=CC(O)=O" | "C(=C)C(=O)O" | "C(=C)C(O)=O" | "CC(=C)(=O)O" | "CC(=C)(O)=O" +group_symbol_left ::= "C=CC(=O)O" | "C=CC(O)=O" | "C(=C)C(=O)O" | "C(=C)C(O)=O" | "CC(=C)(=O)O" | "CC(=C)(O)=O" group_symbol_right ::= "OC(=O)C=C" | "O=C(O)C=C" | "OC(=O)C(=C)" | "O=C(O)C(=C)" | "O(O=)(C=)CC" | "O=(O)(C=)CC" @@ -49,7 +49,7 @@ element_symbol ::= "A" ( "c" | "g" | "l" | "m" | "r" | "s" | "t" | "u" ) | "S" ( "b" | "c" | "e" | "g" | "i" | "m" | "n" | "r" )? | "T" ( "a" | "b" | "c" | "e" | "h" | "i" | "l" | "m" | "s" ) | "U" | "V" | "W" | "Xe" | "Y" "b"? 
| - "Z" ( "n" | "r" ) + "Z" ( "n" | "r" ) ring_closure ::= "%" [1-9] [0-9] | [0-9] diff --git a/examples/grammars/SMILES/chain_extenders.ebnf b/examples/grammars/SMILES/chain_extenders.ebnf index a5ca486..2b83e0c 100644 --- a/examples/grammars/SMILES/chain_extenders.ebnf +++ b/examples/grammars/SMILES/chain_extenders.ebnf @@ -1,8 +1,8 @@ -root ::= (smiles bond?)* ( group_symbol_left group_bond? | group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right | group_bond? group_symbol_right) (bond? smiles)* +root ::= (smiles bond?)* ( group_symbol_left group_bond? | group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right | group_bond? group_symbol_right) (bond? smiles)* -group_radical_left ::= "(" ( group_symbol_left group_bond? smiles* )+ ")" +group_radical_left ::= "(" ( group_symbol_left group_bond? smiles* )+ ")" -group_radical_right ::= "(" ( smiles* group_bond? group_symbol_right )+ ")" +group_radical_right ::= "(" ( smiles* group_bond? group_symbol_right )+ ")" group_bond ::= ( "-" | "\\" | "/" ) @@ -50,7 +50,7 @@ element_symbol ::= "A" ( "c" | "g" | "l" | "m" | "r" | "s" | "t" | "u" ) | "S" ( "b" | "c" | "e" | "g" | "i" | "m" | "n" | "r" )? | "T" ( "a" | "b" | "c" | "e" | "h" | "i" | "l" | "m" | "s" ) | "U" | "V" | "W" | "Xe" | "Y" "b"? | - "Z" ( "n" | "r" ) + "Z" ( "n" | "r" ) ring_closure ::= "%" [1-9] [0-9] | [0-9] diff --git a/examples/grammars/SMILES/generic.ebnf b/examples/grammars/SMILES/generic.ebnf index 1471600..4886cb3 100644 --- a/examples/grammars/SMILES/generic.ebnf +++ b/examples/grammars/SMILES/generic.ebnf @@ -16,7 +16,7 @@ wildcard ::= "*" atom_spec ::= "[" isotope? ( "se" | "as" | aromatic_symbol | element_symbol | wildcard ) chiral_class? h_count? ( charge | class? ) "]" -organic_symbol ::= "B" | "C" | "N" | "O" | "P" | "S" | "F" | "I" | "Br" | "Cl" | "At" | "Ts" +organic_symbol ::= "B" | "C" | "N" | "O" | "P" | "S" | "F" | "I" | "Br" | "Cl" | "At" | "Ts" aromatic_symbol ::= "b" | "c" | "n" | "o" | "p" | "s" @@ -39,7 +39,7 @@ element_symbol ::= "A" ( "c" | "g" | "l" | "m" | "r" | "s" | "t" | "u" ) | "S" ( "b" | "c" | "e" | "g" | "i" | "m" | "n" | "r" )? | "T" ( "a" | "b" | "c" | "e" | "h" | "i" | "l" | "m" | "s" ) | "U" | "V" | "W" | "Xe" | "Y" "b"? | - "Z" ( "n" | "r" ) + "Z" ( "n" | "r" ) ring_closure ::= "%" [1-9] [0-9] | [0-9] diff --git a/examples/grammars/SMILES/isocyanates.ebnf b/examples/grammars/SMILES/isocyanates.ebnf index 606e2c3..f99f75a 100644 --- a/examples/grammars/SMILES/isocyanates.ebnf +++ b/examples/grammars/SMILES/isocyanates.ebnf @@ -1,8 +1,8 @@ root ::= ( group_symbol_left group_bond? | (smiles bond?)* group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right (bond? smiles)* | group_bond? group_symbol_right ) -group_radical_left ::= "(" ( group_symbol_left group_bond? smiles* )+ ")" +group_radical_left ::= "(" ( group_symbol_left group_bond? smiles* )+ ")" -group_radical_right ::= "(" ( smiles* group_bond? group_symbol_right )+ ")" +group_radical_right ::= "(" ( smiles* group_bond? group_symbol_right )+ ")" group_bond ::= ( "-" | "\\" | "/" ) @@ -50,7 +50,7 @@ element_symbol ::= "A" ( "c" | "g" | "l" | "m" | "r" | "s" | "t" | "u" ) | "S" ( "b" | "c" | "e" | "g" | "i" | "m" | "n" | "r" )? | "T" ( "a" | "b" | "c" | "e" | "h" | "i" | "l" | "m" | "s" ) | "U" | "V" | "W" | "Xe" | "Y" "b"? 
| - "Z" ( "n" | "r" ) + "Z" ( "n" | "r" ) ring_closure ::= "%" [1-9] [0-9] | [0-9] diff --git a/examples/grammars/calflow.ebnf b/examples/grammars/calflow.ebnf index e3eaaae..ad6fdd4 100644 --- a/examples/grammars/calflow.ebnf +++ b/examples/grammars/calflow.ebnf @@ -1,4 +1,4 @@ -root ::= call +root ::= call call ::= event | "(Yield " org ")" | "(Yield " "(size" org "))" | "(Yield " event ")" | "(Yield " weather ")" | "(Yield " "(> " "(size" event ")" number "))" | "(do " datetime " " call ")" | "(do " call " " call ")" | "(do " org " " call ")" @@ -57,7 +57,7 @@ event_constraint ::= "(& " event_constraint " " event_constraint ")" | "(FindLastEvent " event_constraint ")" | "(^(Event) " "EmptyStructConstraint)" -location_constraint ::= "(?= " location ")" | "(roomRequest)" | "(&" location_constraint " " location_constraint ")" +location_constraint ::= "(?= " location ")" | "(roomRequest)" | "(&" location_constraint " " location_constraint ")" location ::= "(Event.location " event_constraint ")" | "(LocationKeyphrase.apply " string ")" diff --git a/examples/grammars/geo_query.ebnf b/examples/grammars/geo_query.ebnf index d1d7c70..3a29b35 100644 --- a/examples/grammars/geo_query.ebnf +++ b/examples/grammars/geo_query.ebnf @@ -22,7 +22,7 @@ city ::= "city(" city ")" | "smallest_one(density_1(" city "))" | ALL_CITY -place ::= "placeid('" PLACENAME "')" | +place ::= "placeid('" PLACENAME "')" | "lake(" place ")" | "mountain(" place ")" | "place(" place ")" | @@ -41,10 +41,10 @@ place ::= "placeid('" PLACENAME "')" | "exclude(" place coma_sep place ")" | ALL_PLACE -river ::= "river(" river ")" | +river ::= "river(" river ")" | "riverid('" RIVERNAME "')" | - "major(" river ")" | - "loc_2(" country ")" | + "major(" river ")" | + "loc_2(" country ")" | "loc_2(" state ")" | "longer(" river ")" | "traverse_2(" city ")" | diff --git a/examples/grammars/overnight.ebnf b/examples/grammars/overnight.ebnf index b74ba43..a9ce8f9 100644 --- a/examples/grammars/overnight.ebnf +++ b/examples/grammars/overnight.ebnf @@ -1,4 +1,4 @@ -root ::= "(listValue " list_value ")" +root ::= "(listValue " list_value ")" list_value ::= "(filter " ( list_value " " PROPERTY | list_value " " PROPERTY OP list_value | list_value " " "(ensureNumericProperty " PROPERTY ")" OP "(ensureNumericEntity " list_value ")" ) ")" | @@ -15,7 +15,7 @@ PROPERTY ::= "shape" | "color" | "length" | "is_special" | "width" | "height" | "(reverse " ( "left" | "right" | "above" | "below" ) ")" -SINGLETON_VALUE ::= "en.block" | "en.shape" | "en.color" +SINGLETON_VALUE ::= "en.block" | "en.shape" | "en.color" ENTITY_VALUE ::= "en.block.block1" | "en.block.block2" | "en.shape.pyramid" | "en.shape.cube" | "en.color.red" | "en.color.green" @@ -24,4 +24,4 @@ NUMBER_VALUE ::= ( "3" | "6" ) " " "en.inch" | "2" OP ::= " " ( "=" | ">" | "<" | ">=" | "<=" | "!=" ) " " -AGGREGATE ::= " " ("sum" | "max" | "min" | "avg" ) " " +AGGREGATE ::= " " ("sum" | "max" | "min" | "avg" ) " " diff --git a/tests/test_parsing/test_parsing.py b/tests/test_parsing/test_parsing.py index 81793de..97d8063 100644 --- a/tests/test_parsing/test_parsing.py +++ b/tests/test_parsing/test_parsing.py @@ -138,7 +138,6 @@ def test__parse_literal_string(self): self.assertEqual(3, len(outbuf), f"len(outbuf): {len(outbuf)} != 3") self.assertListEqual([2, ord("你"), ord("你")], outbuf) - def test__parse_escape(self): escaped_char_src = '"\\n"' outbuf = [] diff --git a/tests/test_string_recognizer/test_smiles.py b/tests/test_string_recognizer/test_smiles.py index 15e8b3a..62d3c58 100644 --- 
a/tests/test_string_recognizer/test_smiles.py +++ b/tests/test_string_recognizer/test_smiles.py @@ -64,9 +64,17 @@ class MoleculeTestCase: MoleculeTestCase("trans_bond_left", "O=C=N\\C1CC(C\\N=C=O)(CC(C1)(C)C)C"), MoleculeTestCase("trans_bond", "O=C=N\\CCCCCC/N=C=O"), MoleculeTestCase("group_radicals", "CCOC(C(N=C=O)CCCCN=C=O)=O"), - MoleculeTestCase("simple_atom", "O=C=NC1=CC=CC(CC2=CC=C(C=C2N=C=O)CC3=CC=C(C=C3)N=C=O)=C1"), - MoleculeTestCase("single_bond_no_hyphen", "O=C=NC1=CC(CC2=C(C=C(C=C2)CC3=CC=C(C=C3N=C=O)CC4=CC=C(C=C4)N=C=O)N=C=O)=CC=C1"), - MoleculeTestCase("double_bond", "O=C=NC1=CC=C(C=C1)CC2=CC=C(C=C2N=C=O)CC3=C(C=C(C=C3)CC4=CC=C(C=C4N=C=O)CC5=CC=C(C=C5)N=C=O)N=C=O"), + MoleculeTestCase( + "simple_atom", "O=C=NC1=CC=CC(CC2=CC=C(C=C2N=C=O)CC3=CC=C(C=C3)N=C=O)=C1" + ), + MoleculeTestCase( + "single_bond_no_hyphen", + "O=C=NC1=CC(CC2=C(C=C(C=C2)CC3=CC=C(C=C3N=C=O)CC4=CC=C(C=C4)N=C=O)N=C=O)=CC=C1", + ), + MoleculeTestCase( + "double_bond", + "O=C=NC1=CC=C(C=C1)CC2=CC=C(C=C2N=C=O)CC3=C(C=C(C=C3)CC4=CC=C(C=C4N=C=O)CC5=CC=C(C=C5)N=C=O)N=C=O", + ), MoleculeTestCase("interleaved_cycle_explicit", "CC1(CC(CC(CN=C=O)(C1)C)N=C=O)C"), MoleculeTestCase("interleaved_cycle_colon", "CC1=C(C=C(C=C1)CN=C=O)N=C=O"), MoleculeTestCase("cycles", "O=C=N\\c1ccc(cc1)Cc2ccc(\\N=C=O)cc2"), @@ -98,8 +106,12 @@ class MoleculeTestCase: MoleculeTestCase("", "C=CC(=O)OCC(CO)(COC(=O)C=C)COC(=O)C=C"), MoleculeTestCase("", "CCC(COCCCOC(=O)C=C)(COCCCOC(=O)C=C)COCCCOC(=O)C=C"), MoleculeTestCase("", "CCC(COCC(CC)(COC(=O)C=C)COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C"), - MoleculeTestCase("", "C=CC(=O)OCC(CO)(COCC(COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C)COC(=O)C=C"), - MoleculeTestCase("", "C=CC(=O)OCC(COCC(COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C") + MoleculeTestCase( + "", "C=CC(=O)OCC(CO)(COCC(COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C)COC(=O)C=C" + ), + MoleculeTestCase( + "", "C=CC(=O)OCC(COCC(COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C" + ), ] valid_chain_extender_sentences = [ diff --git a/transformers_cfg/tokenization/vocab_struct.py b/transformers_cfg/tokenization/vocab_struct.py index eb06a50..f1c606d 100644 --- a/transformers_cfg/tokenization/vocab_struct.py +++ b/transformers_cfg/tokenization/vocab_struct.py @@ -28,8 +28,10 @@ def replace_hex(match): hex_value = match.group(1) return chr(int(hex_value, 16)) - if ("gpt2" in tokenizer.__class__.__name__.lower() - or "pretrained" in tokenizer.__class__.__name__.lower()): # llama3 tokenizer + if ( + "gpt2" in tokenizer.__class__.__name__.lower() + or "pretrained" in tokenizer.__class__.__name__.lower() + ): # llama3 tokenizer special = tokenizer.additional_special_tokens_ids # Here, the decoder does a string replace on a bunch of sequences From 6ef563d4454f51db0ba85c90763e6b42f0174058 Mon Sep 17 00:00:00 2001 From: Saibo Desktop Date: Sat, 25 May 2024 23:18:31 +0200 Subject: [PATCH 09/17] style: rename TokenTrie into CodepointTrie and refactor implementation --- transformers_cfg/token_grammar_recognizer.py | 20 ++-- .../tokenization/{trie.py => byte_trie.py} | 0 .../tokenization/codepoint_trie.py | 88 ++++++++++++++++++ transformers_cfg/tokenization/vocab_struct.py | 92 ------------------- 4 files changed, 98 insertions(+), 102 deletions(-) rename transformers_cfg/tokenization/{trie.py => byte_trie.py} (100%) create mode 100644 transformers_cfg/tokenization/codepoint_trie.py delete mode 100644 transformers_cfg/tokenization/vocab_struct.py diff --git a/transformers_cfg/token_grammar_recognizer.py b/transformers_cfg/token_grammar_recognizer.py index 
d8195e6..a8cb58d 100644 --- a/transformers_cfg/token_grammar_recognizer.py +++ b/transformers_cfg/token_grammar_recognizer.py @@ -7,8 +7,8 @@ from transformers_cfg.recognizer import StringRecognizer, AcceptState from transformers_cfg.parser import parse_ebnf -from transformers_cfg.tokenization.trie import ByteTrie -from transformers_cfg.tokenization.vocab_struct import LEAF, TokenTrie +from transformers_cfg.tokenization.byte_trie import ByteTrie +from transformers_cfg.tokenization.codepoint_trie import LEAF, CodePointTrie from transformers_cfg.tokenization.mapping import get_mapping logger = logging.getLogger(__name__) @@ -30,14 +30,14 @@ def __init__(self, grammar_str, tokenizer, start_rule_name="root", unicode=False ) self.eos_token_id = tokenizer.eos_token_id - self.token_trie = TokenTrie(tokenizer) + self.code_point_token_trie = CodePointTrie(tokenizer) self.tokenizer = tokenizer self.string_recognizer = StringRecognizer(grammar_encoding, self.start_rule_id) self.unicode_trie = ByteTrie.from_tokenizer(tokenizer, unicode=unicode) self.mapping = get_mapping(tokenizer, unicode=unicode) assert len(self.mapping) == len( - self.token_trie - ), f"{len(self.mapping)}, {len(self.token_trie)}" + self.code_point_token_trie + ), f"{len(self.mapping)}, {len(self.code_point_token_trie)}" def _consume_token_id( self, token_id: int, accept_state: AcceptState @@ -142,7 +142,7 @@ def get_token_acceptance_array_for_stack(self, stack, partial_utf8, device): else: accepts = [False] * len(self.mapping) token_acceptance = check_token_acceptance_in_trie( - self.token_trie.trie, + self.code_point_token_trie.trie, [stack], self.string_recognizer, self.eos_token_id, @@ -252,11 +252,11 @@ def check_token_acceptance_in_trie(trie, stacks, grammar, eos_token_id, accepts) continue new_stacks = set() - for stk in stacks: - if not stk: + for stack in stacks: + if not stack: continue - next_element_offset = stk[-1] + next_element_offset = stack[-1] num_chars = grammar.grammar_encoding[next_element_offset] if not grammar.char_acceptance_at_element(next_element_offset).get( @@ -266,7 +266,7 @@ def check_token_acceptance_in_trie(trie, stacks, grammar, eos_token_id, accepts) continue next_element_offset += num_chars + 1 - new_stack = list(stk[:-1]) + new_stack = list(stack[:-1]) if grammar.grammar_encoding[next_element_offset]: new_stack.append(next_element_offset) new_stacks.update(grammar.expand_stack_head(tuple(new_stack))) diff --git a/transformers_cfg/tokenization/trie.py b/transformers_cfg/tokenization/byte_trie.py similarity index 100% rename from transformers_cfg/tokenization/trie.py rename to transformers_cfg/tokenization/byte_trie.py diff --git a/transformers_cfg/tokenization/codepoint_trie.py b/transformers_cfg/tokenization/codepoint_trie.py new file mode 100644 index 0000000..53080c8 --- /dev/null +++ b/transformers_cfg/tokenization/codepoint_trie.py @@ -0,0 +1,88 @@ +################# +# DATA STRUCTURES +################# + +import logging +import re +from typing import List + +logger = logging.getLogger(__name__) + +LEAF = -1 + + +def fmt_token_as_codepoints(token_id, tokenizer, only_ascii=True) -> List[int]: + + special_token_ids = tokenizer.additional_special_tokens_ids + + tokenizer_class_name = tokenizer.__class__.__name__.lower() + + if "gpt2" in tokenizer_class_name or "pretrained" in tokenizer_class_name: + # GPT-2 or Pretrained tokenizers + # No additional space handling needed + handle_spaces = False + elif "llama" in tokenizer_class_name or "t5" in tokenizer_class_name: + # Llama or T5 tokenizers + # 
Handle leading space in token + handle_spaces = True + else: + # logger.warning( + # "Warning: unrecognized tokenizer: using default token formatting" + # ) + handle_spaces = False + + if token_id in special_token_ids: + return None + token = tokenizer.decode([token_id], clean_up_tokenization_spaces=False) + if handle_spaces: + raw_token = tokenizer.convert_ids_to_tokens(token_id) + if raw_token.startswith("▁"): + token = " " + token + code_points = [ord(c) for c in token] + # keep only code points within ASCII range + code_points = code_points if all(c < 128 for c in code_points) else None + return code_points + + +class CodePointTrie: + def __init__(self, tokenizer, only_ascii=True): + self.eos_token_id = tokenizer.eos_token_id + self.all_token_codepoints = [] + self.trie = {} + # we only keep ASCII code points + # the reason why we should do this is because to handle unicode properly, we need to handle multi-byte characters + # this can not be done with a simple code point trie + # if we set only_ascii to False, we will be able to handle a subset of unicode characters + # this behavior is probably not what we want + self.only_ascii = only_ascii + self.load_tokens(tokenizer) + + def id2str(self, token_id): + return self.all_token_codepoints[token_id] + + def __len__(self): + return len(self.all_token_codepoints) + + def load_tokens(self, tokenizer): + self.all_token_codepoints = [ + fmt_token_as_codepoints(token_id, tokenizer, self.only_ascii) + for token_id in range(len(tokenizer.get_vocab())) + ] + for token_id, token_codepoints in enumerate(self.all_token_codepoints): + if token_codepoints is not None: + self.insert_into_trie(self.trie, token_codepoints, token_id) + + def insert_into_trie(self, trie, token_bytes, token_id): + current = trie + for byte in token_bytes: + if byte not in current: + current[byte] = {} + current = current[byte] + current[LEAF] = token_id + + +if __name__ == "__main__": + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + token_trie = CodePointTrie(tokenizer) diff --git a/transformers_cfg/tokenization/vocab_struct.py b/transformers_cfg/tokenization/vocab_struct.py deleted file mode 100644 index f1c606d..0000000 --- a/transformers_cfg/tokenization/vocab_struct.py +++ /dev/null @@ -1,92 +0,0 @@ -################# -# DATA STRUCTURES -################# - -import logging -import re - -logger = logging.getLogger(__name__) - -LEAF = -1 - - -class TokenTrie: - def __init__(self, tokenizer): - self.eos_token_id = tokenizer.eos_token_id - self.tokens = [] - self.trie = {} - self.load_tokens(tokenizer) - - def id2str(self, token_id): - return self.tokens[token_id] - - def __len__(self): - return len(self.tokens) - - def load_tokens(self, tokenizer): - def replace_hex(match): - hex_value = match.group(1) - return chr(int(hex_value, 16)) - - if ( - "gpt2" in tokenizer.__class__.__name__.lower() - or "pretrained" in tokenizer.__class__.__name__.lower() - ): # llama3 tokenizer - special = tokenizer.additional_special_tokens_ids - - # Here, the decoder does a string replace on a bunch of sequences - # like ' .' for '.'. This interferes with our assumptions, where a - # token should always have exactly one representation. - # Fortunately(?) text-generation-inference doesn't seem to run this - # cleanup, so we get extraneous spaces. So, in order to generate - # the right token set for TGI, we have to skip the space trimming. 
- # See: - # https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3588-L3600 - def fmt_token(id): - if id in special: - return None - return bytes( - tokenizer.decode([id], clean_up_tokenization_spaces=False), "utf-8" - ) - - elif ( - "llama" in tokenizer.__class__.__name__.lower() - or "t5" in tokenizer.__class__.__name__.lower() - ): - - def fmt_token(id): - token = tokenizer.convert_ids_to_tokens(id) - token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token) - token = token.replace("▁", " ") - return bytes(token, "utf-8") - - else: - logger.warning( - "Warning: unrecognized tokenizer: using default token formatting" - ) - - def fmt_token(id): - token = tokenizer.convert_ids_to_tokens(id) - return bytes(token, "utf-8") - - # note: vocab_size doesn't work here because there are also - # get_added_vocab() tokens - self.tokens = [fmt_token(i) for i in range(len(tokenizer.get_vocab()))] - for token_id, token_bytes in enumerate(self.tokens): - if token_bytes is not None: - self.insert_into_trie(self.trie, token_bytes, token_id) - - def insert_into_trie(self, trie, token_bytes, token_id): - current = trie - for byte in token_bytes: - if byte not in current: - current[byte] = {} - current = current[byte] - current[LEAF] = token_id - - -if __name__ == "__main__": - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained("gpt2") - token_trie = TokenTrie(tokenizer) From 943a8a6a4c5a641907d131b9f09a2292c2d35888 Mon Sep 17 00:00:00 2001 From: Arina Rak Date: Mon, 27 May 2024 14:16:51 +0200 Subject: [PATCH 10/17] trie related recognizer refactoring --- transformers_cfg/recognizer.py | 73 +++--------------- transformers_cfg/token_grammar_recognizer.py | 78 ++++++++------------ transformers_cfg/tokenization/byte_trie.py | 22 +++--- transformers_cfg/utf8_utils.py | 52 +++++++------ 4 files changed, 81 insertions(+), 144 deletions(-) diff --git a/transformers_cfg/recognizer.py b/transformers_cfg/recognizer.py index dfbc92e..27f2eec 100644 --- a/transformers_cfg/recognizer.py +++ b/transformers_cfg/recognizer.py @@ -100,9 +100,6 @@ def init_stack(self, start_rule_id: int) -> Set[Tuple[int]]: def get_initial_accept_state(self) -> AcceptState: return AcceptState(self.init_stack(self.start_rule_id), PartialUTF8()) - def get_termination_accept_state(self) -> AcceptState: - return AcceptState(set(), PartialUTF8()) - @lru_cache(maxsize=32768) def expand_stack_head(self, stack: Tuple[int]) -> Set[Tuple[int]]: """ @@ -130,7 +127,6 @@ def expand_stack_head(self, stack: Tuple[int]) -> Set[Tuple[int]]: new_stacks: Set[Tuple[int]] = set() # Loop over alternates of referenced rule to build new stacks while self.grammar_encoding[ref_subrule_offset] != END_OF_RULE_MARKER: - # copy the original stack without the last element new_stack = list(stack[:-1]) # if the rule ref is followed by another element, we add it to the stack next_element_offset = cur_element_offset + 2 @@ -150,12 +146,6 @@ def expand_stack_head(self, stack: Tuple[int]) -> Set[Tuple[int]]: return new_stacks - def _consume_byte(self, byte: int, accept_state: AcceptState) -> AcceptState: - # suppose we have code point 一, ord('一') = 19968, we need to match 3 bytes - # we need to match 3 bytes, so we need to call _consume_byte_partial_match 3 times - return self._consume_bytes(bytes([byte]), accept_state) - - # @lru_cache(maxsize=32768) def _try_accept_bytes( self, byte_seq: bytes, @@ -167,8 +157,6 @@ def _try_accept_bytes( The difference between accept_bytes and consume_bytes is 
that accept_bytes returns a boolean and consume_bytes returns a new accept state """ - if type(byte_seq) is list: - byte_seq = bytes(byte_seq) code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8) if verbose: logging.debug( @@ -177,7 +165,6 @@ def _try_accept_bytes( new_stacks = self._consume_code_points_for_all_stacks(code_points, stacks) for stack in new_stacks: - # stack is empty, meaning that the variables are all consumed if len(stack) == 0: return True @@ -192,12 +179,14 @@ def _consume_bytes( accept_state: Optional[AcceptState] = None, verbose=True, ) -> AcceptState: + if accept_state is None: accept_state = self.get_initial_accept_state() stacks = accept_state.stacks partial_utf8 = accept_state.partial_utf8 if type(byte_seq) is list: byte_seq = bytes(byte_seq) + code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8) if verbose: logging.debug( @@ -226,7 +215,7 @@ def _consume_code_point_for_all_stacks( ) -> Set[Tuple[int]]: """ consume a character from the stack - char_code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127] + code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127] """ new_stacks: Set[Tuple[int]] = set() @@ -244,7 +233,7 @@ def _consume_code_point_for_single_stack( ) -> Set[Tuple[int]]: """ consume a character from the stack - char_code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127] + code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127] """ # TODO, the below code will raise an error when the stack is empty, but why is this happening? # if len(stacks) == 0: @@ -253,15 +242,13 @@ def _consume_code_point_for_single_stack( # to indicate that the character is not accepted new_stacks: Set[Tuple[int]] = set() - if code_point == 0: - return new_stacks - # stack is empty - if len(stack) == 0: + + if code_point == 0 or len(stack) == 0: return new_stacks element_offset = stack[-1] - found = self.accept_code_point_at_element(code_point, element_offset) + if not found: return new_stacks @@ -398,20 +385,6 @@ def _accept_string(self, string: str, accept_state: Optional[AcceptState] = None ) return at_least_one_stack_is_empty - def _can_stop(self, stacks: Set[Tuple[int]]): - # This happens in practice, but maybe it shouldn't? TODO - if len(stacks) == 0: - return True - # if any of the stack is empty, we can stop - for stack in stacks: - if len(stack) == 0: - return True - else: - return False - - def _must_stop(self, stacks: Set[Tuple[int]]): - return len(stacks) == 0 or all(len(stack) == 0 for stack in stacks) - ############################# # # Not Used @@ -422,50 +395,28 @@ def _must_stop(self, stacks: Set[Tuple[int]]): @lru_cache(maxsize=None) def char_acceptance_at_element(self, element_offset): """ - Caches and returns a dictionary indicating whether a Unicode character is accepted + Caches and returns a set of accepted Unicode characters at a given rule position. This function considers Unicode characters, dynamically - inserting accepted ranges into a dictionary to optimize memory usage. + inserting accepted ranges into the set to optimize memory usage. Args: - rule_offset: The offset in the grammar encoding where the rule starts. Returns: - - A dictionary where each key is a Unicode character (or range) and the value is True if accepted. + - A set of accepted Unicode characters (or range). 
""" logging.debug(f"element_offset: {element_offset}") - acceptance = {} + acceptance = set() num_chars = self.grammar_encoding[element_offset] element_offset += 1 for i in range(0, num_chars, 2): start = self.grammar_encoding[element_offset + i] end = self.grammar_encoding[element_offset + i + 1] for j in range(start, end + 1): - acceptance[j] = True + acceptance.add(j) logging.debug(acceptance) return acceptance - # def _consume_code_points_new( - # self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False - # ) -> Set[Tuple[int]]: - # new_stacks: Set[Tuple[int]] = set() - # for stack in stacks: - # new_stacks.update( - # self._consume_code_points_per_stack(tuple(code_points), stack, verbose) - # ) - # return new_stacks - # - # @lru_cache(maxsize=30000) - # def _consume_code_points_per_stack( - # self, code_points: Tuple[int], stack: Tuple[int], verbose=False - # ) -> Set[Tuple[int]]: - # stacks = {stack} - # - # for code_point in code_points: - # # Update the stacks variable by consuming each code point. - # stacks = self._consume_code_point_for_all_stacks(code_point, (stack,)) - # - # return stacks - if __name__ == "__main__": # set logging level diff --git a/transformers_cfg/token_grammar_recognizer.py b/transformers_cfg/token_grammar_recognizer.py index a8cb58d..90272c1 100644 --- a/transformers_cfg/token_grammar_recognizer.py +++ b/transformers_cfg/token_grammar_recognizer.py @@ -10,6 +10,7 @@ from transformers_cfg.tokenization.byte_trie import ByteTrie from transformers_cfg.tokenization.codepoint_trie import LEAF, CodePointTrie from transformers_cfg.tokenization.mapping import get_mapping +from typing import Set, Tuple logger = logging.getLogger(__name__) @@ -35,25 +36,32 @@ def __init__(self, grammar_str, tokenizer, start_rule_name="root", unicode=False self.string_recognizer = StringRecognizer(grammar_encoding, self.start_rule_id) self.unicode_trie = ByteTrie.from_tokenizer(tokenizer, unicode=unicode) self.mapping = get_mapping(tokenizer, unicode=unicode) - assert len(self.mapping) == len( + self.vocab_size = len(self.mapping) + assert self.vocab_size == len( self.code_point_token_trie - ), f"{len(self.mapping)}, {len(self.code_point_token_trie)}" + ), f"{self.vocab_size}, {len(self.code_point_token_trie)}" + + def _must_stop(self, stacks: Set[Tuple[int]]): + return len(stacks) == 0 or all(len(stack) == 0 for stack in stacks) + + def _can_stop(self, stacks: Set[Tuple[int]]): + # if at least one of the stack is empty, we can stop + return len(stacks) == 0 or any(len(stack) == 0 for stack in stacks) def _consume_token_id( self, token_id: int, accept_state: AcceptState ) -> AcceptState: - if self.string_recognizer._must_stop(accept_state.stacks): + if self._must_stop(accept_state.stacks): if token_id == self.eos_token_id: - return self.string_recognizer.get_termination_accept_state() + return AcceptState.empty_state() else: raise ValueError( f"All stacks are empty, so the only token accepted is EOS({self.eos_token_id}), but got {token_id}" ) if token_id == self.eos_token_id: - if self.string_recognizer._can_stop(accept_state.stacks): - # if at least one of the stack is empty, we can stop + if self._can_stop(accept_state.stacks): # we clear all the stacks, meaning that we don't accept any token after EOS - return self.string_recognizer.get_termination_accept_state() + return AcceptState.empty_state() else: raise ValueError( f"At least one of the stack should be empty when EOS is reached. 
However, " @@ -66,28 +74,6 @@ def _consume_token_id( ) return accept_state - def try_accept_token_id(self, token_id: int, accept_state: AcceptState) -> bool: - stacks = accept_state.stacks - if self.string_recognizer._must_stop(stacks): - if token_id == self.eos_token_id: - return True - else: - return False - if token_id == self.eos_token_id: - if self.string_recognizer._can_stop(stacks): - # if at least one of the stack is empty, we can stop - # we clear all the stacks, meaning that we don't accept any token after EOS - return True - else: - return False - # for code_point in self.mapping.map(token_id): - # stacks = self.grammar._consume_char_code_point(code_point, stacks) - bytes_or_codepoints = self.mapping.map(token_id, verbose=False) - new_acc_state = self.string_recognizer._consume_bytes( - bytes_or_codepoints, accept_state, verbose=False - ) - return len(new_acc_state.stacks) > 0 - def consume_token_ids(self, *args, **kwargs): """Process a list of tokens according to the grammar rules.""" raise NotImplementedError @@ -102,10 +88,9 @@ def filter_vocab(self, accept_state, device) -> torch.Tensor: if not accept_state.stacks: # Check if stacks is empty # Handle the empty case: for example, return a tensor of False # The size of the tensor should match the size of your vocabulary - vocab_size = len(self.mapping) logger.debug(f"Empty stack, sum of acceptance: {0}") # size of the vocab - accepts = [False] * vocab_size + accepts = [False] * self.vocab_size accepts[self.eos_token_id] = True return torch.tensor(accepts, dtype=torch.bool, device=device) @@ -127,26 +112,30 @@ def get_token_acceptance(self, accept_state, device) -> torch.Tensor: return acceptance @lru_cache(maxsize=32768) - def get_token_acceptance_array_for_stack(self, stack, partial_utf8, device): - # stack = list(stack) # needs to come in as a tuple for lru_cache - assert isinstance(stack, tuple) + def get_token_acceptance_array_for_stack(self, stack: Tuple, partial_utf8, device): + assert isinstance(stack, tuple) + + token_acceptance = [False] * self.vocab_size + if self.byte_encoding: - + # boolean function checking if a byte sequence is accepted by the grammar accept_f = lambda x: self.string_recognizer._try_accept_bytes( - x, {stack}, partial_utf8=partial_utf8 + bytes(x), {stack}, partial_utf8=partial_utf8 ) - token_acceptance = self.unicode_trie.get_token_acceptance( - accept=accept_f, accept_eos=False, eos_token_id=self.eos_token_id + self.unicode_trie.get_token_acceptance( + accept=accept_f, + accept_eos=False, + eos_token_id=self.eos_token_id, + token_acceptance=token_acceptance ) else: - accepts = [False] * len(self.mapping) - token_acceptance = check_token_acceptance_in_trie( + check_token_acceptance_in_trie( self.code_point_token_trie.trie, [stack], self.string_recognizer, self.eos_token_id, - accepts, + token_acceptance, ) x = torch.tensor(token_acceptance, dtype=torch.bool, device=device) x_eos = self.validate_and_set_eos_acceptance(x) @@ -241,7 +230,6 @@ def _consume_token_ids( def check_token_acceptance_in_trie(trie, stacks, grammar, eos_token_id, accepts): - for byte, next_trie in trie.items(): if byte == LEAF: token_id = next_trie @@ -259,10 +247,8 @@ def check_token_acceptance_in_trie(trie, stacks, grammar, eos_token_id, accepts) next_element_offset = stack[-1] num_chars = grammar.grammar_encoding[next_element_offset] - if not grammar.char_acceptance_at_element(next_element_offset).get( - byte, False - ): - # if the current byte is not accepted by the current rule, we need to try next rule + # if the 
current byte is not accepted by the current rule, we need to try next rule + if not grammar.accept_code_point_at_element(byte, next_element_offset): continue next_element_offset += num_chars + 1 diff --git a/transformers_cfg/tokenization/byte_trie.py b/transformers_cfg/tokenization/byte_trie.py index dd4cbe3..fcb48e0 100644 --- a/transformers_cfg/tokenization/byte_trie.py +++ b/transformers_cfg/tokenization/byte_trie.py @@ -5,17 +5,8 @@ from transformers_cfg.tokenization.mapping import get_mapping -# from transformers_cfg.parser import parse_ebnf -# from transformers_cfg.recognizer import GrammarRecognizer -# from transformers_cfg.token_grammar_recognizer import IncrementalTokenGrammarRecognizer - logger = logging.getLogger(__name__) -# def check_token_acceptance_in_trie(trie, stacks, grammar, partial_utf8, accept_eos=True, eos_token_id=None) -> List[bool]: -# accept_f = lambda x: grammar._probe_bytes_partial_match(x, stack=stacks, partial_utf8=partial_utf8) -# accepts = trie.get_token_acceptance(accept=accept_f, accept_eos=accept_eos, eos_token_id=eos_token_id) -# return accepts - class TrieNode: def __init__(self): @@ -76,14 +67,16 @@ def dfs(self, accept=lambda x: True, verbose=False) -> List[Tuple[List[int], int def bfs( self, predicate=lambda x: True, verbose=False ) -> List[Tuple[List[int], int]]: + queue = deque([(self.root, [])]) + # TODO: do we need to keep track of the byte sequence? valid_byte_seqs: List[Tuple[List[int], int]] = [] counter = {"visited": 0, "pruned": 0} while queue: counter["visited"] += 1 node, byte_seq = queue.popleft() - if predicate(byte_seq): + if predicate(bytes(byte_seq)): if node.is_end_of_word: valid_byte_seqs.append((byte_seq, node.token_id)) for char, next_node in node.children.items(): @@ -95,18 +88,21 @@ def bfs( return valid_byte_seqs def get_token_acceptance( - self, accept=lambda x: True, accept_eos=True, eos_token_id=None + self, accept, accept_eos, eos_token_id, token_acceptance ) -> List[bool]: + """ + Finds all acceptable tokens for a fixed stack (verified with accept function). + Modifies token_acceptance: a list of booleans, where the ith element is True if the ith token is acceptable, False otherwise. 
+ """ valid_byte_seqs: List[Tuple[List[int], int]] = self.bfs(accept, verbose=True) valid_token_ids: List[int] = [token_id for _, token_id in valid_byte_seqs] - token_acceptance: List[bool] = [False] * (len(self)) + for token_id in valid_token_ids: token_acceptance[token_id] = True if not accept_eos: # eos_token is mapped to an empty string, so it's always accepted regardless of the accept function # this can be undesirable, so we can set it to False to ignore it token_acceptance[eos_token_id] = False - return token_acceptance def _dfs( diff --git a/transformers_cfg/utf8_utils.py b/transformers_cfg/utf8_utils.py index f82dbaf..7bcf490 100644 --- a/transformers_cfg/utf8_utils.py +++ b/transformers_cfg/utf8_utils.py @@ -1,7 +1,6 @@ from dataclasses import dataclass -from typing import Tuple - -from dataclasses import dataclass +from typing import List, Tuple +from functools import lru_cache @dataclass @@ -36,34 +35,41 @@ def __eq__(self, other): return self.value == other.value and self.n_remain == other.n_remain -from typing import List, Tuple -from functools import lru_cache +@lru_cache(maxsize=3000000) +def decode_utf8_intermediate( + src: bytes, pos: int, value: int, n_remain: int +) -> Tuple[int, int, int]: + while pos < len(src) and n_remain > 0: + next_byte = src[pos] # Get the next byte to process + # Check if the continuation byte format is correct (`10xxxxxx`) + if (next_byte >> 6) != 2: + return -1, -1, -1 + + # Accumulate the value by shifting left and adding the relevant 6 bits + value = (value << 6) + (next_byte & 0x3F) + pos += 1 # Move to the next byte + n_remain -= 1 # Decrement the number of remaining bytes + return value, n_remain, pos @lru_cache(maxsize=3000000) def decode_utf8( src: bytes, partial_start: PartialUTF8 ) -> Tuple[List[int], PartialUTF8]: + # Lookup table for determining the total bytes based on the first byte's high 4 bits lookup = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4] pos = 0 # Position in the src bytes to start decoding from code_points = [] # List to store the decoded Unicode code points + value = partial_start.value # Start with any previously partial decoded value n_remain = partial_start.n_remain # Number of bytes remaining from a partial decode # If there's a partial sequence left from last decode, try to continue decoding it - while pos < len(src) and n_remain > 0: - next_byte = src[pos] # Get the next byte to process - # Check if the continuation byte format is correct (`10xxxxxx`) - if (next_byte >> 6) != 2: - # If not, it's an invalid sequence. Abort and return a special error state. - code_points = [0] - return code_points, PartialUTF8(0, -1) - - # Accumulate the value by shifting left and adding the relevant 6 bits - value = (value << 6) + (next_byte & 0x3F) - pos += 1 # Move to the next byte - n_remain -= 1 # Decrement the number of remaining bytes + value, n_remain, pos = decode_utf8_intermediate(src, pos, value, n_remain) + # Invalid sequence, return a special error state. 
+ if value == -1 and n_remain == -1 and pos == -1: + return [0], PartialUTF8(0, -1) # If we've completed a partial sequence, add its value to the code points if partial_start.n_remain > 0 and n_remain == 0: @@ -86,13 +92,11 @@ def decode_utf8( value = first_byte & mask # Apply the mask to get the initial value pos += 1 # Move to the next byte - # Process the continuation bytes - while pos < len(src) and n_remain > 0: - next_byte = src[pos] - # Shift the accumulated value and add the next 6 significant bits - value = (value << 6) + (next_byte & 0x3F) - pos += 1 # Move to the next byte - n_remain -= 1 # Decrement the number of remaining bytes + # Decode the continuation bytes + value, n_remain, pos = decode_utf8_intermediate(src, pos, value, n_remain) + # Invalid sequence, return a special error state. + if value == -1 and n_remain == -1 and pos == -1: + return [0], PartialUTF8(0, -1) # If the sequence is complete, add its decoded value to the code points if n_remain == 0: From 19f9e4e38e825d98a734b8e82e41bbcf4f63ce21 Mon Sep 17 00:00:00 2001 From: Arina Rak Date: Mon, 27 May 2024 16:02:45 +0200 Subject: [PATCH 11/17] mapping refactored + sanity checks for t5/bloom/llama3 and t5 added --- tests/test_tokenizers/test_bloom.py | 1 - tests/test_tokenizers/test_llama3.py | 15 +++++ tests/test_tokenizers/test_phi3.py | 16 +++++ tests/test_tokenizers/test_t5.py | 1 - transformers_cfg/tokenization/mapping.py | 85 ++++++++++++++---------- 5 files changed, 80 insertions(+), 38 deletions(-) create mode 100644 tests/test_tokenizers/test_llama3.py create mode 100644 tests/test_tokenizers/test_phi3.py diff --git a/tests/test_tokenizers/test_bloom.py b/tests/test_tokenizers/test_bloom.py index ebee494..6253698 100644 --- a/tests/test_tokenizers/test_bloom.py +++ b/tests/test_tokenizers/test_bloom.py @@ -7,7 +7,6 @@ import logging -# @unittest.skip("GPTNeoXTokenizerFast is not available for testing") class BloomTokenizerTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = BloomTokenizerFast diff --git a/tests/test_tokenizers/test_llama3.py b/tests/test_tokenizers/test_llama3.py new file mode 100644 index 0000000..8a32e68 --- /dev/null +++ b/tests/test_tokenizers/test_llama3.py @@ -0,0 +1,15 @@ +import unittest + +from transformers import GPT2TokenizerFast +from tests._tokenizer_common import TokenizerTesterMixin + +import logging + + +class Llama3TokenizerTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = GPT2TokenizerFast + pretrained_name = "meta-llama/Meta-Llama-3-8B" + + def setUp(self): + super().setUp() diff --git a/tests/test_tokenizers/test_phi3.py b/tests/test_tokenizers/test_phi3.py new file mode 100644 index 0000000..29bca94 --- /dev/null +++ b/tests/test_tokenizers/test_phi3.py @@ -0,0 +1,16 @@ +import unittest + +from transformers import T5TokenizerFast + +from tests._tokenizer_common import TokenizerTesterMixin + +import logging + + +class Phi3TokenizerTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = T5TokenizerFast + pretrained_name = "microsoft/Phi-3-mini-4k-instruct" + + def setUp(self): + super().setUp() diff --git a/tests/test_tokenizers/test_t5.py b/tests/test_tokenizers/test_t5.py index 936bd39..f4d6bbf 100644 --- a/tests/test_tokenizers/test_t5.py +++ b/tests/test_tokenizers/test_t5.py @@ -7,7 +7,6 @@ import logging -@unittest.skip("T5Tokenizer's mapping is not well defined, not working") class T5TokenizerTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = T5TokenizerFast diff --git 
a/transformers_cfg/tokenization/mapping.py b/transformers_cfg/tokenization/mapping.py index 7cd1913..a89f1c2 100644 --- a/transformers_cfg/tokenization/mapping.py +++ b/transformers_cfg/tokenization/mapping.py @@ -2,6 +2,7 @@ from transformers_cfg.utils import get_tokenizer_model_type, ints2bytes from transformers import AutoTokenizer +import re import logging log = logging.getLogger(__name__) @@ -10,25 +11,22 @@ def get_mapping(tokenizer, unicode=False): log.debug(f"tokenizer type: {tokenizer.__class__.__name__}") log.debug(f"tokenizer model type: {get_tokenizer_model_type(tokenizer)}") + tokenizer_name = tokenizer.__class__.__name__.lower() if not unicode: - if ( - "gpt2" in tokenizer.__class__.__name__.lower() - or "bloom" in tokenizer.__class__.__name__.lower() - or "pretrainedtokenizer" in tokenizer.__class__.__name__.lower() - or "codegen" in tokenizer.__class__.__name__.lower() - or "gptneox" in tokenizer.__class__.__name__.lower() + if re.match( + r"gpt2|bloom|pretrainedtokenizer|codegen|gptneox|Llama-3", tokenizer_name ): return BBPEMapping(tokenizer) - elif "t5" in tokenizer.__class__.__name__.lower(): + elif re.match(r"t5|Phi-3", tokenizer_name): return BPEMapping(tokenizer) - elif "llama" in tokenizer.__class__.__name__.lower(): + elif "llama" in tokenizer_name: return LlamaBPEMapping(tokenizer) - elif "xglm" in tokenizer.__class__.__name__.lower(): + elif "xglm" in tokenizer_name: return UniGramMapping(tokenizer) else: raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__.__name__}") else: - if "gpt2" in tokenizer.__class__.__name__.lower(): + if "gpt2" in tokenizer_name: return UnicodeBBPEMapping(tokenizer) else: raise NotImplementedError( @@ -36,6 +34,16 @@ def get_mapping(tokenizer, unicode=False): ) +class ReplacePrefixMixin: + def __init__(self, prefix): + self.prefix = prefix + + def _replace_prefix(self, token: str) -> str: + if token.startswith(self.prefix): + return token.replace(self.prefix, "", 1) + return token + + class Mapping: def __init__(self, tokenizer): self.eos_token_id = tokenizer.eos_token_id @@ -48,20 +56,23 @@ def __len__(self): return self._length def _map(self, token_id: int) -> str: - # This is the case for BOS, - if token_id in self.special: - return "" # if token_id is tensor, convert it to int if hasattr(token_id, "item"): token_id = token_id.item() + # This is the case for BOS, + if token_id in self.special: + return "" raw_token = self.tokenizer.convert_ids_to_tokens(token_id) return raw_token + def _encode(self, token: str) -> bytes: + return bytes(token, "utf-8") + def map(self, token_id: int, verbose=False) -> bytes: token = self._map(token_id) if verbose: log.debug(f"token_id: {token_id}, token: {token}") - return bytes(token, "utf-8") + return self._encode(token) class BBPEMapping(Mapping): @@ -71,7 +82,7 @@ def __init__(self, *args, **kwargs): def _map(self, token_id: int) -> str: raw_token = super()._map(token_id) if raw_token.startswith("Ġ"): - raw_token = raw_token.replace("Ġ", " ") + raw_token = raw_token.replace("Ġ", " ", 1) return raw_token @@ -82,17 +93,8 @@ def __init__(self, *args, **kwargs): self.tokenizer ) - def _map(self, token_id: int, verbose=False) -> str: - raw_token = super()._map(token_id) - # if raw_token.startswith("Ġ"): - # raw_token = raw_token.replace("Ġ", " ") - return raw_token - - def map(self, token_id: int, verbose=False) -> bytes: - raw_token = self._map(token_id, verbose) - if verbose: - log.debug(f"token_id: {token_id}, raw_token: {raw_token}") - return 
self.intermediate_encoding.token2bytes(raw_token) + def _encode(self, token: str) -> bytes: + return self.intermediate_encoding.token2bytes(token) @staticmethod def get_intermediate_encoding(tokenizer): @@ -107,17 +109,19 @@ def __init__(self, tokenizer): super().__init__(tokenizer) self.last_token_id = None + def _check_bos_token(self, token_id: int) -> bool: + # specific to BPE + at_bos = self.last_token_id is None + self.last_token_id = token_id if token_id != self.eos_token_id else None + return at_bos + def _map(self, token_id: int) -> str: raw_token = super()._map(token_id) - # we need to check if the token is at the beginning of the sentence to remove the space # specific to BPE - at_bos = False - if self.last_token_id is not None and self.last_token_id == self.bos_token_id: - at_bos = True - self.last_token_id = token_id + at_bos = self._check_bos_token(token_id) if raw_token.startswith("▁"): - raw_token = raw_token.replace("▁", " ") + raw_token = raw_token.replace("▁", " ", 1) if at_bos: # remove space at the beginning of the sentence raw_token = raw_token[1:] @@ -128,6 +132,11 @@ class LlamaBPEMapping(BPEMapping): def __init__(self, tokenizer): super().__init__(tokenizer) + def _check_bos_token(self, token_id: int) -> bool: + at_bos = self.last_token_id and (self.last_token_id == self.bos_token_id) + self.last_token_id = token_id + return at_bos + def _map(self, token_id: int) -> str: raw_token = super()._map(token_id) # if the token is hex, token is a string like "<0x00>" @@ -183,11 +192,11 @@ def __init__(self, tokenizer): self.char2byte: Dict[str, int] = tokenizer.byte_decoder # code point to byte self.cdp2byte: Dict[int, int] = {ord(c): b for c, b in self.char2byte.items()} - self.byte2cdp: Dict[int, int] = {v: k for k, v in self.cdp2byte.items()} + self.byte2cdp: Dict[int, int] = {b: c for c, b in self.cdp2byte.items()} def map(self, byte: int) -> int: assert 0 <= byte < 256, f"byte: {byte} is not in the range [0, 256)" - return ord(self.byte2char[byte]) + return self.byte2cdp[byte] def token_ids2bytes(self, token_ids: List[int]) -> bytes: tokens: List[str] = self.tokenizer.convert_ids_to_tokens(token_ids) @@ -196,10 +205,14 @@ def token_ids2bytes(self, token_ids: List[int]) -> bytes: tokens = [ "" if token in self.tokenizer.all_special_ids else token for token in tokens ] - bytes: List[List[int]] = [self.token2bytes(token) for token in tokens] + bytes_per_token: List[List[int]] = [self.token2bytes(token) for token in tokens] # join the bytes - return ints2bytes(sum(bytes, [])) + bytes = sum(bytes_per_token, []) + # verify range and convert to bytes + bytes = ints2bytes(bytes) + return bytes + # Not used def token_id2bytes(self, token_id: int) -> bytes: token: str = self.tokenizer.convert_ids_to_tokens(token_id) return self.token2bytes(token) From b3f8b1d40ea6d6420abce7a6868aa1552e16cfc4 Mon Sep 17 00:00:00 2001 From: Arina Rak Date: Tue, 4 Jun 2024 15:21:47 +0200 Subject: [PATCH 12/17] fix for t5 --- tests/_tokenizer_common.py | 58 ++++++++++++-------- tests/test_tokenizers/test_phi3.py | 4 +- transformers_cfg/token_grammar_recognizer.py | 15 +++-- 3 files changed, 47 insertions(+), 30 deletions(-) diff --git a/tests/_tokenizer_common.py b/tests/_tokenizer_common.py index d0660ee..b5501b5 100644 --- a/tests/_tokenizer_common.py +++ b/tests/_tokenizer_common.py @@ -29,6 +29,15 @@ class TokenizerTesterMixin: # test_sentencepiece must also be set to True test_sentencepiece_ignore_case = False + def _check_for_unk(self, token_ids): + for token_id in token_ids: + if token_id == 
self.tokenizer.unk_token_id: + warnings.warn( + f"unk token found in input_token_ids: {token_ids}, skipping test" + ) + return True + return False + def setUp(self): self.tokenizer = self.get_tokenizer() @@ -51,12 +60,8 @@ def test_json_parsable(self): pprint_token_ids(self.tokenizer, token_ids) # check if there is unk token - for token_id in token_ids: - if token_id == self.tokenizer.unk_token_id: - warnings.warn( - f"unk token found in input_token_ids: {token_ids}, skipping test" - ) - return + if self._check_for_unk(token_ids): + return acc_state = JsontokenRecognizer._consume_token_ids(token_ids, as_string=False) # the json object is complete, so the stacks should be empty @@ -78,12 +83,8 @@ def test_balanced_parentheses(self): pprint_token_ids(self.tokenizer, token_ids) # check if there is unk token - for token_id in token_ids: - if token_id == self.tokenizer.unk_token_id: - warnings.warn( - f"unk token found in input_token_ids: {token_ids}, skipping test" - ) - return + if self._check_for_unk(token_ids): + return accept_state = recognizer._consume_token_ids(token_ids, as_string=False) # the json object is complete, so the stacks should be empty @@ -92,16 +93,29 @@ def test_balanced_parentheses(self): f"stacks: {accept_state.stacks}, not empty", ) - # inbalanced_parentheses = "((((((((()))))))))))))" - # token_ids = self.tokenizer.encode(inbalanced_parentheses) - # pprint_token_ids(self.tokenizer, token_ids) - # - # # check if there is unk token - # stacks = recognizer._consume_token_ids( - # token_ids, recognizer.grammar.stacks, as_string=False - # ) - # - # self.assertTrue(stacks != [] and stacks != [[]], f"stacks: {stacks}, empty") + def test_multiple_sequences(self): + # Test that the global bos setting works with multiple sequences + with open("examples/grammars/balanced_parentheses.ebnf", "r") as file: + input_text = file.read() + recognizer = IncrementalTokenRecognizer( + grammar_str=input_text, start_rule_name="root", tokenizer=self.tokenizer + ) + + balanced_parentheses_samples = ["((((((((()))))))))", "()"] + + # check if there is unk token + for sample in balanced_parentheses_samples: + token_ids = self.tokenizer.encode(sample) + pprint_token_ids(self.tokenizer, token_ids) + if self._check_for_unk(token_ids): + return + + accept_state = recognizer._consume_token_ids(token_ids, as_string=False) + # the json object is complete, so the stacks should be empty + self.assertTrue( + accept_state.stacks == set() or accept_state.stacks == set(tuple()), + f"stacks: {accept_state.stacks}, not empty", + ) @unittest.skip("Not implemented") def test_emoji(self): diff --git a/tests/test_tokenizers/test_phi3.py b/tests/test_tokenizers/test_phi3.py index 29bca94..32fc586 100644 --- a/tests/test_tokenizers/test_phi3.py +++ b/tests/test_tokenizers/test_phi3.py @@ -1,6 +1,6 @@ import unittest -from transformers import T5TokenizerFast +from transformers import LlamaTokenizerFast from tests._tokenizer_common import TokenizerTesterMixin @@ -9,7 +9,7 @@ class Phi3TokenizerTest(TokenizerTesterMixin, unittest.TestCase): - tokenizer_class = T5TokenizerFast + tokenizer_class = LlamaTokenizerFast pretrained_name = "microsoft/Phi-3-mini-4k-instruct" def setUp(self): diff --git a/transformers_cfg/token_grammar_recognizer.py b/transformers_cfg/token_grammar_recognizer.py index 90272c1..dffdc2a 100644 --- a/transformers_cfg/token_grammar_recognizer.py +++ b/transformers_cfg/token_grammar_recognizer.py @@ -53,6 +53,7 @@ def _consume_token_id( ) -> AcceptState: if self._must_stop(accept_state.stacks): if 
token_id == self.eos_token_id: + self.mapping.last_token_id = None return AcceptState.empty_state() else: raise ValueError( @@ -60,6 +61,7 @@ def _consume_token_id( ) if token_id == self.eos_token_id: if self._can_stop(accept_state.stacks): + self.mapping.last_token_id = None # we clear all the stacks, meaning that we don't accept any token after EOS return AcceptState.empty_state() else: @@ -115,19 +117,19 @@ def get_token_acceptance(self, accept_state, device) -> torch.Tensor: def get_token_acceptance_array_for_stack(self, stack: Tuple, partial_utf8, device): assert isinstance(stack, tuple) - + token_acceptance = [False] * self.vocab_size - + if self.byte_encoding: # boolean function checking if a byte sequence is accepted by the grammar accept_f = lambda x: self.string_recognizer._try_accept_bytes( bytes(x), {stack}, partial_utf8=partial_utf8 ) self.unicode_trie.get_token_acceptance( - accept=accept_f, - accept_eos=False, - eos_token_id=self.eos_token_id, - token_acceptance=token_acceptance + accept=accept_f, + accept_eos=False, + eos_token_id=self.eos_token_id, + token_acceptance=token_acceptance, ) else: check_token_acceptance_in_trie( @@ -219,6 +221,7 @@ def _consume_token_ids( string = self.tokenizer.decode(token_ids) accept_state = self.string_recognizer._consume_string(string, accept_state) else: + print(self.tokenizer.eos_token_id in token_ids) for i, token_id in enumerate(token_ids): accept_state = self._consume_token_id(token_id, accept_state) if len(accept_state.stacks) > 0: From 798a234aa7372c6f56319d41448cdc3a1d4e0ed6 Mon Sep 17 00:00:00 2001 From: Arina Rak Date: Mon, 17 Jun 2024 15:05:25 +0200 Subject: [PATCH 13/17] mcqa and arithmetic CoT --- examples/CoT_aqua.py | 177 ++++++++++++++++++ examples/generate_chain_of_though.py | 127 +++++++++++++ .../grammars/chain_of_thought_arithmetic.ebnf | 7 + examples/grammars/chain_of_thought_mcqa.ebnf | 5 + examples/grammars/mcqa.ebnf | 1 + 5 files changed, 317 insertions(+) create mode 100644 examples/CoT_aqua.py create mode 100644 examples/generate_chain_of_though.py create mode 100644 examples/grammars/chain_of_thought_arithmetic.ebnf create mode 100644 examples/grammars/chain_of_thought_mcqa.ebnf create mode 100644 examples/grammars/mcqa.ebnf diff --git a/examples/CoT_aqua.py b/examples/CoT_aqua.py new file mode 100644 index 0000000..e3e32ea --- /dev/null +++ b/examples/CoT_aqua.py @@ -0,0 +1,177 @@ +import re +import torch +import argparse +from sklearn.metrics import accuracy_score +from transformers import AutoModelForCausalLM, AutoTokenizer +import evaluate +from transformers_cfg.grammar_utils import IncrementalGrammarConstraint +from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor +from datasets import load_dataset +from tqdm import tqdm +from collections import defaultdict + + +def parse_args(): + parser = argparse.ArgumentParser(description="Generate calflow strings") + parser.add_argument( + "--model-id", + type=str, + default="unsloth/mistral-7b-instruct-v0.2-bnb-4bit", + help="Model ID", + ) + parser.add_argument("--device", type=str, help="Device to put the model on") + return parser.parse_args() + + +def create_prompts(sample): + cot_in_context = "Think step-by-step, Question: How many keystrokes are needed to type the numbers from 1 to 500?\nAnswer Choices: A)1156 B)1392 C)1480 D)1562 E)1788\nReasoning: There are 9 one-digit numbers from 1 to 9. There are 90 two-digit numbers from 10 to 99. There are 401 three-digit numbers from 100 to 500. 
9 + 90 * 2 + 401 * 3 = 1392.\nAnswer: B);\n" + in_context = "Question: How many keystrokes are needed to type the numbers from 1 to 500?\nAnswer Choices: A)1156 B)1392 C)1480 D)1562 E)1788.\nAnswer: B);\n" + + sample_text = f"Question: {sample['question']}\nAnswer Choices: {' '.join(sample['options'])}\n" + + prompt_cot = f"{cot_in_context}{sample_text}Reasoning: " + sample["prompt_cot"] = prompt_cot + + prompt_1_shot = f"{in_context}{sample_text}Answer: " + sample["prompt_1_shot"] = prompt_1_shot + + return sample + + +def extract_answers(batch, generations, answers): + def _parse_prediction(prediction): + pattern = r"[A-E]\)" + predcted_answer = re.search(pattern, prediction) + return predcted_answer[0][0] if predcted_answer else "" + + batch_size = len(batch["prompt_cot"]) + + for i in range(batch_size): + prompt_1_shot = batch["prompt_1_shot"][i] + prompt_cot = batch["prompt_cot"][i] + batch_size = len(batch["prompt_cot"]) + + unconstrained_prediction = generations[i][len(prompt_cot) :] + constrained_cot_prediction = generations[i + batch_size][len(prompt_cot) :] + constrained_mcqa_prediction = generations[i + 2 * batch_size][ + len(prompt_1_shot) : + ] + + answers["gt"].append(batch["correct"][i]) + answers["unconstrained"].append(_parse_prediction(unconstrained_prediction)) + answers["constrained_cot"].append(_parse_prediction(constrained_cot_prediction)) + answers["constrained_mcqa"].append( + _parse_prediction(constrained_mcqa_prediction) + ) + + +def count_empty(predictions): + return sum(1 for pred in predictions if not pred) + + +def load_grammar_processor(grammar_path, tokenizer): + with open(grammar_path, "r") as file: + grammar_str = file.read() + + grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) + grammar_processor = GrammarConstrainedLogitsProcessor(grammar) + return grammar_processor + + +def main(): + args = parse_args() + model_id = args.model_id + + # Detect if GPU is available, otherwise use CPU + device = torch.device( + args.device or ("cuda" if torch.cuda.is_available() else "cpu") + ) + print(f"Using device: {device}") + + # Load model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + # Load model to defined device + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") + model.generation_config.pad_token_id = model.generation_config.eos_token_id + + test_dataset = load_dataset("deepmind/aqua_rat", split="test") + test_dataset = test_dataset.map(create_prompts) + + max_new_tokens = 300 + batch_size = 8 + + answers = defaultdict(list) + + for i, batch in enumerate(tqdm(test_dataset.iter(batch_size=batch_size))): + # Load grammars + cot_grammar_processor = load_grammar_processor( + "examples/grammars/chain_of_thought_mcqa.ebnf", tokenizer + ) + mcqa_grammar_processor = load_grammar_processor( + "examples/grammars/mcqa.ebnf", tokenizer + ) + + input_ids_1_shot = tokenizer( + batch["prompt_1_shot"], + add_special_tokens=False, + return_tensors="pt", + padding=True, + )["input_ids"].to(device) + + input_ids_cot = tokenizer( + batch["prompt_cot"], + add_special_tokens=False, + return_tensors="pt", + padding=True, + )["input_ids"].to(device) + + unconstrained_output = model.generate( + input_ids_cot, + do_sample=False, + max_new_tokens=max_new_tokens, + repetition_penalty=1.1, + num_return_sequences=1, + ) + + constrained_output_cot = model.generate( + input_ids_cot, + do_sample=False, + max_new_tokens=max_new_tokens, + 
logits_processor=[cot_grammar_processor], + repetition_penalty=1.1, + num_return_sequences=1, + ) + + constrained_output_mcqa = model.generate( + input_ids_1_shot, + do_sample=False, + max_new_tokens=max_new_tokens, + logits_processor=[mcqa_grammar_processor], + repetition_penalty=1.1, + num_return_sequences=1, + ) + + # decode outputs (possibly of different lengths across decoding modes) + generations = ( + tokenizer.batch_decode(unconstrained_output, skip_special_tokens=True) + + tokenizer.batch_decode(constrained_output_cot, skip_special_tokens=True) + + tokenizer.batch_decode(constrained_output_mcqa, skip_special_tokens=True) + ) + + extract_answers(batch, generations, answers) + + print( + f"Unconstrained accuracy: {accuracy_score(y_true=answers['gt'], y_pred=answers['unconstrained']):.3f}, empty: {count_empty(answers['unconstrained'])} out of {len(answers['unconstrained'])}", + ) + print( + f"Constrained accuracy (COT): {accuracy_score(y_true=answers['gt'], y_pred=answers['constrained_cot']):.3f}, empty: {count_empty(answers['constrained_cot'])} out of {len(answers['constrained_cot'])}" + ) + print( + f"Constrained accuracy (MCQA): {accuracy_score(y_true=answers['gt'], y_pred=answers['constrained_mcqa']):.3f}, , empty: {count_empty(answers['constrained_mcqa'])} out of {len(answers['constrained_mcqa'])}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/generate_chain_of_though.py b/examples/generate_chain_of_though.py new file mode 100644 index 0000000..45c51ac --- /dev/null +++ b/examples/generate_chain_of_though.py @@ -0,0 +1,127 @@ +import torch +import argparse +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers_cfg.grammar_utils import IncrementalGrammarConstraint +from transformers_cfg.recognizer import StringRecognizer +from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor +from transformers_cfg.parser import parse_ebnf + + +def parse_args(): + parser = argparse.ArgumentParser(description="Generate calflow strings") + parser.add_argument( + "--model-id", + type=str, + default="unsloth/mistral-7b-instruct-v0.2-bnb-4bit", + help="Model ID", + ) + parser.add_argument("--device", type=str, help="Device to put the model on") + return parser.parse_args() + + +def main(): + args = parse_args() + model_id = args.model_id + + # Detect if GPU is available, otherwise use CPU + device = torch.device( + args.device or ("cuda" if torch.cuda.is_available() else "cpu") + ) + print(f"Using device: {device}") + + # Load model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + # Load model to defined device + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") + model.generation_config.pad_token_id = model.generation_config.eos_token_id + + # Load grammar + with open(f"examples/grammars/chain_of_thought_arithmetic.ebnf", "r") as file: + grammar_str = file.read() + + grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) + grammar_processor = GrammarConstrainedLogitsProcessor(grammar) + + # Generate + prompts = [ + "179*12+34=", # no CoT + "think step-by-step, 12+7*19=12+133=145 >>> 145; 7*8+6*9=56+54=110 >>> 110; 179*12+34=", # CoT + ] + + input_ids = tokenizer( + prompts, add_special_tokens=False, return_tensors="pt", padding=True + )["input_ids"].to( + device + ) # Move input_ids to the same device as model + + n_examples = input_ids.shape[0] + + max_new_tokens = 30 + + 
unconstrained_output = model.generate( + input_ids, + do_sample=False, + max_new_tokens=max_new_tokens, + repetition_penalty=1.9, + num_return_sequences=1, + ) + + constrained_output = model.generate( + input_ids, + do_sample=False, + max_new_tokens=max_new_tokens, + logits_processor=[grammar_processor], + repetition_penalty=1.9, + num_return_sequences=1, + ) + + # decode outputs (possibly of different lengths across decoding modes) + generations = tokenizer.batch_decode( + unconstrained_output, skip_special_tokens=True + ) + tokenizer.batch_decode(constrained_output, skip_special_tokens=True) + + parsed_grammar = parse_ebnf(grammar_str) + string_grammar = StringRecognizer( + parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"] + ) + + print() + for i in range(n_examples): + print(f"Unconstrained: {generations[i]}") + constrained_generation = generations[i + n_examples] + print(f"Constrained: {constrained_generation}") + print( + f"The constrained generation matches the grammar: {string_grammar._accept_string(constrained_generation[len(prompts[i]):])}" + ) + print( + f"The generated prefix matches the grammar: {string_grammar._accept_prefix(constrained_generation[len(prompts[i]):])}" + ) + print() + + +if __name__ == "__main__": + main() + +########################## +# Example output (no chain of thought): +# Unconstrained: +# 179*12+34=0, +# -568. Вторемьте в некоторых другие позиции (включая и +# +# Constrained: +# 179*12+34=0; +# The constrained generation matches the grammar: True +# The generated prefix matches the grammar: True +# +# Example output (with chain of thought): +# Unconstrained: +# think step-by-step, 12+7*19=12+133=145 >>> 145; 7*8+6*9=56+54=110 >>> 110; 179*12+34=2148.0 + 117 = <<< error: invalid type comparison >>>; +# ``` | ```vbnet +# ' +# Constrained: +# think step-by-step, 12+7*19=12+133=145 >>> 145; 7*8+6*9=56+54=110 >>> 110; 179*12+34=2148+34=2182 >>> 2182; +# The constrained generation matches the grammar: True +# The generated prefix matches the grammar: True +########################## diff --git a/examples/grammars/chain_of_thought_arithmetic.ebnf b/examples/grammars/chain_of_thought_arithmetic.ebnf new file mode 100644 index 0000000..860df16 --- /dev/null +++ b/examples/grammars/chain_of_thought_arithmetic.ebnf @@ -0,0 +1,7 @@ +root ::= cot | result + +cot ::= ([-+*/=0-9])* " " result_mark " " result + +result_mark ::= ">>>" + +result ::= [0-9]+ ";" diff --git a/examples/grammars/chain_of_thought_mcqa.ebnf b/examples/grammars/chain_of_thought_mcqa.ebnf new file mode 100644 index 0000000..8c064db --- /dev/null +++ b/examples/grammars/chain_of_thought_mcqa.ebnf @@ -0,0 +1,5 @@ +root ::= cot | result + +cot ::= [\[\]-+*/=% 0-9a-zA-Z., ]* "." 
"\n" result + +result ::= "Answer: " [A-E] ")" ";" \ No newline at end of file diff --git a/examples/grammars/mcqa.ebnf b/examples/grammars/mcqa.ebnf new file mode 100644 index 0000000..d66532f --- /dev/null +++ b/examples/grammars/mcqa.ebnf @@ -0,0 +1 @@ +root ::= [A-E] ")" ";" \ No newline at end of file From 97d535d25467f9758c23f5d75ed0d9b09a48a958 Mon Sep 17 00:00:00 2001 From: Saibo-creator <53392976+Saibo-creator@users.noreply.github.com> Date: Sun, 9 Jun 2024 18:41:38 +0200 Subject: [PATCH 14/17] fix: update func api in debugging_custom_grammar.md (#56) --- docs/debugging_custom_grammars.md | 8 ++++---- transformers_cfg/recognizer.py | 5 ++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/debugging_custom_grammars.md b/docs/debugging_custom_grammars.md index 4542206..496c0bf 100644 --- a/docs/debugging_custom_grammars.md +++ b/docs/debugging_custom_grammars.md @@ -90,18 +90,18 @@ We provide a simple script to do this: ```python from transformers_cfg.parser import parse_ebnf -from transformers_cfg.recognizer import GrammarRecognizer +from transformers_cfg.recognizer import StringRecognizer with open("examples/grammars/json.ebnf", "r") as file: input_text = file.read() parsed_grammar = parse_ebnf(input_text) start_rule_id = parsed_grammar.symbol_table["root"] -recognizer = GrammarRecognizer(parsed_grammar.grammar_encoding, start_rule_id) +recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id) # Test the grammar with a simple input json_input = '{"foo": "bar", "baz": "bat"}' -is_accepted = recognizer._accept_prefix(json_input, recognizer.stacks) +is_accepted = recognizer._accept_prefix(json_input) print(is_accepted) ``` @@ -112,7 +112,7 @@ N.B. the recognizer can accept partial input, so you can try the following: ```python json_input = '{"foo": "bar"' -is_accepted = recognizer._accept_prefix(json_input, recognizer.stacks) +is_accepted = recognizer._accept_prefix(json_input) print(is_accepted) ``` diff --git a/transformers_cfg/recognizer.py b/transformers_cfg/recognizer.py index 27f2eec..577156f 100644 --- a/transformers_cfg/recognizer.py +++ b/transformers_cfg/recognizer.py @@ -186,7 +186,7 @@ def _consume_bytes( partial_utf8 = accept_state.partial_utf8 if type(byte_seq) is list: byte_seq = bytes(byte_seq) - + code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8) if verbose: logging.debug( @@ -418,6 +418,9 @@ def char_acceptance_at_element(self, element_offset): return acceptance +# backward compatibility, add alias of StringRecognizer to GrammarRecognizer +GrammarRecognizer = StringRecognizer + if __name__ == "__main__": # set logging level From 991585e3a77113e819d64587a046474cf2d62918 Mon Sep 17 00:00:00 2001 From: Saibo-creator <53392976+Saibo-creator@users.noreply.github.com> Date: Thu, 13 Jun 2024 22:47:16 +0200 Subject: [PATCH 15/17] use server demo in README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 739f027..9a16f74 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ - **[Token masking optimization](#efficiency)(** (2024-04-25) -- **Online [Demo with JSON Grammar](https://huggingface.co/spaces/saibo/transformers-CFG-JSON-demo) at HF space** (2024-04-10) - +- **Online [Demo with JSON Grammar](http://saibo-creator.xyz:7860/) at HF space** (2024-04-10) + - **Support for Unicode(multilingual) grammars** (2024-02-29) - **Integration with Text-Generation-WebUI** (2023-12-17) From d5559aba0ea4a66721c86d0cb75b136878ab4e04 Mon Sep 17 00:00:00 2001 From: 
Arina Rak
Date: Mon, 17 Jun 2024 15:21:22 +0200
Subject: [PATCH 16/17] formatting upd

---
 examples/grammars/chain_of_thought_mcqa.ebnf | 2 +-
 examples/grammars/mcqa.ebnf                  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/grammars/chain_of_thought_mcqa.ebnf b/examples/grammars/chain_of_thought_mcqa.ebnf
index 8c064db..2030f6a 100644
--- a/examples/grammars/chain_of_thought_mcqa.ebnf
+++ b/examples/grammars/chain_of_thought_mcqa.ebnf
@@ -2,4 +2,4 @@ root ::= cot | result
 
 cot ::= [\[\]-+*/=% 0-9a-zA-Z., ]* "." "\n" result
 
-result ::= "Answer: " [A-E] ")" ";"
\ No newline at end of file
+result ::= "Answer: " [A-E] ")" ";"
diff --git a/examples/grammars/mcqa.ebnf b/examples/grammars/mcqa.ebnf
index d66532f..8707a2f 100644
--- a/examples/grammars/mcqa.ebnf
+++ b/examples/grammars/mcqa.ebnf
@@ -1 +1 @@
-root ::= [A-E] ")" ";"
\ No newline at end of file
+root ::= [A-E] ")" ";"
From 35854e81c59c3c067c77245018fba6ec0f78f455 Mon Sep 17 00:00:00 2001
From: Arina Rak
Date: Mon, 17 Jun 2024 15:29:05 +0200
Subject: [PATCH 17/17] change script description

---
 examples/generate_chain_of_though.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/generate_chain_of_though.py b/examples/generate_chain_of_though.py
index 45c51ac..4b1a291 100644
--- a/examples/generate_chain_of_though.py
+++ b/examples/generate_chain_of_though.py
@@ -8,7 +8,9 @@
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(description="Generate calflow strings")
+    parser = argparse.ArgumentParser(
+        description="Generate chain of thought arithmetic strings"
+    )
     parser.add_argument(
         "--model-id",
         type=str,