diff --git a/src/main/java/com/project/hana_piece/ai/dto/GeminiCallRequest.java b/src/main/java/com/project/hana_piece/ai/dto/GeminiCallRequest.java new file mode 100644 index 0000000..b8c6aba --- /dev/null +++ b/src/main/java/com/project/hana_piece/ai/dto/GeminiCallRequest.java @@ -0,0 +1,34 @@ +package com.project.hana_piece.ai.dto; + +import com.project.hana_piece.ai.vo.GeminiPrompt; +import java.util.List; + +public record GeminiCallRequest(List contents, List safetySettings, GenerationConfig generationConfig) { + + private static final String safetySettingCategory = "HARM_CATEGORY_DANGEROUS_CONTENT"; + private static final String safetySettingThreshold = "BLOCK_ONLY_HIGH"; + private static final Double generationConfigTemperature = 0.1; + private static final Double generationConfigTopP = 0.1; + private static final Integer generationConfigTopK = 10; + + public static record Part(String text) {} + + public static record Content(List parts) {} + + public static record SafetySetting(String category, String threshold) {} + + public static record GenerationConfig(Double temperature, Double topP, Integer topK) {} + + public static GeminiCallRequest fromPrompt(GeminiPrompt geminiPrompt){ + Part part = new Part(geminiPrompt.getTotalPrompt()); + Content content = new Content(List.of(part)); + List contents = List.of(content); + + SafetySetting safetySetting = new SafetySetting(safetySettingCategory, safetySettingThreshold); + List safetySettings = List.of(safetySetting); + + GenerationConfig generationConfig = new GenerationConfig(generationConfigTemperature, generationConfigTopP, generationConfigTopK); + + return new GeminiCallRequest(contents, safetySettings, generationConfig); + } +} diff --git a/src/main/java/com/project/hana_piece/ai/service/AiServiceImpl.java b/src/main/java/com/project/hana_piece/ai/service/AiServiceImpl.java index 04d9633..0d23faa 100644 --- a/src/main/java/com/project/hana_piece/ai/service/AiServiceImpl.java +++ b/src/main/java/com/project/hana_piece/ai/service/AiServiceImpl.java @@ -3,6 +3,7 @@ import static com.project.hana_piece.ai.vo.GeminiResponseField.TEXT; import com.google.gson.JsonObject; +import com.project.hana_piece.ai.dto.GeminiCallRequest; import com.project.hana_piece.ai.dto.GeminiCallResponse; import com.project.hana_piece.ai.exception.GeminiNetworkIOException; import com.project.hana_piece.ai.vo.GeminiPrompt; @@ -48,9 +49,12 @@ public GeminiCallResponse callGenerativeLanguageApi(GeminiPrompt geminiPrompt) { WebClient webClient = WebClient.create(BASE_URL + API_KEY); try { + GeminiCallRequest geminiCallRequest = GeminiCallRequest.fromPrompt(geminiPrompt); + String requestBody = jsonUtil.toJsonString(geminiCallRequest); + String geminiApiResponseString = webClient.post() .contentType(MediaType.APPLICATION_JSON) - .bodyValue(geminiPrompt.getTotalPrompt()) + .bodyValue(requestBody) .retrieve() .bodyToMono(String.class) .block(); diff --git a/src/main/java/com/project/hana_piece/ai/vo/GeminiPrompt.java b/src/main/java/com/project/hana_piece/ai/vo/GeminiPrompt.java index b0e3d82..377ec1c 100644 --- a/src/main/java/com/project/hana_piece/ai/vo/GeminiPrompt.java +++ b/src/main/java/com/project/hana_piece/ai/vo/GeminiPrompt.java @@ -1,17 +1,34 @@ package com.project.hana_piece.ai.vo; +import lombok.AccessLevel; +import lombok.Builder; import lombok.Getter; +import lombok.NoArgsConstructor; @Getter +@NoArgsConstructor(access = AccessLevel.PROTECTED) public class GeminiPrompt { - private final String totalPrompt; + private String requests; + private String constraints; + private String responseFormat; + private String exampleData; - public GeminiPrompt(String prompt) { - this.totalPrompt = createPromptJson(prompt); + public synchronized String getTotalPrompt() { + StringBuilder totalPromptBuilder = new StringBuilder(); + totalPromptBuilder.append(requests) + .append(constraints) + .append(responseFormat) + .append(exampleData); + return totalPromptBuilder.toString(); } - private String createPromptJson(String prompt) { - return "{\"contents\":[{\"parts\":[{\"text\":\""+prompt+"\"}]}]}"; + @Builder + public GeminiPrompt(String requests, String constraints, String responseFormat, + String exampleData) { + this.requests = "Requests: " + requests + "\n"; + this.constraints = "Constraints: " + constraints + "\n"; + this.responseFormat = "ResponseFormat: " + responseFormat + "\n"; + this.exampleData = "ExampleData: "+ exampleData + "\n"; } } diff --git a/src/main/java/com/project/hana_piece/common/util/JsonUtil.java b/src/main/java/com/project/hana_piece/common/util/JsonUtil.java index 380d8c0..c2578a9 100644 --- a/src/main/java/com/project/hana_piece/common/util/JsonUtil.java +++ b/src/main/java/com/project/hana_piece/common/util/JsonUtil.java @@ -4,11 +4,8 @@ import com.google.gson.Gson; import com.google.gson.JsonElement; import com.google.gson.JsonObject; -import com.project.hana_piece.ai.vo.GeminiResponseField; import com.project.hana_piece.common.exception.JsonElementNotFoundException; -import java.util.ArrayList; -import java.util.List; import java.util.Map; import lombok.Getter; import org.springframework.stereotype.Component; @@ -33,10 +30,25 @@ public T fromJson(JsonElement jsonElement, Class jsonObjectClass) { return gson.fromJson(jsonElement, jsonObjectClass); } + /** + * String 타입의 JSON 데이터 -> JsonObject 반환 + * @param payload + * @return + */ public JsonObject toJson(String payload) { return gson.fromJson(payload, JsonObject.class); } + /** + * 제네릭 타입의 데이터 -> JSON String 반환 + * @param payload + * @param + * @return + */ + public String toJsonString(T payload) { + return gson.toJson(payload); + } + /** * 재귀적 탐색을 통해 JsonElement 요소 중 특정 key 값을 갖는 JsonElement 탐색 * @param jsonElement diff --git a/src/main/java/com/project/hana_piece/goal/service/ApartmentService.java b/src/main/java/com/project/hana_piece/goal/service/ApartmentService.java index c2bc6da..503499a 100644 --- a/src/main/java/com/project/hana_piece/goal/service/ApartmentService.java +++ b/src/main/java/com/project/hana_piece/goal/service/ApartmentService.java @@ -38,7 +38,7 @@ public ApartmentPricePredictResponse predictApartmentPrice(ApartmentPricePredict String prompt = String.format("%s년 %s월 가장 최근 거래가가 %.1f억인 %s %s %d평의 가격을 선형 회귀 모델을 이용해 오차범위 50퍼센트 가격 이내에서 중간 시나리오로 최대한 정확히 예측해서 '%.1f억' 형식으로 가격만 알려줘", year, month, priceInBillion, request.region(), request.apartmentNm(), request.area(), priceInBillion); - GeminiPrompt geminiPrompt = new GeminiPrompt(prompt); + GeminiPrompt geminiPrompt = GeminiPrompt.builder().requests(prompt).build(); GeminiCallResponse response = aiService.callGenerativeLanguageApi(geminiPrompt); return new ApartmentPricePredictResponse(response.message()); diff --git a/src/main/java/com/project/hana_piece/product/domain/EnrolledProduct.java b/src/main/java/com/project/hana_piece/product/domain/EnrolledProduct.java index 7325be4..27c7d88 100644 --- a/src/main/java/com/project/hana_piece/product/domain/EnrolledProduct.java +++ b/src/main/java/com/project/hana_piece/product/domain/EnrolledProduct.java @@ -10,8 +10,6 @@ @Entity(name = "enrolled_products") @Getter @NoArgsConstructor(access = AccessLevel.PROTECTED) -@Builder -@AllArgsConstructor(access = AccessLevel.PRIVATE) public class EnrolledProduct extends BaseEntity { @Id @@ -44,5 +42,17 @@ public class EnrolledProduct extends BaseEntity { @Column(name="auto_renewal") private boolean autoRenewal; //yyyyMM - + @Builder + public EnrolledProduct(Product product, UserGoal userGoal, Integer contractPeriod, + BigInteger initialAmount, Long autoDebitAmount, Integer autoDebitDay, String maturityDate, + boolean autoRenewal) { + this.product = product; + this.userGoal = userGoal; + this.contractPeriod = contractPeriod; + this.initialAmount = initialAmount; + this.autoDebitAmount = autoDebitAmount; + this.autoDebitDay = autoDebitDay; + this.maturityDate = maturityDate; + this.autoRenewal = autoRenewal; + } } \ No newline at end of file diff --git a/src/main/java/com/project/hana_piece/product/service/ProductService.java b/src/main/java/com/project/hana_piece/product/service/ProductService.java index b7dfcfd..56df9c3 100644 --- a/src/main/java/com/project/hana_piece/product/service/ProductService.java +++ b/src/main/java/com/project/hana_piece/product/service/ProductService.java @@ -66,7 +66,7 @@ public RecommendationResponse recommendProducts(Long userGoalId) { // 추천 상품 목록 생성 String promptMessage = buildPromptMessage(userGoal, products, enrolledProductIds); - GeminiPrompt geminiPrompt = new GeminiPrompt(promptMessage); + GeminiPrompt geminiPrompt = GeminiPrompt.builder().requests(promptMessage).build(); GeminiCallResponse aiResponse = aiService.callGenerativeLanguageApi(geminiPrompt); String aiResponseMessage = aiResponse.message(); String[] productIdStringList = aiResponseMessage.split(",");