Skip to content

Commit

Permalink
Merge pull request #1142 from andreadimaio/main
Browse files Browse the repository at this point in the history
Introduce observability in watsonx
  • Loading branch information
geoand authored Dec 9, 2024
2 parents 5ec8e05 + d7c304b commit b779f6b
Show file tree
Hide file tree
Showing 9 changed files with 340 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.TOKEN_COUNT_ESTIMATOR;

import java.util.List;
import java.util.function.Supplier;
import java.util.function.Function;

import jakarta.enterprise.context.ApplicationScoped;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.ClassType;
import org.jboss.jandex.ParameterizedType;
import org.jboss.jandex.Type;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.deployment.DotNames;
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.ScoringModelProviderCandidateBuildItem;
Expand All @@ -26,6 +30,7 @@
import io.quarkiverse.langchain4j.watsonx.runtime.WatsonxRecorder;
import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig;
import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxFixedRuntimeConfig;
import io.quarkus.arc.SyntheticCreationalContext;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
Expand Down Expand Up @@ -86,8 +91,8 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
? fixedRuntimeConfig.defaultConfig().mode()
: fixedRuntimeConfig.namedConfig().get(configName).mode();

Supplier<ChatLanguageModel> chatLanguageModel;
Supplier<StreamingChatLanguageModel> streamingChatLanguageModel;
Function<SyntheticCreationalContext<ChatLanguageModel>, ChatLanguageModel> chatLanguageModel;
Function<SyntheticCreationalContext<StreamingChatLanguageModel>, StreamingChatLanguageModel> streamingChatLanguageModel;

if (mode.equalsIgnoreCase("chat")) {
chatLanguageModel = recorder.chatModel(runtimeConfig, configName);
Expand All @@ -106,7 +111,9 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(chatLanguageModel);
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null))
.createWith(chatLanguageModel);

addQualifierIfNecessary(chatBuilder, configName);
beanProducer.produce(chatBuilder.done());
Expand All @@ -116,7 +123,9 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(chatLanguageModel);
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null))
.createWith(chatLanguageModel);

addQualifierIfNecessary(tokenizerBuilder, configName);
beanProducer.produce(tokenizerBuilder.done());
Expand All @@ -126,7 +135,9 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(streamingChatLanguageModel);
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null))
.createWith(streamingChatLanguageModel);

addQualifierIfNecessary(streamingBuilder, configName);
beanProducer.produce(streamingBuilder.done());
Expand Down Expand Up @@ -171,9 +182,8 @@ private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigur

/**
* When both {@code rest-client-jackson} and {@code rest-client-jsonb} are present on the classpath we need to make sure
* that Jackson is used.
* This is not a proper solution as it affects all clients, but it's better than the having the reader/writers be selected
* at random.
* that Jackson is used. This is not a proper solution as it affects all clients, but it's better than the having the
* reader/writers be selected at random.
*/
@BuildStep
public void deprioritizeJsonb(Capabilities capabilities,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import static org.awaitility.Awaitility.await;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.fail;

import java.time.Duration;
import java.util.Date;
Expand All @@ -21,10 +22,12 @@
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingParameters;
Expand Down Expand Up @@ -268,7 +271,25 @@ void check_chat_streaming_model_config() throws Exception {
dev.langchain4j.data.message.UserMessage.from("UserMessage"));

var streamingResponse = new AtomicReference<AiMessage>();
streamingChatModel.generate(messages, WireMockUtil.streamingResponseHandler(streamingResponse));
streamingChatModel.generate(messages, new StreamingResponseHandler<>() {
@Override
public void onNext(String token) {
}

@Override
public void onError(Throwable error) {
fail("Streaming failed: %s".formatted(error.getMessage()), error);
}

@Override
public void onComplete(Response<AiMessage> response) {
assertEquals(FinishReason.LENGTH, response.finishReason());
assertEquals(2, response.tokenUsage().inputTokenCount());
assertEquals(14, response.tokenUsage().outputTokenCount());
assertEquals(16, response.tokenUsage().totalTokenCount());
streamingResponse.set(response.content());
}
});

await().atMost(Duration.ofMinutes(1))
.pollInterval(Duration.ofSeconds(2))
Expand All @@ -277,5 +298,6 @@ void check_chat_streaming_model_config() throws Exception {
assertThat(streamingResponse.get().text())
.isNotNull()
.isEqualTo(". I'm a beginner");

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,31 @@

import java.net.URL;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.client.api.LoggingScope;

import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi;
import io.quarkiverse.langchain4j.watsonx.client.filter.BearerTokenHeaderFactory;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;

public abstract class Watsonx {

private static final Logger logger = Logger.getLogger(Watsonx.class);

protected final String modelId, projectId, spaceId, version;
protected final WatsonxRestApi client;
protected final List<ChatModelListener> listeners;

public Watsonx(Builder<?> builder) {
QuarkusRestClientBuilder restClientBuilder = QuarkusRestClientBuilder.newBuilder()
Expand All @@ -34,6 +47,38 @@ public Watsonx(Builder<?> builder) {
this.spaceId = builder.spaceId;
this.projectId = builder.projectId;
this.version = builder.version;
this.listeners = builder.listeners;
}

protected void beforeSentRequest(ChatModelRequest request, Map<Object, Object> attributes) {
for (ChatModelListener listener : listeners) {
try {
listener.onRequest(new ChatModelRequestContext(request, attributes));
} catch (Exception e) {
logger.warn("Exception while calling model listener", e);
}
}
}

protected void afterReceivedResponse(ChatModelResponse response, ChatModelRequest request, Map<Object, Object> attributes) {
for (ChatModelListener listener : listeners) {
try {
listener.onResponse(new ChatModelResponseContext(response, request, attributes));
} catch (Exception e) {
logger.warn("Exception while calling model listener", e);
}
}
}

protected void onRequestError(Throwable error, ChatModelRequest request, ChatModelResponse partialResponse,
Map<Object, Object> attributes) {
for (ChatModelListener listener : listeners) {
try {
listener.onError(new ChatModelErrorContext(error, request, partialResponse, attributes));
} catch (Exception e) {
logger.warn("Exception while calling model listener", e);
}
}
}

public WatsonxRestApi getClient() {
Expand Down Expand Up @@ -67,6 +112,7 @@ public static abstract class Builder<T extends Builder<T>> {
protected URL url;
protected boolean logResponses;
protected boolean logRequests;
private List<ChatModelListener> listeners = Collections.emptyList();
protected WatsonxTokenGenerator tokenGenerator;

public T modelId(String modelId) {
Expand Down Expand Up @@ -99,6 +145,11 @@ public T timeout(Duration timeout) {
return (T) this;
}

public T listeners(List<ChatModelListener> listeners) {
this.listeners = listeners;
return (T) this;
}

public T tokenGenerator(WatsonxTokenGenerator tokenGenerator) {
this.tokenGenerator = tokenGenerator;
return (T) this;
Expand Down
Loading

0 comments on commit b779f6b

Please sign in to comment.