common : update lora
ggml-ci
ggerganov committed Jan 2, 2025
1 parent 8d117a5 commit 272cd0e
Showing 8 changed files with 40 additions and 40 deletions.
4 changes: 2 additions & 2 deletions common/arg.cpp
@@ -1512,15 +1512,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--lora"}, "FNAME",
"path to LoRA adapter (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & value) {
params.lora_adapters.push_back({ std::string(value), 1.0 });
params.lora_adapters.push_back({ std::string(value), 1.0, nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
add_opt(common_arg(
{"--lora-scaled"}, "FNAME", "SCALE",
"path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & fname, const std::string & scale) {
params.lora_adapters.push_back({ fname, std::stof(scale) });
params.lora_adapters.push_back({ fname, std::stof(scale), nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
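Not part of the commit, but for orientation: with the extra member, the brace-initializers above must now supply all three fields of common_lora_adapter_info. A minimal sketch of what the two handlers produce; the adapter paths and the 0.5 scale are invented.

#include "common.h"

int main() {
    common_params params;

    // --lora adapters/style.gguf  -> default scale, no loaded handle yet
    params.lora_adapters.push_back({ "adapters/style.gguf", 1.0f, nullptr });

    // --lora-scaled adapters/tone.gguf 0.5  -> user-defined scale
    params.lora_adapters.push_back({ "adapters/tone.gguf", 0.5f, nullptr });

    // the trailing nullptr is the new `ptr` member; it stays empty until
    // common_init_from_params() loads the adapter and fills it in
    return 0;
}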
21 changes: 11 additions & 10 deletions common/common.cpp
@@ -922,20 +922,21 @@ struct common_init_result common_init_from_params(common_params & params) {

// load and optionally apply lora adapters
for (auto & la : params.lora_adapters) {
common_lora_adapter_container loaded_la;
loaded_la.path = la.path;
loaded_la.scale = la.scale;
loaded_la.adapter.reset(llama_lora_adapter_init(model, la.path.c_str()));
if (loaded_la.adapter == nullptr) {
llama_lora_adapter_ptr lora;
lora.reset(llama_lora_adapter_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx);
llama_free_model(model);
return iparams;
}
iparams.lora_adapters.emplace_back(std::move(loaded_la)); // copy to list of loaded adapters

la.ptr = lora.get();
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
}

if (!params.lora_init_without_apply) {
common_lora_adapters_apply(lctx, iparams.lora_adapters);
common_lora_adapters_apply(lctx, params.lora_adapters);
}

if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
@@ -1002,11 +1003,11 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}

void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters) {
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora) {
llama_lora_adapter_clear(ctx);
for (auto & la : lora_adapters) {
for (auto & la : lora) {
if (la.scale != 0.0f) {
llama_lora_adapter_set(ctx, la.adapter.get(), la.scale);
llama_lora_adapter_set(ctx, la.ptr, la.scale);
}
}
}
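An editorial aside on the ownership split the loop above introduces: the owning llama_lora_adapter_ptr handles are collected in the init result, while each common_lora_adapter_info only keeps a raw, non-owning ptr into them. A sketch of that invariant, assuming every adapter loads successfully; the helper name is invented.

#include <cassert>

#include "common.h"

// hypothetical helper illustrating the new ownership split
static void check_lora_wiring(common_params & params) {
    common_init_result init = common_init_from_params(params);

    // owning handles are collected in the init result, in the same order
    // as the entries of params.lora_adapters ...
    assert(init.lora.size() == params.lora_adapters.size());

    // ... while each info entry only holds a non-owning view of its adapter
    for (size_t i = 0; i < params.lora_adapters.size(); ++i) {
        assert(params.lora_adapters[i].ptr == init.lora[i].get());
    }

    // `init` owns the model, context and adapters; once it goes out of scope,
    // every la.ptr above dangles, so keep it alive while adapters are in use
}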
10 changes: 5 additions & 5 deletions common/common.h
@@ -24,13 +24,12 @@

#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"

// TODO: "lora_adapter" is tautology
struct common_lora_adapter_info {
std::string path;
float scale;
};

struct common_lora_adapter_container : common_lora_adapter_info {
llama_lora_adapter_ptr adapter;
struct llama_lora_adapter * ptr;
};

using llama_tokens = std::vector<llama_token>;
@@ -478,11 +477,12 @@ std::string fs_get_cache_file(const std::string & filename);
// Model utils
//

// note: defines object's lifetime
struct common_init_result {
llama_model_ptr model;
llama_context_ptr context;

std::vector<common_lora_adapter_container> lora_adapters;
std::vector<llama_lora_adapter_ptr> lora;
};

struct common_init_result common_init_from_params(common_params & params);
Expand All @@ -504,7 +504,7 @@ struct llama_model * common_load_model_from_hf(
const struct llama_model_params & params);

// clear LoRA adapters from context, then apply new list of adapters
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora);

//
// Batch utils
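To make the new layout concrete, a hypothetical caller that loads adapters without applying them and enables them later at a different strength. The types, fields and functions are the ones shown in this commit; the adapter path, the scales, and the flow itself are illustrative.

#include "common.h"

int main() {
    common_params params;
    params.model = DEFAULT_MODEL_PATH;
    params.lora_adapters.push_back({ "adapters/style.gguf", 1.0f, nullptr });
    params.lora_init_without_apply = true;   // load adapters but defer applying them

    // loads model + context, fills la.ptr for every entry, keeps ownership in init.lora
    common_init_result init = common_init_from_params(params);
    llama_context * ctx = init.context.get();
    if (ctx == nullptr) {
        return 1;
    }

    // later: apply the adapters at half strength
    for (auto & la : params.lora_adapters) {
        la.scale = 0.5f;
    }
    common_lora_adapters_apply(ctx, params.lora_adapters);

    return 0;
}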
27 changes: 11 additions & 16 deletions examples/server/server.cpp
@@ -98,7 +98,7 @@ struct slot_params {
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

std::vector<common_lora_adapter_container> lora;
std::vector<common_lora_adapter_info> lora;

std::vector<std::string> antiprompt;
std::vector<std::string> response_fields;
@@ -198,15 +198,14 @@ struct server_task {
bool metrics_reset_bucket = false;

// used by SERVER_TASK_TYPE_SET_LORA
std::vector<common_lora_adapter_container> set_lora;
std::vector<common_lora_adapter_info> set_lora;

server_task(server_task_type type) : type(type) {}

static slot_params params_from_json_cmpl(
const llama_model * model,
const llama_context * ctx,
const common_params & params_base,
const std::vector<common_lora_adapter_container> & lora_base,
const json & data) {
slot_params params;

@@ -265,12 +264,12 @@

if (data.contains("lora")) {
if (data.at("lora").is_array()) {
params.lora = parse_lora_request(lora_base, data.at("lora"));
params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
} else {
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
}
} else {
params.lora = lora_base;
params.lora = params_base.lora_adapters;
}

// TODO: add more sanity checks for the input parameters
@@ -1132,7 +1131,7 @@ struct server_slot {

common_speculative * spec = nullptr;

std::vector<common_lora_adapter_container> lora;
std::vector<common_lora_adapter_info> lora;

// the index relative to completion multi-task request
size_t index = 0;
@@ -1633,8 +1632,6 @@ struct server_context {
llama_model * model = nullptr;
llama_context * ctx = nullptr;

std::vector<common_lora_adapter_container> lora;

llama_model * model_dft = nullptr;

llama_context_params cparams_dft;
@@ -1687,8 +1684,6 @@ struct server_context {
model = llama_init.model.get();
ctx = llama_init.context.get();

lora = std::move(llama_init.lora_adapters);

if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
return false;
@@ -1883,7 +1878,7 @@ struct server_context {
if (!are_lora_equal(task.params.lora, slot.lora)) {
// if lora is changed, we cannot reuse cached tokens
slot.cache_tokens.clear();
slot.lora = std::move(task.params.lora);
slot.lora = task.params.lora;
}

SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
@@ -2577,7 +2572,7 @@ struct server_context {
} break;
case SERVER_TASK_TYPE_SET_LORA:
{
lora = std::move(task.set_lora);
params_base.lora_adapters = std::move(task.set_lora);
auto res = std::make_unique<server_task_result_apply_lora>();
res->id = task.id;
queue_results.send(std::move(res));
@@ -3656,7 +3651,6 @@ int main(int argc, char ** argv) {
ctx_server.model,
ctx_server.ctx,
ctx_server.params_base,
ctx_server.lora,
data);
task.id_selected_slot = json_value(data, "id_slot", -1);

@@ -4083,8 +4077,9 @@

const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
json result = json::array();
for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
auto & lora = ctx_server.lora[i];
const auto & loras = ctx_server.params_base.lora_adapters;
for (size_t i = 0; i < loras.size(); ++i) {
auto & lora = loras[i];
result.push_back({
{"id", i},
{"path", lora.path},
@@ -4103,7 +4098,7 @@
}
server_task task(SERVER_TASK_TYPE_SET_LORA);
task.id = ctx_server.queue_tasks.get_new_id();
task.set_lora = parse_lora_request(ctx_server.lora, body);
task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);

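For context (not in the diff): per-request adapter selection now flows from the request JSON into slot.lora, with params_base.lora_adapters as the default and as the list that SET_LORA replaces. A client-side sketch of the accepted shape; the prompt, n_predict value, scale, and the endpoint mentioned in the comment are assumptions, only the array-of-objects format with "id" and "scale" comes from params_from_json_cmpl above.

// assumes nlohmann::json is available as <nlohmann/json.hpp>
#include <nlohmann/json.hpp>

int main() {
    nlohmann::json body;
    body["prompt"]    = "Write a haiku about autumn.";
    body["n_predict"] = 32;

    // "lora" must be an array of objects with "id" and "scale" fields;
    // "id" addresses the adapter by its index in params_base.lora_adapters
    nlohmann::json entry = { {"id", 0}, {"scale", 0.75} };
    body["lora"] = nlohmann::json::array({ entry });

    // POST body.dump() to the server's completion endpoint
    return 0;
}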
14 changes: 7 additions & 7 deletions examples/server/utils.hpp
@@ -799,25 +799,25 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
}

static bool are_lora_equal(
const std::vector<common_lora_adapter_container> & l1,
const std::vector<common_lora_adapter_container> & l2) {
const std::vector<common_lora_adapter_info> & l1,
const std::vector<common_lora_adapter_info> & l2) {
if (l1.size() != l2.size()) {
return false;
}
for (size_t i = 0; i < l1.size(); ++i) {
// we don't check lora.path to reduce the time complexity
if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) {
return false;
}
}
return true;
}

// parse lora config from JSON request, returned a copy of base_lora with updated scale
static std::vector<common_lora_adapter_container> parse_lora_request(
const std::vector<common_lora_adapter_container> & base_lora,
// parse lora config from JSON request, returned a copy of lora_base with updated scale
static std::vector<common_lora_adapter_info> parse_lora_request(
const std::vector<common_lora_adapter_info> & lora_base,
const json & data) {
std::vector<common_lora_adapter_container> lora(base_lora);
std::vector<common_lora_adapter_info> lora(lora_base);
int max_idx = lora.size();

// clear existing value
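A short sketch of how the two refactored helpers are meant to be combined; the wrapper function, its parameters, and the return convention are invented for illustration, and the base/slot vectors stand in for params_base.lora_adapters and slot.lora.

#include "utils.hpp"

// sketch: re-resolve the adapter list for one request and decide whether the
// slot's cached tokens can be reused
static bool update_request_lora(
        const std::vector<common_lora_adapter_info> & base,       // params_base.lora_adapters
        std::vector<common_lora_adapter_info>       & slot_lora,  // slot.lora
        const json & request_lora) {                              // the "lora" array from the request
    std::vector<common_lora_adapter_info> lora = parse_lora_request(base, request_lora);

    // scale + pointer comparison only; paths are deliberately not compared
    if (are_lora_equal(lora, slot_lora)) {
        return true;              // same adapters, cached tokens stay valid
    }
    slot_lora = std::move(lora);  // adapters changed, caller must clear the cache
    return false;
}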
1 change: 1 addition & 0 deletions src/llama-impl.cpp
@@ -5,6 +5,7 @@
#include <cinttypes>
#include <climits>
#include <cstdarg>
#include <cstring>
#include <vector>
#include <sstream>

2 changes: 2 additions & 0 deletions src/llama-model-loader.cpp
@@ -2,7 +2,9 @@

#include "ggml.h"

#include <array>
#include <cinttypes>
#include <cstring>
#include <future>

const char * llama_file_version_name(llama_fver version) {
1 change: 1 addition & 0 deletions src/llama-model-loader.h
@@ -10,6 +10,7 @@

#include <cstddef>
#include <map>
#include <stdexcept>
#include <unordered_map>

using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;