CoT prompting with GCD examples #59

Open · wants to merge 19 commits into main
4 changes: 2 additions & 2 deletions docs/benchmarking.md
@@ -1,6 +1,6 @@
# Benchmarking constrained generation overhead in transformers-CFG

This document provides guidelines on benchmarking grammar-constrained decoding with the `transformers_cfg` library.

## Table of Contents

@@ -30,7 +30,7 @@ The output of the script will be saved in `transformers_cfg/examples/benchmarkin

The output contains the following columns:

- `prompt`: the text of the prompt (see more on the benchmarking prompt design in `examples/benchmarking/process_benchmarking_logs.ipynb`)
- `n_tokens`: number of tokens generated (can be affected by the `max_new_tokens` parameter)
- `run_id`: run ID (each generation is performed 5 times per prompt to account for noise in the execution-time measurement)
- `total_time`: total overhead (depends on the complexity of the grammar, the model, the prompt, and the device)
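
A minimal sketch for aggregating these logs, assuming they are saved as a CSV with the columns above (the filename below is hypothetical; substitute the actual path produced by the script):

```python
import pandas as pd

# Hypothetical log file name; the script writes its output under
# transformers_cfg/examples/benchmarking.
df = pd.read_csv("examples/benchmarking/benchmarking_logs.csv")

# Average the 5 runs per prompt to smooth out timing noise, then
# normalize the overhead by the number of generated tokens.
summary = df.groupby("prompt").agg(
    mean_total_time=("total_time", "mean"),
    n_tokens=("n_tokens", "first"),
)
summary["time_per_token"] = summary["mean_total_time"] / summary["n_tokens"]
print(summary)
```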
16 changes: 16 additions & 0 deletions docs/json_grammar.md
@@ -0,0 +1,16 @@
# JSON (JavaScript Object Notation) Grammar


## JSON standard

https://datatracker.ietf.org/doc/html/rfc7159

## Clarification

- JSON doesn't support comments. (JSON5 does, but JSON5 isn't part of Python's standard library.)
- JSON doesn't support trailing commas.
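
Both points are easy to verify with Python's standard-library `json` module:

```python
import json

# Trailing commas are rejected by strict JSON (RFC 7159).
try:
    json.loads('{"a": 1,}')
except json.JSONDecodeError as err:
    print(f"trailing comma rejected: {err}")

# Comments are likewise not part of the JSON grammar.
try:
    json.loads('{"a": 1 /* comment */}')
except json.JSONDecodeError as err:
    print(f"comment rejected: {err}")
```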


## JSON5 vs. JSON

https://spec.json5.org/
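
For comparison, a minimal sketch using the third-party `json5` package from PyPI (`pip install json5`), which accepts both constructions:

```python
import json5  # third-party; not part of the standard library

# JSON5 tolerates comments and trailing commas that strict JSON rejects.
data = json5.loads("""
{
    // a comment
    "a": 1,
}
""")
print(data)  # {'a': 1}
```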
177 changes: 177 additions & 0 deletions examples/CoT_aqua.py
@@ -0,0 +1,177 @@
import re
import torch
import argparse
from sklearn.metrics import accuracy_score
from transformers import AutoModelForCausalLM, AutoTokenizer
import evaluate
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
from datasets import load_dataset
from tqdm import tqdm
from collections import defaultdict


def parse_args():
parser = argparse.ArgumentParser(description="Generate calflow strings")
parser.add_argument(
"--model-id",
type=str,
default="unsloth/mistral-7b-instruct-v0.2-bnb-4bit",
help="Model ID",
)
parser.add_argument("--device", type=str, help="Device to put the model on")
return parser.parse_args()


def create_prompts(sample):
cot_in_context = "Think step-by-step, Question: How many keystrokes are needed to type the numbers from 1 to 500?\nAnswer Choices: A)1156 B)1392 C)1480 D)1562 E)1788\nReasoning: There are 9 one-digit numbers from 1 to 9. There are 90 two-digit numbers from 10 to 99. There are 401 three-digit numbers from 100 to 500. 9 + 90 * 2 + 401 * 3 = 1392.\nAnswer: B);\n"
in_context = "Question: How many keystrokes are needed to type the numbers from 1 to 500?\nAnswer Choices: A)1156 B)1392 C)1480 D)1562 E)1788.\nAnswer: B);\n"

sample_text = f"Question: {sample['question']}\nAnswer Choices: {' '.join(sample['options'])}\n"

prompt_cot = f"{cot_in_context}{sample_text}Reasoning: "
sample["prompt_cot"] = prompt_cot

prompt_1_shot = f"{in_context}{sample_text}Answer: "
sample["prompt_1_shot"] = prompt_1_shot

return sample


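# `generations` passed to `extract_answers` concatenates three decoded batches
# in order: [unconstrained CoT | constrained CoT | constrained MCQA], so the
# i-th sample's variants sit at offsets 0, batch_size, and 2 * batch_size.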
def extract_answers(batch, generations, answers):
def _parse_prediction(prediction):
pattern = r"[A-E]\)"
        predicted_answer = re.search(pattern, prediction)
        return predicted_answer[0][0] if predicted_answer else ""

batch_size = len(batch["prompt_cot"])

for i in range(batch_size):
        prompt_1_shot = batch["prompt_1_shot"][i]
        prompt_cot = batch["prompt_cot"][i]

unconstrained_prediction = generations[i][len(prompt_cot) :]
constrained_cot_prediction = generations[i + batch_size][len(prompt_cot) :]
constrained_mcqa_prediction = generations[i + 2 * batch_size][
len(prompt_1_shot) :
]

answers["gt"].append(batch["correct"][i])
answers["unconstrained"].append(_parse_prediction(unconstrained_prediction))
answers["constrained_cot"].append(_parse_prediction(constrained_cot_prediction))
answers["constrained_mcqa"].append(
_parse_prediction(constrained_mcqa_prediction)
)


def count_empty(predictions):
return sum(1 for pred in predictions if not pred)


def load_grammar_processor(grammar_path, tokenizer):
with open(grammar_path, "r") as file:
grammar_str = file.read()

grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
return grammar_processor


def main():
args = parse_args()
model_id = args.model_id

# Detect if GPU is available, otherwise use CPU
device = torch.device(
args.device or ("cuda" if torch.cuda.is_available() else "cpu")
)
print(f"Using device: {device}")

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
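    # Left-pad batched prompts so each sequence ends where generation begins
    # (required for batched generation with decoder-only models)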
tokenizer.padding_side = "left"
# Load model to defined device
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
model.generation_config.pad_token_id = model.generation_config.eos_token_id

test_dataset = load_dataset("deepmind/aqua_rat", split="test")
test_dataset = test_dataset.map(create_prompts)

max_new_tokens = 300
batch_size = 8

answers = defaultdict(list)

for i, batch in enumerate(tqdm(test_dataset.iter(batch_size=batch_size))):
        # Create fresh grammar processors for each batch; the constraint
        # tracks parsing state, so it is re-created rather than reused
cot_grammar_processor = load_grammar_processor(
"examples/grammars/chain_of_thought_mcqa.ebnf", tokenizer
)
mcqa_grammar_processor = load_grammar_processor(
"examples/grammars/mcqa.ebnf", tokenizer
)

input_ids_1_shot = tokenizer(
batch["prompt_1_shot"],
add_special_tokens=False,
return_tensors="pt",
padding=True,
)["input_ids"].to(device)

input_ids_cot = tokenizer(
batch["prompt_cot"],
add_special_tokens=False,
return_tensors="pt",
padding=True,
)["input_ids"].to(device)

unconstrained_output = model.generate(
input_ids_cot,
do_sample=False,
max_new_tokens=max_new_tokens,
repetition_penalty=1.1,
num_return_sequences=1,
)

constrained_output_cot = model.generate(
input_ids_cot,
do_sample=False,
max_new_tokens=max_new_tokens,
logits_processor=[cot_grammar_processor],
repetition_penalty=1.1,
num_return_sequences=1,
)

constrained_output_mcqa = model.generate(
input_ids_1_shot,
do_sample=False,
max_new_tokens=max_new_tokens,
logits_processor=[mcqa_grammar_processor],
repetition_penalty=1.1,
num_return_sequences=1,
)

# decode outputs (possibly of different lengths across decoding modes)
generations = (
tokenizer.batch_decode(unconstrained_output, skip_special_tokens=True)
+ tokenizer.batch_decode(constrained_output_cot, skip_special_tokens=True)
+ tokenizer.batch_decode(constrained_output_mcqa, skip_special_tokens=True)
)

extract_answers(batch, generations, answers)

print(
f"Unconstrained accuracy: {accuracy_score(y_true=answers['gt'], y_pred=answers['unconstrained']):.3f}, empty: {count_empty(answers['unconstrained'])} out of {len(answers['unconstrained'])}",
)
print(
f"Constrained accuracy (COT): {accuracy_score(y_true=answers['gt'], y_pred=answers['constrained_cot']):.3f}, empty: {count_empty(answers['constrained_cot'])} out of {len(answers['constrained_cot'])}"
)
print(
f"Constrained accuracy (MCQA): {accuracy_score(y_true=answers['gt'], y_pred=answers['constrained_mcqa']):.3f}, , empty: {count_empty(answers['constrained_mcqa'])} out of {len(answers['constrained_mcqa'])}"
)


if __name__ == "__main__":
main()
6 changes: 4 additions & 2 deletions examples/benchmarking/run_generation.sh
@@ -1,4 +1,6 @@
#!/bin/bash

grammar_path=$1
grammar_name=$(basename $grammar_path)
prompts_path=$2
model_id=${3:-"openai-community/gpt2"}
@@ -19,7 +21,7 @@ do
do
echo "Prompt: $prompt"
for run_id in {1..5}
do
echo "Measurment: $run_id"
kernprof -b --skip-zero -v time_benchmarking.py $grammar_path "$prompt" $max_new_tokens $model_id > $tmp_file
unconstrained_time=$(cat $tmp_file | grep "Unconstrained time: " | awk '{print $3;}')
129 changes: 129 additions & 0 deletions examples/generate_chain_of_though.py
@@ -0,0 +1,129 @@
import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.recognizer import StringRecognizer
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
from transformers_cfg.parser import parse_ebnf


def parse_args():
parser = argparse.ArgumentParser(
description="Generate chain of thought arithmentic strings"
)
parser.add_argument(
"--model-id",
type=str,
default="unsloth/mistral-7b-instruct-v0.2-bnb-4bit",
help="Model ID",
)
parser.add_argument("--device", type=str, help="Device to put the model on")
return parser.parse_args()


def main():
args = parse_args()
model_id = args.model_id

# Detect if GPU is available, otherwise use CPU
device = torch.device(
args.device or ("cuda" if torch.cuda.is_available() else "cpu")
)
print(f"Using device: {device}")

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
# Load model to defined device
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
model.generation_config.pad_token_id = model.generation_config.eos_token_id

# Load grammar
with open(f"examples/grammars/chain_of_thought_arithmetic.ebnf", "r") as file:
grammar_str = file.read()

grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

# Generate
prompts = [
"179*12+34=", # no CoT
"think step-by-step, 12+7*19=12+133=145 >>> 145; 7*8+6*9=56+54=110 >>> 110; 179*12+34=", # CoT
]

input_ids = tokenizer(
prompts, add_special_tokens=False, return_tensors="pt", padding=True
)["input_ids"].to(
device
) # Move input_ids to the same device as model

n_examples = input_ids.shape[0]

max_new_tokens = 30

unconstrained_output = model.generate(
input_ids,
do_sample=False,
max_new_tokens=max_new_tokens,
repetition_penalty=1.9,
num_return_sequences=1,
)

constrained_output = model.generate(
input_ids,
do_sample=False,
max_new_tokens=max_new_tokens,
logits_processor=[grammar_processor],
repetition_penalty=1.9,
num_return_sequences=1,
)

# decode outputs (possibly of different lengths across decoding modes)
generations = tokenizer.batch_decode(
unconstrained_output, skip_special_tokens=True
) + tokenizer.batch_decode(constrained_output, skip_special_tokens=True)

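    # Re-parse the grammar into a plain string recognizer so the constrained
    # generations can be checked for full-string and prefix membership below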
parsed_grammar = parse_ebnf(grammar_str)
string_grammar = StringRecognizer(
parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"]
)

print()
for i in range(n_examples):
print(f"Unconstrained: {generations[i]}")
constrained_generation = generations[i + n_examples]
print(f"Constrained: {constrained_generation}")
print(
f"The constrained generation matches the grammar: {string_grammar._accept_string(constrained_generation[len(prompts[i]):])}"
)
print(
f"The generated prefix matches the grammar: {string_grammar._accept_prefix(constrained_generation[len(prompts[i]):])}"
)
print()


if __name__ == "__main__":
main()

##########################
# Example output (no chain of thought):
# Unconstrained:
# 179*12+34=0,
# -568. Вторемьте в некоторых другие позиции (включая и
#
# Constrained:
# 179*12+34=0;
# The constrained generation matches the grammar: True
# The generated prefix matches the grammar: True
#
# Example output (with chain of thought):
# Unconstrained:
# think step-by-step, 12+7*19=12+133=145 >>> 145; 7*8+6*9=56+54=110 >>> 110; 179*12+34=2148.0 + 117 = <<< error: invalid type comparison >>>;
# ``` | ```vbnet
# '
# Constrained:
# think step-by-step, 12+7*19=12+133=145 >>> 145; 7*8+6*9=56+54=110 >>> 110; 179*12+34=2148+34=2182 >>> 2182;
# The constrained generation matches the grammar: True
# The generated prefix matches the grammar: True
##########################
10 changes: 5 additions & 5 deletions examples/grammars/SMILES/acrylates.ebnf
@@ -1,12 +1,12 @@
root ::= (smiles bond?)* ( group_symbol_left group_bond? | group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right | group_bond? group_symbol_right) (bond? smiles)*

group_radical_left ::= "(" ( group_symbol_left (group_bond smiles+)? )+ ")"

group_radical_right ::= "(" ( (smiles+ group_bond )? group_symbol_right )+ ")"

group_bond ::= ( "-" | "\\" | "/" )

group_symbol_left ::= "C=CC(=O)O" | "C=CC(O)=O" | "C(=C)C(=O)O" | "C(=C)C(O)=O" | "CC(=C)(=O)O" | "CC(=C)(O)=O"

group_symbol_right ::= "OC(=O)C=C" | "O=C(O)C=C" | "OC(=O)C(=C)" | "O=C(O)C(=C)" | "O(O=)(C=)CC" | "O=(O)(C=)CC"

@@ -49,7 +49,7 @@ element_symbol ::= "A" ( "c" | "g" | "l" | "m" | "r" | "s" | "t" | "u" ) |
"S" ( "b" | "c" | "e" | "g" | "i" | "m" | "n" | "r" )? |
"T" ( "a" | "b" | "c" | "e" | "h" | "i" | "l" | "m" | "s" ) |
"U" | "V" | "W" | "Xe" | "Y" "b"? |
"Z" ( "n" | "r" )
"Z" ( "n" | "r" )


ring_closure ::= "%" [1-9] [0-9] | [0-9]