Skip to content

Commit

Permalink
Merge pull request #46 from HanaPiece/feat/gemini-api-improve
Browse files Browse the repository at this point in the history
feat/#15  구조화 된 응답을 위한 프롬프트 관리 클래스 생성 & AI API 온도 매개변수 값 조정
  • Loading branch information
duddn2012 authored Jun 4, 2024
2 parents 0d3b7b6 + 6e9f5df commit aa52ca1
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 14 deletions.
34 changes: 34 additions & 0 deletions src/main/java/com/project/hana_piece/ai/dto/GeminiCallRequest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.project.hana_piece.ai.dto;

import com.project.hana_piece.ai.vo.GeminiPrompt;
import java.util.List;

public record GeminiCallRequest(List<Content> contents, List<SafetySetting> safetySettings, GenerationConfig generationConfig) {

private static final String safetySettingCategory = "HARM_CATEGORY_DANGEROUS_CONTENT";
private static final String safetySettingThreshold = "BLOCK_ONLY_HIGH";
private static final Double generationConfigTemperature = 0.1;
private static final Double generationConfigTopP = 0.1;
private static final Integer generationConfigTopK = 10;

public static record Part(String text) {}

public static record Content(List<Part> parts) {}

public static record SafetySetting(String category, String threshold) {}

public static record GenerationConfig(Double temperature, Double topP, Integer topK) {}

public static GeminiCallRequest fromPrompt(GeminiPrompt geminiPrompt){
Part part = new Part(geminiPrompt.getTotalPrompt());
Content content = new Content(List.of(part));
List<Content> contents = List.of(content);

SafetySetting safetySetting = new SafetySetting(safetySettingCategory, safetySettingThreshold);
List<SafetySetting> safetySettings = List.of(safetySetting);

GenerationConfig generationConfig = new GenerationConfig(generationConfigTemperature, generationConfigTopP, generationConfigTopK);

return new GeminiCallRequest(contents, safetySettings, generationConfig);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import static com.project.hana_piece.ai.vo.GeminiResponseField.TEXT;

import com.google.gson.JsonObject;
import com.project.hana_piece.ai.dto.GeminiCallRequest;
import com.project.hana_piece.ai.dto.GeminiCallResponse;
import com.project.hana_piece.ai.exception.GeminiNetworkIOException;
import com.project.hana_piece.ai.vo.GeminiPrompt;
Expand Down Expand Up @@ -48,9 +49,12 @@ public GeminiCallResponse callGenerativeLanguageApi(GeminiPrompt geminiPrompt) {
WebClient webClient = WebClient.create(BASE_URL + API_KEY);

try {
GeminiCallRequest geminiCallRequest = GeminiCallRequest.fromPrompt(geminiPrompt);
String requestBody = jsonUtil.toJsonString(geminiCallRequest);

String geminiApiResponseString = webClient.post()
.contentType(MediaType.APPLICATION_JSON)
.bodyValue(geminiPrompt.getTotalPrompt())
.bodyValue(requestBody)
.retrieve()
.bodyToMono(String.class)
.block();
Expand Down
27 changes: 22 additions & 5 deletions src/main/java/com/project/hana_piece/ai/vo/GeminiPrompt.java
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
package com.project.hana_piece.ai.vo;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;

@Getter
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class GeminiPrompt {

private final String totalPrompt;
private String requests;
private String constraints;
private String responseFormat;
private String exampleData;

public GeminiPrompt(String prompt) {
this.totalPrompt = createPromptJson(prompt);
public synchronized String getTotalPrompt() {
StringBuilder totalPromptBuilder = new StringBuilder();
totalPromptBuilder.append(requests)
.append(constraints)
.append(responseFormat)
.append(exampleData);
return totalPromptBuilder.toString();
}

private String createPromptJson(String prompt) {
return "{\"contents\":[{\"parts\":[{\"text\":\""+prompt+"\"}]}]}";
@Builder
public GeminiPrompt(String requests, String constraints, String responseFormat,
String exampleData) {
this.requests = "Requests: " + requests + "\n";
this.constraints = "Constraints: " + constraints + "\n";
this.responseFormat = "ResponseFormat: " + responseFormat + "\n";
this.exampleData = "ExampleData: "+ exampleData + "\n";
}
}
18 changes: 15 additions & 3 deletions src/main/java/com/project/hana_piece/common/util/JsonUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@
import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.project.hana_piece.ai.vo.GeminiResponseField;
import com.project.hana_piece.common.exception.JsonElementNotFoundException;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import lombok.Getter;
import org.springframework.stereotype.Component;
Expand All @@ -33,10 +30,25 @@ public <T> T fromJson(JsonElement jsonElement, Class<T> jsonObjectClass) {
return gson.fromJson(jsonElement, jsonObjectClass);
}

/**
* String 타입의 JSON 데이터 -> JsonObject 반환
* @param payload
* @return
*/
public JsonObject toJson(String payload) {
return gson.fromJson(payload, JsonObject.class);
}

/**
* 제네릭 타입의 데이터 -> JSON String 반환
* @param payload
* @param <T>
* @return
*/
public <T> String toJsonString(T payload) {
return gson.toJson(payload);
}

/**
* 재귀적 탐색을 통해 JsonElement 요소 중 특정 key 값을 갖는 JsonElement 탐색
* @param jsonElement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public ApartmentPricePredictResponse predictApartmentPrice(ApartmentPricePredict
String prompt = String.format("%s년 %s월 가장 최근 거래가가 %.1f억인 %s %s %d평의 가격을 선형 회귀 모델을 이용해 오차범위 50퍼센트 가격 이내에서 중간 시나리오로 최대한 정확히 예측해서 '%.1f억' 형식으로 가격만 알려줘",
year, month, priceInBillion, request.region(), request.apartmentNm(), request.area(), priceInBillion);

GeminiPrompt geminiPrompt = new GeminiPrompt(prompt);
GeminiPrompt geminiPrompt = GeminiPrompt.builder().requests(prompt).build();
GeminiCallResponse response = aiService.callGenerativeLanguageApi(geminiPrompt);

return new ApartmentPricePredictResponse(response.message());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
@Entity(name = "enrolled_products")
@Getter
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@Builder
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class EnrolledProduct extends BaseEntity {

@Id
Expand Down Expand Up @@ -44,5 +42,17 @@ public class EnrolledProduct extends BaseEntity {
@Column(name="auto_renewal")
private boolean autoRenewal; //yyyyMM


@Builder
public EnrolledProduct(Product product, UserGoal userGoal, Integer contractPeriod,
BigInteger initialAmount, Long autoDebitAmount, Integer autoDebitDay, String maturityDate,
boolean autoRenewal) {
this.product = product;
this.userGoal = userGoal;
this.contractPeriod = contractPeriod;
this.initialAmount = initialAmount;
this.autoDebitAmount = autoDebitAmount;
this.autoDebitDay = autoDebitDay;
this.maturityDate = maturityDate;
this.autoRenewal = autoRenewal;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public RecommendationResponse recommendProducts(Long userGoalId) {

// 추천 상품 목록 생성
String promptMessage = buildPromptMessage(userGoal, products, enrolledProductIds);
GeminiPrompt geminiPrompt = new GeminiPrompt(promptMessage);
GeminiPrompt geminiPrompt = GeminiPrompt.builder().requests(promptMessage).build();
GeminiCallResponse aiResponse = aiService.callGenerativeLanguageApi(geminiPrompt);
String aiResponseMessage = aiResponse.message();
String[] productIdStringList = aiResponseMessage.split(",");
Expand Down

0 comments on commit aa52ca1

Please sign in to comment.