Add last_segment_tokens to stream_generate GenerationResponse #1123

Open
mattjcly opened this issue Nov 25, 2024 · 3 comments

Comments

@mattjcly

It would be useful to understand exactly what tokens went into GenerationResponse.text as returned by stream_generate.

It would also be helpful to get the logprobs for those tokens, as a way to understand the model's "certainty" while generating the segment.

However, this second part could be done offline with a lookup table from prior GenerationResponse.token values to their GenerationResponse.logprobs, since last_segment_tokens must be a subset of prior tokens.

I would imagine that the streaming detokenizer will require an expansion of its API, so that detokenizer.last_segment also stores and exposes the tokens that went into it in some way.
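
For concreteness, a rough sketch of the shape I have in mind (purely hypothetical: last_segment_tokens does not exist today, and everything except the fields already discussed in this thread is omitted):

from dataclasses import dataclass, field
from typing import List

import mlx.core as mx


@dataclass
class GenerationResponse:
    # Existing fields used below (the real dataclass has more; omitted for brevity).
    text: str            # detokenized segment for this step (may be empty)
    token: int           # token sampled at this step
    logprobs: mx.array   # log-probabilities for this step
    # Proposed: exactly the token ids that were detokenized into `text`.
    last_segment_tokens: List[int] = field(default_factory=list)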

@awni
Member

awni commented Dec 16, 2024

I looked into this a bit more today. My recommendation for now is to do something like the following:

from mlx_lm import stream_generate

segment_tokens = []
for response in stream_generate(model, tokenizer, prompt, **kwargs):
    segment_tokens.append(response.token)
    if response.text:
        # At this point segment_tokens corresponds to response.text.
        # Clear segment_tokens and start accumulating the next segment.
        segment_tokens.clear()

What do you think about that? Does it work for your use case?

You can do the same thing with the logprobs, by the way: keep a secondary list, append response.logprobs to it, and clear it at the same time you clear the tokens.
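
For example, extending the loop above (handle_segment here is just a placeholder for whatever you do with a finished segment):

segment_tokens = []
segment_logprobs = []
for response in stream_generate(model, tokenizer, prompt, **kwargs):
    segment_tokens.append(response.token)
    # Assuming response.logprobs is the per-step log-probability vector,
    # indexing it with the sampled token gives that token's logprob.
    segment_logprobs.append(response.logprobs[response.token].item())
    if response.text:
        # segment_tokens / segment_logprobs now correspond to response.text.
        handle_segment(response.text, segment_tokens, segment_logprobs)
        segment_tokens.clear()
        segment_logprobs.clear()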

@mattjcly
Author

@awni thanks a lot for taking a look!
This is actually what we currently have in https://github.com/lmstudio-ai/mlx-engine.

See the following example text+tokens snippet from a generation using mlx-community/Meta-Llama-3.1-8B-Instruct-4bit:

// t1
text segment: ' kings)'
segment tokens: [[8, ')'], [4184, ' according']]

// t2
text segment: ' according'
segment tokens: [[311, ' to']]

// t3
text segment: ' to'
segment tokens: [[3230, ' specific']]

I think the "issue" with this implementation occurs at the boundary of multi-token segments, where the text segment appears to lag one behind the segment tokens.

In the above example, the text segment output at t2 is yielded because of both (I think) token id 4184 (from t1, the ' according') and token id 311 (from t2, because of the space at the start of ' to'). Combined, these were the tokens that enabled the yield-able segment? Do you think it'd be viable/reasonable to somehow reliably get something like:

// t1
text segment: ' kings)'
segment tokens: [[<id>, ' kings'], [8, ')'], [4184, ' according']]

// t2
text segment: ' according'
segment tokens: [[4184, ' according'], [311, ' to']]

// t3
text segment: ' to'
segment tokens: [[311, ' to'], [3230, ' specific']]

If we had logprobs for each of the "segment tokens" in this example, we could then reasonably estimate how "confident" the model was in predicting each segment, by taking some sort of aggregate statistic across all the tokens of that segment.
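
For illustration, the kind of aggregate I have in mind (just one option; the logprob values in the comment are made up):

import math

def segment_confidence(segment_logprobs):
    # Rough per-segment confidence: geometric-mean probability of the segment's tokens.
    if not segment_logprobs:
        return 0.0
    mean_logprob = sum(segment_logprobs) / len(segment_logprobs)
    return math.exp(mean_logprob)

# e.g. for the proposed t2 segment tokens [' according', ' to'] with
# illustrative logprobs [-0.31, -0.05]:
# segment_confidence([-0.31, -0.05]) -> ~0.84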

Otherwise, in the original t3 case:

// t3
text segment: ' to'
segment tokens: [[3230, ' specific']]

I'm not sure how we could confidently correlate the logprob of token 3230 with the "confidence" of the model's output of ' to'.

Does that make sense, or do you think my desire falls apart / is unachievable in some way? I'm also open to other suggestions you may have for how one could glean useful information from logprobs given the current buffering detokenizer situation. I'm afraid I'm having a hard time seeing intuitively how to do so at the moment :(

@awni
Member

awni commented Dec 17, 2024

@mattjcly I think the decoding behavior has changed in the latest mlx-lm, at least for the BPE detokenizer, so there should not be a lag now:

from mlx_lm import load, stream_generate
model, tokenizer = load("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")

prompt = "Write a story about Einstein"
messages = [{ "role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)


segment_tokens = []
for response in stream_generate(model, tokenizer, prompt):
    segment_tokens.append(response.token)
    if response.text:
        # At this point segment_tokens corresponds to response.text
        # Then clear segment tokens and start again.
        print(response.text, segment_tokens, tokenizer.decode(segment_tokens))
        segment_tokens.clear()

Output:

It [2181] It
 was [574]  was
 a [264]  a

Could you try running that with your prompt and see if it makes more sense?
