Mismatch for XLM_ROBERTA tokenizer w/wo special tokens 🔀 #852

Open

KarelZe opened this issue Dec 1, 2024 · 3 comments

KarelZe commented Dec 1, 2024

Thanks for this awesome library 💯

I'm currently trying to convert DONUT end-to-end; it internally uses the XLMRoberta tokenizer in its preprocessor (see here).
XLMRoberta is also among the supported tokenizers in onnxruntime-extensions (see here).

However, the onnx-based tokenizer yields different input ids for sequences that contain additional special tokens, and decoding yields different tokens for sequences both with and without special tokens. See below:

from onnxruntime_extensions import (
    OrtPyFunction,
    gen_processing_models,
)
from transformers import DonutProcessor

path_model: str = "naver-clova-ix/donut-base"

processor = DonutProcessor.from_pretrained(path_model, use_fast=False)

tokenizer = processor.tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": ["<s_name>", "</s_name>"]})

print("class: ", type(tokenizer))

# generate matching ONNX tokenizer and detokenizer models from the slow tokenizer
m_tok, m_detok = gen_processing_models(tokenizer, pre_kwargs={}, post_kwargs={})
print("Tokenizer Model Inputs:", [node.name for node in m_tok.graph.input])
print(
    "Tokenizer Model Outputs:",
    [node.name for node in m_tok.graph.output],
)

print("-" * 8, "example (plain)", "-" * 8)

text = "Hello World."
ids = tokenizer.encode(text, return_tensors="np")
actual_ids = OrtPyFunction(m_tok)([text])

print("input:", text)
print("ids (enc):", actual_ids[0], "<->", ids[0])
print(
    "tokens (dec):",
    OrtPyFunction(m_detok)(actual_ids[0])[0],
    "<->",
    tokenizer.decode(actual_ids[0]),
    "\n",
)

print("-" * 8, "example (special tokens)", "-" * 8)

text = "<s_name>Angela</s_name>"

ids = tokenizer.encode(text, return_tensors="np")
actual_ids = OrtPyFunction(m_tok)([text])
print("input:", text)
print("ids (enc):", actual_ids[0], "<->", ids[0])
print(
    "tokens (dec):",
    OrtPyFunction(m_detok)(actual_ids[0])[0],
    "<->",
    tokenizer.decode(ids[0]),
)

Output:

class:  <class 'transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer'>
Tokenizer Model Inputs: ['inputs']
Tokenizer Model Outputs: ['tokens', 'instance_indices', 'token_indices']
-------- example (plain) --------
input: Hello World.
ids (enc): [    0 37857 48225 39539     2] <-> [    0 37857 48225 39539     2]
tokens (dec):  ⁇  Ord특 que <-> <s> Hello World.</s> 

-------- example (special tokens) --------
input: <s_name>Angela</s_name>
ids (enc): [    0 41040 46192 41403  4131 34791 13892     4 40598 46192 41403  4131
 34791     2] <-> [    0 57525 46299 57526     2]
tokens (dec):  ⁇  estu W斗 таможенн BleTemp _ 이제 W斗 таможенн Ble <-> <s> <s_name> Angela </s_name> </s>

with these dependencies:

onnx==1.17.0
onnxoptimizer==0.3.13
onnxruntime==1.20.1
onnxruntime-extensions==0.13.0
tokenizers==0.20.3
transformers==4.46.3

As shown above, the input ids from the encoder match as long as no additional special tokens are present. If additional special tokens are present, however, the onnx-based tokenizer decomposes the special tokens such as <s_name> into individual sub-tokens. To my understanding, Hugging Face handles special tokens first and only then applies the default tokenization. The output of the decoder differs in both cases; interestingly, BOS <s> and EOS </s> are not handled correctly either.
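To illustrate the "specials first" behavior I would expect, here is a minimal, hypothetical sketch of the splitting step (not the actual Hugging Face implementation):

import re

def split_on_special_tokens(text: str, special_tokens: list[str]) -> list[str]:
    # protect each added special token so the underlying sentencepiece model
    # never sees it; only the remaining chunks get tokenized normally
    # (in general, longer tokens should come first in the alternation)
    pattern = "(" + "|".join(re.escape(t) for t in special_tokens) + ")"
    return [chunk for chunk in re.split(pattern, text) if chunk]

print(split_on_special_tokens("<s_name>Angela</s_name>", ["<s_name>", "</s_name>"]))
# -> ['<s_name>', 'Angela', '</s_name>']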

As an alternative, I also experimented with directly converting to a sentencepiece tokenizer following this test, without success. Another alternative, which I haven't tried yet, is altering the bpe model file directly, as shown here, and loading it here.

Could you please give me a hint on how to align the output of the onnx tokenizer with the hf tokenizer?

The issue could be related to #828 (also sentencepiece-based with sentinel tokens).

KarelZe (Author) commented Dec 11, 2024

I made some progress on the wrong decoding and wrong handling of special tokens during encoding.

wrong decoding

Decoding requires "fairseq" to be true (see here), so that an offset is added to retrieve the tokens from the sentencepiece file using the transformed fairseq indices. I haven't yet been able to figure out why it is not passed correctly to the decoder, as it is activated here. For the time being I constructed the decoder graph manually. @wenbingl Could you please shed some light on this?

# see: https://github.com/microsoft/onnxruntime-extensions/commit/1d8b81f59bf782244508485c6dbaa6418948ba12#diff-9d505227452e3355c8bb29f6649f6ac17b358b0ca9c1405a24e4a95ffaeb28c4R289

from pathlib import Path

import onnx
from onnx import TensorProto, helper

opset = 18
sp_model_bytes = Path("extended.model").read_bytes()

sentencepiece_decoder_attributes = {
    "model": sp_model_bytes,
}

ids = helper.make_tensor_value_info(
    "ids", TensorProto.INT64, [None]
)
fairseq = helper.make_tensor_value_info(
    "fairseq", TensorProto.BOOL, [1]
)
output_tensor = helper.make_tensor_value_info(
    "str", TensorProto.STRING, [None]
)

# set fairseq for correct offset
# https://github.com/microsoft/onnxruntime-extensions/blob/c3674f5c03f1c30c4a4cbc7b890d5a2d6cac1d9f/operators/tokenizer/sentencepiece_decoder.hpp#L45
node = helper.make_node(
    "SentencepieceDecoder",
    inputs=["ids", "fairseq"],
    outputs=["str"],
    domain="ai.onnx.contrib",
    name="sentencepiece-decoder",
    **sentencepiece_decoder_attributes
)

graph = helper.make_graph(
    nodes=[node],
    name="SentencePieceDecoderGraph",
    inputs=[ids, fairseq],
    outputs=[output_tensor],
)

model = helper.make_model(
    graph,
    producer_name="custom-sentencepiece-decoder",
    opset_imports=[
        helper.make_opsetid("", opset),
        # also declare the custom-op domain used by the SentencepieceDecoder node
        helper.make_opsetid("ai.onnx.contrib", 1),
    ],
)

onnx.save(model, "tokenizer_decoder_manual.onnx")
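For context, this is my understanding of the fairseq index mapping that the decoder has to undo when "fairseq" is true (a sketch following the convention in XLMRobertaTokenizer; the helper name is mine):

FAIRSEQ_OFFSET = 1  # fairseq id = sentencepiece id + 1 for regular pieces
FAIRSEQ_SPECIALS = {0: "<s>", 1: "<pad>", 2: "</s>", 3: "<unk>"}

def fairseq_id_to_piece(sp, fairseq_id: int) -> str:
    # sp is a loaded sentencepiece.SentencePieceProcessor
    if fairseq_id in FAIRSEQ_SPECIALS:
        return FAIRSEQ_SPECIALS[fairseq_id]
    return sp.id_to_piece(fairseq_id - FAIRSEQ_OFFSET)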

handling of special tokens

I added the special tokens manually to the sentencepiece protobuf file. Pieces are inserted at the corresponding sentencepiece index, which is off by one compared to the fairseq index (e.g. fairseq id 57525 corresponds to sentencepiece index 57524, see the output above). We can then load the extended vocab file and convert the tokenizer.

from pathlib import Path

import sentencepiece as spm
from onnxruntime_extensions import gen_processing_models
from transformers import DonutProcessor, XLMRobertaTokenizer
from transformers.convert_slow_tokenizer import import_protobuf

processor = DonutProcessor.from_pretrained("katanaml-org/invoices-donut-model-v1")
tokenizer = processor.tokenizer

vocab_file = tokenizer.vocab_file
added_tokens = tokenizer.get_added_vocab()
# filter sos, eos, mask etc. + sort for correct insertion
special_token_ids = {
    idx: tok for tok, idx in added_tokens.items() if idx >= 4
}
special_token_ids = dict(sorted(special_token_ids.items()))
min_special_token_id = min(special_token_ids)

model = import_protobuf()

m = model.ModelProto()
m.ParseFromString(Path(vocab_file).read_bytes())

# count existing pieces before padding
existing_pieces = len(m.pieces)
print(existing_pieces)

# reduce by one, as the fairseq offset is later added by the transformers/onnx tokenizer
# see: https://github.com/huggingface/transformers/blob/main/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py#L270
filler_pieces = min_special_token_id - existing_pieces - 1

# insert unused dummies so the special tokens land at the correct index;
# pieces must be unique, hence the growing dummy string
for i in range(filler_pieces):
    filler = model.ModelProto().SentencePiece()
    filler.piece = "∅" * (i + 1)
    filler.score = 0
    m.pieces.append(filler)

print(filler_pieces)

for token in special_token_ids.values():
    # for ULM it is fine to append any piece, see https://github.com/google/sentencepiece/issues/903
    new_p = model.ModelProto().SentencePiece()
    new_p.piece = token
    new_p.score = 0
    m.pieces.append(new_p)

with Path("extended.model").open(mode="wb") as f:
    f.write(m.SerializeToString())

sp = spm.SentencePieceProcessor("extended.model")
token_seq = '<s_summary><s_total_vat>$ 88,92</s_total_vat><s_total_net_worth>$ 889,20</s_total_net_worth><s_total_gross_worth>$ 978,12</s_total_gross_worth></s_summary><s_items><s_item_vat>10%</s_item_vat><s_item_qty>2,00</s_item_qty><s_item_net_worth>889,20</s_item_net_worth><s_item_net_price>444,60</s_item_net_price><s_item_gross_worth>978,12</s_item_gross_worth><s_item_desc>12" Marble Lapis Inlay Chess Table Top With 2" Pieces & 15" Wooden Stand W537</s_item_desc></s_items><s_header><s_seller_tax_id>985-73-8194</s_seller_tax_id><s_seller>Bradley-Andrade 9879 Elizabeth Common Lake Jonathan, RI 12335</s_seller><s_invoice_no>97159829</s_invoice_no><s_invoice_date>09/18/2015</s_invoice_date><s_iban>GB81LZWO32519172531418</s_iban><s_client_tax_id>994-72-1270</s_client_tax_id><s_client>Castro PLC Unit 9678 Box 9664 DPO AP 69387</s_client></s_header>'
print(sp.encode_as_pieces(token_seq))
print(sp.encode(token_seq))

tokenizer = XLMRobertaTokenizer(vocab_file="extended.model", sp_model_kwargs={})

ort_tokenizer, ort_decoder = gen_processing_models(
    tokenizer, pre_kwargs={"CAST_TOKEN_ID": True}, post_kwargs={"CAST_TOKEN_ID": True},
)
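As a quick sanity check that both sides now agree (an untested sketch, reusing tokenizer and ort_tokenizer from the snippet above):

from onnxruntime_extensions import OrtPyFunction

text = "<s_name>Angela</s_name>"
hf_ids = tokenizer.encode(text)
ort_ids = OrtPyFunction(ort_tokenizer)([text])[0]
print(hf_ids, "<->", ort_ids)  # should match once the vocab indices are aligned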

wenbingl (Member) commented

> [quotes KarelZe's comment of Dec 11, 2024 in full]

@sayanshaw24

wenbingl (Member) commented

> Decoding requires "fairseq" to be true (see here), so that an offset is added to retrieve the tokens from the sentencepiece file using the transformed fairseq indices. I wasn't yet able to figure out why it is not passed correctly to the decoder […]

I think there is no default-int handling in the post_processing step of the tokenizer, in this function:

def post_processing(self, **kwargs):
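If a missing default for that input is indeed the problem, a possible workaround until it is fixed could be to feed the flag explicitly to the manually built decoder from above (a sketch, assuming the tokenizer_decoder_manual.onnx produced earlier):

import numpy as np
import onnx
from onnxruntime_extensions import OrtPyFunction

decoder = OrtPyFunction(onnx.load("tokenizer_decoder_manual.onnx"))
# ids taken from the hf output in the first comment: <s> <s_name> Angela </s_name> </s>
ids = np.array([0, 57525, 46299, 57526, 2], dtype=np.int64)
print(decoder(ids, np.array([True])))  # pass fairseq=True explicitly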
