diff --git a/build.gradle b/build.gradle index ce5adb6..94665d2 100644 --- a/build.gradle +++ b/build.gradle @@ -56,6 +56,8 @@ dependencies { // Spring Security 설정 implementation 'org.springframework.boot:spring-boot-starter-security' + + implementation 'org.springframework.cloud:spring-cloud-starter-aws:2.2.6.RELEASE' } tasks.register('importConfigSubmodule', Copy) { diff --git a/src/main/java/slvtwn/khu/toyouserver/agent/aws/AwsConfiguration.java b/src/main/java/slvtwn/khu/toyouserver/agent/aws/AwsConfiguration.java new file mode 100644 index 0000000..bab7382 --- /dev/null +++ b/src/main/java/slvtwn/khu/toyouserver/agent/aws/AwsConfiguration.java @@ -0,0 +1,33 @@ +package slvtwn.khu.toyouserver.agent.aws; + +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.services.s3.AmazonS3Client; +import com.amazonaws.services.s3.AmazonS3ClientBuilder; +import lombok.Getter; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Getter +@Configuration +public class AwsConfiguration { + + @Value("${cloud.aws.credentials.access-key}") + private String accessKey; + @Value("${cloud.aws.credentials.secret-key}") + private String secretKey; + @Value("${cloud.aws.region.static}") + private String region; + @Value("${cloud.aws.s3.bucket}") + private String bucket; + + @Bean + public AmazonS3Client amazonS3Client() { + BasicAWSCredentials awsCredentials = new BasicAWSCredentials(accessKey, secretKey); + return (AmazonS3Client) AmazonS3ClientBuilder.standard() + .withRegion(region) + .withCredentials(new AWSStaticCredentialsProvider(awsCredentials)) + .build(); + } +} diff --git a/src/main/java/slvtwn/khu/toyouserver/agent/aws/S3Agent.java b/src/main/java/slvtwn/khu/toyouserver/agent/aws/S3Agent.java new file mode 100644 index 0000000..827ac60 --- /dev/null +++ b/src/main/java/slvtwn/khu/toyouserver/agent/aws/S3Agent.java @@ -0,0 +1,37 @@ +package slvtwn.khu.toyouserver.agent.aws; + +import com.amazonaws.services.s3.AmazonS3Client; +import com.amazonaws.services.s3.model.ObjectMetadata; +import com.amazonaws.services.s3.model.PutObjectRequest; +import java.io.ByteArrayInputStream; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Component; +import slvtwn.khu.toyouserver.common.response.ResponseType; +import slvtwn.khu.toyouserver.exception.ToyouException; + +@RequiredArgsConstructor +@Component +public class S3Agent { + + private final AmazonS3Client amazonS3Client; + private final AwsConfiguration configuration; + + public String uploadFile(byte[] fileData, String fileName, String contentType) { + try { + ObjectMetadata metadata = new ObjectMetadata(); + metadata.setContentType(contentType); + metadata.setContentLength(fileData.length); + + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(fileData); + amazonS3Client.putObject(new PutObjectRequest(configuration.getBucket(), fileName, byteArrayInputStream, metadata)); + + return fileName; + } catch (Exception e) { + throw new ToyouException(ResponseType.INTERNAL_SERVER_ERROR); + } + } + + public String getUrl(String key) { + return amazonS3Client.getResourceUrl(configuration.getBucket(), key); + } +} diff --git a/src/main/java/slvtwn/khu/toyouserver/agent/model/StickerModelAgent.java b/src/main/java/slvtwn/khu/toyouserver/agent/model/StickerModelAgent.java new file mode 100644 index 0000000..10260d6 --- /dev/null +++ b/src/main/java/slvtwn/khu/toyouserver/agent/model/StickerModelAgent.java @@ -0,0 +1,27 @@ +package slvtwn.khu.toyouserver.agent.model; + +import java.util.Base64; +import java.util.List; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Component; +import slvtwn.khu.toyouserver.agent.WebClientWrapper; + +@RequiredArgsConstructor +@Component +public class StickerModelAgent { + + private final StickerModelConfiguration configuration; + + private final WebClientWrapper webClientWrapper; + + public List generateStickers(StickerModelRequest request) { + StickerModelResponse response = webClientWrapper.send( + configuration.getEndpoint(), (httpHeaders -> { + }), request, StickerModelResponse.class + ); + + return response.base64EncodedImages().stream() + .map(each -> Base64.getDecoder().decode(each)) + .toList(); + } +} diff --git a/src/main/java/slvtwn/khu/toyouserver/agent/model/StickerModelConfiguration.java b/src/main/java/slvtwn/khu/toyouserver/agent/model/StickerModelConfiguration.java new file mode 100644 index 0000000..0533cf9 --- /dev/null +++ b/src/main/java/slvtwn/khu/toyouserver/agent/model/StickerModelConfiguration.java @@ -0,0 +1,13 @@ +package slvtwn.khu.toyouserver.agent.model; + +import lombok.Getter; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +@Getter +@Component +public class StickerModelConfiguration { + + @Value("${sticker-model.endpoint}") + private String endpoint; +} diff --git a/src/main/java/slvtwn/khu/toyouserver/agent/model/StickerModelRequest.java b/src/main/java/slvtwn/khu/toyouserver/agent/model/StickerModelRequest.java new file mode 100644 index 0000000..92bd737 --- /dev/null +++ b/src/main/java/slvtwn/khu/toyouserver/agent/model/StickerModelRequest.java @@ -0,0 +1,5 @@ +package slvtwn.khu.toyouserver.agent.model; + +public record StickerModelRequest(String prompt, String color) { + +} diff --git a/src/main/java/slvtwn/khu/toyouserver/agent/model/StickerModelResponse.java b/src/main/java/slvtwn/khu/toyouserver/agent/model/StickerModelResponse.java new file mode 100644 index 0000000..1596f6b --- /dev/null +++ b/src/main/java/slvtwn/khu/toyouserver/agent/model/StickerModelResponse.java @@ -0,0 +1,7 @@ +package slvtwn.khu.toyouserver.agent.model; + +import java.util.List; + +public record StickerModelResponse(List base64EncodedImages) { + +} diff --git a/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/MetaDataResponse.java b/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/MetaDataResponse.java deleted file mode 100644 index cd58bb8..0000000 --- a/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/MetaDataResponse.java +++ /dev/null @@ -1,26 +0,0 @@ -package slvtwn.khu.toyouserver.agent.modellabs; - -public record MetaDataResponse( - String prompt, - String model_id, - String negative_prompt, - String scheduler, - String safetychecker, - int W, - int H, - double guidance_scale, - long seed, - int steps, - int n_samples, - String full_url, - String upscale, - String panorama, - String self_attention, - String embeddings, - String lora_model, - String lora_strength, - String outdir, - String file_prefix -) { -} - diff --git a/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/ModelLabsAgent.java b/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/ModelLabsAgent.java deleted file mode 100644 index dc7a60a..0000000 --- a/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/ModelLabsAgent.java +++ /dev/null @@ -1,50 +0,0 @@ -package slvtwn.khu.toyouserver.agent.modellabs; - -import java.util.List; -import lombok.RequiredArgsConstructor; -import org.springframework.stereotype.Component; -import slvtwn.khu.toyouserver.agent.WebClientWrapper; - -@RequiredArgsConstructor -@Component -public class ModelLabsAgent { - - private final ModelLabsConfiguration configuration; - private final WebClientWrapper webClientWrapper; - - public List generateCoverWithKeywords(List keywords) { - ModelLabsRequest request = setupBody(keywords); - - ModelLabsResponse response = webClientWrapper.send( - configuration.getModelLabsEndPoint(), (httpHeaders -> { - }), request, ModelLabsResponse.class); - - return response.output(); - } - - private ModelLabsRequest setupBody(List keywords) { - // TODO: prompt, negative prompt, keywords - String prompt = ""; - String negativePrompt = ""; - - return ModelLabsRequest.builder() - .prompt(prompt) - .key(configuration.getKey()) - .model_id(configuration.getModelId()) - .negative_prompt(negativePrompt) - .width("512") - .height("512") - .samples("3") - .num_inference_steps("30") - .safety_checker("no") - .enhance_prompt("yes") - .guidance_scale(9) - .panorama("no") - .self_attention("no") - .upscale("no") - .tomesd("yes") - .use_karras_sigmas("yes") - .scheduler("DPMSolverMultistepScheduler") - .build(); - } -} diff --git a/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/ModelLabsConfiguration.java b/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/ModelLabsConfiguration.java deleted file mode 100644 index afa6079..0000000 --- a/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/ModelLabsConfiguration.java +++ /dev/null @@ -1,17 +0,0 @@ -package slvtwn.khu.toyouserver.agent.modellabs; - -import lombok.Getter; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Component; - -@Getter -@Component -public class ModelLabsConfiguration { - - @Value("${model-labs.endpoint}") - private String modelLabsEndPoint; - @Value("${model-labs.key}") - private String key; - @Value("${model-labs.model-id}") - private String modelId; -} diff --git a/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/ModelLabsRequest.java b/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/ModelLabsRequest.java deleted file mode 100644 index 1f1d71d..0000000 --- a/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/ModelLabsRequest.java +++ /dev/null @@ -1,32 +0,0 @@ -package slvtwn.khu.toyouserver.agent.modellabs; - -import lombok.Builder; - -@Builder -public record ModelLabsRequest( - String key, - String model_id, - String prompt, - String negative_prompt, - String width, - String height, - String samples, - String num_inference_steps, - String safety_checker, - String enhance_prompt, - String seed, - int guidance_scale, - String panorama, - String self_attention, - String upscale, - String embeddings_model, - String lora_model, - String tomesd, - String use_karras_sigmas, - String vae, - String lora_strength, - String scheduler, - String webhook, - String track_id -) { -} diff --git a/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/ModelLabsResponse.java b/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/ModelLabsResponse.java deleted file mode 100644 index 7f9fd03..0000000 --- a/src/main/java/slvtwn/khu/toyouserver/agent/modellabs/ModelLabsResponse.java +++ /dev/null @@ -1,13 +0,0 @@ -package slvtwn.khu.toyouserver.agent.modellabs; - -import java.util.List; - -public record ModelLabsResponse( - String status, - double generationTime, - int id, - List output, - MetaDataResponse meta -) { - -} diff --git a/src/main/java/slvtwn/khu/toyouserver/application/AgentService.java b/src/main/java/slvtwn/khu/toyouserver/application/AgentService.java index 5fc46fd..eea5b2a 100644 --- a/src/main/java/slvtwn/khu/toyouserver/application/AgentService.java +++ b/src/main/java/slvtwn/khu/toyouserver/application/AgentService.java @@ -2,10 +2,16 @@ import java.util.Arrays; import java.util.List; +import java.util.UUID; import lombok.AllArgsConstructor; import org.springframework.stereotype.Service; +import slvtwn.khu.toyouserver.agent.aws.S3Agent; import slvtwn.khu.toyouserver.agent.gpt.ChatGptAgent; import slvtwn.khu.toyouserver.agent.gpt.ChatGptResponse; +import slvtwn.khu.toyouserver.agent.model.StickerModelAgent; +import slvtwn.khu.toyouserver.agent.model.StickerModelRequest; +import slvtwn.khu.toyouserver.dto.GenerateStickerRequest; +import slvtwn.khu.toyouserver.dto.GenerateStickerResponse; import slvtwn.khu.toyouserver.dto.KeywordRequest; import slvtwn.khu.toyouserver.dto.KeywordResponse; @@ -13,13 +19,26 @@ @Service public class AgentService { + private static final String STICKER_MIME_TYPE = "image/png"; + + private final StickerModelAgent stickerModelAgent; private final ChatGptAgent chatGptAgent; + private final S3Agent s3Agent; + + public GenerateStickerResponse generateStickers(GenerateStickerRequest request) { + List urls = stickerModelAgent.generateStickers(new StickerModelRequest(request.prompt(), request.color())).stream() + .map(each -> s3Agent.uploadFile(each, UUID.nameUUIDFromBytes(each).toString(), STICKER_MIME_TYPE)) + .map(s3Agent::getUrl) + .toList(); + + return new GenerateStickerResponse(urls); + } public KeywordResponse generateKeywords(KeywordRequest request) { String content = request.content(); String prompt = String.format(""" Suggest 3 keywords that could represent emotions or characteristics in the content. - + content: @@ -31,7 +50,7 @@ public KeywordResponse generateKeywords(KeywordRequest request) { 반가움, 기대, 아쉬움 - + content: @@ -42,7 +61,7 @@ public KeywordResponse generateKeywords(KeywordRequest request) { 축하, 아쉬움, 즐거움 - + content: %s keywords: """, content); diff --git a/src/main/java/slvtwn/khu/toyouserver/application/RollingPaperService.java b/src/main/java/slvtwn/khu/toyouserver/application/RollingPaperService.java index a0c6f05..ab2be80 100644 --- a/src/main/java/slvtwn/khu/toyouserver/application/RollingPaperService.java +++ b/src/main/java/slvtwn/khu/toyouserver/application/RollingPaperService.java @@ -5,7 +5,6 @@ import org.springframework.data.domain.PageRequest; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; -import slvtwn.khu.toyouserver.agent.modellabs.ModelLabsAgent; import slvtwn.khu.toyouserver.common.response.ResponseType; import slvtwn.khu.toyouserver.domain.Member; import slvtwn.khu.toyouserver.domain.RollingPaper; @@ -31,17 +30,18 @@ public class RollingPaperService { private final RollingPaperRepository rollingPaperRepository; private final StickerRepository stickerRepository; private final MemberRepository memberRepository; - private final ModelLabsAgent modelLabsAgent; private final RollingPaperRepository rollingpaperRepository; private final UserRepository userRepository; + // TODO: ModelLabs 관련 문제로 DISABLED + @Deprecated public void generateCoverImageAndUpdateRollingPaper(CoverRequest request, Long rollingPaperId) { - List keywords = request.keywords(); - String coverImageUrl = modelLabsAgent.generateCoverWithKeywords(keywords) - .get(FIRST_COVER_IMAGE); - RollingPaper rollingPaper = rollingpaperRepository.findById(rollingPaperId) - .orElseThrow(() -> new ToyouException(ResponseType.BAD_REQUEST)); - rollingPaper.updateCoverImage(coverImageUrl); +// List keywords = request.keywords(); +// String coverImageUrl = modelLabsAgent.generateCoverWithKeywords(keywords) +// .get(FIRST_COVER_IMAGE); +// RollingPaper rollingPaper = rollingpaperRepository.findById(rollingPaperId) +// .orElseThrow(() -> new ToyouException(ResponseType.BAD_REQUEST)); +// rollingPaper.updateCoverImage(coverImageUrl); } public void sendRollingPaper(Long recipientUserId, RollingPaperRequest request) { diff --git a/src/main/java/slvtwn/khu/toyouserver/dto/GenerateStickerRequest.java b/src/main/java/slvtwn/khu/toyouserver/dto/GenerateStickerRequest.java new file mode 100644 index 0000000..ee7d800 --- /dev/null +++ b/src/main/java/slvtwn/khu/toyouserver/dto/GenerateStickerRequest.java @@ -0,0 +1,5 @@ +package slvtwn.khu.toyouserver.dto; + +public record GenerateStickerRequest(String prompt, String color) { + +} diff --git a/src/main/java/slvtwn/khu/toyouserver/dto/GenerateStickerResponse.java b/src/main/java/slvtwn/khu/toyouserver/dto/GenerateStickerResponse.java new file mode 100644 index 0000000..3b84700 --- /dev/null +++ b/src/main/java/slvtwn/khu/toyouserver/dto/GenerateStickerResponse.java @@ -0,0 +1,7 @@ +package slvtwn.khu.toyouserver.dto; + +import java.util.List; + +public record GenerateStickerResponse(List stickers) { + +} diff --git a/src/main/java/slvtwn/khu/toyouserver/presentation/AgentController.java b/src/main/java/slvtwn/khu/toyouserver/presentation/AgentController.java index a33b227..136543e 100644 --- a/src/main/java/slvtwn/khu/toyouserver/presentation/AgentController.java +++ b/src/main/java/slvtwn/khu/toyouserver/presentation/AgentController.java @@ -6,6 +6,8 @@ import org.springframework.web.bind.annotation.RestController; import slvtwn.khu.toyouserver.application.AgentService; import slvtwn.khu.toyouserver.common.response.ToyouResponse; +import slvtwn.khu.toyouserver.dto.GenerateStickerRequest; +import slvtwn.khu.toyouserver.dto.GenerateStickerResponse; import slvtwn.khu.toyouserver.dto.KeywordRequest; import slvtwn.khu.toyouserver.dto.KeywordResponse; @@ -15,9 +17,15 @@ public class AgentController { private final AgentService agentService; + @PostMapping("/generate-stickers") + public ToyouResponse generateStickers(@RequestBody GenerateStickerRequest request) { + GenerateStickerResponse response = agentService.generateStickers(request); + return ToyouResponse.from(response); + } + @PostMapping("/generate-keywords") - public ToyouResponse generateKeywords(@RequestBody KeywordRequest keywordRequest) { - KeywordResponse response = agentService.generateKeywords(keywordRequest); + public ToyouResponse generateKeywords(@RequestBody KeywordRequest request) { + KeywordResponse response = agentService.generateKeywords(request); return ToyouResponse.from(response); } } diff --git a/src/main/java/slvtwn/khu/toyouserver/presentation/RollingPaperController.java b/src/main/java/slvtwn/khu/toyouserver/presentation/RollingPaperController.java index 36d4bb9..e3f7a02 100644 --- a/src/main/java/slvtwn/khu/toyouserver/presentation/RollingPaperController.java +++ b/src/main/java/slvtwn/khu/toyouserver/presentation/RollingPaperController.java @@ -43,6 +43,8 @@ public ToyouResponse sendRollingPaper(@UserAuthentication Long RequestUser return ToyouResponse.noContent(); } + // TODO: ModelLabs 관련 문제로 DISABLED + @Deprecated @PostMapping("/rollingpapers/{rollingPaperId}/generate-cover") public ToyouResponse generateCoverImage(@UserAuthentication Long userId, @PathVariable Long rollingPaperId, diff --git a/src/test/java/slvtwn/khu/toyouserver/application/RollingPaperServiceTest.java b/src/test/java/slvtwn/khu/toyouserver/application/RollingPaperServiceTest.java index adc6b8c..74b6ce0 100644 --- a/src/test/java/slvtwn/khu/toyouserver/application/RollingPaperServiceTest.java +++ b/src/test/java/slvtwn/khu/toyouserver/application/RollingPaperServiceTest.java @@ -16,7 +16,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.mock.mockito.MockBean; import org.springframework.transaction.annotation.Transactional; -import slvtwn.khu.toyouserver.agent.modellabs.ModelLabsAgent; import slvtwn.khu.toyouserver.domain.Group; import slvtwn.khu.toyouserver.domain.Member; import slvtwn.khu.toyouserver.domain.RollingPaper; @@ -34,37 +33,37 @@ class RollingPaperServiceTest { @PersistenceContext private EntityManager entityManager; - @MockBean - private ModelLabsAgent modelLabsAgent; +// @MockBean +// private ModelLabsAgent modelLabsAgent; @Autowired private RollingPaperService rollingPaperService; - @Test - void 커버_이미지를_생성하면_롤링페이퍼_커버가_업데이트된다() { - // given - Group group = new Group("name"); - User user = new User("name", LocalDate.now(), "introduction", "profile_picture", null); - Member member = new Member(user, group); - RollingPaper rollingPaper = new RollingPaper(null, "title", "content", 1L, member); - - entityManager.persist(user); - entityManager.persist(group); - entityManager.persist(member); - entityManager.persist(rollingPaper); - - String coverImageUrl = "cover_image_url"; - CoverRequest request = new CoverRequest(List.of()); - - given(modelLabsAgent.generateCoverWithKeywords(any())) - .willReturn(List.of(coverImageUrl)); - - // when - rollingPaperService.generateCoverImageAndUpdateRollingPaper(request, rollingPaper.getId()); - - // then - assertThat(rollingPaper.getCoverImageUrl()).isEqualTo(coverImageUrl); - } +// @Test +// void 커버_이미지를_생성하면_롤링페이퍼_커버가_업데이트된다() { +// // given +// Group group = new Group("name"); +// User user = new User("name", LocalDate.now(), "introduction", "profile_picture", null); +// Member member = new Member(user, group); +// RollingPaper rollingPaper = new RollingPaper(null, "title", "content", 1L, member); +// +// entityManager.persist(user); +// entityManager.persist(group); +// entityManager.persist(member); +// entityManager.persist(rollingPaper); +// +// String coverImageUrl = "cover_image_url"; +// CoverRequest request = new CoverRequest(List.of()); +// +// given(modelLabsAgent.generateCoverWithKeywords(any())) +// .willReturn(List.of(coverImageUrl)); +// +// // when +// rollingPaperService.generateCoverImageAndUpdateRollingPaper(request, rollingPaper.getId()); +// +// // then +// assertThat(rollingPaper.getCoverImageUrl()).isEqualTo(coverImageUrl); +// } @Test void 롤링페이퍼를_전송할_수_있다() { diff --git a/src/test/resources/application.yml b/src/test/resources/application.yml index 48441f8..5fc4e5c 100644 --- a/src/test/resources/application.yml +++ b/src/test/resources/application.yml @@ -34,7 +34,15 @@ openai: endpoint: test model: gpt-4o-mini -model-labs: - endpoint: test - key: test - model-id: test \ No newline at end of file +sticker-model: + endpoint: abc + +cloud: + aws: + credentials: + access-key: abc + secret-key: abc + region: + static: abc + s3: + bucket: abc diff --git a/toyou-server-config b/toyou-server-config index d16c870..a5e3fc8 160000 --- a/toyou-server-config +++ b/toyou-server-config @@ -1 +1 @@ -Subproject commit d16c870f250210bf0ec44bbffe8a0397a43588c5 +Subproject commit a5e3fc8b5f6db092a27238396fd3f298c3e639b7