Skip to content

Commit

Permalink
Update to LangChain4j 0.34
Browse files Browse the repository at this point in the history
Closes: #873
  • Loading branch information
geoand authored and jmartisk committed Sep 6, 2024
1 parent d12d6f7 commit c605e0f
Show file tree
Hide file tree
Showing 72 changed files with 56 additions and 828 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
Object memoryId = memoryId(methodCreateInfo, methodArgs, context.chatMemoryProvider != null);
boolean needsMemorySeed = needsMemorySeed(context, memoryId); // we need to know figure this out before we add the system and user message

boolean hasMethodSpecificTools = methodCreateInfo.getToolClassNames() != null
&& !methodCreateInfo.getToolClassNames().isEmpty();
List<ToolSpecification> toolSpecifications = hasMethodSpecificTools ? methodCreateInfo.getToolSpecifications()
: context.toolSpecifications;
Map<String, ToolExecutor> toolExecutors = hasMethodSpecificTools ? methodCreateInfo.getToolExecutors()
: context.toolExecutors;

Type returnType = methodCreateInfo.getReturnType();
AugmentationResult augmentationResult = null;
if (context.retrievalAugmentor != null) {
Expand Down Expand Up @@ -163,7 +170,12 @@ public Flow.Publisher<?> apply(AugmentationResult ar) {
GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, (UserMessage) augmentedUserMessage,
context.chatMemory(memoryId), ar);
List<ChatMessage> messagesToSend = messagesToSend(augmentedUserMessage, needsMemorySeed);
return Multi.createFrom().emitter(new MultiEmitterConsumer(messagesToSend, context, memoryId));
return Multi.createFrom()
.emitter(new MultiEmitterConsumer(messagesToSend, toolSpecifications,
toolExecutors,
ar.contents(),
context,
memoryId));
}

private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
Expand Down Expand Up @@ -205,27 +217,21 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
if (isTokenStream(returnType)) {
// TODO Indicate the output guardrails cannot be used when streaming
chatMemory.commit(); // for streaming cases, we really have to commit because all alternatives are worse
return new AiServiceTokenStream(messagesToSend, context, memoryId);
return new AiServiceTokenStream(messagesToSend, toolSpecifications, toolExecutors,
(augmentationResult != null ? augmentationResult.contents() : null), context, memoryId);
}

if (isMulti(returnType)) {
// TODO Indicate the output guardrails cannot be used when streaming
chatMemory.commit(); // for streaming cases, we really have to commit because all alternatives are worse
return Multi.createFrom().emitter(new MultiEmitterConsumer(messagesToSend, context, memoryId));
return Multi.createFrom().emitter(new MultiEmitterConsumer(messagesToSend, toolSpecifications,
toolExecutors, augmentationResult != null ? augmentationResult.contents() : null, context, memoryId));
}

Future<Moderation> moderationFuture = triggerModerationIfNeeded(context, methodCreateInfo, messagesToSend);

log.debug("Attempting to obtain AI response");

List<ToolSpecification> toolSpecifications = context.toolSpecifications;
Map<String, ToolExecutor> toolExecutors = context.toolExecutors;
// override with method specific info
if (methodCreateInfo.getToolClassNames() != null && !methodCreateInfo.getToolClassNames().isEmpty()) {
toolSpecifications = methodCreateInfo.getToolSpecifications();
toolExecutors = methodCreateInfo.getToolExecutors();
}

Response<AiMessage> response = toolSpecifications == null
? context.chatModel.generate(messagesToSend)
: context.chatModel.generate(messagesToSend, toolSpecifications);
Expand Down Expand Up @@ -572,19 +578,29 @@ public interface Wrapper {

private static class MultiEmitterConsumer implements Consumer<MultiEmitter<? super String>> {
private final List<ChatMessage> messagesToSend;
private final List<ToolSpecification> toolSpecifications;
private final Map<String, ToolExecutor> toolExecutors;
private final List<dev.langchain4j.rag.content.Content> contents;
private final QuarkusAiServiceContext context;
private final Object memoryId;

public MultiEmitterConsumer(List<ChatMessage> messagesToSend, QuarkusAiServiceContext context,
public MultiEmitterConsumer(List<ChatMessage> messagesToSend,
List<ToolSpecification> toolSpecifications,
Map<String, ToolExecutor> toolExecutors,
List<dev.langchain4j.rag.content.Content> contents,
QuarkusAiServiceContext context,
Object memoryId) {
this.messagesToSend = messagesToSend;
this.toolSpecifications = toolSpecifications;
this.toolExecutors = toolExecutors;
this.contents = contents;
this.context = context;
this.memoryId = memoryId;
}

@Override
public void accept(MultiEmitter<? super String> em) {
new AiServiceTokenStream(messagesToSend, context, memoryId)
new AiServiceTokenStream(messagesToSend, toolSpecifications, toolExecutors, contents, context, memoryId)
.onNext(em::emit)
.onComplete(new Consumer<>() {
@Override
Expand Down
2 changes: 1 addition & 1 deletion docs/modules/ROOT/pages/includes/attributes.adoc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
:project-version: 0.17.2
:langchain4j-version: 0.33.0
:langchain4j-version: 0.34.0
:examples-dir: ./../examples/
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ void testSearch() {
Embedding matchingEmbedding = this.embeddingModel().embed(textSegment).content();
this.embeddingStore().add(matchingEmbedding, textSegment);
});
this.awaitUntilPersisted();
long totalIngestTime = System.currentTimeMillis() - startTime;
Log.debugf("End Ingesting %s embeddings in %d ms.", nbRows, totalIngestTime);
Embedding queryEmbedding = this.embeddingModel().embed("matching").content();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,6 @@ protected EmbeddingModel embeddingModel() {
return embeddingModel;
}

@Override
protected void awaitUntilPersisted() {
delay();
}

@Override
protected void clearStore() {
Log.info("About to delete all embeddings");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

Expand All @@ -23,6 +24,8 @@
import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.ResponseFormat;
import dev.ai4j.openai4j.chat.ResponseFormatType;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
Expand Down Expand Up @@ -69,7 +72,7 @@ public class AzureOpenAiChatModel implements ChatLanguageModel, TokenCountEstima
private final Double frequencyPenalty;
private final Integer maxRetries;
private final Tokenizer tokenizer;
private final String responseFormat;
private final ResponseFormat responseFormat;
private final List<ChatModelListener> listeners;

public AzureOpenAiChatModel(String endpoint,
Expand Down Expand Up @@ -119,7 +122,11 @@ public AzureOpenAiChatModel(String endpoint,
throw new IllegalArgumentException("max-retries must be at least 1");
}
this.tokenizer = tokenizer;
this.responseFormat = responseFormat;
this.responseFormat = responseFormat == null ? null
: ResponseFormat.builder()
.type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT)))
.build();
;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import dev.ai4j.openai4j.SyncOrAsync;
import dev.ai4j.openai4j.image.GenerateImagesRequest;
import dev.ai4j.openai4j.image.GenerateImagesResponse;
import dev.ai4j.openai4j.image.ImageData;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.output.Response;
Expand Down Expand Up @@ -106,7 +107,7 @@ private void persistIfNecessary(GenerateImagesResponse response) {
throw new UncheckedIOException(e);
}

for (GenerateImagesResponse.ImageData data : response.data()) {
for (ImageData data : response.data()) {
try {
data.url(
data.url() != null
Expand All @@ -118,7 +119,7 @@ private void persistIfNecessary(GenerateImagesResponse response) {
}
}

private static Image fromImageData(GenerateImagesResponse.ImageData data) {
private static Image fromImageData(ImageData data) {
return Image.builder().url(data.url()).base64Data(data.b64Json()).revisedPrompt(data.revisedPrompt()).build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
import java.net.Proxy;
import java.time.Duration;
import java.util.List;
import java.util.Locale;

import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.chat.ChatCompletionChoice;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.Delta;
import dev.ai4j.openai4j.chat.ResponseFormat;
import dev.ai4j.openai4j.chat.ResponseFormatType;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
Expand Down Expand Up @@ -58,7 +61,7 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
private final Double presencePenalty;
private final Double frequencyPenalty;
private final Tokenizer tokenizer;
private final String responseFormat;
private final ResponseFormat responseFormat;

public AzureOpenAiStreamingChatModel(String endpoint,
String apiVersion,
Expand Down Expand Up @@ -100,7 +103,11 @@ public AzureOpenAiStreamingChatModel(String endpoint,
this.presencePenalty = presencePenalty;
this.frequencyPenalty = frequencyPenalty;
this.tokenizer = tokenizer;
this.responseFormat = responseFormat;
this.responseFormat = responseFormat == null ? null
: ResponseFormat.builder()
.type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT)))
.build();
;
}

@Override
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Loading

0 comments on commit c605e0f

Please sign in to comment.