Skip to content

Commit

Permalink
#5211 - AI assistant prototype
Browse files Browse the repository at this point in the history
- Factor retrievers out into own classes and introduce an extension point for them
- Introduce a chunker interface
- Add source information to chunks
  • Loading branch information
reckart committed Jan 2, 2025
1 parent 0b9f16f commit 2c20c88
Show file tree
Hide file tree
Showing 21 changed files with 798 additions and 228 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@

public interface AssistantService
{
List<MAssistantTextMessage> getConversationMessages(String aSessionOwner, Project aProject);
List<MAssistantTextMessage> getAllChatMessages(String aSessionOwner, Project aProject);

List<MAssistantTextMessage> getChatMessages(String aSessionOwner, Project aProject);

void processUserMessage(String aSessionOwner, Project aProject,
MAssistantTextMessage aMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.time.format.FormatStyle;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
Expand All @@ -52,11 +48,10 @@
import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.clarin.webanno.security.model.User;
import de.tudarmstadt.ukp.inception.assistant.config.AssistantProperties;
import de.tudarmstadt.ukp.inception.assistant.documents.DocumentQueryService;
import de.tudarmstadt.ukp.inception.assistant.model.MAssistantClearCommand;
import de.tudarmstadt.ukp.inception.assistant.model.MAssistantMessage;
import de.tudarmstadt.ukp.inception.assistant.model.MAssistantTextMessage;
import de.tudarmstadt.ukp.inception.assistant.userguide.UserGuideQueryService;
import de.tudarmstadt.ukp.inception.assistant.retriever.RetrieverExtensionPoint;
import de.tudarmstadt.ukp.inception.project.api.event.AfterProjectRemovedEvent;
import de.tudarmstadt.ukp.inception.project.api.event.BeforeProjectRemovedEvent;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaChatMessage;
Expand All @@ -75,24 +70,21 @@ public class AssistantServiceImpl
private final ConcurrentMap<AssistentStateKey, AssistentState> states;
private final OllamaClient ollamaClient;
private final AssistantProperties properties;
private final UserGuideQueryService documentationIndexingService;
private final DocumentQueryService documentQueryService;
private final EncodingRegistry encodingRegistry;
private final RetrieverExtensionPoint retrieverExtensionPoint;

public AssistantServiceImpl(SessionRegistry aSessionRegistry,
SimpMessagingTemplate aMsgTemplate, OllamaClient aOllamaClient,
AssistantProperties aProperties, UserGuideQueryService aDocumentationIndexingService,
DocumentQueryService aDocumentQueryService,
EncodingRegistry aEncodingRegistry)
AssistantProperties aProperties, EncodingRegistry aEncodingRegistry,
RetrieverExtensionPoint aRetrieverExtensionPoint)
{
sessionRegistry = aSessionRegistry;
msgTemplate = aMsgTemplate;
states = new ConcurrentHashMap<>();
ollamaClient = aOllamaClient;
properties = aProperties;
documentationIndexingService = aDocumentationIndexingService;
documentQueryService = aDocumentQueryService;
encodingRegistry = aEncodingRegistry;
retrieverExtensionPoint = aRetrieverExtensionPoint;
}

// Set order so this is handled before session info is removed from sessionRegistry
Expand Down Expand Up @@ -133,13 +125,30 @@ public void onAfterProjectRemoved(AfterProjectRemovedEvent aEvent)
}

@Override
public List<MAssistantTextMessage> getConversationMessages(String aSessionOwner, Project aProject)
public List<MAssistantTextMessage> getAllChatMessages(String aSessionOwner, Project aProject)
{
var state = getState(aSessionOwner, aProject);
return state.getMessages();

return state.getMessages().stream() //
.filter(MAssistantTextMessage.class::isInstance) //
.map(MAssistantTextMessage.class::cast) //
.toList();
}

/**
 * Returns the text messages recorded for the given session owner and project, leaving out
 * those flagged as internal.
 *
 * @param aSessionOwner
 *            the session owner whose assistant state is looked up
 * @param aProject
 *            the project whose assistant state is looked up
 * @return the recorded {@link MAssistantTextMessage}s for which {@code internal()} is false
 */
@Override
public List<MAssistantTextMessage> getChatMessages(String aSessionOwner, Project aProject)
{
var state = getState(aSessionOwner, aProject);

// The state holds MAssistantMessage instances in general, so we first narrow to text messages.
// In dev mode, we also record internal messages, so we need to filter them out again here
return state.getMessages().stream() //
.filter(MAssistantTextMessage.class::isInstance) //
.map(MAssistantTextMessage.class::cast) //
.filter(msg -> !msg.internal()) //
.toList();
}

void recordMessage(String aSessionOwner, Project aProject, MAssistantTextMessage aMessage)
void recordMessage(String aSessionOwner, Project aProject, MAssistantMessage aMessage)
{
var state = getState(aSessionOwner, aProject);
state.addMessage(aMessage);
Expand All @@ -162,43 +171,45 @@ public void clearConversation(String aSessionOwner, Project aProject)

dispatchMessage(aSessionOwner, aProject, new MAssistantClearCommand());
}

@Override
public void processUserMessage(String aSessionOwner, Project aProject,
MAssistantTextMessage aMessage)
{
// Dispatch message early so the front-end can enter waiting state
dispatchMessage(aSessionOwner, aProject, aMessage);

var responseId = UUID.randomUUID();
try {
var systemMessages = generateSystemMessages(aSessionOwner, aProject, aMessage);
var transientMessages = generateTransientMessages(aSessionOwner, aProject, aMessage);
var recentMessages = getConversationMessages(aSessionOwner, aProject);
var conversationMessages = getChatMessages(aSessionOwner, aProject);

// We record the message only now to ensure it is not included in the conversation messages fetched above
recordMessage(aSessionOwner, aProject, aMessage);

if (properties.isDevMode()) {
// For testing purposes we send these messages to the UI and also record them; the ones
// flagged as internal are filtered out again by getChatMessages
for (var msg : transientMessages) {
recordMessage(aSessionOwner, aProject, msg);
dispatchMessage(aSessionOwner, aProject, msg);
}
}

var conversation = limitConversationToContextLength(systemMessages, transientMessages,
recentMessages, aMessage, properties.getChat().getContextLength());
var recentConversation = limitConversationToContextLength(systemMessages,
transientMessages, conversationMessages, aMessage,
properties.getChat().getContextLength());

var request = OllamaChatRequest.builder() //
.withModel(properties.getChat().getModel()) //
.withStream(true) //
.withMessages(conversation.stream() //
.withMessages(recentConversation.stream() //
.map(msg -> new OllamaChatMessage(msg.role(), msg.message())) //
.toList()) //
.withOption(OllamaOptions.NUM_CTX, properties.getChat().getContextLength()) //
.withOption(OllamaOptions.TOP_P, properties.getChat().getTopP()) //
.withOption(OllamaOptions.TOP_K, properties.getChat().getTopK()) //
.withOption(OllamaOptions.REPEAT_PENALTY, properties.getChat().getRepeatPenalty()) //
.withOption(OllamaOptions.REPEAT_PENALTY,
properties.getChat().getRepeatPenalty()) //
.withOption(OllamaOptions.TEMPERATURE, properties.getChat().getTemperature()) //
.build();

Expand Down Expand Up @@ -242,74 +253,15 @@ private List<MAssistantTextMessage> generateTransientMessages(String aSessionOwn
{
var transientMessages = new ArrayList<MAssistantTextMessage>();

addTransientContextFromUserManual(aSessionOwner, aProject, transientMessages, aMessage);

addTransientContextFromDocuments(aSessionOwner, aProject, transientMessages, aMessage);

var dtf = DateTimeFormatter.ofLocalizedDateTime(FormatStyle.MEDIUM);
transientMessages.add(MAssistantTextMessage.builder() //
.withRole(SYSTEM).internal() //
.withMessage("The current time is " + LocalDateTime.now(ZoneOffset.UTC).format(dtf)) //
.build());

return transientMessages;
}

private void addTransientContextFromUserManual(String aSessionOwner, Project aProject,
List<MAssistantTextMessage> aConversation, MAssistantTextMessage aMessage)
{
var messageBody = new StringBuilder();
var passages = documentationIndexingService.query(aMessage.message(), 3, 0.8);
for (var passage : passages) {
messageBody.append("\n```user-manual\n").append(passage).append("\n```\n\n");
}

MAssistantTextMessage message;
if (messageBody.isEmpty()) {
message = MAssistantTextMessage.builder() //
.withRole(SYSTEM).internal() //
.withMessage("There seems to be no relevant information in the user manual.") //
.build();
}
else {
message = MAssistantTextMessage.builder() //
.withRole(SYSTEM).internal() //
.withMessage(
"""
Use the context information from following user manual entries to respond.
""" + messageBody
.toString()) //
.build();
}
aConversation.add(message);
}

private void addTransientContextFromDocuments(String aSessionOwner, Project aProject,
List<MAssistantTextMessage> aConversation, MAssistantTextMessage aMessage)
{
var messageBody = new StringBuilder();
var passages = documentQueryService.query(aProject, aMessage.message(), 10,
properties.getEmbedding().getChunkScoreThreshold());
for (var passage : passages) {
messageBody.append("```context\n").append(passage).append("\n```\n\n");
for (var retriever : retrieverExtensionPoint.getExtensions(aProject)) {
transientMessages.addAll(retriever.retrieve(aSessionOwner, aProject, aMessage));
}

if (!messageBody.isEmpty()) {
var message = MAssistantTextMessage.builder() //
.withRole(SYSTEM).internal() //
.withMessage(
"""
Use the context information from the following documents to respond.
The source of this information are the authors of the documents.
""" + messageBody
.toString()) //
.build();
aConversation.add(message);
}
return transientMessages;
}

private List<MAssistantTextMessage> generateSystemMessages(String aSessionOwner, Project aProject,
MAssistantTextMessage aMessage)
private List<MAssistantTextMessage> generateSystemMessages(String aSessionOwner,
Project aProject, MAssistantTextMessage aMessage)
{
var primeDirectives = asList(
"You are Dominick, a helpful assistant within the annotation tool INCEpTION.",
Expand All @@ -332,7 +284,8 @@ private List<MAssistantTextMessage> generateSystemMessages(String aSessionOwner,
}

private List<MAssistantTextMessage> limitConversationToContextLength(
List<MAssistantTextMessage> aSystemMessages, List<MAssistantTextMessage> aTransientMessages,
List<MAssistantTextMessage> aSystemMessages,
List<MAssistantTextMessage> aTransientMessages,
List<MAssistantTextMessage> aRecentMessages, MAssistantTextMessage aLatestUserMessage,
int aContextLength)
{
Expand Down Expand Up @@ -449,30 +402,34 @@ private void clearState(String aSessionOwner)

private static class AssistentState
{
private LinkedList<MAssistantTextMessage> messages = new LinkedList<>();
private LinkedList<MAssistantMessage> messages = new LinkedList<>();

public List<MAssistantTextMessage> getMessages()
public List<MAssistantMessage> getMessages()
{
return new ArrayList<>(messages);
}

public void addMessage(MAssistantTextMessage aMessage)
public void addMessage(MAssistantMessage aMessage)
{
synchronized (messages) {
var found = false;
var i = messages.listIterator(messages.size());

// If a message with the same ID already exists, update it
while (i.hasPrevious() && !found) {
var m = i.previous();
if (Objects.equals(m.id(), aMessage.id())) {
if (aMessage.done()) {
i.set(aMessage);
}
else {
i.set(m.append(aMessage));
if (aMessage instanceof MAssistantTextMessage textMsg) {
var i = messages.listIterator(messages.size());

// If a message with the same ID already exists, update it
while (i.hasPrevious() && !found) {
var m = i.previous();
if (m instanceof MAssistantTextMessage existingTextMsg) {
if (Objects.equals(existingTextMsg.id(), textMsg.id())) {
if (textMsg.done()) {
i.set(textMsg);
}
else {
i.set(existingTextMsg.append(textMsg));
}
found = true;
}
}
found = true;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public List<MAssistantTextMessage> onSubscribeToAssistantMessages(SimpMessageHea
throws IOException
{
var project = projectService.getProject(aProjectId);
return assistantService.getConversationMessages(aPrincipal.getName(), project);
return assistantService.getAllChatMessages(aPrincipal.getName(), project);
}

@MessageMapping(PROJECT_ASSISTANT_TOPIC_TEMPLATE)
Expand Down
Loading

0 comments on commit 2c20c88

Please sign in to comment.