Skip to content

Commit

Permalink
Several bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jschm42 committed Jun 25, 2024
1 parent fa843c8 commit 9a3c08e
Show file tree
Hide file tree
Showing 22 changed files with 384 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,6 @@ public class RestClientConfiguration {
@Value("${spring.ai.ollama.base-url}")
private String ollamaBaseUrl;

@Value("${elevenlabs.api-key}")
private String elevenLabsApiKey;

@Value("${elevenlabs.api-url}")
private String elevenLabsBaseUrl;

@Bean(name = "elevenLabsRestClient")
public RestClient elevenLabsRestClient() {
return RestClient.builder()
.baseUrl(elevenLabsBaseUrl)
.defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + elevenLabsApiKey)
.build();
}

@Bean(name = "openAiRestClient")
public RestClient openAiRestClient() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public class AssistantSpringController {

private final AssistantSpringService assistantService;


private final FileStorageService fileStorageService;

public AssistantSpringController(AssistantSpringService assistantService,
Expand Down Expand Up @@ -178,6 +179,11 @@ public ResponseEntity<byte[]> getImage(@PathVariable String threadId,
}
}

@PostMapping("/threads/{threadId}/regenerate")
public void regenerateThread(@PathVariable("threadId") String threadId) {
assistantService.regenerateThread(threadId);
}

@DeleteMapping("/threads/{threadId}")
public void deleteThread(@PathVariable("threadId") String threadId) {
assistantService.deleteThread(threadId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public Map<String, String> mapAssistantProperties(
Map<String, AssistantPropertyValue> properties) {
Map<String, String> mappedProperties = new HashMap<>();

Arrays.stream(AssistantProperties.values()).forEach(property -> {
Arrays.stream(AssistantProperty.values()).forEach(property -> {
AssistantPropertyValue propertyValue = properties.get(property.getKey());
if (propertyValue != null) {
mappedProperties.put(
Expand All @@ -89,8 +89,12 @@ public Map<String, String> mapAssistantProperties(
public Map<String, AssistantPropertyValue> mapProperties(Map<String, String> properties) {
Map<String, AssistantPropertyValue> mappedProperties = new HashMap<>();

Arrays.stream(AssistantProperties.values()).forEach(property -> {
Arrays.stream(AssistantProperty.values()).forEach(property -> {
String propertyValue = properties.get(property.getKey());
if (propertyValue != null && propertyValue.isEmpty()) {
propertyValue = null;
}

AssistantPropertyValue assistantPropertyValue = new AssistantPropertyValue();
assistantPropertyValue.setPropertyValue(propertyValue);
mappedProperties.put(property.getKey(), assistantPropertyValue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,43 @@

package com.talkforgeai.backend.assistant.service;

public enum AssistantProperties {
public enum AssistantProperty {
TTS_TYPE("tts_type", ""),

SPEECHAPI_VOICE("speechAPI_voice", ""),

ELEVENLABS_VOICEID("elevenlabs_voiceId", ""),
ELEVENLABS_MODELID("elevenlabs_modelId", "eleven_monolingual_v2"),
ELEVENLABS_MODELID("elevenlabs_modelId", "eleven_multilingual_v2"),
ELEVENLABS_SIMILARITYBOOST("elevenlabs_similarityBoost", "0"),
ELEVENLABS_STABILITY("elevenlabs_stability", "0"),

MODEL_TEMPERATURE("model_temperature", "0.7"),
MODEL_TOP_P("model_topP", "1.0"),
MODEL_FREQUENCY_PENALTY("model_frequencyPenalty", "0"),
MODEL_PRESENCE_PENALTY("model_presencePenalty", "0"),
MODEL_MAX_TOKENS("model_maxTokens", "4096"),

FEATURE_PLANTUML("feature_plantUMLGeneration", "false"),
FEATURE_IMAGEGENERATION("feature_imageGeneration", "false"),
FEATURE_AUTOSPEAKDEFAULT("feature_autoSpeakDefault", "false"),
FEATURE_TITLEGENERATION("feature_titleGeneration", "true");

private final String key;
private final String defaultValue;

AssistantProperties(String key, String defaultValue) {
AssistantProperty(String key, String defaultValue) {
this.key = key;
this.defaultValue = defaultValue;
}

public static AssistantProperty fromKey(String key) {
for (AssistantProperty property : AssistantProperty.values()) {
if (property.getKey().equals(key)) {
return property;
}
}
throw new IllegalArgumentException("Unknown key: " + key);
}

public String getDefaultValue() {
return defaultValue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import com.talkforgeai.backend.memory.functions.MemoryContextStorageFunction.Request;
import com.talkforgeai.backend.memory.functions.MemoryContextStorageFunction.Response;
import com.talkforgeai.backend.memory.functions.MemoryFunctionContext;
import com.talkforgeai.backend.memory.service.DBVectorStore;
import com.talkforgeai.backend.memory.service.MemoryService;
import com.talkforgeai.backend.storage.FileStorageService;
import com.talkforgeai.backend.transformers.MessageProcessor;
Expand Down Expand Up @@ -103,8 +104,18 @@
public class AssistantSpringService {

public static final Logger LOGGER = LoggerFactory.getLogger(AssistantSpringService.class);
public static final String SYSTEM_MESSAGE_PLANTUML = "You can generate PlantUML diagrams. PlantUML code that you generate will be transformed to a downloadable image.";
public static final String SYSTEM_MESSAGE_IMAGE_GEN = "You can generate an image by using the following syntax: !image_gen[<image prompt>]";
private static final String SYSTEM_MESSAGE_MEMORY = """
Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.
---------------------
LONG_TERM_MEMORY:
%s
---------------------
""";


private final UniversalChatService universalChatService;
private final UniversalImageGenService universalImageGenService;
Expand Down Expand Up @@ -144,18 +155,16 @@ public AssistantSpringService(

private static @NotNull Mono<InitInfos> getInitInfosMono(Mono<AssistantDto> assistantEntityMono,
Mono<List<MessageDto>> pastMessages) {
Mono<InitInfos> initInfosMono = Mono.zip(
return Mono.zip(
Arrays.asList(assistantEntityMono, pastMessages),
args -> new InitInfos((AssistantDto) args[0], (List<MessageDto>) args[1]));
return initInfosMono;
}

private static @NotNull Flux<ServerSentEvent<String>> getRunIdEventFlux(String runId) {
Flux<ServerSentEvent<String>> runIdMono = Flux.just(ServerSentEvent.<String>builder()
return Flux.just(ServerSentEvent.<String>builder()
.event("run.started")
.data(runId)
.build());
return runIdMono;
}

private static @NotNull List<Message> getFinalPromptMessageList(
Expand All @@ -174,27 +183,30 @@ public AssistantSpringService(

List<Message> finalPromptMessageList = new ArrayList<>(promptMessageList);

if (assistantDto.properties()
.get(AssistantProperties.FEATURE_IMAGEGENERATION.getKey()).equals(
"true")) {
finalPromptMessageList.addFirst(new SystemMessage(SYSTEM_MESSAGE_IMAGE_GEN));
// Remove the last message if it was from the user
if (!finalPromptMessageList.isEmpty()
&& finalPromptMessageList.getLast() instanceof UserMessage) {
finalPromptMessageList.removeLast();
}

if (assistantDto.properties()
.get(AssistantProperties.FEATURE_PLANTUML.getKey()).equals(
"true")) {
finalPromptMessageList.addFirst(new SystemMessage(SYSTEM_MESSAGE_PLANTUML));
.get(AssistantProperty.FEATURE_IMAGEGENERATION.getKey()).equals("true")) {
finalPromptMessageList.addFirst(new SystemMessage(SYSTEM_MESSAGE_IMAGE_GEN));
}

finalPromptMessageList.addFirst(new SystemMessage(assistantDto.instructions()));

StringBuilder memoryMessage = new StringBuilder();
if (!memoryResultsList.isEmpty()) {
if (!memoryResultsList.isEmpty() && assistantDto.memory() == MemoryType.AI_DECIDES) {
StringBuilder memoryMessage = new StringBuilder();
memoryMessage.append("Use the following information from memory:\n");
memoryResultsList.forEach(
result -> memoryMessage.append(result.content()).append("\n"));
memoryMessage.append("\nUser message:\n");

String memorySystemMessage = SYSTEM_MESSAGE_MEMORY.formatted(memoryMessage);
finalPromptMessageList.addFirst(new SystemMessage(memorySystemMessage));
}

return finalPromptMessageList;
}

Expand Down Expand Up @@ -262,8 +274,8 @@ public Flux<ServerSentEvent<String>> streamRunConversation(String assistantId, S

Mono<Object> saveUserMessageMono = getSaveUserMessageMono(assistantId, threadId, userMessage);
Mono<AssistantDto> assistantEntityMono = getAssistantEntityMono(assistantId);
Mono<List<MessageDto>> pastMessages = getPastMessagesMono(threadId);
Mono<InitInfos> initInfosMono = getInitInfosMono(assistantEntityMono, pastMessages);
Mono<List<MessageDto>> pastMessagesMono = getPastMessagesMono(threadId);
Mono<InitInfos> initInfosMono = getInitInfosMono(assistantEntityMono, pastMessagesMono);

StringBuilder assistantMessageContent = new StringBuilder();

Expand All @@ -272,12 +284,18 @@ public Flux<ServerSentEvent<String>> streamRunConversation(String assistantId, S
.then(initInfosMono)
.flux()
.flatMap(initInfos -> {
List<DocumentWithoutEmbeddings> memorySearchResults = getMemorySearchResults(
initInfos.assistantDto.id(), initInfos.assistantDto.memory(), userMessage);
List<DocumentWithoutEmbeddings> memorySearchResults = new ArrayList<>();

if (MemoryType.AI_DECIDES == initInfos.assistantDto.memory()) {
LOGGER.info("Searching memory for assistant '{}'", assistantId);
memorySearchResults = getMemorySearchResults(
initInfos.assistantDto.id(), initInfos.assistantDto.memory(), userMessage);
}

return Flux.just(
new PreparedInfos(initInfos.assistantDto(), initInfos.pastMessages(),
memorySearchResults));
memorySearchResults)
);
})
.flatMap(preparedInfos -> {
AssistantDto assistantDto = preparedInfos.assistantDto();
Expand Down Expand Up @@ -341,6 +359,14 @@ public Flux<ServerSentEvent<String>> streamRunConversation(String assistantId, S
private @NotNull ServerSentEvent<String> mapChatResponse(@NotNull ChatResponse chatResponse,
StringBuilder assistantMessageContent) {

if (chatResponse.getResult() == null || chatResponse.getResult().getOutput() == null) {
LOGGER.warn("Empty ChatResponse received: {}", chatResponse.getResult());
return ServerSentEvent.<String>builder()
.event("thread.message.delta")
.data("")
.build();
}

String content = chatResponse.getResult().getOutput().getContent();
LOGGER.trace("ChatResponse received: {}", chatResponse.getResult());

Expand Down Expand Up @@ -407,7 +433,8 @@ private FunctionCallbackWrapper<Request, Response> getMemoryFunctionCallback(
LOGGER.info("Searching memory for message: {}", message);

FilterExpressionBuilder expressionBuilder = new FilterExpressionBuilder();
Expression assistantExpression = expressionBuilder.eq("assistantId", assistantId).build();
Expression assistantExpression = expressionBuilder.eq(DBVectorStore.SEARCH_CONVERSATION_ID,
assistantId).build();

List<DocumentWithoutEmbeddings> searchResults = memoryService.search(
SearchRequest.query(message)
Expand Down Expand Up @@ -701,6 +728,10 @@ public List<String> retrieveModels(LlmSystem system) {
return universalChatService.getModels(system);
}

public void regenerateThread(String threadId) {

}

record InitInfos(AssistantDto assistantDto, List<MessageDto> pastMessages) {

}
Expand Down
Loading

0 comments on commit 9a3c08e

Please sign in to comment.