Skip to content

Commit

Permalink
Fix the prompt for mistral when using instruct/interactive mode. (#1013)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Oct 1, 2023
1 parent 328167e commit f6054e9
Showing 1 changed file with 31 additions and 12 deletions.
43 changes: 31 additions & 12 deletions candle-examples/examples/quantized/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,23 @@ enum Which {
Mistral7bInstruct,
}

impl Which {
fn is_mistral(&self) -> bool {
match self {
Self::L7b
| Self::L13b
| Self::L70b
| Self::L7bChat
| Self::L13bChat
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode => false,
Self::Mistral7b | Self::Mistral7bInstruct => true,
}
}
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
Expand Down Expand Up @@ -114,17 +131,10 @@ impl Args {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = match self.which {
Which::L7b
| Which::L13b
| Which::L70b
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode
| Which::L7bChat
| Which::L13bChat
| Which::L70bChat => "hf-internal-testing/llama-tokenizer",
Which::Mistral7b | Which::Mistral7bInstruct => "mistralai/Mistral-7B-v0.1",
let repo = if self.which.is_mistral() {
"mistralai/Mistral-7B-v0.1"
} else {
"hf-internal-testing/llama-tokenizer"
};
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
Expand Down Expand Up @@ -315,7 +325,11 @@ fn main() -> anyhow::Result<()> {
prompt.pop();
}
}
prompt
if args.which.is_mistral() {
format!("[INST] {prompt} [/INST]")
} else {
prompt
}
}
};
print!("{}", &prompt_str);
Expand Down Expand Up @@ -351,6 +365,8 @@ fn main() -> anyhow::Result<()> {
all_tokens.push(next_token);
print_token(next_token, &tokenizer);

let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();

let start_post_prompt = std::time::Instant::now();
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
Expand All @@ -369,6 +385,9 @@ fn main() -> anyhow::Result<()> {
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
print_token(next_token, &tokenizer);
if next_token == eos_token {
break;
};
}
let dt = start_post_prompt.elapsed();
println!(
Expand Down

0 comments on commit f6054e9

Please sign in to comment.