Skip to content

Commit

Permalink
Enhance Message and Memory API Validation and storage (#3283)
Browse files Browse the repository at this point in the history
* Enchance Message and Memory API Validation and storage

Throw an error when an unknown field is provided in CreateConversation or CreateInteraction.
Skip saving empty fields in interactions and conversations to optimize storage usage.
Modify GET requests for interactions and conversations to return only non-null fields.
Throw an exception if all fields in a create interaction call are empty or null.
Add unit tests to cover the above cases.

Signed-off-by: rithin-pullela-aws <[email protected]>

* Update unit test to check for null instead of empty map

Signed-off-by: rithin-pullela-aws <[email protected]>

* Refactored userstr to Camel Case

Signed-off-by: rithin-pullela-aws <[email protected]>

* Addressing comments

Used assertThrows and added promptTemplate with empty string in test_ToXContent to ensure well rounded testing of expected functionality

Signed-off-by: rithin-pullela-aws <[email protected]>

* Undo: throw an error when an unknown field is provided in CreateConversation or CreateInteraction.

Signed-off-by: rithin-pullela-aws <[email protected]>

---------

Signed-off-by: rithin-pullela-aws <[email protected]>
  • Loading branch information
rithin-pullela-aws authored Dec 24, 2024
1 parent d09374c commit 06d39b9
Show file tree
Hide file tree
Showing 13 changed files with 160 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,18 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
builder.field(ActionConstants.CONVERSATION_ID_FIELD, conversationId);
builder.field(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, id);
builder.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, createTime);
builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input);
builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate);
builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response);
builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin);
if (input != null && !input.trim().isEmpty()) {
builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input);
}
if (promptTemplate != null && !promptTemplate.trim().isEmpty()) {
builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate);
}
if (response != null && !response.trim().isEmpty()) {
builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response);
}
if (origin != null && !origin.trim().isEmpty()) {
builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin);
}
if (additionalInfo != null) {
builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,17 @@ public void test_ToXContent() throws IOException {
.builder()
.conversationId("conversation id")
.origin("amazon bedrock")
.promptTemplate(" ")
.parentInteractionId("parant id")
.additionalInfo(Collections.singletonMap("suggestion", "new suggestion"))
.response("sample response")
.traceNum(1)
.build();
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
interaction.toXContent(builder, EMPTY_PARAMS);
String interactionContent = TestHelper.xContentBuilderToString(builder);
assertEquals(
"{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}",
"{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"response\":\"sample response\",\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}",
interactionContent
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,28 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest)
}
try (XContentParser parser = restRequest.contentParser()) {
Map<String, Object> body = parser.map();
String name = null;
String applicationType = null;
Map<String, String> additionalInfo = null;

for (String key : body.keySet()) {
switch (key) {
case ActionConstants.REQUEST_CONVERSATION_NAME_FIELD:
name = (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD);
break;
case APPLICATION_TYPE_FIELD:
applicationType = (String) body.get(APPLICATION_TYPE_FIELD);
break;
case META_ADDITIONAL_INFO_FIELD:
additionalInfo = (Map<String, String>) body.get(META_ADDITIONAL_INFO_FIELD);
break;
default:
parser.skipChildren();
break;
}
}
if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) {
return new CreateConversationRequest(
(String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD),
body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD),
body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map<String, String>) body.get(META_ADDITIONAL_INFO_FIELD)
);
return new CreateConversationRequest(name, applicationType, additionalInfo);
} else {
return new CreateConversationRequest();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ public static CreateInteractionRequest fromRestRequest(RestRequest request) thro
}
}

boolean allFieldsEmpty = (input == null || input.trim().isEmpty())
&& (prompt == null || prompt.trim().isEmpty())
&& (response == null || response.trim().isEmpty())
&& (origin == null || origin.trim().isEmpty())
&& (addinf == null || addinf.isEmpty());
if (allFieldsEmpty) {
throw new IllegalArgumentException(
"At least one of the following parameters must be non-empty: " + "input, prompt_template, response, origin, additional_info"
);
}
return new CreateInteractionRequest(cid, input, prompt, response, origin, addinf, parintid, tracenum);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -139,24 +140,24 @@ public void createConversation(
) {
initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> {
if (indexExists) {
String userstr = getUserStrFromThreadContext();
String userStr = getUserStrFromThreadContext();
Instant now = Instant.now();
IndexRequest request = Requests
.indexRequest(META_INDEX_NAME)
.source(
ConversationalIndexConstants.META_CREATED_TIME_FIELD,
now,
ConversationalIndexConstants.META_UPDATED_TIME_FIELD,
now,
ConversationalIndexConstants.META_NAME_FIELD,
name,
ConversationalIndexConstants.USER_FIELD,
userstr == null ? null : User.parse(userstr).getName(),
ConversationalIndexConstants.APPLICATION_TYPE_FIELD,
applicationType,
ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD,
additionalInfos == null ? Map.of() : additionalInfos
);
Map<String, Object> sourceMap = new HashMap<>();
sourceMap.put(ConversationalIndexConstants.META_CREATED_TIME_FIELD, now);
sourceMap.put(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, now);
if (name != null && !name.trim().isEmpty()) {
sourceMap.put(ConversationalIndexConstants.META_NAME_FIELD, name);
}
if (userStr != null && !userStr.trim().isEmpty()) {
sourceMap.put(ConversationalIndexConstants.USER_FIELD, User.parse(userStr).getName());
}
if (applicationType != null && !applicationType.trim().isEmpty()) {
sourceMap.put(ConversationalIndexConstants.APPLICATION_TYPE_FIELD, applicationType);
}
if (additionalInfos != null && !additionalInfos.isEmpty()) {
sourceMap.put(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, additionalInfos);
}
IndexRequest request = Requests.indexRequest(META_INDEX_NAME).source(sourceMap);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<String> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
ActionListener<IndexResponse> al = ActionListener.wrap(resp -> {
Expand Down Expand Up @@ -210,12 +211,12 @@ public void getConversations(int from, int maxResults, ActionListener<List<Conve
return;
}
SearchRequest request = Requests.searchRequest(META_INDEX_NAME);
String userstr = getUserStrFromThreadContext();
String userStr = getUserStrFromThreadContext();
QueryBuilder queryBuilder;
if (userstr == null)
if (userStr == null)
queryBuilder = new MatchAllQueryBuilder();
else
queryBuilder = new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, User.parse(userstr).getName());
queryBuilder = new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, User.parse(userStr).getName());
request.source().query(queryBuilder);
request.source().from(from).size(maxResults);
request.source().sort(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, SortOrder.DESC);
Expand Down Expand Up @@ -264,8 +265,8 @@ public void deleteConversation(String conversationId, ActionListener<Boolean> li
return;
}
DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId);
String userstr = getUserStrFromThreadContext();
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
String userStr = getUserStrFromThreadContext();
String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName();
this.checkAccess(conversationId, ActionListener.wrap(access -> {
if (access) {
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
Expand Down Expand Up @@ -308,7 +309,7 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
listener.onResponse(true);
return;
}
String userstr = getUserStrFromThreadContext();
String userStr = getUserStrFromThreadContext();
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
GetRequest getRequest = Requests.getRequest(META_INDEX_NAME).id(conversationId);
Expand All @@ -318,12 +319,12 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
throw new ResourceNotFoundException("Memory [" + conversationId + "] not found");
}
// If security is off - User doesn't exist - you have permission
if (userstr == null || User.parse(userstr) == null) {
if (userStr == null || User.parse(userStr) == null) {
internalListener.onResponse(true);
return;
}
ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap());
String user = User.parse(userstr).getName();
String user = User.parse(userStr).getName();
// If you're not the owner of this conversation, you do not have permission
if (!user.equals(conversation.getUser())) {
internalListener.onResponse(false);
Expand Down Expand Up @@ -353,9 +354,9 @@ public void searchConversations(SearchRequest request, ActionListener<SearchResp
QueryBuilder originalQuery = request.source().query();
BoolQueryBuilder newQuery = new BoolQueryBuilder();
newQuery.must(originalQuery);
String userstr = getUserStrFromThreadContext();
if (userstr != null) {
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
String userStr = getUserStrFromThreadContext();
if (userStr != null) {
String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName();
newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, user));
}
request.source().query(newQuery);
Expand Down Expand Up @@ -388,11 +389,11 @@ public void updateConversation(String conversationId, UpdateRequest updateReques
if (access) {
innerUpdateConversation(updateRequest, listener);
} else {
String userstr = client
String userStr = client
.threadPool()
.getThreadContext()
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName();
throw new OpenSearchStatusException(
"User [" + user + "] does not have access to memory " + conversationId,
RestStatus.UNAUTHORIZED
Expand Down Expand Up @@ -421,7 +422,7 @@ public void getConversation(String conversationId, ActionListener<ConversationMe
listener.onFailure(new IndexNotFoundException("cannot get memory since the memory index does not exist", META_INDEX_NAME));
return;
}
String userstr = getUserStrFromThreadContext();
String userStr = getUserStrFromThreadContext();
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<ConversationMeta> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
GetRequest request = Requests.getRequest(META_INDEX_NAME).id(conversationId);
Expand All @@ -432,12 +433,12 @@ public void getConversation(String conversationId, ActionListener<ConversationMe
}
ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap());
// If no security, return conversation
if (userstr == null || User.parse(userstr) == null) {
if (userStr == null || User.parse(userStr) == null) {
internalListener.onResponse(conversation);
return;
}
// If security and correct user, return conversation
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName();
if (user.equals(conversation.getUser())) {
internalListener.onResponse(conversation);
log.info("Successfully get the memory for {}", conversationId);
Expand Down
Loading

0 comments on commit 06d39b9

Please sign in to comment.