Skip to content

Commit

Permalink
Merge pull request #22 from olmangjolmang/OMJM-82-recommend-posts
Browse files Browse the repository at this point in the history
생성형 AI Gemini 도입, 함께 읽으면 좋은 아티클 기능 구현
  • Loading branch information
chaeyeonKong authored Jul 9, 2024
2 parents a6ec0e6 + 4455f5f commit 74aa0fc
Show file tree
Hide file tree
Showing 11 changed files with 310 additions and 76 deletions.
7 changes: 4 additions & 3 deletions src/main/java/com/ticle/server/ServerApplication.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.PropertySource;

@SpringBootApplication
public class ServerApplication {

public static void main(String[] args) {
SpringApplication.run(ServerApplication.class, args);
}
public static void main(String[] args) {
SpringApplication.run(ServerApplication.class, args);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ public ResponseEntity<ResponseTemplate<Object>> findAllArticles(@RequestParam(re

Page<PostResponse> postPage = postService.findAllByCategory(category, page);

// List<PostResponse> postResponses = postPage.getContent().stream()
// .map(PostResponse::from)
// .collect(Collectors.toList());


return ResponseEntity
.status(HttpStatus.OK)
.body(ResponseTemplate.from(postPage));
Expand All @@ -59,7 +54,8 @@ public ResponseEntity<ResponseTemplate<Object>> findArticle(@PathVariable long i

return ResponseEntity
.status(HttpStatus.OK)
.body(ResponseTemplate.from(post));
.body(ResponseTemplate.from(PostResponse.from((Post) post)));

}

@Operation(summary = "아티클 스크랩", description = "새로운 아티클 스크랩, 스크랩 취소")
Expand Down Expand Up @@ -94,4 +90,6 @@ public ResponseEntity<ResponseTemplate<Object>> memoArticle(@PathVariable long i
.body(ResponseTemplate.from(MemoDto.from((Memo) memo)));
}
}


}
4 changes: 4 additions & 0 deletions src/main/java/com/ticle/server/post/domain/Post.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,8 @@ public class Post extends BaseTimeEntity {
@OneToMany(mappedBy = "post", cascade = CascadeType.ALL, orphanRemoval = true)
private List<Scrapped> scrappeds;

@Transient
@Column(name = "recommend_post")
private List recommendPost;

}
56 changes: 56 additions & 0 deletions src/main/java/com/ticle/server/post/dto/GeminiRequest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.ticle.server.post.dto;

import lombok.*;

import java.util.ArrayList;
import java.util.List;

@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
@Getter
@Setter

public class GeminiRequest {
private List<Content> contents;
private GenerationConfig generationConfig;

@Getter
@Setter
public static class Content {
private Parts parts;
}

@Getter
@Setter
public static class Parts {
private String text;

}

@Getter
@Setter
public static class GenerationConfig {
private int candidate_count;
private int max_output_tokens;
private double temperature;

}

public GeminiRequest(String prompt) {
this.contents = new ArrayList<>();
Content content = new Content();
Parts parts = new Parts();

parts.setText(prompt);
content.setParts(parts);

this.contents.add(content);
this.generationConfig = new GenerationConfig();
this.generationConfig.setCandidate_count(1);
this.generationConfig.setMax_output_tokens(1000);
this.generationConfig.setTemperature(0.7);
}
}

84 changes: 84 additions & 0 deletions src/main/java/com/ticle/server/post/dto/GeminiResponse.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package com.ticle.server.post.dto;

import lombok.*;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class GeminiResponse {

private List<Candidate> candidates;
private PromptFeedback promptFeedback;

@Data
public static class Candidate {
private Content content;
private String finishReason;
private int index;
private List<SafetyRating> safetyRatings;
}

@Data
public static class Content {
private List<Parts> parts;
private String role;
}

@Data
public static class Parts {
private String text;
}

@Data
public static class SafetyRating {
private String category;
private String probability;
}

@Data
public static class PromptFeedback {
private List<SafetyRating> safetyRatings;
}

// Gemini 리턴
// postDetail의 recommenPosts 결과를 List 형태로 만들어주는 함수
public List<Map<String, Object>> formatRecommendPost() {
List<Map<String, Object>> recommendPosts = new ArrayList<>();

if (candidates != null) {
for (Candidate candidate : candidates) {
Content content = candidate.getContent();
if (content != null && content.getParts() != null && content.getParts().size() > 0) {
String combinedString = content.getParts().get(0).getText();

String[] parts = combinedString.split("postTitle=");

if (parts.length == 2) {
String postId = parts[0].substring(parts[0].indexOf("[") + 1, parts[0].indexOf("]"));
String postTitle = parts[1].substring(parts[1].indexOf("[") + 1, parts[1].indexOf("]"));

//postid, posttitle를 각각 list화
String[] postIds = postId.split(", ");
String[] postTitles = postTitle.split(", ");

for (int i = 0; i < postIds.length; i++) {
Map<String, Object> postMap = new LinkedHashMap<>();
postMap.put("postId", Long.parseLong(postIds[i]));
postMap.put("postTitle", postTitles[i].replaceAll("'", ""));
recommendPosts.add(postMap);
}
}
}
}
}

return recommendPosts;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.ticle.server.post.dto;


import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;

@Configuration
@RequiredArgsConstructor
public class GeminiRestTemplateConfig {

@Bean
@Qualifier("geminiRestTemplate")
public RestTemplate geminiRestTemplate() {
RestTemplate restTemplate = new RestTemplate();
restTemplate.getInterceptors().add((request, body, execution) -> execution.execute(request, body));

return restTemplate;
}
}
21 changes: 21 additions & 0 deletions src/main/java/com/ticle/server/post/dto/PostIdTitleDto.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.ticle.server.post.dto;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;

@NoArgsConstructor
@AllArgsConstructor
@Getter
public class PostIdTitleDto {
private Long postId;
private String title;

@Override
public String toString() {
return "{" +
"postId=" + postId +
", title='" + title + '\'' +
'}';
}
}
7 changes: 5 additions & 2 deletions src/main/java/com/ticle/server/post/dto/PostResponse.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
import com.ticle.server.user.domain.type.Category;
import lombok.Getter;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;

import java.time.LocalDateTime;
import java.util.List;

@Getter
@AllArgsConstructor
@NoArgsConstructor
public class PostResponse {

private Long postId;
Expand All @@ -19,7 +22,7 @@ public class PostResponse {
private LocalDateTime createdDate;
private Category postCategory;
private S3Info image;
private Long userId;
private List recommendPost;

public static PostResponse from(Post post) {
return new PostResponse(
Expand All @@ -30,7 +33,7 @@ public static PostResponse from(Post post) {
post.getCreatedDate(),
post.getCategory(),
post.getImage(),
post.getUser() != null ? post.getUser().getId() : null
post.getRecommendPost()
);
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package com.ticle.server.post.repository;

import com.ticle.server.post.domain.Post;
import com.ticle.server.post.dto.PostIdTitleDto;
import com.ticle.server.user.domain.type.Category;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;

import java.util.List;

Expand All @@ -14,4 +16,6 @@ public interface PostRepository extends JpaRepository<Post, Long> {

Page<Post> findByCategory(Category category, Pageable pageable);

@Query("SELECT new com.ticle.server.post.dto.PostIdTitleDto(p.postId, p.title) FROM Post p")
List<PostIdTitleDto> findAllPostSummaries();
}
46 changes: 44 additions & 2 deletions src/main/java/com/ticle/server/post/service/PostService.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.ticle.server.memo.domain.Memo;
import com.ticle.server.mypage.repository.MemoRepository;
import com.ticle.server.post.dto.PostResponse;
import com.ticle.server.post.dto.*;
import com.ticle.server.scrapped.dto.ScrappedDto;
import com.ticle.server.user.domain.type.Category;
import com.ticle.server.post.domain.Post;
Expand All @@ -12,18 +12,25 @@
import com.ticle.server.user.domain.User;
import com.ticle.server.user.service.UserService;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Value;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.server.ResponseStatusException;

import java.util.List;
import java.util.Optional;

@RequiredArgsConstructor
@Service

public class PostService {

private final PostRepository postRepository;
Expand All @@ -50,9 +57,44 @@ public Page<PostResponse> findAllByCategory(String category, int page) {

}


@Qualifier("geminiRestTemplate")
@Autowired
private RestTemplate restTemplate;

@Value("${gemini.api.url}")
private String apiUrl;

@Value("${gemini.api.key}")
private String geminiApiKey;

//postId로 조회한 특정 post 정보 리턴
public Post findById(long id) {
return postRepository.findById(id).orElseThrow(() -> new IllegalArgumentException("not found: " + id));

Optional<Post> optionalPost = postRepository.findById(id);
Post post = optionalPost.orElseThrow(() -> new IllegalArgumentException("Post not found with ID: " + id));

String now_post_title = post.getTitle();
List<PostIdTitleDto> alltitle = postRepository.findAllPostSummaries();

// Gemini에 요청 전송
String requestUrl = apiUrl + "?key=" + geminiApiKey;
String prompt = "현재 기사의 제목은 " + now_post_title + " 이야. " +
"다음은 기사의 리스트야. 리스트 안의 title과 비교해서 " +
"현재의 기사 제목과 가장 연관 된 기사 3개의 id, title을 각각의 리스트로 추출해줘." +
"단, 현재 기사는 제외한다." +
"예: postId=[1,2,3], postTitle=[title1, title2, title3] "
+ alltitle;

System.out.println(prompt);

GeminiRequest request = new GeminiRequest(prompt);
GeminiResponse response = restTemplate.postForObject(requestUrl, request, GeminiResponse.class);

List recommendPost = response.formatRecommendPost(); // 리턴 형식 지정하는 함수
post.setRecommendPost(recommendPost);

return post;
}


Expand Down
Loading

0 comments on commit 74aa0fc

Please sign in to comment.