diff --git a/gpt4all-chat/server.cpp b/gpt4all-chat/server.cpp index c8485d93e11d0..9c0fd17826972 100644 --- a/gpt4all-chat/server.cpp +++ b/gpt4all-chat/server.cpp @@ -19,8 +19,6 @@ #include #include -#include -#include #include using namespace Qt::Literals::StringLiterals; @@ -207,26 +205,29 @@ void Server::start() QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &request, bool isChat) { - // We've been asked to do a completion... + // Parse JSON request QJsonParseError err; const QJsonDocument document = QJsonDocument::fromJson(request.body(), &err); if (err.error || !document.isObject()) { - std::cerr << "ERROR: invalid json in completions body" << std::endl; + std::cerr << "ERROR: invalid JSON in completions body" << std::endl; return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent); } + #if defined(DEBUG) printf("/v1/completions %s\n", qPrintable(document.toJson(QJsonDocument::Indented))); fflush(stdout); #endif + const QJsonObject body = document.object(); - if (!body.contains("model")) { // required - std::cerr << "ERROR: completions contains no model" << std::endl; + if (!body.contains("model")) { + std::cerr << "ERROR: completions contain no model" << std::endl; return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent); } + QJsonArray messages; if (isChat) { if (!body.contains("messages")) { - std::cerr << "ERROR: chat completions contains no messages" << std::endl; + std::cerr << "ERROR: chat completions contain no messages" << std::endl; return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent); } messages = body["messages"].toArray(); @@ -236,16 +237,12 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo(); const QList modelList = ModelList::globalInstance()->selectableModelList(); for (const ModelInfo &info : modelList) { - Q_ASSERT(info.installed); - if (!info.installed) - continue; - if (modelRequested == info.name() || modelRequested == info.filename()) { + if (info.installed && (modelRequested == info.name() || modelRequested == info.filename())) { modelInfo = info; break; } } - // We only support one prompt for now QList prompts; if (body.contains("prompt")) { QJsonValue promptValue = body["prompt"]; @@ -256,91 +253,23 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re for (const QJsonValue &v : array) prompts.append(v.toString()); } - } else + } else { prompts.append(" "); - - int max_tokens = 16; - if (body.contains("max_tokens")) - max_tokens = body["max_tokens"].toInt(); - - float temperature = 1.f; - if (body.contains("temperature")) - temperature = body["temperature"].toDouble(); - - float top_p = 1.f; - if (body.contains("top_p")) - top_p = body["top_p"].toDouble(); - - float min_p = 0.f; - if (body.contains("min_p")) - min_p = body["min_p"].toDouble(); - - int n = 1; - if (body.contains("n")) - n = body["n"].toInt(); - - int logprobs = -1; // supposed to be null by default?? - if (body.contains("logprobs")) - logprobs = body["logprobs"].toInt(); - - bool echo = false; - if (body.contains("echo")) - echo = body["echo"].toBool(); - - // We currently don't support any of the following... -#if 0 - // FIXME: Need configurable reverse prompts - QList stop; - if (body.contains("stop")) { - QJsonValue stopValue = body["stop"]; - if (stopValue.isString()) - stop.append(stopValue.toString()); - else { - QJsonArray array = stopValue.toArray(); - for (QJsonValue v : array) - stop.append(v.toString()); - } } - // FIXME: QHttpServer doesn't support server-sent events - bool stream = false; - if (body.contains("stream")) - stream = body["stream"].toBool(); - - // FIXME: What does this do? - QString suffix; - if (body.contains("suffix")) - suffix = body["suffix"].toString(); - - // FIXME: We don't support - float presence_penalty = 0.f; - if (body.contains("presence_penalty")) - top_p = body["presence_penalty"].toDouble(); - - // FIXME: We don't support - float frequency_penalty = 0.f; - if (body.contains("frequency_penalty")) - top_p = body["frequency_penalty"].toDouble(); - - // FIXME: We don't support - int best_of = 1; - if (body.contains("best_of")) - logprobs = body["best_of"].toInt(); - - // FIXME: We don't need - QString user; - if (body.contains("user")) - suffix = body["user"].toString(); -#endif + int max_tokens = body.value("max_tokens").toInt(16); + float temperature = body.value("temperature").toDouble(1.0); + float top_p = body.value("top_p").toDouble(1.0); + float min_p = body.value("min_p").toDouble(0.0); + int n = body.value("n").toInt(1); + bool echo = body.value("echo").toBool(false); QString actualPrompt = prompts.first(); - // if we're a chat completion we have messages which means we need to prepend these to the prompt if (!messages.isEmpty()) { QList chats; - for (int i = 0; i < messages.count(); ++i) { - QJsonValue v = messages.at(i); - QString content = v.toObject()["content"].toString(); + for (int i = 0; i < messages.count(); ++i) { + QString content = messages.at(i).toObject()["content"].toString(); if (!content.endsWith("\n") && i < messages.count() - 1) content += "\n"; chats.append(content); @@ -348,10 +277,8 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re actualPrompt.prepend(chats.join("\n")); } - // adds prompt/response items to GUI emit requestServerNewPromptResponsePair(actualPrompt); // blocks - // load the new model if necessary setShouldBeLoaded(true); if (modelInfo.filename().isEmpty()) { @@ -362,107 +289,83 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError); } - // don't remember any context resetContext(); - const QString promptTemplate = modelInfo.promptTemplate(); - const float top_k = modelInfo.topK(); - const int n_batch = modelInfo.promptBatchSize(); - const float repeat_penalty = modelInfo.repeatPenalty(); - const int repeat_last_n = modelInfo.repeatPenaltyTokens(); + QByteArray responseData; + QTextStream stream(&responseData, QIODevice::WriteOnly); + + QString randomId = "chatcmpl-" + QUuid::createUuid().toString(QUuid::WithoutBraces).replace("-", ""); - int promptTokens = 0; - int responseTokens = 0; - QList>> responses; for (int i = 0; i < n; ++i) { - if (!promptInternal( - m_collections, - actualPrompt, - promptTemplate, - max_tokens /*n_predict*/, - top_k, - top_p, - min_p, - temperature, - n_batch, - repeat_penalty, - repeat_last_n)) { + if (!promptInternal(m_collections, + actualPrompt, + modelInfo.promptTemplate(), + max_tokens /*n_predict*/, + modelInfo.topK(), + top_p, + min_p, + temperature, + modelInfo.promptBatchSize(), + modelInfo.repeatPenalty(), + modelInfo.repeatPenaltyTokens())) { std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl; return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError); } - QString echoedPrompt = actualPrompt; - if (!echoedPrompt.endsWith("\n")) - echoedPrompt += "\n"; - responses.append(qMakePair((echo ? u"%1\n"_s.arg(actualPrompt) : QString()) + response(), m_databaseResults)); - if (!promptTokens) - promptTokens += m_promptTokens; - responseTokens += m_promptResponseTokens - m_promptTokens; - if (i != n - 1) - resetResponse(); - } - QJsonObject responseObject; - responseObject.insert("id", "foobarbaz"); - responseObject.insert("object", "text_completion"); - responseObject.insert("created", QDateTime::currentSecsSinceEpoch()); - responseObject.insert("model", modelInfo.name()); + QString result = (echo ? u"%1\n"_s.arg(actualPrompt) : QString()) + response(); - QJsonArray choices; + for (const QString &token : result.split(' ')) { + QJsonObject delta; + delta.insert("content", token + " "); - if (isChat) { - int index = 0; - for (const auto &r : responses) { - QString result = r.first; - QList infos = r.second; QJsonObject choice; - choice.insert("index", index++); - choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop"); - QJsonObject message; - message.insert("role", "assistant"); - message.insert("content", result); - choice.insert("message", message); - if (MySettings::globalInstance()->localDocsShowReferences()) { - QJsonArray references; - for (const auto &ref : infos) - references.append(resultToJson(ref)); - choice.insert("references", references); - } - choices.append(choice); - } - } else { - int index = 0; - for (const auto &r : responses) { - QString result = r.first; - QList infos = r.second; - QJsonObject choice; - choice.insert("text", result); - choice.insert("index", index++); - choice.insert("logprobs", QJsonValue::Null); // We don't support - choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop"); - if (MySettings::globalInstance()->localDocsShowReferences()) { - QJsonArray references; - for (const auto &ref : infos) - references.append(resultToJson(ref)); - choice.insert("references", references); - } - choices.append(choice); + choice.insert("index", i); + choice.insert("delta", delta); + + QJsonObject responseChunk; + responseChunk.insert("id", randomId); + responseChunk.insert("object", "chat.completion.chunk"); + responseChunk.insert("created", QDateTime::currentSecsSinceEpoch()); + responseChunk.insert("model", modelInfo.name()); + responseChunk.insert("choices", QJsonArray{choice}); + + stream << "data: " << QJsonDocument(responseChunk).toJson(QJsonDocument::Compact) << "\n\n"; + stream.flush(); } + + if (i != n - 1) + resetResponse(); } - responseObject.insert("choices", choices); + // Final empty delta to signify the end of the stream + QJsonObject delta; + delta.insert("content", QJsonValue::Null); - QJsonObject usage; - usage.insert("prompt_tokens", int(promptTokens)); - usage.insert("completion_tokens", int(responseTokens)); - usage.insert("total_tokens", int(promptTokens + responseTokens)); - responseObject.insert("usage", usage); + QJsonObject choice; + choice.insert("index", 0); + choice.insert("delta", delta); + choice.insert("finish_reason", "stop"); -#if defined(DEBUG) - QJsonDocument newDoc(responseObject); - printf("/v1/completions %s\n", qPrintable(newDoc.toJson(QJsonDocument::Indented))); - fflush(stdout); -#endif + QJsonObject finalChunk; + finalChunk.insert("id", randomId); + finalChunk.insert("object", "chat.completion.chunk"); + finalChunk.insert("created", QDateTime::currentSecsSinceEpoch()); + finalChunk.insert("model", modelInfo.name()); + finalChunk.insert("choices", QJsonArray{choice}); + + stream << "data: " << QJsonDocument(finalChunk).toJson(QJsonDocument::Compact) << "\n\n"; + stream << "data: [DONE]\n\n"; + stream.flush(); + + // Log the entire response data + qDebug() << "Full SSE Response:\n" << responseData; + + // Create the response + QHttpServerResponse response(responseData, QHttpServerResponder::StatusCode::Ok); + response.setHeader("Content-Type", "text/event-stream"); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); - return QHttpServerResponse(responseObject); + return response; }