Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] GPT 를 활용한 musicgen promt 및 노래 제목 생성 #76

Merged
merged 3 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.kuit.agarang.domain.ai.model.dto.MusicAnswer;
import com.kuit.agarang.domain.ai.model.dto.TextAnswer;
import com.kuit.agarang.domain.ai.model.dto.QuestionResponse;
import com.kuit.agarang.domain.ai.model.entity.cache.GPTChatHistory;
import com.kuit.agarang.domain.ai.service.AIService;
import com.kuit.agarang.global.common.model.dto.BaseResponse;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -36,7 +37,8 @@ public ResponseEntity<BaseResponse<Void>> saveLastAnswer(@RequestBody TextAnswer

@PostMapping("/music")
public ResponseEntity<BaseResponse<Void>> saveLastAnswer(@RequestBody MusicAnswer answer) {
AIService.saveMusicChoice(answer);
GPTChatHistory chatHistory = AIService.setMusicChoice(answer);
AIService.createMusicGenPrompt(chatHistory);
return ResponseEntity.ok(new BaseResponse<>());
}
}
16 changes: 16 additions & 0 deletions src/main/java/com/kuit/agarang/domain/ai/model/dto/MusicInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,20 @@ public static MusicInfo from(MusicChoice musicChoice) {
throw new BusinessException(BaseResponseStatus.INVALID_MUSIC_CHOICE);
}
}

public String getInstrumentAsString() {
return instrument.toString().toLowerCase();
}

public String getGenreAsString() {
return genre.toString().toLowerCase();
}

public String getMoodAsString() {
return mood.toString().toLowerCase();
}

public String getTempoAsString() {
return tempo.toString().toLowerCase();
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package com.kuit.agarang.domain.ai.model.dto.gpt;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kuit.agarang.global.common.exception.exception.OpenAPIException;
import com.kuit.agarang.global.common.model.dto.BaseResponseStatus;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
Expand All @@ -18,14 +14,4 @@
public class GPTImageDescription {
private String text;
private List<String> noun;

public static GPTImageDescription from(String content) {
ObjectMapper objectMapper = new ObjectMapper();
try {
return objectMapper.readValue(content, GPTImageDescription.class);
} catch (JsonProcessingException e) {
log.info("gpt's image description : {}", content);
throw new OpenAPIException(BaseResponseStatus.INVALID_GPT_RESPONSE);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.kuit.agarang.domain.ai.model.entity.cache;

import com.kuit.agarang.domain.ai.model.dto.MusicInfo;
import com.kuit.agarang.domain.ai.model.dto.gpt.GPTImageDescription;
import com.kuit.agarang.domain.ai.model.dto.gpt.GPTMessage;
import com.kuit.agarang.global.s3.model.dto.S3File;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
Expand All @@ -12,16 +14,16 @@
@Getter
@NoArgsConstructor
public class GPTChatHistory {
private String imageTempPath;
private List<String> hashtags;
private S3File image;
private GPTImageDescription imageDescription;
private List<GPTMessage> historyMessages;
@Setter
private MusicInfo musicInfo;

@Builder
public GPTChatHistory(String imageTempPath, List<String> hashtags, List<GPTMessage> historyMessages, MusicInfo musicInfo) {
this.imageTempPath = imageTempPath;
this.hashtags = hashtags;
public GPTChatHistory(S3File image, GPTImageDescription imageDescription, List<GPTMessage> historyMessages, MusicInfo musicInfo) {
this.image = image;
this.imageDescription = imageDescription;
this.historyMessages = historyMessages;
this.musicInfo = musicInfo;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
public enum GPTSystemRole {
IMAGE_DESCRIBER("You're the one who describes the image. Emotionally describe the image."),
COUNSELOR("너는 상담사야."),
ASSISTANT("You are an assistant.");
ASSISTANT("You are an assistant."),
MUSIC_PROMPT_ENGINEER("너는 text to music prompt 엔지니어야."),
MUSIC_TITLE_WRITER("너는 한국어 음악 노래 작명가야.")
;

private String text;
}
38 changes: 26 additions & 12 deletions src/main/java/com/kuit/agarang/domain/ai/service/AIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ public class AIService {
private final ObjectMapper objectMapper;

public QuestionResponse getFirstQuestion(MultipartFile image) throws Exception {
S3File s3File = s3FileUtil.uploadTempFile(image);
S3File convertedImage = s3FileUtil.uploadTempFile(image);

// image -> gpt -> 노래제목, 해시태그 생성
String prompt = promptUtil.createImageDescriptionPrompt();
GPTChat imageChat = gptChatService.chatWithImage(s3File, prompt);
GPTImageDescription imageDescription = GPTImageDescription.from(gptUtil.getGPTAnswer(imageChat));
GPTChat imageChat = gptChatService.chatWithImage(convertedImage, prompt);
GPTImageDescription imageDescription = gptUtil.parseJson(imageChat, GPTImageDescription.class);

// 해시태그 -> gpt -> 질문1 생성
prompt = promptUtil.createImageQuestionPrompt(imageDescription);
GPTChat questionChat = gptChatService.chat(GPTSystemRole.COUNSELOR, prompt, 0L);
GPTChat questionChat = gptChatService.chat(GPTSystemRole.COUNSELOR, prompt, 0L, false);
String question = gptUtil.getGPTAnswer(questionChat);

// 질문1 -> tts -> 오디오 변환
Expand All @@ -62,8 +62,8 @@ public QuestionResponse getFirstQuestion(MultipartFile image) throws Exception {
List<GPTMessage> historyMessage = gptUtil.createHistoryMessage(questionChat);
redisService.save(redisKey,
GPTChatHistory.builder()
.imageTempPath(s3File.getFilename())
.hashtags(imageDescription.getNoun())
.image(convertedImage.cleanBytes())
.imageDescription(imageDescription)
.historyMessages(historyMessage)
.build());

Expand Down Expand Up @@ -127,20 +127,34 @@ public void createMemoryText(String gptChatHistoryId) {
redisService.save(gptChatHistoryId, chatHistory);
}

public void saveMusicChoice(MusicAnswer answer) {
public GPTChatHistory setMusicChoice(MusicAnswer answer) {
GPTChatHistory chatHistory = redisService.get(answer.getId(), GPTChatHistory.class)
.orElseThrow(() -> new OpenAPIException(BaseResponseStatus.NOT_FOUND_HISTORY_CHAT));

chatHistory.setMusicInfo(MusicInfo.from(answer.getMusicChoice()));
log.info("music info : {}, {}, {}, {}",
chatHistory.getMusicInfo().getInstrument(), chatHistory.getMusicInfo().getGenre(),
chatHistory.getMusicInfo().getMood(), chatHistory.getMusicInfo().getTempo());
redisService.save(answer.getId(), chatHistory);
return chatHistory;
}

@Async
public void createMusicGenPrompt(GPTChatHistory chatHistory) {
String prompt = promptUtil.createMusicGenPrompt(chatHistory.getImageDescription(), chatHistory.getMusicInfo());
GPTChat chat = gptChatService.chat(GPTSystemRole.MUSIC_PROMPT_ENGINEER, prompt, 1L, true);
String musicGenPrompt = gptUtil.parseJson(chat, "prompt");
log.info(musicGenPrompt);

prompt = promptUtil.createMusicTitlePrompt(musicGenPrompt, chatHistory.getMusicInfo());
chat = gptChatService.chat(GPTSystemRole.MUSIC_TITLE_WRITER, prompt, 1L, true);
String musicTitle = gptUtil.parseJson(chat, "music_name");
log.info(musicTitle);

// TODO : 음악 생성

// TODO : DB 저장
}

public String getCharacterBubble(Character character, String familyRole) {
String prompt = promptUtil.createCharacterBubble(character, familyRole);
GPTChat chat = gptChatService.chat(GPTSystemRole.ASSISTANT, prompt, 1L);
GPTChat chat = gptChatService.chat(GPTSystemRole.ASSISTANT, prompt, 1L, false);
return gptUtil.getGPTAnswer(chat);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ public GPTChat chatWithImage(S3File image, String prompt) {
return new GPTChat(request, response);
}

public GPTChat chat(GPTSystemRole systemRole, String prompt, Long temperature) {
public GPTChat chat(GPTSystemRole systemRole, String prompt, Long temperature, boolean requiredJson) {
GPTMessage systemMessage = gptUtil.createSystemMessage(systemRole);
GPTMessage message = gptUtil.createTextMessage(prompt);

GPTRequest request = new GPTRequest(new ArrayList<>(List.of(systemMessage, message)));
request.setTemperature(temperature);
request.setRequiredJson(requiredJson);
GPTResponse response = gptClientUtil.post(request, GPTResponse.class);
return new GPTChat(request, response);
}
Expand Down
21 changes: 21 additions & 0 deletions src/main/java/com/kuit/agarang/domain/ai/utils/GPTPromptUtil.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.kuit.agarang.domain.ai.utils;

import com.kuit.agarang.domain.ai.model.dto.MusicInfo;
import com.kuit.agarang.domain.ai.model.dto.gpt.GPTImageDescription;
import com.kuit.agarang.domain.baby.model.entity.Character;
import org.springframework.stereotype.Component;
Expand Down Expand Up @@ -34,4 +35,24 @@ public String createCharacterBubble(Character character, String familyRole) {
.append(familyRole).append("에게 하고 싶은 말을 10자 내로 짧게 작성해줘. ")
.append(character.getName()).append("의 특징은 다음과 같아. ").append(character.getDescription()).toString();
}

public String createMusicGenPrompt(GPTImageDescription imageDescription, MusicInfo musicInfo) {
return new StringBuilder("text to music 이란, 장르, 악기, 무드, 속도를 포함한 prompt 를 입력으로 Music 을 생성하는 task를 의미해. ")
.append("너는 지금부터, 아래의 내 요구사항을 반드시 지켜서 music 생성을 위한 prompt 를 만들어야돼. ")
.append("Music Generation 을 위한 Text Prompt 예시는 다음과 같아.\n")
.append("1. Pop dance track with catchy melodies, tropical percussion, and upbeat rhythms, perfect for the beach.\n")
.append("2. A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle.\n")
.append("3. earthy tones, environmentally conscious, ukulele-infused, harmonic, breezy, easygoing, organic instrumentation, gentle grooves.\n")
.append(musicInfo.getGenreAsString()).append("와 ").append(imageDescription.getText()).append(" 에 어울리고, ")
.append(musicInfo.getInstrumentAsString()).append("로 연주되며, ").append(musicInfo.getMoodAsString()).append("이 잘 표현되는 prompt를 만들어줘. ")
.append("반드시 json 형식으로 알려줘. ").append("{\"prompt\": prompt}").toString();
}

public String createMusicTitlePrompt(String musicGenPrompt, MusicInfo musicInfo) {
return new StringBuilder("산모들의 태교 노래 제목를 만들어줘. ").append("태교 노래 제목의 예시는 다음과 같아.\n")
.append("1. 봄마중 꽃마중, ").append("2. 찬 바람이 불던 밤, ").append("3. 깊은 밤을 날아서\n")
.append(musicGenPrompt).append(" 이미지가 떠오를 수 있으면서, ")
.append(musicInfo.getMoodAsString()).append(" 분위기가 잘 표현되는 한국어로 된 태교 음악 제목을 만들어줘. ")
.append("반드시 json 형식으로 알려줘. ").append("{\"music_name\": music_name}").toString();
}
}
26 changes: 26 additions & 0 deletions src/main/java/com/kuit/agarang/domain/ai/utils/GPTUtil.java
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
package com.kuit.agarang.domain.ai.utils;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kuit.agarang.domain.ai.model.dto.gpt.GPTChat;
import com.kuit.agarang.domain.ai.model.dto.gpt.GPTContent;
import com.kuit.agarang.domain.ai.model.dto.gpt.GPTMessage;
import com.kuit.agarang.domain.ai.model.enums.GPTRole;
import com.kuit.agarang.domain.ai.model.enums.GPTSystemRole;
import com.kuit.agarang.global.common.exception.exception.OpenAPIException;
import com.kuit.agarang.global.common.model.dto.BaseResponseStatus;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.List;

@Slf4j
@Component
public class GPTUtil {

private final ObjectMapper objectMapper = new ObjectMapper();

public GPTMessage createImageQuestion(String prompt, String imageUrl) {
return GPTMessage.builder()
.role(GPTRole.USER)
Expand Down Expand Up @@ -57,4 +63,24 @@ private GPTMessage getResponseMessage(GPTChat gptChat) {
throw new OpenAPIException(BaseResponseStatus.INVALID_GPT_RESPONSE);
}
}

public <T> T parseJson(GPTChat chat, Class<T> clazz) {
String jsonString = getGPTAnswer(chat);
try {
return objectMapper.readValue(jsonString, clazz);
} catch (JsonProcessingException e) {
log.info("invalid gpt's json answer (clazz) : {}", jsonString);
throw new OpenAPIException(BaseResponseStatus.INVALID_GPT_RESPONSE);
}
}

public String parseJson(GPTChat chat, String filedName) {
String jsonString = getGPTAnswer(chat);
try {
return objectMapper.readTree(jsonString).get(filedName).asText();
} catch (JsonProcessingException e) {
log.info("invalid gpt's json answer (string) : {}", jsonString);
throw new OpenAPIException(BaseResponseStatus.INVALID_GPT_RESPONSE);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,9 @@ public String toGPTImageUrl() {
String base64EncodeData = Base64.getEncoder().encodeToString(this.getBytes());
return "data:" + this.getContentType().getMimeType() + ";base64," + base64EncodeData;
}

public S3File cleanBytes() {
this.bytes = null;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ void updateObject() {
GPTMessage.builder().role(GPTRole.USER).content("user-user").build()));

GPTChatHistory chatHistory = GPTChatHistory.builder()
.imageTempPath("path")
.hashtags(List.of("a", "b", "c"))
.imageDescription(new GPTImageDescription("text", List.of("a", "b", "c")))
.historyMessages(messages)
.build();
redisService.save(KEY, chatHistory);
Expand All @@ -70,8 +69,8 @@ void updateObject() {

// then
GPTChatHistory updatedGptChatHistory = redisService.get(KEY, GPTChatHistory.class).get();
assertEquals(gptChatHistory.getImageTempPath(), updatedGptChatHistory.getImageTempPath());
assertEquals(gptChatHistory.getHashtags(), updatedGptChatHistory.getHashtags());
assertEquals(gptChatHistory.getImageDescription().getText(), updatedGptChatHistory.getImageDescription().getText());
assertEquals(gptChatHistory.getImageDescription().getNoun(), updatedGptChatHistory.getImageDescription().getNoun());

assertEquals(gptChatHistory.getHistoryMessages().get(0).getContent(),
updatedGptChatHistory.getHistoryMessages().get(0).getContent());
Expand Down Expand Up @@ -141,8 +140,7 @@ void saveGPTChatAndGet() {

request.getMessages().add(responseMessage); // 질문, 대답 합친 history message 포함해서 선언
GPTChatHistory chatHistory = GPTChatHistory.builder()
.imageTempPath("images/image.jpeg")
.hashtags(imageDescription.getNoun())
.imageDescription(imageDescription)
.historyMessages(request.getMessages())
.build();

Expand All @@ -152,8 +150,8 @@ void saveGPTChatAndGet() {
// then
GPTChatHistory savedChatHistory = redisService.get(KEY, GPTChatHistory.class)
.orElseThrow(() -> new RuntimeException(""));
assertEquals(chatHistory.getImageTempPath(), savedChatHistory.getImageTempPath());
assertEquals(chatHistory.getHashtags(), savedChatHistory.getHashtags());
assertEquals(chatHistory.getImageDescription().getText(), savedChatHistory.getImageDescription().getText());
assertEquals(chatHistory.getImageDescription().getNoun(), savedChatHistory.getImageDescription().getNoun());
assertEquals(requestMessage.getContent(), savedChatHistory.getHistoryMessages().get(0).getContent());
assertEquals(responseMessage.getContent(), savedChatHistory.getHistoryMessages().get(1).getContent());
}
Expand Down