diff --git a/src/main/java/com/example/CatchStudy/global/config/StompHandler.java b/src/main/java/com/example/CatchStudy/global/config/StompHandler.java new file mode 100644 index 0000000..ec7804a --- /dev/null +++ b/src/main/java/com/example/CatchStudy/global/config/StompHandler.java @@ -0,0 +1,52 @@ +package com.example.CatchStudy.global.config; + +import com.example.CatchStudy.domain.dto.response.Response; +import com.example.CatchStudy.global.exception.CatchStudyException; +import com.example.CatchStudy.global.exception.ErrorCode; +import com.example.CatchStudy.global.jwt.JwtUtil; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.jsonwebtoken.ExpiredJwtException; +import jakarta.servlet.http.HttpServletResponse; +import lombok.RequiredArgsConstructor; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.stereotype.Component; +import org.springframework.util.StringUtils; + +@Component +@RequiredArgsConstructor +public class StompHandler implements ChannelInterceptor { + + private final JwtUtil jwtUtil; + + @Override + public Message preSend(Message message, MessageChannel channel) { + StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message); + String accessToken = ""; + // 연결 요청에 대해 실행 + if(accessor.getCommand() == StompCommand.CONNECT) { + accessToken = accessor.getFirstNativeHeader("Authorization"); + } + + if (StringUtils.hasText(accessToken) && accessToken.startsWith("Bearer ")) { + accessToken = accessToken.substring(7); + } + + try { + jwtUtil.validateAccessToken(accessToken); + } catch (ExpiredJwtException e) { + throw new CatchStudyException(ErrorCode.EXPIRED_ACCESS_TOKEN); + } + + Authentication authentication = jwtUtil.getAuthentication(accessToken); + SecurityContextHolder.getContext().setAuthentication(authentication); + accessor.setUser(authentication); + + return message; + } +} diff --git a/src/main/java/com/example/CatchStudy/global/config/StompWebSocketConfig.java b/src/main/java/com/example/CatchStudy/global/config/StompWebSocketConfig.java index f84e509..4326019 100644 --- a/src/main/java/com/example/CatchStudy/global/config/StompWebSocketConfig.java +++ b/src/main/java/com/example/CatchStudy/global/config/StompWebSocketConfig.java @@ -2,6 +2,7 @@ import lombok.RequiredArgsConstructor; import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.StompEndpointRegistry; @@ -12,6 +13,8 @@ @RequiredArgsConstructor public class StompWebSocketConfig implements WebSocketMessageBrokerConfigurer { + private final StompHandler stompHandler; + @Override public void registerStompEndpoints(StompEndpointRegistry registry) { @@ -26,4 +29,9 @@ public void configureMessageBroker(MessageBrokerRegistry registry) { registry.setApplicationDestinationPrefixes("/pub"); // 사용자 -> 서버 registry.enableSimpleBroker("/sub"); // 서버 -> 사용자 } + + @Override + public void configureClientInboundChannel(ChannelRegistration registration) { + registration.interceptors(stompHandler); + } }