Skip to content

Commit

Permalink
Merge pull request #89 from Princess-in-silvertown/feat/87
Browse files Browse the repository at this point in the history
Feat: 이미지 모델 연결 & S3 저장 로직 생성
  • Loading branch information
woosung1223 authored Sep 30, 2024
2 parents 6b294da + 41370e9 commit 7a5208f
Show file tree
Hide file tree
Showing 21 changed files with 218 additions and 184 deletions.
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ dependencies {

// Spring Security 설정
implementation 'org.springframework.boot:spring-boot-starter-security'

implementation 'org.springframework.cloud:spring-cloud-starter-aws:2.2.6.RELEASE'
}

tasks.register('importConfigSubmodule', Copy) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package slvtwn.khu.toyouserver.agent.aws;

import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.services.s3.AmazonS3Client;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import lombok.Getter;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Getter
@Configuration
public class AwsConfiguration {

@Value("${cloud.aws.credentials.access-key}")
private String accessKey;
@Value("${cloud.aws.credentials.secret-key}")
private String secretKey;
@Value("${cloud.aws.region.static}")
private String region;
@Value("${cloud.aws.s3.bucket}")
private String bucket;

@Bean
public AmazonS3Client amazonS3Client() {
BasicAWSCredentials awsCredentials = new BasicAWSCredentials(accessKey, secretKey);
return (AmazonS3Client) AmazonS3ClientBuilder.standard()
.withRegion(region)
.withCredentials(new AWSStaticCredentialsProvider(awsCredentials))
.build();
}
}
37 changes: 37 additions & 0 deletions src/main/java/slvtwn/khu/toyouserver/agent/aws/S3Agent.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package slvtwn.khu.toyouserver.agent.aws;

import com.amazonaws.services.s3.AmazonS3Client;
import com.amazonaws.services.s3.model.ObjectMetadata;
import com.amazonaws.services.s3.model.PutObjectRequest;
import java.io.ByteArrayInputStream;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
import slvtwn.khu.toyouserver.common.response.ResponseType;
import slvtwn.khu.toyouserver.exception.ToyouException;

@RequiredArgsConstructor
@Component
public class S3Agent {

private final AmazonS3Client amazonS3Client;
private final AwsConfiguration configuration;

public String uploadFile(byte[] fileData, String fileName, String contentType) {
try {
ObjectMetadata metadata = new ObjectMetadata();
metadata.setContentType(contentType);
metadata.setContentLength(fileData.length);

ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(fileData);
amazonS3Client.putObject(new PutObjectRequest(configuration.getBucket(), fileName, byteArrayInputStream, metadata));

return fileName;
} catch (Exception e) {
throw new ToyouException(ResponseType.INTERNAL_SERVER_ERROR);
}
}

public String getUrl(String key) {
return amazonS3Client.getResourceUrl(configuration.getBucket(), key);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package slvtwn.khu.toyouserver.agent.model;

import java.util.Base64;
import java.util.List;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
import slvtwn.khu.toyouserver.agent.WebClientWrapper;

@RequiredArgsConstructor
@Component
public class StickerModelAgent {

private final StickerModelConfiguration configuration;

private final WebClientWrapper webClientWrapper;

public List<byte[]> generateStickers(StickerModelRequest request) {
StickerModelResponse response = webClientWrapper.send(
configuration.getEndpoint(), (httpHeaders -> {
}), request, StickerModelResponse.class
);

return response.base64EncodedImages().stream()
.map(each -> Base64.getDecoder().decode(each))
.toList();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package slvtwn.khu.toyouserver.agent.model;

import lombok.Getter;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

@Getter
@Component
public class StickerModelConfiguration {

@Value("${sticker-model.endpoint}")
private String endpoint;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package slvtwn.khu.toyouserver.agent.model;

public record StickerModelRequest(String prompt, String color) {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package slvtwn.khu.toyouserver.agent.model;

import java.util.List;

public record StickerModelResponse(List<String> base64EncodedImages) {

}

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

25 changes: 22 additions & 3 deletions src/main/java/slvtwn/khu/toyouserver/application/AgentService.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,43 @@

import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import lombok.AllArgsConstructor;
import org.springframework.stereotype.Service;
import slvtwn.khu.toyouserver.agent.aws.S3Agent;
import slvtwn.khu.toyouserver.agent.gpt.ChatGptAgent;
import slvtwn.khu.toyouserver.agent.gpt.ChatGptResponse;
import slvtwn.khu.toyouserver.agent.model.StickerModelAgent;
import slvtwn.khu.toyouserver.agent.model.StickerModelRequest;
import slvtwn.khu.toyouserver.dto.GenerateStickerRequest;
import slvtwn.khu.toyouserver.dto.GenerateStickerResponse;
import slvtwn.khu.toyouserver.dto.KeywordRequest;
import slvtwn.khu.toyouserver.dto.KeywordResponse;

@AllArgsConstructor
@Service
public class AgentService {

private static final String STICKER_MIME_TYPE = "image/png";

private final StickerModelAgent stickerModelAgent;
private final ChatGptAgent chatGptAgent;
private final S3Agent s3Agent;

public GenerateStickerResponse generateStickers(GenerateStickerRequest request) {
List<String> urls = stickerModelAgent.generateStickers(new StickerModelRequest(request.prompt(), request.color())).stream()
.map(each -> s3Agent.uploadFile(each, UUID.nameUUIDFromBytes(each).toString(), STICKER_MIME_TYPE))
.map(s3Agent::getUrl)
.toList();

return new GenerateStickerResponse(urls);
}

public KeywordResponse generateKeywords(KeywordRequest request) {
String content = request.content();
String prompt = String.format("""
Suggest 3 keywords that could represent emotions or characteristics in the content.
<example>
<request>
content:
Expand All @@ -31,7 +50,7 @@ public KeywordResponse generateKeywords(KeywordRequest request) {
반가움, 기대, 아쉬움
</response>
</example>
<example>
<request>
content:
Expand All @@ -42,7 +61,7 @@ public KeywordResponse generateKeywords(KeywordRequest request) {
축하, 아쉬움, 즐거움
</response>
</example>
content: %s
keywords:
""", content);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import org.springframework.data.domain.PageRequest;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import slvtwn.khu.toyouserver.agent.modellabs.ModelLabsAgent;
import slvtwn.khu.toyouserver.common.response.ResponseType;
import slvtwn.khu.toyouserver.domain.Member;
import slvtwn.khu.toyouserver.domain.RollingPaper;
Expand All @@ -31,17 +30,18 @@ public class RollingPaperService {
private final RollingPaperRepository rollingPaperRepository;
private final StickerRepository stickerRepository;
private final MemberRepository memberRepository;
private final ModelLabsAgent modelLabsAgent;
private final RollingPaperRepository rollingpaperRepository;
private final UserRepository userRepository;

// TODO: ModelLabs 관련 문제로 DISABLED
@Deprecated
public void generateCoverImageAndUpdateRollingPaper(CoverRequest request, Long rollingPaperId) {
List<String> keywords = request.keywords();
String coverImageUrl = modelLabsAgent.generateCoverWithKeywords(keywords)
.get(FIRST_COVER_IMAGE);
RollingPaper rollingPaper = rollingpaperRepository.findById(rollingPaperId)
.orElseThrow(() -> new ToyouException(ResponseType.BAD_REQUEST));
rollingPaper.updateCoverImage(coverImageUrl);
// List<String> keywords = request.keywords();
// String coverImageUrl = modelLabsAgent.generateCoverWithKeywords(keywords)
// .get(FIRST_COVER_IMAGE);
// RollingPaper rollingPaper = rollingpaperRepository.findById(rollingPaperId)
// .orElseThrow(() -> new ToyouException(ResponseType.BAD_REQUEST));
// rollingPaper.updateCoverImage(coverImageUrl);
}

public void sendRollingPaper(Long recipientUserId, RollingPaperRequest request) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package slvtwn.khu.toyouserver.dto;

public record GenerateStickerRequest(String prompt, String color) {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package slvtwn.khu.toyouserver.dto;

import java.util.List;

public record GenerateStickerResponse(List<String> stickers) {

}
Loading

0 comments on commit 7a5208f

Please sign in to comment.