From d7c304b69c7e2a271b883fa9bcb26202eed8ba23 Mon Sep 17 00:00:00 2001 From: Andrea Di Maio Date: Sun, 8 Dec 2024 19:53:46 +0100 Subject: [PATCH] Introduce observability in watsonx --- .../watsonx/deployment/WatsonxProcessor.java | 28 ++-- .../GenerationAllPropertiesTest.java | 24 +++- .../langchain4j/watsonx/Watsonx.java | 51 +++++++ .../langchain4j/watsonx/WatsonxChatModel.java | 91 ++++++++++-- .../watsonx/WatsonxGenerationModel.java | 135 +++++++++++++----- .../watsonx/WatsonxTokenGenerator.java | 2 +- .../bean/TextStreamingChatResponse.java | 7 + .../watsonx/runtime/WatsonxRecorder.java | 96 ++++++++----- .../DisabledModelsWatsonRecorderTest.java | 5 +- 9 files changed, 340 insertions(+), 99 deletions(-) diff --git a/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java b/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java index 13539e093..f39da5e76 100644 --- a/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java +++ b/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java @@ -7,15 +7,19 @@ import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.TOKEN_COUNT_ESTIMATOR; import java.util.List; -import java.util.function.Supplier; +import java.util.function.Function; import jakarta.enterprise.context.ApplicationScoped; import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.ClassType; +import org.jboss.jandex.ParameterizedType; +import org.jboss.jandex.Type; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import io.quarkiverse.langchain4j.ModelName; +import io.quarkiverse.langchain4j.deployment.DotNames; import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; import 
io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem; import io.quarkiverse.langchain4j.deployment.items.ScoringModelProviderCandidateBuildItem; @@ -26,6 +30,7 @@ import io.quarkiverse.langchain4j.watsonx.runtime.WatsonxRecorder; import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxFixedRuntimeConfig; +import io.quarkus.arc.SyntheticCreationalContext; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.deployment.Capabilities; import io.quarkus.deployment.Capability; @@ -86,8 +91,8 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon ? fixedRuntimeConfig.defaultConfig().mode() : fixedRuntimeConfig.namedConfig().get(configName).mode(); - Supplier chatLanguageModel; - Supplier streamingChatLanguageModel; + Function, ChatLanguageModel> chatLanguageModel; + Function, StreamingChatLanguageModel> streamingChatLanguageModel; if (mode.equalsIgnoreCase("chat")) { chatLanguageModel = recorder.chatModel(runtimeConfig, configName); @@ -106,7 +111,9 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon .setRuntimeInit() .defaultBean() .scope(ApplicationScoped.class) - .supplier(chatLanguageModel); + .addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE, + new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null)) + .createWith(chatLanguageModel); addQualifierIfNecessary(chatBuilder, configName); beanProducer.produce(chatBuilder.done()); @@ -116,7 +123,9 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon .setRuntimeInit() .defaultBean() .scope(ApplicationScoped.class) - .supplier(chatLanguageModel); + .addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE, + new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null)) + .createWith(chatLanguageModel); 
addQualifierIfNecessary(tokenizerBuilder, configName); beanProducer.produce(tokenizerBuilder.done()); @@ -126,7 +135,9 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon .setRuntimeInit() .defaultBean() .scope(ApplicationScoped.class) - .supplier(streamingChatLanguageModel); + .addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE, + new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null)) + .createWith(streamingChatLanguageModel); addQualifierIfNecessary(streamingBuilder, configName); beanProducer.produce(streamingBuilder.done()); @@ -171,9 +182,8 @@ private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigur /** * When both {@code rest-client-jackson} and {@code rest-client-jsonb} are present on the classpath we need to make sure - * that Jackson is used. - * This is not a proper solution as it affects all clients, but it's better than the having the reader/writers be selected - * at random. + * that Jackson is used. This is not a proper solution as it affects all clients, but it's better than having the + * reader/writers be selected at random.
*/ @BuildStep public void deprioritizeJsonb(Capabilities capabilities, diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java index b3392bce0..a4076f79d 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java @@ -4,6 +4,7 @@ import static org.awaitility.Awaitility.await; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.fail; import java.time.Duration; import java.util.Date; @@ -21,10 +22,12 @@ import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.StreamingResponseHandler; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.TokenCountEstimator; import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.scoring.ScoringModel; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingParameters; @@ -268,7 +271,25 @@ void check_chat_streaming_model_config() throws Exception { dev.langchain4j.data.message.UserMessage.from("UserMessage")); var streamingResponse = new AtomicReference(); - streamingChatModel.generate(messages, WireMockUtil.streamingResponseHandler(streamingResponse)); + streamingChatModel.generate(messages, new StreamingResponseHandler<>() { + @Override + public void 
onNext(String token) { + } + + @Override + public void onError(Throwable error) { + fail("Streaming failed: %s".formatted(error.getMessage()), error); + } + + @Override + public void onComplete(Response response) { + assertEquals(FinishReason.LENGTH, response.finishReason()); + assertEquals(2, response.tokenUsage().inputTokenCount()); + assertEquals(14, response.tokenUsage().outputTokenCount()); + assertEquals(16, response.tokenUsage().totalTokenCount()); + streamingResponse.set(response.content()); + } + }); await().atMost(Duration.ofMinutes(1)) .pollInterval(Duration.ofSeconds(2)) @@ -277,5 +298,6 @@ void check_chat_streaming_model_config() throws Exception { assertThat(streamingResponse.get().text()) .isNotNull() .isEqualTo(". I'm a beginner"); + } } diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Watsonx.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Watsonx.java index dc3a55c12..73eb300f7 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Watsonx.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Watsonx.java @@ -2,18 +2,31 @@ import java.net.URL; import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Map; import java.util.concurrent.TimeUnit; +import org.jboss.logging.Logger; import org.jboss.resteasy.reactive.client.api.LoggingScope; +import dev.langchain4j.model.chat.listener.ChatModelErrorContext; +import dev.langchain4j.model.chat.listener.ChatModelListener; +import dev.langchain4j.model.chat.listener.ChatModelRequest; +import dev.langchain4j.model.chat.listener.ChatModelRequestContext; +import dev.langchain4j.model.chat.listener.ChatModelResponse; +import dev.langchain4j.model.chat.listener.ChatModelResponseContext; import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; import 
io.quarkiverse.langchain4j.watsonx.client.filter.BearerTokenHeaderFactory; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; public abstract class Watsonx { + private static final Logger logger = Logger.getLogger(Watsonx.class); + protected final String modelId, projectId, spaceId, version; protected final WatsonxRestApi client; + protected final List listeners; public Watsonx(Builder builder) { QuarkusRestClientBuilder restClientBuilder = QuarkusRestClientBuilder.newBuilder() @@ -34,6 +47,38 @@ public Watsonx(Builder builder) { this.spaceId = builder.spaceId; this.projectId = builder.projectId; this.version = builder.version; + this.listeners = builder.listeners; + } + + protected void beforeSentRequest(ChatModelRequest request, Map attributes) { + for (ChatModelListener listener : listeners) { + try { + listener.onRequest(new ChatModelRequestContext(request, attributes)); + } catch (Exception e) { + logger.warn("Exception while calling model listener", e); + } + } + } + + protected void afterReceivedResponse(ChatModelResponse response, ChatModelRequest request, Map attributes) { + for (ChatModelListener listener : listeners) { + try { + listener.onResponse(new ChatModelResponseContext(response, request, attributes)); + } catch (Exception e) { + logger.warn("Exception while calling model listener", e); + } + } + } + + protected void onRequestError(Throwable error, ChatModelRequest request, ChatModelResponse partialResponse, + Map attributes) { + for (ChatModelListener listener : listeners) { + try { + listener.onError(new ChatModelErrorContext(error, request, partialResponse, attributes)); + } catch (Exception e) { + logger.warn("Exception while calling model listener", e); + } + } } public WatsonxRestApi getClient() { @@ -67,6 +112,7 @@ public static abstract class Builder> { protected URL url; protected boolean logResponses; protected boolean logRequests; + private List listeners = Collections.emptyList(); protected WatsonxTokenGenerator 
tokenGenerator; public T modelId(String modelId) { @@ -99,6 +145,11 @@ public T timeout(Duration timeout) { return (T) this; } + public T listeners(List listeners) { + this.listeners = listeners; + return (T) this; + } + public T tokenGenerator(WatsonxTokenGenerator tokenGenerator) { this.tokenGenerator = tokenGenerator; return (T) this; diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java index a600376ac..37a8369e9 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java @@ -4,7 +4,9 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; import java.util.stream.Collectors; @@ -16,6 +18,8 @@ import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.TokenCountEstimator; +import dev.langchain4j.model.chat.listener.ChatModelRequest; +import dev.langchain4j.model.chat.listener.ChatModelResponse; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; @@ -35,6 +39,7 @@ public class WatsonxChatModel extends Watsonx implements ChatLanguageModel, StreamingChatLanguageModel, TokenCountEstimator { + private static final String ID_CONTEXT = "ID"; private static final String USAGE_CONTEXT = "USAGE"; private static final String FINISH_REASON_CONTEXT = "FINISH_REASON"; private static final String ROLE_CONTEXT = "ROLE"; @@ -67,12 +72,20 @@ public Response generate(List messages, List attributes = new ConcurrentHashMap<>(); + 
ChatModelRequest chatModelRequest = toChatModelRequest(messages, toolSpecifications); + beforeSentRequest(chatModelRequest, attributes); + TextChatRequest request = new TextChatRequest(modelId, spaceId, projectId, convertedMessages, tools, null, parameters); TextChatResponse response = retryOn(new Callable() { @Override public TextChatResponse call() throws Exception { - return client.chat(request, version); + try { + return client.chat(request, version); + } catch (RuntimeException exception) { + onRequestError(exception, chatModelRequest, null, attributes); + throw exception; + } } }); @@ -80,11 +93,11 @@ public TextChatResponse call() throws Exception { TextChatResultMessage message = choice.message(); TextChatUsage usage = response.usage(); - AiMessage content; + AiMessage aiMessage; if (message.toolCalls() != null && message.toolCalls().size() > 0) { - content = AiMessage.from(message.toolCalls().stream().map(TextChatToolCall::convert).toList()); + aiMessage = AiMessage.from(message.toolCalls().stream().map(TextChatToolCall::convert).toList()); } else { - content = AiMessage.from(message.content().trim()); + aiMessage = AiMessage.from(message.content().trim()); } var finishReason = toFinishReason(choice.finishReason()); @@ -93,7 +106,9 @@ public TextChatResponse call() throws Exception { usage.completionTokens(), usage.totalTokens()); - return Response.from(content, tokenUsage, finishReason); + ChatModelResponse chatModelResponse = toChatModelResponse(response.id(), aiMessage, tokenUsage, finishReason); + afterReceivedResponse(chatModelResponse, chatModelRequest, attributes); + return Response.from(aiMessage, tokenUsage, finishReason); } @Override @@ -104,6 +119,10 @@ public void generate(List messages, List toolSpe ? 
toolSpecifications.stream().map(TextChatParameterTool::of).toList() : null; + Map attributes = new ConcurrentHashMap<>(); + ChatModelRequest chatModelRequest = toChatModelRequest(messages, toolSpecifications); + beforeSentRequest(chatModelRequest, attributes); + TextChatRequest request = new TextChatRequest(modelId, spaceId, projectId, convertedMessages, tools, null, parameters); Context context = Context.empty(); context.put(TOOLS_CONTEXT, new ArrayList()); @@ -125,6 +144,10 @@ public void accept(TextStreamingChatResponse chunk) { var message = chunk.choices().get(0); + if (!context.contains(ID_CONTEXT) && chunk.id() != null) { + context.put(ID_CONTEXT, chunk.id()); + } + if (message.finishReason() != null) { context.put(FINISH_REASON_CONTEXT, message.finishReason()); } @@ -183,6 +206,27 @@ public void accept(TextStreamingChatResponse chunk) { new Consumer() { @Override public void accept(Throwable error) { + + AiMessage aiMessage = null; + TokenUsage tokenUsage = null; + String id = context.contains(ID_CONTEXT) ? context.get(ID_CONTEXT) : null; + FinishReason finishReason = context.contains(FINISH_REASON_CONTEXT) + ? 
toFinishReason(context.get(FINISH_REASON_CONTEXT)) + : null; + + if (context.contains(COMPLETE_MESSAGE_CONTEXT)) { + StringBuilder message = context.get(COMPLETE_MESSAGE_CONTEXT); + aiMessage = AiMessage.from(message.toString()); + } + + if (context.contains(USAGE_CONTEXT)) { + TextStreamingChatResponse.TextChatUsage textChatUsage = context.get(USAGE_CONTEXT); + tokenUsage = textChatUsage.toTokenUsage(); + } + + ChatModelResponse chatModelResponse = toChatModelResponse(id, aiMessage, tokenUsage, + finishReason); + onRequestError(error, chatModelRequest, chatModelResponse, attributes); handler.onError(error); } }, @@ -190,11 +234,9 @@ public void accept(Throwable error) { @Override public void run() { - TextStreamingChatResponse.TextChatUsage usage = context.get(USAGE_CONTEXT); - TokenUsage tokenUsage = new TokenUsage( - usage.promptTokens(), - usage.completionTokens(), - usage.totalTokens()); + String id = context.get(ID_CONTEXT); + TextStreamingChatResponse.TextChatUsage textChatUsage = context.get(USAGE_CONTEXT); + TokenUsage tokenUsage = textChatUsage.toTokenUsage(); String finishReason = context.get(FINISH_REASON_CONTEXT); FinishReason finishReasonObj = toFinishReason(finishReason); @@ -211,6 +253,11 @@ public void run() { } else { StringBuilder message = context.get(COMPLETE_MESSAGE_CONTEXT); + AiMessage aiMessage = AiMessage.from(message.toString()); + ChatModelResponse chatModelResponse = toChatModelResponse(id, aiMessage, tokenUsage, + finishReasonObj); + + afterReceivedResponse(chatModelResponse, chatModelRequest, attributes); handler.onComplete( Response.from(AiMessage.from(message.toString()), tokenUsage, finishReasonObj)); } @@ -256,6 +303,28 @@ public static Builder builder() { return new Builder(); } + private ChatModelRequest toChatModelRequest(List messages, List toolSpecifications) { + return ChatModelRequest.builder() + .maxTokens(parameters.getMaxTokens()) + .messages(messages) + .model(modelId) + .temperature(parameters.getTemperature()) + 
.toolSpecifications(toolSpecifications) + .topP(parameters.getTopP()) + .build(); + } + + private ChatModelResponse toChatModelResponse(String id, AiMessage aiMessage, TokenUsage tokenUsage, + FinishReason finishReason) { + return ChatModelResponse.builder() + .aiMessage(aiMessage) + .finishReason(finishReason) + .id(id) + .model(modelId) + .tokenUsage(tokenUsage) + .build(); + } + private FinishReason toFinishReason(String reason) { if (reason == null) return FinishReason.OTHER; diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java index 554686e0c..dd8d6e073 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java @@ -3,10 +3,11 @@ import static io.quarkiverse.langchain4j.watsonx.WatsonxUtils.retryOn; import static java.util.stream.Collectors.joining; -import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; import java.util.function.Function; @@ -18,6 +19,8 @@ import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.TokenCountEstimator; +import dev.langchain4j.model.chat.listener.ChatModelRequest; +import dev.langchain4j.model.chat.listener.ChatModelResponse; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; @@ -32,6 +35,11 @@ public class WatsonxGenerationModel extends Watsonx implements ChatLanguageModel, StreamingChatLanguageModel, TokenCountEstimator { + private 
static final String INPUT_TOKEN_COUNT_CONTEXT = "INPUT_TOKEN_COUNT"; + private static final String GENERATED_TOKEN_COUNT_CONTEXT = "GENERATED_TOKEN_COUNT"; + private static final String COMPLETE_MESSAGE_CONTEXT = "COMPLETE_MESSAGE"; + private static final String FINISH_REASON_CONTEXT = "FINISH_REASON"; + private final TextGenerationParameters parameters; private final String promptJoiner; @@ -65,12 +73,21 @@ public WatsonxGenerationModel(Builder builder) { @Override public Response generate(List messages) { + Map attributes = new ConcurrentHashMap<>(); + ChatModelRequest chatModelRequest = toChatModelRequest(messages); + beforeSentRequest(chatModelRequest, attributes); + Result result = retryOn(new Callable() { @Override public TextGenerationResponse call() throws Exception { TextGenerationRequest request = new TextGenerationRequest(modelId, spaceId, projectId, toInput(messages), parameters); - return client.generation(request, version); + try { + return client.generation(request, version); + } catch (RuntimeException exception) { + onRequestError(exception, chatModelRequest, null, attributes); + throw exception; + } } }).results().get(0); @@ -79,16 +96,24 @@ public TextGenerationResponse call() throws Exception { result.inputTokenCount(), result.generatedTokenCount()); - AiMessage content = AiMessage.from(result.generatedText()); - return Response.from(content, tokenUsage, finishReason); + AiMessage aiMessage = AiMessage.from(result.generatedText()); + ChatModelResponse chatModelResponse = toChatModelResponse(null, aiMessage, tokenUsage, finishReason); + afterReceivedResponse(chatModelResponse, chatModelRequest, attributes); + return Response.from(aiMessage, tokenUsage, finishReason); } @Override public void generate(List messages, StreamingResponseHandler handler) { TextGenerationRequest request = new TextGenerationRequest(modelId, spaceId, projectId, toInput(messages), parameters); + Map attributes = new ConcurrentHashMap<>(); + ChatModelRequest chatModelRequest 
= toChatModelRequest(messages); + beforeSentRequest(chatModelRequest, attributes); + Context context = Context.empty(); - context.put("response", new ArrayList()); + context.put(COMPLETE_MESSAGE_CONTEXT, new StringBuilder()); + context.put(INPUT_TOKEN_COUNT_CONTEXT, 0); + context.put(GENERATED_TOKEN_COUNT_CONTEXT, 0); client.generationStreaming(request, version) .subscribe() @@ -101,14 +126,22 @@ public void accept(TextGenerationResponse response) { if (response == null || response.results() == null || response.results().isEmpty()) return; - String chunk = response.results().get(0).generatedText(); + StringBuilder stringBuilder = context.get(COMPLETE_MESSAGE_CONTEXT); + Result chunk = response.results().get(0); - if (chunk.isEmpty()) - return; + if (!chunk.stopReason().equals("not_finished")) { + context.put(FINISH_REASON_CONTEXT, chunk.stopReason()); + } + + int inputTokenCount = context.get(INPUT_TOKEN_COUNT_CONTEXT); + context.put(INPUT_TOKEN_COUNT_CONTEXT, inputTokenCount + chunk.inputTokenCount()); - List responses = context.get("response"); - responses.add(response); - handler.onNext(chunk); + int generatedTokenCount = context.get(GENERATED_TOKEN_COUNT_CONTEXT); + context.put(GENERATED_TOKEN_COUNT_CONTEXT, + generatedTokenCount + chunk.generatedTokenCount()); + + stringBuilder.append(chunk.generatedText()); + handler.onNext(chunk.generatedText()); } catch (Exception e) { handler.onError(e); @@ -118,41 +151,48 @@ public void accept(TextGenerationResponse response) { new Consumer() { @Override public void accept(Throwable error) { + + StringBuilder response = context.get(COMPLETE_MESSAGE_CONTEXT); + FinishReason finishReason = context.contains(FINISH_REASON_CONTEXT) + ? toFinishReason(context.get(FINISH_REASON_CONTEXT)) + : null; + int inputTokenCount = context.contains(INPUT_TOKEN_COUNT_CONTEXT) + ? context.get(INPUT_TOKEN_COUNT_CONTEXT) + : 0; + int generatedTokenCount = context.contains(GENERATED_TOKEN_COUNT_CONTEXT) + ? 
context.get(GENERATED_TOKEN_COUNT_CONTEXT) + : 0; + + AiMessage aiMessage = AiMessage.from(response.toString()); + TokenUsage tokenUsage = new TokenUsage(inputTokenCount, generatedTokenCount); + ChatModelResponse chatModelResponse = toChatModelResponse(null, aiMessage, tokenUsage, + finishReason); + onRequestError(error, chatModelRequest, chatModelResponse, attributes); handler.onError(error); } }, new Runnable() { @Override public void run() { - List list = context.get("response"); - - int inputTokenCount = 0; - int outputTokenCount = 0; - String stopReason = null; - StringBuilder builder = new StringBuilder(); - - for (int i = 0; i < list.size(); i++) { - - TextGenerationResponse.Result response = list.get(i).results().get(0); - if (i == 0) - inputTokenCount = response.inputTokenCount(); - - if (i == list.size() - 1) { - outputTokenCount = response.generatedTokenCount(); - stopReason = response.stopReason(); - } - - builder.append(response.generatedText()); - } - - AiMessage content; + StringBuilder response = context.get(COMPLETE_MESSAGE_CONTEXT); + FinishReason finishReason = context.contains(FINISH_REASON_CONTEXT) + ? toFinishReason(context.get(FINISH_REASON_CONTEXT)) + : null; + int inputTokenCount = context.contains(INPUT_TOKEN_COUNT_CONTEXT) + ? context.get(INPUT_TOKEN_COUNT_CONTEXT) + : 0; + int outputTokenCount = context.contains(GENERATED_TOKEN_COUNT_CONTEXT) + ? 
context.get(GENERATED_TOKEN_COUNT_CONTEXT) + : 0; + + AiMessage aiMessage = AiMessage.from(response.toString()); TokenUsage tokenUsage = new TokenUsage(inputTokenCount, outputTokenCount); - FinishReason finishReason = toFinishReason(stopReason); + ChatModelResponse chatModelResponse = toChatModelResponse(null, aiMessage, tokenUsage, + finishReason); - String message = builder.toString(); - content = AiMessage.from(message); - handler.onComplete(Response.from(content, tokenUsage, finishReason)); + afterReceivedResponse(chatModelResponse, chatModelRequest, attributes); + handler.onComplete(Response.from(aiMessage, tokenUsage, finishReason)); } }); } @@ -170,6 +210,27 @@ public Integer call() throws Exception { }); } + private ChatModelRequest toChatModelRequest(List messages) { + return ChatModelRequest.builder() + .maxTokens(parameters.getMaxNewTokens()) + .messages(messages) + .model(modelId) + .temperature(parameters.getTemperature()) + .topP(parameters.getTopP()) + .build(); + } + + private ChatModelResponse toChatModelResponse(String id, AiMessage aiMessage, TokenUsage tokenUsage, + FinishReason finishReason) { + return ChatModelResponse.builder() + .aiMessage(aiMessage) + .finishReason(finishReason) + .id(id) + .model(modelId) + .tokenUsage(tokenUsage) + .build(); + } + public static Builder builder() { return new Builder(); } diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxTokenGenerator.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxTokenGenerator.java index 30b3e49c8..6832fbedb 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxTokenGenerator.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxTokenGenerator.java @@ -15,7 +15,7 @@ public class WatsonxTokenGenerator { - private final static Semaphore lock = new Semaphore(1); + private final Semaphore lock = new 
Semaphore(1); private final IAMRestApi client; private final String apiKey; private final String grantType; diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextStreamingChatResponse.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextStreamingChatResponse.java index 8d0272461..5a56ec8c0 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextStreamingChatResponse.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextStreamingChatResponse.java @@ -6,6 +6,7 @@ import com.fasterxml.jackson.databind.PropertyNamingStrategies; import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.model.output.TokenUsage; import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatToolCall; public record TextStreamingChatResponse(String id, String modelId, List choices, Long created, @@ -17,6 +18,12 @@ public record TextChatResultChoice(Integer index, TextChatResultMessage delta, S @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) public record TextChatUsage(Integer completionTokens, Integer promptTokens, Integer totalTokens) { + public TokenUsage toTokenUsage() { + return new TokenUsage( + promptTokens, + completionTokens, + totalTokens); + } } @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java index e7cafa653..4e0fbee75 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java @@ -11,10 +11,14 @@ import 
java.util.function.Function; import java.util.function.Supplier; +import jakarta.enterprise.inject.Instance; +import jakarta.enterprise.util.TypeLiteral; + import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.DisabledChatLanguageModel; import dev.langchain4j.model.chat.DisabledStreamingChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.chat.listener.ChatModelListener; import dev.langchain4j.model.embedding.DisabledEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.scoring.ScoringModel; @@ -30,6 +34,7 @@ import io.quarkiverse.langchain4j.watsonx.runtime.config.IAMConfig; import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; import io.quarkiverse.langchain4j.watsonx.runtime.config.ScoringModelConfig; +import io.quarkus.arc.SyntheticCreationalContext; import io.quarkus.runtime.annotations.Recorder; import io.smallrye.config.ConfigValidationException; @@ -38,80 +43,94 @@ public class WatsonxRecorder { private static final Map tokenGeneratorCache = new HashMap<>(); private static final ConfigValidationException.Problem[] EMPTY_PROBLEMS = new ConfigValidationException.Problem[0]; + private static final TypeLiteral> CHAT_MODEL_LISTENER_TYPE_LITERAL = new TypeLiteral<>() { + }; - public Supplier chatModel(LangChain4jWatsonxConfig runtimeConfig, String configName) { + public Function, ChatLanguageModel> chatModel( + LangChain4jWatsonxConfig runtimeConfig, String configName) { LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName); + String apiKey = firstOrDefault(null, watsonRuntimeConfig.apiKey(), runtimeConfig.defaultConfig().apiKey()); if (watsonRuntimeConfig.enableIntegration()) { var builder = chatBuilder(runtimeConfig, configName); - return new Supplier<>() { + + return new Function<>() { @Override - public ChatLanguageModel get() { - 
return builder.build(); + public ChatLanguageModel apply(SyntheticCreationalContext context) { + return builder + .tokenGenerator(createTokenGenerator(watsonRuntimeConfig.iam(), apiKey)) + .listeners(context.getInjectedReference(CHAT_MODEL_LISTENER_TYPE_LITERAL).stream().toList()) + .build(); } }; } else { - return new Supplier<>() { - + return new Function<>() { @Override - public ChatLanguageModel get() { + public ChatLanguageModel apply(SyntheticCreationalContext context) { return new DisabledChatLanguageModel(); } - }; } } - public Supplier streamingChatModel(LangChain4jWatsonxConfig runtimeConfig, + public Function, StreamingChatLanguageModel> streamingChatModel( + LangChain4jWatsonxConfig runtimeConfig, String configName) { LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName); + String apiKey = firstOrDefault(null, watsonRuntimeConfig.apiKey(), runtimeConfig.defaultConfig().apiKey()); if (watsonRuntimeConfig.enableIntegration()) { var builder = chatBuilder(runtimeConfig, configName); - return new Supplier<>() { + + return new Function<>() { @Override - public StreamingChatLanguageModel get() { - return builder.build(); + public StreamingChatLanguageModel apply(SyntheticCreationalContext context) { + return builder + .tokenGenerator(createTokenGenerator(watsonRuntimeConfig.iam(), apiKey)) + .listeners(context.getInjectedReference(CHAT_MODEL_LISTENER_TYPE_LITERAL).stream().toList()) + .build(); } }; } else { - return new Supplier<>() { - + return new Function<>() { @Override - public StreamingChatLanguageModel get() { + public StreamingChatLanguageModel apply(SyntheticCreationalContext context) { return new DisabledStreamingChatLanguageModel(); } - }; } } - public Supplier generationModel(LangChain4jWatsonxConfig runtimeConfig, + public Function, ChatLanguageModel> generationModel( + LangChain4jWatsonxConfig runtimeConfig, String configName) { LangChain4jWatsonxConfig.WatsonConfig 
watsonRuntimeConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName); + String apiKey = firstOrDefault(null, watsonRuntimeConfig.apiKey(), runtimeConfig.defaultConfig().apiKey()); if (watsonRuntimeConfig.enableIntegration()) { var builder = generationBuilder(runtimeConfig, configName); - return new Supplier<>() { + return new Function<>() { @Override - public ChatLanguageModel get() { - return builder.build(); + public ChatLanguageModel apply(SyntheticCreationalContext context) { + return builder + .tokenGenerator(createTokenGenerator(watsonRuntimeConfig.iam(), apiKey)) + .listeners(context.getInjectedReference(CHAT_MODEL_LISTENER_TYPE_LITERAL).stream().toList()) + .build(); } }; } else { - return new Supplier<>() { - + return new Function<>() { @Override - public ChatLanguageModel get() { + public ChatLanguageModel apply(SyntheticCreationalContext context) { return new DisabledChatLanguageModel(); } @@ -119,26 +138,30 @@ public ChatLanguageModel get() { } } - public Supplier generationStreamingModel(LangChain4jWatsonxConfig runtimeConfig, + public Function, StreamingChatLanguageModel> generationStreamingModel( + LangChain4jWatsonxConfig runtimeConfig, String configName) { LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName); + String apiKey = firstOrDefault(null, watsonRuntimeConfig.apiKey(), runtimeConfig.defaultConfig().apiKey()); if (watsonRuntimeConfig.enableIntegration()) { var builder = generationBuilder(runtimeConfig, configName); - return new Supplier<>() { + return new Function<>() { @Override - public StreamingChatLanguageModel get() { - return builder.build(); + public StreamingChatLanguageModel apply(SyntheticCreationalContext context) { + return builder + .tokenGenerator(createTokenGenerator(watsonRuntimeConfig.iam(), apiKey)) + .listeners(context.getInjectedReference(CHAT_MODEL_LISTENER_TYPE_LITERAL).stream().toList()) + .build(); } }; } else { - return new Supplier<>() { 
- + return new Function<>() { @Override - public StreamingChatLanguageModel get() { + public StreamingChatLanguageModel apply(SyntheticCreationalContext context) { return new DisabledStreamingChatLanguageModel(); } @@ -167,7 +190,6 @@ public Supplier embeddingModel(LangChain4jWatsonxConfig runtimeC EmbeddingModelConfig embeddingModelConfig = watsonConfig.embeddingModel(); var builder = WatsonxEmbeddingModel.builder() - .tokenGenerator(createTokenGenerator(watsonConfig.iam(), apiKey)) .url(url) .timeout(watsonConfig.timeout().orElse(Duration.ofSeconds(10))) .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), watsonConfig.logRequests())) @@ -181,7 +203,9 @@ public Supplier embeddingModel(LangChain4jWatsonxConfig runtimeC return new Supplier<>() { @Override public WatsonxEmbeddingModel get() { - return builder.build(); + return builder + .tokenGenerator(createTokenGenerator(watsonConfig.iam(), apiKey)) + .build(); } }; @@ -217,7 +241,6 @@ public Supplier scoringModel(LangChain4jWatsonxConfig runtimeConfi ScoringModelConfig rerankModelConfig = watsonConfig.scoringModel(); var builder = WatsonxScoringModel.builder() - .tokenGenerator(createTokenGenerator(watsonConfig.iam(), apiKey)) .url(url) .timeout(watsonConfig.timeout().orElse(Duration.ofSeconds(10))) .logRequests(firstOrDefault(false, rerankModelConfig.logRequests(), watsonConfig.logRequests())) @@ -231,7 +254,9 @@ public Supplier scoringModel(LangChain4jWatsonxConfig runtimeConfi return new Supplier<>() { @Override public WatsonxScoringModel get() { - return builder.build(); + return builder + .tokenGenerator(createTokenGenerator(watsonConfig.iam(), apiKey)) + .build(); } }; } @@ -246,7 +271,6 @@ private WatsonxChatModel.Builder chatBuilder(LangChain4jWatsonxConfig runtimeCon } ChatModelConfig chatModelConfig = watsonConfig.chatModel(); - String apiKey = firstOrDefault(null, watsonConfig.apiKey(), runtimeConfig.defaultConfig().apiKey()); URL url; try { @@ -256,7 +280,6 @@ private 
WatsonxChatModel.Builder chatBuilder(LangChain4jWatsonxConfig runtimeCon } return WatsonxChatModel.builder() - .tokenGenerator(createTokenGenerator(watsonConfig.iam(), apiKey)) .url(url) .timeout(watsonConfig.timeout().orElse(Duration.ofSeconds(10))) .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), watsonConfig.logRequests())) @@ -300,7 +323,6 @@ private WatsonxGenerationModel.Builder generationBuilder(LangChain4jWatsonxConfi String promptJoiner = generationModelConfig.promptJoiner(); return WatsonxGenerationModel.builder() - .tokenGenerator(createTokenGenerator(watsonConfig.iam(), apiKey)) .url(url) .timeout(watsonConfig.timeout().orElse(Duration.ofSeconds(10))) .logRequests(firstOrDefault(false, generationModelConfig.logRequests(), watsonConfig.logRequests())) diff --git a/model-providers/watsonx/runtime/src/test/java/io/quarkiverse/langchain4j/watsonx/runtime/DisabledModelsWatsonRecorderTest.java b/model-providers/watsonx/runtime/src/test/java/io/quarkiverse/langchain4j/watsonx/runtime/DisabledModelsWatsonRecorderTest.java index d778c3368..4b2704c50 100644 --- a/model-providers/watsonx/runtime/src/test/java/io/quarkiverse/langchain4j/watsonx/runtime/DisabledModelsWatsonRecorderTest.java +++ b/model-providers/watsonx/runtime/src/test/java/io/quarkiverse/langchain4j/watsonx/runtime/DisabledModelsWatsonRecorderTest.java @@ -32,13 +32,12 @@ void setupMocks() { @Test void disabledChatModel() { assertThat(recorder - .generationModel(runtimeConfig, NamedConfigUtil.DEFAULT_NAME) - .get()) + .generationModel(runtimeConfig, NamedConfigUtil.DEFAULT_NAME).apply(null)) .isNotNull() .isExactlyInstanceOf(DisabledChatLanguageModel.class); assertThat( - recorder.generationStreamingModel(runtimeConfig, NamedConfigUtil.DEFAULT_NAME).get()) + recorder.generationStreamingModel(runtimeConfig, NamedConfigUtil.DEFAULT_NAME).apply(null)) .isNotNull() .isExactlyInstanceOf(DisabledStreamingChatLanguageModel.class);