From e5113e8d746bfc10b70d956a3ae64dd460becfda Mon Sep 17 00:00:00 2001
From: ochafik
Date: Mon, 30 Dec 2024 03:40:34 +0000
Subject: [PATCH] Add --jinja and --chat-template-file flags

---
 Makefile                                       |  2 +
 common/CMakeLists.txt                          |  2 +
 common/arg.cpp                                 | 43 ++++++++++-
 common/common.cpp                              | 68 +++++++++++++++-
 common/common.h                                | 14 +++-
 examples/server/README.md                      |  2 +-
 examples/server/server.cpp                     | 67 ++++++++++++----
 .../server/tests/unit/test_chat_completion.py  | 15 ++--
 examples/server/tests/utils.py                 |  7 +-
 examples/server/utils.hpp                      | 40 ++++++----
 scripts/get_hf_chat_template.py                | 77 +++++++++++++++++++
 src/CMakeLists.txt                             |  2 +-
 12 files changed, 289 insertions(+), 50 deletions(-)
 create mode 100755 scripts/get_hf_chat_template.py

diff --git a/Makefile b/Makefile
index 19ae0d5f1c87b..295522ba356b4 100644
--- a/Makefile
+++ b/Makefile
@@ -1361,7 +1361,9 @@ llama-server: \
 	examples/server/httplib.h \
 	examples/server/index.html.hpp \
 	examples/server/loading.html.hpp \
+	common/chat-template.hpp \
 	common/json.hpp \
+	common/minja.hpp \
 	$(OBJ_ALL)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index df1cdf9a59af3..24b7f8741aab4 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
     arg.cpp
     arg.h
     base64.hpp
+    chat-template.hpp
     common.cpp
     common.h
     console.cpp
@@ -64,6 +65,7 @@ add_library(${TARGET} STATIC
     json.hpp
     log.cpp
     log.h
+    minja.hpp
    ngram-cache.cpp
    ngram-cache.h
    sampling.cpp
diff --git a/common/arg.cpp b/common/arg.cpp
index deb11378657f4..edcda60e08e16 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1889,24 +1889,59 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--jinja"},
+        "use jinja template for chat (default: disabled)",
+        [](common_params & params) {
+            params.use_jinja = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
             "set custom jinja chat template (default: template taken from model's metadata)\n"
             "if suffix/prefix are specified, template will be disabled\n"
+            "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
             "list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
         ),
         [](common_params & params, const std::string & value) {
-            if (!common_chat_verify_template(value)) {
+            if (!common_chat_verify_template(value, params.use_jinja)) {
                 throw std::runtime_error(string_format(
-                    "error: the supplied chat template is not supported: %s\n"
-                    "note: llama.cpp does not use jinja parser, we only support commonly used templates\n",
-                    value.c_str()
+                    "error: the supplied chat template is not supported: %s%s\n",
+                    value.c_str(),
+                    params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates"
                 ));
             }
             params.chat_template = value;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
+    add_opt(common_arg(
+        {"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
+        "set custom jinja chat template file (default: template taken from model's metadata)\n"
+        "if suffix/prefix are specified, template will be disabled\n"
+        "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
+        "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template",
+        [](common_params & params, const std::string & value) {
+            std::ifstream file(value);
+            if (!file) {
+                throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
+            }
+            std::string chat_template;
+            std::copy(
+                std::istreambuf_iterator<char>(file),
+                std::istreambuf_iterator<char>(),
+                std::back_inserter(chat_template)
+            );
+            if (!common_chat_verify_template(chat_template, params.use_jinja)) {
+                throw std::runtime_error(string_format(
+                    "error: the supplied chat template is not supported: %s%s\n",
+                    value.c_str(),
+                    params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates"
+                ));
+            }
+            params.chat_template = chat_template;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
     add_opt(common_arg(
         {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
         string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
diff --git a/common/common.cpp b/common/common.cpp
index 20be9291161ca..6bdcd80a1b756 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1576,13 +1576,13 @@ std::vector<llama_token> common_tokenize(
     return result;
 }
 
-std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+static std::string _common_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     std::string piece;
     piece.resize(piece.capacity());  // using string internal cache, 15 bytes + '\n'
-    const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
     if (n_chars < 0) {
         piece.resize(-n_chars);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
         GGML_ASSERT(check == -n_chars);
     }
     else {
@@ -1592,6 +1592,10 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token
     return piece;
 }
 
+std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+    return _common_token_to_piece(llama_get_model(ctx), token, special);
+}
+
 std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
@@ -1612,7 +1616,21 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
     return text;
 }
 
-bool common_chat_verify_template(const std::string & tmpl) {
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
+            chat_template.apply({{
+                {"role", "user"},
+                {"content", "test"},
+            }}, json(), true);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
+
     llama_chat_message chat[] = {{"user", "test"}};
     int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
@@ -1693,6 +1711,48 @@ std::string common_chat_format_example(const struct llama_model * model,
     return common_chat_apply_template(model, tmpl, msgs, true);
 }
 
+static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
+    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
+    if (tlen > 0) {
+        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
+        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
+            return std::string(curr_tmpl_buf.data(), tlen);
+        }
+    }
+    return "";
+}
+
+llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
+{
+    auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true);
+    auto eos_token = _common_token_to_piece(model, llama_token_eos(model), true);
+    std::string default_template_src = chat_template_override;
+    std::string tool_use_template_src = chat_template_override;
+    if (chat_template_override.empty()) {
+        default_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template");
+        tool_use_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
+    }
+    if (default_template_src.empty() || default_template_src == "chatml") {
+        if (!tool_use_template_src.empty()) {
+            default_template_src = tool_use_template_src;
+        } else {
+            default_template_src = R"(
+                {%- for message in messages -%}
+                    {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
+                {%- endfor -%}
+                {%- if add_generation_prompt -%}
+                    {{- "<|im_start|>assistant\n" -}}
+                {%- endif -%}
+            )";
+        }
+    }
+    return {
+        .default_template = { default_template_src, bos_token, eos_token },
+        .tool_use_template = tool_use_template_src.empty() ? std::nullopt
+            : std::optional<minja::chat_template>({ tool_use_template_src, bos_token, eos_token }),
+    };
+}
+
 //
 // KV cache utils
 //
diff --git a/common/common.h b/common/common.h
index 1d2bd932c211d..7747d66d55b67 100644
--- a/common/common.h
+++ b/common/common.h
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "llama.h"
+#include "chat-template.hpp"
 
 #include <string>
 #include <vector>
@@ -324,6 +325,7 @@ struct common_params {
     std::string hostname      = "127.0.0.1";
     std::string public_path   = "";         // NOLINT
     std::string chat_template = "";         // NOLINT
+    bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;
 
     std::vector<std::string> api_keys;
@@ -571,8 +573,8 @@ struct common_chat_msg {
     std::string content;
 };
 
-// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool common_chat_verify_template(const std::string & tmpl);
+// Check if the template is supported or not. Returns true if it's valid
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
 
 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
@@ -593,6 +595,14 @@ std::string common_chat_format_single(const struct llama_model * model,
 std::string common_chat_format_example(const struct llama_model * model,
         const std::string & tmpl);
 
+
+struct llama_chat_templates {
+    minja::chat_template default_template;
+    std::optional<minja::chat_template> tool_use_template;
+};
+
+llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
+
 //
 // KV cache utils
 //
diff --git a/examples/server/README.md b/examples/server/README.md
index c7d91be9976c4..24ef85727092d 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -129,7 +129,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') |
 | `--grammar-file FNAME` | file to read grammar from |
 | `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
-
+| `--jinja` | Enable experimental Jinja templating engine (needed for tool use) |
 
 **Example-specific params**
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 30ff3b14957dc..cfa90056ae995 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1623,15 +1623,35 @@ struct server_context {
         return true;
     }
 
-    bool validate_model_chat_template() const {
-        std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
-        std::string template_key = "tokenizer.chat_template";
-        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        if (res >= 0) {
-            llama_chat_message chat[] = {{"user", "test"}};
-            std::string tmpl = std::string(model_template.data(), model_template.size());
-            int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
-            return chat_res > 0;
+    bool validate_model_chat_template(bool use_jinja) const {
+        llama_chat_message chat[] = {{"user", "test"}};
+
+        if (use_jinja) {
+            auto templates = llama_chat_templates_from_model(model, "");
+            try {
+                templates.default_template.apply({{
+                    {"role", "user"},
+                    {"content", "test"},
+                }}, json(), true);
+                if (templates.tool_use_template) {
+                    templates.tool_use_template->apply({{
+                        {"role", "user"},
+                        {"content", "test"},
+                    }}, json(), true);
+                }
+                return true;
+            } catch (const std::exception & e) {
+                SRV_ERR("failed to apply template: %s\n", e.what());
+            }
+        } else {
+            std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
+            std::string template_key = "tokenizer.chat_template";
+            int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+            if (res >= 0) {
+                std::string tmpl = std::string(model_template.data(), model_template.size());
+                int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
+                return chat_res > 0;
+            }
         }
         return false;
     }
@@ -3476,15 +3496,30 @@
         }
     };
 
-    const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+    std::mutex chat_templates_mutex;
+    std::optional<llama_chat_templates> chat_templates;
+
+    auto get_chat_templates = [&ctx_server, &chat_templates_mutex, &chat_templates]() -> const llama_chat_templates & {
+        std::lock_guard<std::mutex> lock(chat_templates_mutex);
+        if (!chat_templates) {
+            chat_templates = llama_chat_templates_from_model(ctx_server.model, ctx_server.params_base.chat_template);
+        }
+        return *chat_templates;
+    };
+
+    const auto handle_props = [&ctx_server, &res_ok, &get_chat_templates](const httplib::Request &, httplib::Response & res) {
         // this endpoint is publicly available, please only return what is safe to be exposed
+        const auto & templates = get_chat_templates();
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params_base.n_parallel },
             { "model_path",                  ctx_server.params_base.model },
-            { "chat_template",               llama_get_chat_template(ctx_server.model) },
+            { "chat_template",               templates.default_template.source() },
             { "build_info",                  build_info },
         };
+        if (ctx_server.params_base.use_jinja && templates.tool_use_template) {
+            data["chat_template_tool_use"] = templates.tool_use_template->source();
+        }
         res_ok(res, data);
     };
 
@@ -3685,13 +3720,17 @@ int main(int argc, char ** argv) {
         return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
     };
 
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic, &get_chat_templates](const httplib::Request & req, httplib::Response & res) {
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+        auto body = json::parse(req.body);
+        const auto & templates = get_chat_templates();
+        const auto & chat_template = body.contains("tools") && templates.tool_use_template ? *templates.tool_use_template : templates.default_template;
+        json data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, params.use_jinja);
+
         return handle_completions_generic(
             SERVER_TASK_TYPE_COMPLETION,
             data,
@@ -4111,7 +4150,7 @@ int main(int argc, char ** argv) {
 
     // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
     if (params.chat_template.empty()) {
-        if (!ctx_server.validate_model_chat_template()) {
+        if (!ctx_server.validate_model_chat_template(params.use_jinja)) {
             LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
             params.chat_template = "chatml";
         }
diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 88549708113e9..ef716cc1ab223 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -4,22 +4,24 @@
 
 server = ServerPreset.tinyllama2()
 
-
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
 
 
 @pytest.mark.parametrize(
-    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja",
     [
-        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True),
     ]
 )
-def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
+def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja):
     global server
+    server.jinja = jinja
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "model": model,
@@ -102,6 +104,7 @@ def test_chat_completion_with_openai_library():
 
 @pytest.mark.parametrize("response_format,n_predicted,re_content", [
     ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
+    ({"type": "json_schema", "json_schema": {"const": "42"}}, 6, "\"42\""),
     ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
     ({"type": "json_object"}, 10, "(\\{|John)+"),
     ({"type": "sound"}, 0, None),
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index 277125e88b534..f0fe7b15dbf68 100644
--- a/examples/server/tests/utils.py
+++ b/examples/server/tests/utils.py
@@ -68,8 +68,9 @@ class ServerProcess:
     pooling: str | None = None
     draft: int | None = None
     api_key: str | None = None
-    response_format: str | None = None
     lora_files: List[str] | None = None
+    chat_template_file: str | None = None
+    jinja: bool | None = None
     disable_ctx_shift: int | None = False
     draft_min: int | None = None
     draft_max: int | None = None
@@ -154,6 +155,10 @@ def start(self, timeout_seconds: int = 10) -> None:
         if self.lora_files:
             for lora_file in self.lora_files:
                 server_args.extend(["--lora", lora_file])
+        if self.chat_template_file:
+            server_args.extend(["--chat-template-file", self.chat_template_file])
+        if self.jinja:
+            server_args.append("--jinja")
         if self.disable_ctx_shift:
             server_args.extend(["--no-context-shift"])
         if self.api_key:
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 334f2f19207ef..81a2d62e960bc 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -16,6 +16,8 @@
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
+#include "minja.hpp"
+#include "chat-template.hpp"
 
 #include <random>
 #include <sstream>
@@ -382,19 +384,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
     return formatted_chat;
 }
 
-static std::string llama_get_chat_template(const struct llama_model * model) {
-    std::string template_key = "tokenizer.chat_template";
-    // call with NULL buffer to get the total size of the string
-    int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
-    if (res < 2) {
-        return "";
-    } else {
-        std::vector<char> model_template(res + 1, 0);
-        llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        return std::string(model_template.data(), model_template.size() - 1);
-    }
-}
-
 //
 // base64 utils (TODO: move to common in the future)
 //
@@ -552,11 +541,21 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
 static json oaicompat_completion_params_parse(
     const struct llama_model * model,
     const json & body, /* openai api json semantics */
-    const std::string & chat_template) {
+    const minja::chat_template & tmpl,
+    bool use_jinja)
+{
     json llama_params;
 
-    // Apply chat template to the list of messages
-    llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
+    auto tools = json_value(body, "tools", json());
+    auto has_tools = tools.is_array() && !tools.empty();
+
+    if (has_tools) {
+        if (use_jinja) {
+            LOG_WRN("tools param is not fully supported yet\n");
+        } else {
+            throw std::runtime_error("tools param requires --jinja flag");
+        }
+    }
 
     // Handle "stop" field
     if (body.contains("stop") && body.at("stop").is_string()) {
@@ -579,6 +578,13 @@
         }
     }
 
+    // Apply chat template to the list of messages
+    if (use_jinja) {
+        llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
+    } else {
+        llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages"));
+    }
+
     // Handle "n" field
     int n_choices = json_value(body, "n", 1);
     if (n_choices != 1) {
@@ -594,7 +600,7 @@ static json oaicompat_completion_params_parse(
     }
 
     // Params supported by OAI but unsupported by llama.cpp
-    static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
+    static const std::vector<std::string> unsupported_params { "tool_choice" };
     for (const auto & param : unsupported_params) {
         if (body.contains(param)) {
             throw std::runtime_error("Unsupported param: " + param);
diff --git a/scripts/get_hf_chat_template.py b/scripts/get_hf_chat_template.py
new file mode 100755
index 0000000000000..820b84efc26b1
--- /dev/null
+++ b/scripts/get_hf_chat_template.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+'''
+    Fetches the Jinja chat template of a HuggingFace model.
+    If a model has multiple chat templates, you can specify the variant name.
+
+    Syntax:
+        ./scripts/get_hf_chat_template.py model_id [variant]
+
+    Examples:
+        ./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
+        ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
+        ./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct
+'''
+
+import json
+import re
+import sys
+
+
+def get_hf_chat_template(model_id, variant=None):
+    try:
+        # Use huggingface_hub library if available.
+        # Allows access to gated models if the user has access and ran `huggingface-cli login`.
+        from huggingface_hub import hf_hub_download
+        with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
+            config_str = f.read()
+    except ImportError:
+        import requests
+        assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}"
+        response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json")
+        if response.status_code == 401:
+            raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`')
+        response.raise_for_status()
+        config_str = response.text
+
+    try:
+        config = json.loads(config_str)
+    except json.JSONDecodeError:
+        # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
+        # (Remove extra '}' near the end of the file)
+        config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
+
+    chat_template = config['chat_template']
+    if isinstance(chat_template, str):
+        return chat_template
+    else:
+        variants = {
+            ct['name']: ct['template']
+            for ct in chat_template
+        }
+
+        def format_variants():
+            return ', '.join(f'"{v}"' for v in variants.keys())
+
+        if variant is None:
+            if 'default' not in variants:
+                raise Exception(f'Please specify a chat template variant (one of {format_variants()})')
+            variant = 'default'
+            print(f'Note: picked "default" chat template variant (out of {format_variants()})', file=sys.stderr)
+        elif variant not in variants:
+            raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})")
+
+        return variants[variant]
+
+
+def main(args):
+    if len(args) < 1:
+        raise ValueError("Please provide a model ID and an optional variant name")
+    model_id = args[0]
+    variant = None if len(args) < 2 else args[1]
+
+    template = get_hf_chat_template(model_id, variant)
+    print(template, end=None)
+
+
+if __name__ == '__main__':
+    main(sys.argv[1:])
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 2d3ea09945790..4bb58146ede32 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -17,7 +17,7 @@ add_library(llama
             unicode-data.cpp
             )
 
-target_include_directories(llama PUBLIC . ../include)
+target_include_directories(llama PUBLIC . ../include ../common)
 target_compile_features   (llama PUBLIC cxx_std_17) # don't bump
 
 target_link_libraries(llama PUBLIC ggml)
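
Usage sketch for the new flags (the model path and the saved template filename below are illustrative placeholders, not taken from this patch):

    # fetch a model's tool_use chat template variant with the new helper script
    ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > hermes-tool-use.jinja

    # start the server with the experimental Jinja engine and the saved template
    ./llama-server -m model.gguf --jinja --chat-template-file hermes-tool-use.jinja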