
Sliding window + chunking input for mistral model #1524

Merged
vince62s merged 3 commits into OpenNMT:master from dev/mistral_support on Nov 17, 2023

Conversation

@minhthuc2502 (Collaborator)

Mistral is similar to Llama. This adds a converter for Mistral, so we can run inference with Mistral without any other modification, provided the prompt length is < 4096 tokens.

@vince62s (Member)

It seems to be a duplicate of #1516; however, same remark as on 5258816: just change the oneAPI version, since it seems the other one is no longer supported.

To fully support Mistral we would need to modify the KV cache slightly so that the max length stays at sliding_window once the sequence becomes longer than sliding_window (on the seq_len dimension): we need to remove the first item.
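A minimal sketch of that idea in Python, with PyTorch tensors standing in for CTranslate2's internal storage (the shapes and names are illustrative assumptions, not the actual CTranslate2 code):

```python
import torch

def append_to_kv_cache(cached_keys, cached_values, new_key, new_value, sliding_window):
    """Append the new time step to the cache, then drop the oldest steps
    once the cache grows past the sliding window (seq_len is dimension 2)."""
    keys = torch.cat([cached_keys, new_key], dim=2)
    values = torch.cat([cached_values, new_value], dim=2)
    if keys.size(2) > sliding_window:
        # Keep only the last `sliding_window` positions.
        keys = keys[:, :, -sliding_window:, :]
        values = values[:, :, -sliding_window:, :]
    return keys, values

# Example: batch=1, heads=8, head_dim=64, cache already full at the window limit.
sliding_window = 4096
k_cache = torch.zeros(1, 8, sliding_window, 64)
v_cache = torch.zeros(1, 8, sliding_window, 64)
k_new = torch.zeros(1, 8, 1, 64)
v_new = torch.zeros(1, 8, 1, 64)
k_cache, v_cache = append_to_kv_cache(k_cache, v_cache, k_new, v_new, sliding_window)
assert k_cache.size(2) == sliding_window  # the oldest position was dropped
```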

@minhthuc2502 (Collaborator, Author)

Thank you for your remark. I will update the KV cache as soon as possible. When the prompt is shorter than the sliding window length, it already works perfectly.

@BBC-Esq commented on Oct 28, 2023

When will this be added to ctranslate2?

@vince62s (Member)

It lacks two things:

  1. You need to get the sliding_window parameter from the config file. You can add it here:
    https://github.com/OpenNMT/CTranslate2/pull/1524/files#diff-589896765935bed0caf86539bd02041d80e2588e9ec675138596adfb443c67ddR1309
    and pass it here:
    https://github.com/OpenNMT/CTranslate2/pull/1524/files#diff-589896765935bed0caf86539bd02041d80e2588e9ec675138596adfb443c67ddR1338

  2. Here:
    https://github.com/OpenNMT/CTranslate2/blob/master/src/layers/attention.cc#L531-L536
    is where we concat the current K, V time step onto the existing cache.
    We need to remove the first time step when the cache length + 1 > sliding_window (which we need to retrieve from above).
    We can do it directly there, add an op, or do it in concat with an extra parameter;
    I think adding an op would make things clearer.

For reference, we do the same in OpenNMT-py:
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/multi_headed_attn.py#L411-L413
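For item 1, a rough illustration of where the value comes from: Hugging Face Mistral checkpoints expose sliding_window in their config.json (4096 for Mistral-7B-v0.1). The function below is only a hedged sketch of reading and forwarding that value; it is not the actual CTranslate2 converter API touched in this PR.

```python
import json

def read_sliding_window(config_path: str) -> int:
    """Read the sliding_window value from a Hugging Face Mistral config.json.
    Illustrative sketch only; the real converter change lives in the files
    linked above."""
    with open(config_path) as f:
        config = json.load(f)
    # Mistral-7B-v0.1 sets "sliding_window": 4096; fall back to 0 (disabled).
    return config.get("sliding_window", 0) or 0

# The converter would store this value in the model spec so the attention
# layers can cap the KV cache length at inference time.
```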

@vince62s mentioned this pull request on Oct 29, 2023
@minhthuc2502 changed the title from "Add converter for mistral model" to "[WIP] Add converter for mistral model" on Oct 31, 2023
@minhthuc2502 (Collaborator, Author) commented on Oct 31, 2023

Hello @vince62s. Thank you for your help. I am working on it. For the sliding window and the rolling buffer cache, I am waiting for some hardware needed to test it. Additionally, I think that to fully support Mistral I also need to implement chunking for very large inputs, but I don't clearly understand the idea of chunking in this case: with an input of 10000 tokens and a window size of 4096, after the first 2 chunks, where we can compute attention over the cache and over a chunk of size 4096, which query size should we use for the next chunk: the remaining (10000 - 2 * 4096) tokens, or the last 4096 tokens?

Thank you in advance

@vince62s (Member)

I am working on it too.
I think we don't need to implement the same rolling buffer cache.
We just need to limit the size of the KV cache to "sliding_window" by removing the first entry one by one as we add the new key/value proj.

@minhthuc2502 force-pushed the dev/mistral_support branch 5 times, most recently from be1275f to 9ec70f2 on November 2, 2023
@minhthuc2502 changed the title from "[WIP] Add converter for mistral model" to "Sliding window + chunking input for mistral model" on Nov 3, 2023
@vince62s (Member) commented on Nov 3, 2023

OK, fully reviewed offline with @minhthuc2502.
In summary, this PR does 3 things:

  • create the Mistral config loader
  • limit the KV cache to seqlen = sliding_window, which limits the memory footprint in the MHA
  • chunk the initial prompt so that each chunk has a seqlen capped to sliding_window, and process each chunk so that the KV cache updates and "slides" with the above mechanism (a sketch follows below)
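A rough sketch of that chunked prefill, in Python pseudocode (decoder.forward_step and the cache handling are illustrative assumptions, not the actual CTranslate2 internals):

```python
def prefill_in_chunks(decoder, prompt_ids, sliding_window):
    """Feed the prompt in chunks of at most `sliding_window` tokens.
    After each chunk, the KV cache is updated and trimmed so that it never
    holds more than `sliding_window` positions."""
    kv_cache = None
    last_logits = None
    for start in range(0, len(prompt_ids), sliding_window):
        chunk = prompt_ids[start:start + sliding_window]
        # forward_step stands in for running the decoder over one chunk,
        # attending to the chunk itself plus whatever remains in the cache.
        last_logits, kv_cache = decoder.forward_step(chunk, kv_cache)
        # The cache trim (dropping positions beyond sliding_window) happens
        # inside the attention layer, as discussed earlier in this thread.
    return last_logits, kv_cache
```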

We could add 2 small things, but for later:

  1. the chunk size could be a setting, so that we can slice into bigger or smaller parts
  2. the attention mask for the initial prompt (step 0) is the standard transformer mask (upper triangular), whereas Mistral was trained with a "band" mask. It does not really impact inference, but it could be a nice-to-have (and it is relevant only if we set a chunk size > sliding_window, hence only if we implement 1).
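For point 2, a small sketch of the difference, using a boolean mask where True means "may attend" (the complement of the additive upper-triangular mask mentioned above); the helper names are illustrative:

```python
import torch

def causal_mask(seq_len):
    # Standard causal mask: position i may attend to every j <= i.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

def band_mask(seq_len, sliding_window):
    # Sliding-window ("band") mask: position i may attend only to
    # j in [i - sliding_window + 1, i].
    band = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool),
                      diagonal=-(sliding_window - 1))
    return causal_mask(seq_len) & band

print(band_mask(5, 3).int())
# rows 3 and 4 no longer attend to the oldest positions:
# [[1, 0, 0, 0, 0],
#  [1, 1, 0, 0, 0],
#  [1, 1, 1, 0, 0],
#  [0, 1, 1, 1, 0],
#  [0, 0, 1, 1, 1]]
```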

I tested with examples/llama2/chat.py and it works perfectly fine.

EDIT: we are still having some issues with long contexts, working on it.

closes #1501

@vince62s (Member) commented on Nov 3, 2023

Please don't put user questions in PR threads; use the issues for that, or better, the forum. In the end, the answer is: usage will be the same as for other models.

@minhthuc2502 force-pushed the dev/mistral_support branch 6 times, most recently from 7e0d6dd to b770497 on November 16, 2023
@vince62s merged commit 120746e into OpenNMT:master on Nov 17, 2023
17 checks passed