Add `last_segment_tokens` to `stream_generate` `GenerationResponse` #1123
Comments
I looked into this a bit more today. My recommendation for now is to do something like the following:

```python
segment_tokens = []
for response in stream_generate(model, tokenizer, prompt, **kwargs):
    segment_tokens.append(response.token)
    if response.text:
        # At this point segment_tokens corresponds to response.text.
        # Then clear segment_tokens and start again.
        segment_tokens.clear()
```

What do you think about that? Does it work for your use case? You can do the same thing with the logprobs, by the way: just keep a secondary list of logprobs alongside the tokens.
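A minimal sketch of that "secondary list of logprobs" idea, assuming `response.logprobs` holds the full log-probability vector for each step (so indexing it with the sampled `response.token` gives that token's logprob); the model path is just an example:

```python
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
prompt = "Write a story about Einstein"

segment_tokens, segment_logprobs = [], []
for response in stream_generate(model, tokenizer, prompt):
    segment_tokens.append(response.token)
    # Assumption: response.logprobs is the per-step log-probability vector,
    # so this picks out the logprob of the token that was actually sampled.
    segment_logprobs.append(response.logprobs[response.token].item())
    if response.text:
        # segment_tokens / segment_logprobs now line up with response.text.
        print(response.text, segment_tokens, segment_logprobs)
        segment_tokens.clear()
        segment_logprobs.clear()
```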
@awni thanks a lot for taking a look! See the following example text + tokens snippet from a generation using
I think the "issue" with this implementation occurs at the boundary of multi-token segments, where the text segment appears to lag one behind the segment tokens. In the above example, the text segment output at t2 is yielded because of both (I think) token id 4184 (from t1, the ' according') and token id 311 (from t2, because of the space at the start of ' to'). Combined, these were the tokens that enabled the yield-able segment? Do you think it'd be viable/reasonable to somehow reliably get something like:

If we had logprobs for each of the "segment tokens" in this example, we could reasonably estimate how "confident" the model was in predicting each segment by taking some aggregate statistic across all of the segment's tokens. Otherwise, in the original t3 case:

I'm not sure how we could really confidently correlate the logprob of token 3230 to the "confidence" in the model's output of ' to'. Does that make any sense, or do you think what I'm after falls apart/is unachievable in some way? I'm also open to other suggestions you may have for how one could glean useful information from logprobs with the current buffering detokenizer situation. I'm afraid I'm currently having a hard time seeing intuitively how to do so :(
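For concreteness, here is a small sketch of the kind of per-segment aggregate statistic being described, assuming we already had an aligned `segment_logprobs` list for one segment (the helper name and the numbers are made up for illustration):

```python
import math

def segment_confidence(segment_logprobs: list[float]) -> float:
    """Collapse per-token logprobs into a rough per-segment confidence.

    Uses the geometric-mean probability of the segment's tokens; min, sum,
    or perplexity would be equally reasonable aggregates.
    """
    mean_logprob = sum(segment_logprobs) / len(segment_logprobs)
    return math.exp(mean_logprob)

# Hypothetical two-token segment like " according to"
print(segment_confidence([-0.11, -0.35]))  # ~0.79
```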
@mattjcly I think the behavior has changed in the decoding in the latest mlx-lm, at least for the BPE decoder, so there should not be a lag now:

```python
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")

prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

segment_tokens = []
for response in stream_generate(model, tokenizer, prompt):
    segment_tokens.append(response.token)
    if response.text:
        # At this point segment_tokens corresponds to response.text.
        # Then clear segment_tokens and start again.
        print(response.text, segment_tokens, tokenizer.decode(segment_tokens))
        segment_tokens.clear()
```

which prints (first few segments):

```
It [2181] It
was [574] was
a [264] a
```

Could you try running that with your prompt and see if it makes more sense?
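If it helps to double-check the alignment programmatically, a small variation of the loop above can verify that the decoded segment tokens round-trip to the streamed text (a sketch; exact whitespace handling may differ between tokenizers):

```python
segment_tokens = []
for response in stream_generate(model, tokenizer, prompt):
    segment_tokens.append(response.token)
    if response.text:
        decoded = tokenizer.decode(segment_tokens)
        # With the updated BPE decoder these should match; flag any drift.
        if decoded != response.text:
            print(f"mismatch: text={response.text!r} decoded={decoded!r}")
        segment_tokens.clear()
```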
It would be useful to understand exactly which tokens went into `GenerationResponse.text` as returned by `stream_generate`. It would also be helpful to get the `logprobs` for those tokens, as a way to understand the model's "certainty" throughout outputting the segment. However, this second part could be done offline with a lookup table between prior `GenerationResponse.token` values and their `GenerationResponse.logprobs`, since `last_segment_tokens` must be a subset of prior `token`s.

I would imagine that the streaming detokenizer will require an expansion of its API so that `detokenizer.last_segment` will have the tokens that went into it stored and output in some way.