Skip to content

Commit

Permalink
feat: WebSocket을 HTTP로 변경
Browse files Browse the repository at this point in the history
  • Loading branch information
redcarrot1 committed Feb 7, 2024
1 parent 1cc4f22 commit 5616e71
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,138 +25,138 @@
@RequiredArgsConstructor
public class ClientInferenceController2 extends WebSocketClientController {

private final Map<WebSocketSession, SessionInfo> sessions = new ConcurrentHashMap<>();
private final int savePeriodMilliSeconds = 5000;
private final byte[] scoreByte = {0x30, 0x31, 0x32, 0x33, 0x34, 0x35};
private final InferenceService inferenceService;
private final ClientRepository clientRepository;
private final ImageInferenceRepository imageInferenceRepository;
private final SocketClient socketClient = new SocketClient("127.0.0.1", 9999, 2000);

//@Value("${image.path}")
private String path = "/home/hsk4991149/static/image/";

@Override
public void afterConnectionEstablished(WebSocketSession session) {
Integer totalExpectedSeconds = parseTimeHeader(session);
if (totalExpectedSeconds == null) return;

Client client = new Client(totalExpectedSeconds);
clientRepository.save(client);

session.setBinaryMessageSizeLimit(1024 * 1024 * 10);
long currentTimeMillis = System.currentTimeMillis();
sessions.put(session, new SessionInfo(client, currentTimeMillis, totalExpectedSeconds));
}

private Integer parseTimeHeader(WebSocketSession session) {
List<String> time = session.getHandshakeHeaders().get("time");
if (time == null) {
System.err.println("time 헤더가 없습니다.");
try {
session.sendMessage(new TextMessage("time 헤더가 없습니다."));
session.close();
} catch (IOException e) {
}
return null;
}
System.out.println("time 헤더: " + time.get(0));

int totalExpectedSeconds;
try {
totalExpectedSeconds = Integer.parseInt(time.get(0));
} catch (NumberFormatException e) {
System.err.println("time 헤더가 숫자가 아닙니다.");
try {
session.sendMessage(new TextMessage("time 헤더가 숫자가 아닙니다."));
session.close();
} catch (IOException e1) {
}
return null;
}
return totalExpectedSeconds;
}

@Override
protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) {
byte[] bytes = message.getPayload().array();
// if (bytes.length == 1) {
// for (int i = 0; i <= 5; i++) {
// if (bytes[0] == scoreByte[i]) {
// System.out.println("점수 : " + i);
// Client client = sessions.get(session).getClient();
// client.updateResult(i, solution.gdsc.PathPal.domain.client.domain.CloseStatus.NORMAL);
// clientRepository.save(client);
// }
// }
// private final Map<WebSocketSession, SessionInfo> sessions = new ConcurrentHashMap<>();
// private final int savePeriodMilliSeconds = 5000;
// private final byte[] scoreByte = {0x30, 0x31, 0x32, 0x33, 0x34, 0x35};
// private final InferenceService inferenceService;
// private final ClientRepository clientRepository;
// private final ImageInferenceRepository imageInferenceRepository;
// private final SocketClient socketClient = new SocketClient("127.0.0.1", 9999, 2000);
//
// //@Value("${image.path}")
// private String path = "/home/hsk4991149/static/image/";
//
// @Override
// public void afterConnectionEstablished(WebSocketSession session) {
// Integer totalExpectedSeconds = parseTimeHeader(session);
// if (totalExpectedSeconds == null) return;
//
// Client client = new Client(totalExpectedSeconds);
// clientRepository.save(client);
//
// session.setBinaryMessageSizeLimit(1024 * 1024 * 10);
// long currentTimeMillis = System.currentTimeMillis();
// sessions.put(session, new SessionInfo(client, currentTimeMillis, totalExpectedSeconds));
// }
//
// private Integer parseTimeHeader(WebSocketSession session) {
// List<String> time = session.getHandshakeHeaders().get("time");
// if (time == null) {
// System.err.println("time 헤더가 없습니다.");
// try {
// session.sendMessage(new TextMessage("time 헤더가 없습니다."));
// session.close();
// } catch (IOException e) {
// }
// return;
// return null;
// }

String responseMessage = "";
try {
List<Inference> inferences = socketClient.inferenceImage(bytes);
System.out.println("Inference 변환 성공");
// List<InferenceTranslate> inferenceTranslates = inferenceService.convertInference(inferences);
// System.out.println("Translate 성공");
// responseMessage = JsonUtil.toJsonFormat(inferenceTranslates);
// System.out.println("json 변환 성공");
responseMessage = inferenceService.convertInference2(inferences);
} catch (Exception e) {
System.err.println("추론 실패");
responseMessage = "";
}

saveImageAndInferenceResult(session, bytes, responseMessage);

try {
System.out.println("전송 메시지: " + responseMessage);
session.sendMessage(new TextMessage(responseMessage));
} catch (IOException e) {
System.err.println("전송실패");
throw new RuntimeException(e);
}
}

private void saveImageAndInferenceResult(WebSocketSession session, byte[] bytes, String responseMessage) {
long currentTimeMillis = System.currentTimeMillis();

SessionInfo sessionInfo = sessions.get(session);
if (currentTimeMillis - sessionInfo.recentlySaveTime > savePeriodMilliSeconds) {

String fileFullName = path + sessionInfo.getClient().getId() + "T" + currentTimeMillis + ".jpeg";
try (FileImageOutputStream imageOutput = new FileImageOutputStream(new File(fileFullName))) {
imageOutput.write(bytes, 0, bytes.length);

ImageInference imageInference = new ImageInference(fileFullName, sessionInfo.getClient(), responseMessage);
imageInferenceRepository.save(imageInference);

sessionInfo.updateRecentlySaveTime(currentTimeMillis);
System.out.println("데이터 저장 (currentTimeMillis: " + currentTimeMillis + "). imageId = " + fileFullName);
} catch (Exception e) {
System.err.println("이미지 저장 실패");
}
}
}

@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
if (status.equalsCode(CloseStatus.NORMAL)) {
System.out.println("정상 종료");
Client client = sessions.get(session).getClient();
client.updateResult(-1, solution.gdsc.PathPal.domain.client.domain.CloseStatus.NORMAL);
clientRepository.save(client);
}
else {
System.out.println("비정상 종료. reason = " + status.getReason() + "\n");
Client client = sessions.get(session).getClient();
client.updateResult(-1, solution.gdsc.PathPal.domain.client.domain.CloseStatus.convert(status));
clientRepository.save(client);
}
sessions.remove(session);
}
// System.out.println("time 헤더: " + time.get(0));
//
// int totalExpectedSeconds;
// try {
// totalExpectedSeconds = Integer.parseInt(time.get(0));
// } catch (NumberFormatException e) {
// System.err.println("time 헤더가 숫자가 아닙니다.");
// try {
// session.sendMessage(new TextMessage("time 헤더가 숫자가 아닙니다."));
// session.close();
// } catch (IOException e1) {
// }
// return null;
// }
// return totalExpectedSeconds;
// }
//
// @Override
// protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) {
// byte[] bytes = message.getPayload().array();
//// if (bytes.length == 1) {
//// for (int i = 0; i <= 5; i++) {
//// if (bytes[0] == scoreByte[i]) {
//// System.out.println("점수 : " + i);
//// Client client = sessions.get(session).getClient();
//// client.updateResult(i, solution.gdsc.PathPal.domain.client.domain.CloseStatus.NORMAL);
//// clientRepository.save(client);
//// }
//// }
////
//// try {
//// session.close();
//// } catch (IOException e) {
//// }
//// return;
//// }
//
// String responseMessage = "";
// try {
// List<Inference> inferences = socketClient.inferenceImage(bytes);
// System.out.println("Inference 변환 성공");
//// List<InferenceTranslate> inferenceTranslates = inferenceService.convertInference(inferences);
//// System.out.println("Translate 성공");
//// responseMessage = JsonUtil.toJsonFormat(inferenceTranslates);
//// System.out.println("json 변환 성공");
// responseMessage = inferenceService.convertInference2(inferences);
// } catch (Exception e) {
// System.err.println("추론 실패");
// responseMessage = "";
// }
//
// saveImageAndInferenceResult(session, bytes, responseMessage);
//
// try {
// System.out.println("전송 메시지: " + responseMessage);
// session.sendMessage(new TextMessage(responseMessage));
// } catch (IOException e) {
// System.err.println("전송실패");
// throw new RuntimeException(e);
// }
// }
//
// private void saveImageAndInferenceResult(WebSocketSession session, byte[] bytes, String responseMessage) {
// long currentTimeMillis = System.currentTimeMillis();
//
// SessionInfo sessionInfo = sessions.get(session);
// if (currentTimeMillis - sessionInfo.recentlySaveTime > savePeriodMilliSeconds) {
//
// String fileFullName = path + sessionInfo.getClient().getId() + "T" + currentTimeMillis + ".jpeg";
// try (FileImageOutputStream imageOutput = new FileImageOutputStream(new File(fileFullName))) {
// imageOutput.write(bytes, 0, bytes.length);
//
// ImageInference imageInference = new ImageInference(fileFullName, sessionInfo.getClient(), responseMessage);
// imageInferenceRepository.save(imageInference);
//
// sessionInfo.updateRecentlySaveTime(currentTimeMillis);
// System.out.println("데이터 저장 (currentTimeMillis: " + currentTimeMillis + "). imageId = " + fileFullName);
// } catch (Exception e) {
// System.err.println("이미지 저장 실패");
// }
// }
// }
//
// @Override
// public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
// if (status.equalsCode(CloseStatus.NORMAL)) {
// System.out.println("정상 종료");
// Client client = sessions.get(session).getClient();
// client.updateResult(-1, solution.gdsc.PathPal.domain.client.domain.CloseStatus.NORMAL);
// clientRepository.save(client);
// }
// else {
// System.out.println("비정상 종료. reason = " + status.getReason() + "\n");
// Client client = sessions.get(session).getClient();
// client.updateResult(-1, solution.gdsc.PathPal.domain.client.domain.CloseStatus.convert(status));
// clientRepository.save(client);
// }
// sessions.remove(session);
// }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package solution.gdsc.PathPal.domain.inference.api;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpSession;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
import solution.gdsc.PathPal.domain.client.domain.Client;
import solution.gdsc.PathPal.domain.client.domain.ImageInference;
import solution.gdsc.PathPal.domain.client.repository.ClientRepository;
import solution.gdsc.PathPal.domain.client.repository.ImageInferenceRepository;
import solution.gdsc.PathPal.domain.inference.domain.Inference;
import solution.gdsc.PathPal.domain.inference.service.InferenceService;
import solution.gdsc.PathPal.domain.inference.service.SocketClient;

import javax.imageio.stream.FileImageOutputStream;
import java.io.File;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@RestController
@RequiredArgsConstructor
public class HttpClientInferenceController {

private final SocketClient socketClient = new SocketClient("127.0.0.1", 9999, 2000);
private final InferenceService inferenceService;
private final ImageInferenceRepository imageInferenceRepository;
private final ClientRepository clientRepository;
private final int savePeriodMilliSeconds = 5000;
private String path = "/home/hsk4991149/static/image/";
//private String path = "/Users/hongseungtaeg/Desktop/project/GDSC-PathPal/PathPal/src/main/resources/static/";

private final Map<HttpSession, SessionInfo> sessionMap = new HashMap<>();

@PostMapping("/inference")
public String inference(@RequestParam(name = "image") MultipartFile file,
@RequestParam(name = "time") Integer time,
HttpServletRequest request) throws Exception {
HttpSession session = request.getSession();

SessionInfo sessionInfo = sessionMap.get(session);
if (sessionInfo == null) {
int expectedSeconds = time == null ? 500 : time;
Client client = new Client(expectedSeconds);
clientRepository.save(client);

long currentTimeMillis = System.currentTimeMillis();
sessionInfo = new SessionInfo(client, currentTimeMillis, expectedSeconds);
sessionMap.put(session, sessionInfo);
}

byte[] bytes = file.getBytes();
List<Inference> inferences = socketClient.inferenceImage(bytes);
String responseMessage = inferenceService.convertInference2(inferences);

long currentTimeMillis = System.currentTimeMillis();
if (currentTimeMillis - sessionInfo.recentlySaveTime > savePeriodMilliSeconds) {

String fileFullName = path + sessionInfo.getClient().getId() + "T" + currentTimeMillis + ".jpeg";
try (FileImageOutputStream imageOutput = new FileImageOutputStream(new File(fileFullName))) {
imageOutput.write(bytes, 0, bytes.length);

ImageInference imageInference = new ImageInference(fileFullName, sessionInfo.getClient(), responseMessage);
imageInferenceRepository.save(imageInference);

sessionInfo.updateRecentlySaveTime(currentTimeMillis);
System.out.println("데이터 저장 (currentTimeMillis: " + currentTimeMillis + "). imageId = " + fileFullName);
} catch (Exception e) {
System.err.println("이미지 저장 실패");
}
}
return responseMessage;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
@EnableWebSocket
public class WebSocketConfiguration implements WebSocketConfigurer {

@Autowired
private WebSocketClientController clientInferenceController;
//@Autowired
//private WebSocketClientController clientInferenceController;
//private SocketHandlerTest clientInferenceController;

@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(clientInferenceController, "/socket") // Handshake 주소
.setAllowedOrigins("*");
// registry.addHandler(clientInferenceController, "/socket") // Handshake 주소
// .setAllowedOrigins("*");
}
}

0 comments on commit 5616e71

Please sign in to comment.