Skip to content

Commit

Permalink
#5211 - AI assistant prototype
Browse files Browse the repository at this point in the history
- Update settings documentation
- Adjust parameters for embedding
- Allow assistant to provide attribution to document context used in responses
- Allow specifying a document when using the DIAM scrollTo function
  • Loading branch information
reckart committed Jan 6, 2025
1 parent a558d05 commit 2260ace
Show file tree
Hide file tree
Showing 25 changed files with 11,443 additions and 14,349 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package de.tudarmstadt.ukp.clarin.webanno.api.annotation.page;

import static de.tudarmstadt.ukp.clarin.webanno.model.ValidationMode.NEVER;
import static de.tudarmstadt.ukp.inception.rendering.model.Range.rangeClippedToDocument;
import static de.tudarmstadt.ukp.inception.rendering.selection.FocusPosition.CENTERED;
import static de.tudarmstadt.ukp.inception.support.WebAnnoConst.CURATION_USER;
import static java.lang.String.format;
Expand Down Expand Up @@ -213,6 +214,7 @@ protected void onParameterArrival(IRequestParameters aRequestParameters,

var previousDoc = getModelObject().getDocument();
var aPreviousUser = getModelObject().getUser();

handleParameters(document, focus, user);

updateDocumentView(aTarget, previousDoc, aPreviousUser, focus);
Expand Down Expand Up @@ -265,24 +267,24 @@ protected void updateUrlFragment(AjaxRequestTarget aTarget)
* the document to open
* @return whether the document had to be switched or not.
*/
public boolean actionShowSelectedDocument(AjaxRequestTarget aTarget, SourceDocument aDocument)
public boolean actionShowDocument(AjaxRequestTarget aTarget, SourceDocument aDocument)
{
if (!Objects.equals(aDocument.getId(), getModelObject().getDocument().getId())) {
List<SourceDocument> docs = getListOfDocs();
if (!docs.contains(aDocument)) {
error("The document [" + aDocument.getName() + "] is not accessible");
if (aTarget != null) {
aTarget.addChildren(getPage(), IFeedback.class);
}
return false;
}
if (Objects.equals(aDocument.getId(), getModelObject().getDocument().getId())) {
return false;
}

getModelObject().setDocument(aDocument, docs);
actionLoadDocument(aTarget);
return true;
var docs = getListOfDocs();
if (!docs.contains(aDocument)) {
error("The document [" + aDocument.getName() + "] is not accessible");
if (aTarget != null) {
aTarget.addChildren(getPage(), IFeedback.class);
}
return false;
}

return false;
getModelObject().setDocument(aDocument, docs);
actionLoadDocument(aTarget);
return true;
}

/**
Expand All @@ -303,13 +305,14 @@ public void actionShowSelectedDocument(AjaxRequestTarget aTarget, SourceDocument
int aBegin, int aEnd)
throws IOException
{
boolean switched = actionShowSelectedDocument(aTarget, aDocument);
boolean switched = actionShowDocument(aTarget, aDocument);

var state = getModelObject();

var cas = getEditorCas();
state.getPagingStrategy().moveToOffset(state, cas, aBegin, new VRange(aBegin, aEnd),
CENTERED);
var range = rangeClippedToDocument(cas, aBegin, aEnd);
state.getPagingStrategy().moveToOffset(state, cas, aBegin,
new VRange(range.getBegin(), range.getEnd()), CENTERED);

if (!switched && state.getPagingStrategy() instanceof NoPagingStrategy) {
return;
Expand Down Expand Up @@ -496,7 +499,7 @@ public boolean isAnnotationFinished()

private boolean loadAnnotationFinished()
{
AnnotatorState state = getModelObject();
var state = getModelObject();
return documentService.isAnnotationFinished(state.getDocument(), state.getUser());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static org.apache.wicket.event.Broadcast.BREADTH;

import java.util.List;
import java.util.Optional;

import org.apache.uima.cas.CAS;
import org.apache.wicket.Page;
Expand Down Expand Up @@ -54,7 +53,7 @@ public void moveToOffset(AnnotatorViewState aState, CAS aCas, int aOffset, VRang
List<Unit> units = units(aCas);

// Find the unit containing the given offset
Unit unit = units.stream() //
var unit = units.stream() //
.filter(u -> u.getBegin() <= aOffset && aOffset <= u.getEnd()) //
.findFirst() //
.orElseThrow(() -> new IllegalArgumentException(
Expand All @@ -63,7 +62,7 @@ public void moveToOffset(AnnotatorViewState aState, CAS aCas, int aOffset, VRang
// How many rows to display before the unit such that the unit is centered?
int rowsInPageBeforeUnit = aState.getPreferences().getWindowSize() / 2;
// The -1 below is because unit.getIndex() is 1-based
Unit firstUnit = units.get(Math.max(0, unit.getIndex() - rowsInPageBeforeUnit - 1));
var firstUnit = units.get(Math.max(0, unit.getIndex() - rowsInPageBeforeUnit - 1));

aState.setPageBegin(aCas, firstUnit.getBegin());
aState.setFocusUnitIndex(unit.getIndex());
Expand All @@ -77,15 +76,15 @@ public void moveToOffset(AnnotatorViewState aState, CAS aCas, int aOffset, VRang

private void fireScrollToEvent(int aOffset, VRange aPingRange, FocusPosition aPos)
{
RequestCycle requestCycle = RequestCycle.get();
var requestCycle = RequestCycle.get();

if (requestCycle == null) {
return;
}

Optional<IPageRequestHandler> handler = requestCycle.find(IPageRequestHandler.class);
var handler = requestCycle.find(IPageRequestHandler.class);
if (handler.isPresent() && handler.get().isPageInstanceCreated()) {
Page page = (Page) handler.get().getPage();
var page = (Page) handler.get().getPage();
var target = requestCycle.find(AjaxRequestTarget.class).orElse(null);
page.send(page, BREADTH, new ScrollToEvent(target, aOffset, aPingRange, aPos));
}
Expand Down
14 changes: 14 additions & 0 deletions inception/inception-assistant/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-api-render</artifactId>
</dependency>
<dependency>
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-diam</artifactId>
</dependency>

<dependency>
<groupId>com.knuddels</groupId>
Expand Down Expand Up @@ -289,6 +293,16 @@
<groupId>com.github.eirslett</groupId>
<artifactId>frontend-maven-plugin</artifactId>
<executions>
<execution>
<id>npm link dependencies</id>
<goals>
<goal>npm</goal>
</goals>
<phase>${ts-link-dependency-phase}</phase>
<configuration>
<arguments>link @inception-project/inception-js-api @inception-project/inception-diam</arguments>
</configuration>
</execution>
<execution>
<id>npm build</id>
<goals>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,10 @@ private List<MTextMessage> generateTransientMessages(ChatContext aAssistant, MTe
private List<MTextMessage> generateSystemMessages()
{
var primeDirectives = asList(
"You are " + properties.getNickname() + ", a helpful assistant within the annotation tool INCEpTION.",
"Your name is " + properties.getNickname() + ".",
"You are a helpful assistant within the annotation tool INCEpTION.",
"INCEpTION always refers to the annotation tool, never anything else such as the movie.",
"Do not include references to INCEpTION unless the user explicitly asks about the environment itself.",
"If the source of an information is known, provide it in your response.",
"The document retriever automatically provides you with relevant information from the current project.",
"The user guide retriever automatically provides you with relevant information from the user guide.",
"Use this relevant information when responding to the user."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,29 +57,27 @@ public class AssistantWebsocketControllerImpl

@Autowired
public AssistantWebsocketControllerImpl(ServletContext aServletContext,
SimpMessagingTemplate aMsgTemplate, AssistantService aAssistantService, ProjectService aProjectService
, UserDao aUserService)
SimpMessagingTemplate aMsgTemplate, AssistantService aAssistantService,
ProjectService aProjectService, UserDao aUserService)
{
assistantService = aAssistantService;
projectService = aProjectService;
userService = aUserService;
}

@SubscribeMapping(PROJECT_ASSISTANT_TOPIC_TEMPLATE)
public List<MTextMessage> onSubscribeToAssistantMessages(SimpMessageHeaderAccessor aHeaderAccessor,
Principal aPrincipal, //
public List<MTextMessage> onSubscribeToAssistantMessages(
SimpMessageHeaderAccessor aHeaderAccessor, Principal aPrincipal, //
@DestinationVariable(PARAM_PROJECT) long aProjectId)
throws IOException
{
var project = projectService.getProject(aProjectId);
return assistantService.getAllChatMessages(aPrincipal.getName(), project);
}

@MessageMapping(PROJECT_ASSISTANT_TOPIC_TEMPLATE)
public void onUserMessage(SimpMessageHeaderAccessor aHeaderAccessor,
Principal aPrincipal, //
@DestinationVariable(PARAM_PROJECT) long aProjectId,
@Payload String aMessage)
public void onUserMessage(SimpMessageHeaderAccessor aHeaderAccessor, Principal aPrincipal, //
@DestinationVariable(PARAM_PROJECT) long aProjectId, @Payload String aMessage)
throws IOException
{
var project = projectService.getProject(aProjectId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.lang.invoke.MethodHandles;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.UUID;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -79,10 +80,15 @@ public List<MTextMessage> retrieve(ChatContext aAssistant, MTextMessage aMessage
var references = new LinkedHashMap<Chunk, MReference>();
var body = new StringBuilder();
for (var chunk : chunks) {
var reference = new MReference(String.valueOf(references.size() + 1), "doc",
chunk.documentName(),
"#!d=" + chunk.documentId() + "&hl=" + chunk.begin() + "-" + chunk.end(),
chunk.score());
var reference = MReference.builder() //
//.withId(String.valueOf(references.size() + 1)) //
.withId(UUID.randomUUID().toString()) //
.withDocumentId(chunk.documentId()) //
.withDocumentName(chunk.documentName()) //
.withBegin(chunk.begin()) //
.withEnd(chunk.end()) //
.withScore(chunk.score()) //
.build();
references.put(chunk, reference);
renderChunkJson(body, chunk, reference);
}
Expand All @@ -91,16 +97,40 @@ public List<MTextMessage> retrieve(ChatContext aAssistant, MTextMessage aMessage
return emptyList();
}

return asList(MTextMessage.builder() //
var msg = MTextMessage.builder() //
.withActor("Document context retriever") //
.withRole(SYSTEM).internal() //
.withMessage(join("\n", asList(
"The document retriever found the following relevant information in the following documents.",
"", //
body.toString(), "",
"It is critical to mention the source of each document text in the form `{{ref::ref-id}}`.")))
.withReferences(references.values()) //
.build());
.withReferences(references.values());

// Works good with qwen72b but not with granite 8b
// msg.withMessage(join("\n", asList(
// "The document retriever found the following relevant information in the documents of this project.",
// "", //
// body.toString(), "",
// "It is critical to mention the source of each document text in the form `{{ref::ref-id}}`.")));

msg.withMessage(join("\n", asList(
"""
Use the following documents from this project to respond.
It is absolutely critital to mention the `{{ref::ref-id}}` after each individual information from a document.
Here is an example of how to include the ref-id:
Input:
{
"document": "The Eiffel Tower is located in Paris, France. It is one of the most famous landmarks in the world.",
"ref-id": "123"
}
Response:
The Eiffel Tower is located in Paris, France {{ref::123}}.
It is one of the most famous landmarks in the world {{ref::123}}.
Now, use the same pattern to process the following document:
""",
"", //
body.toString())));

return asList(msg.build());
}

private void renderChunkJson(StringBuilder body, Chunk chunk, MReference aReference)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,72 @@
*/
package de.tudarmstadt.ukp.inception.assistant.model;

public record MReference(String id, String type, String title, String target, double score) {
public record MReference(String id, long documentId, String documentName, int begin, int end,
double score)
{
private MReference(Builder builder)
{
this(builder.id, builder.documentId, builder.documentName, builder.begin, builder.end,
builder.score);
}

public static Builder builder()
{
return new Builder();
}

public static final class Builder
{
private String id;
private long documentId;
private String documentName;
private int begin;
private int end;
private double score;

private Builder()
{
}

public Builder withId(String aId)
{
id = aId;
return this;
}

public Builder withDocumentId(long aDocumentId)
{
documentId = aDocumentId;
return this;
}

public Builder withDocumentName(String aDocumentName)
{
documentName = aDocumentName;
return this;
}

public Builder withBegin(int aBegin)
{
begin = aBegin;
return this;
}

public Builder withEnd(int aEnd)
{
end = aEnd;
return this;
}

public Builder withScore(double aScore)
{
score = aScore;
return this;
}

public MReference build()
{
return new MReference(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import de.tudarmstadt.ukp.clarin.webanno.security.UserDao;
import de.tudarmstadt.ukp.inception.assistant.AssistantService;
import de.tudarmstadt.ukp.inception.assistant.AssistantWebsocketController;
import de.tudarmstadt.ukp.inception.diam.editor.DiamAjaxBehavior;
import de.tudarmstadt.ukp.inception.support.svelte.SvelteBehavior;
import jakarta.servlet.ServletContext;

Expand All @@ -50,18 +51,22 @@ public class AssistantPanel
private @SpringBean UserDao userService;
private @SpringBean AssistantService assistantService;

private DiamAjaxBehavior diamBehavior;

public AssistantPanel(String aId)
{
super(aId);
setOutputMarkupPlaceholderTag(true);
}

@Override
protected void onInitialize()
{
super.onInitialize();

add(new SvelteBehavior());

add(diamBehavior = new DiamAjaxBehavior(null));
}

@Override
Expand All @@ -77,6 +82,7 @@ protected void onConfigure()
}

Map<String, Object> properties = Map.of( //
"ajaxEndpointUrl", diamBehavior.getCallbackUrl(), //
"wsEndpointUrl", constructEndpointUrl(), //
"documentId", state.getDocument().getId(), //
"csrfToken", getCsrfTokenFromSession(), //
Expand Down
Loading

0 comments on commit 2260ace

Please sign in to comment.