diff --git a/app/src/main/java/com/techcourse/dao/UserDao.java b/app/src/main/java/com/techcourse/dao/UserDao.java index db2c407299..e5a9559361 100644 --- a/app/src/main/java/com/techcourse/dao/UserDao.java +++ b/app/src/main/java/com/techcourse/dao/UserDao.java @@ -1,7 +1,6 @@ package com.techcourse.dao; import com.techcourse.domain.User; -import java.sql.Connection; import java.util.List; import javax.sql.DataSource; import org.springframework.dao.DataAccessException; @@ -41,12 +40,6 @@ public void update(final User user) { jdbcTemplate.update(sql, user.getAccount(), user.getPassword(), user.getEmail()); } - public void update(Connection connection, User user) { - final var sql = "update users set account = ?, password = ?, email = ? "; - - jdbcTemplate.update(connection, sql, user.getAccount(), user.getPassword(), user.getEmail()); - } - public List findAll() { final var sql = "select id, account, password, email from users "; diff --git a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java index 2cd13076a1..7ef907c94e 100644 --- a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java +++ b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java @@ -1,7 +1,6 @@ package com.techcourse.dao; import com.techcourse.domain.UserHistory; -import java.sql.Connection; import javax.sql.DataSource; import org.springframework.jdbc.core.JdbcTemplate; @@ -31,18 +30,4 @@ public void log(final UserHistory userHistory) { ); } - public void log(Connection connection, UserHistory userHistory) { - final var sql = "insert into user_history (user_id, account, password, email, created_at, created_by) values (?, ?, ?, ?, ?, ?)"; - - jdbcTemplate.update( - connection, - sql, - userHistory.getUserId(), - userHistory.getAccount(), - userHistory.getPassword(), - userHistory.getEmail(), - userHistory.getCreatedAt(), - userHistory.getCreateBy() - ); - } } diff --git a/app/src/main/java/com/techcourse/service/AppUserService.java b/app/src/main/java/com/techcourse/service/AppUserService.java new file mode 100644 index 0000000000..dd06db3096 --- /dev/null +++ b/app/src/main/java/com/techcourse/service/AppUserService.java @@ -0,0 +1,37 @@ +package com.techcourse.service; + +import com.techcourse.dao.UserDao; +import com.techcourse.dao.UserHistoryDao; +import com.techcourse.domain.User; +import com.techcourse.domain.UserHistory; + +public class AppUserService implements UserService { + + private final UserDao userDao; + private final UserHistoryDao userHistoryDao; + + public AppUserService(UserDao userDao, UserHistoryDao userHistoryDao) { + this.userDao = userDao; + this.userHistoryDao = userHistoryDao; + } + + @Override + public User findById(final long id) { + return userDao.findById(id); + } + + @Override + public void insert(final User user) { + userDao.insert(user); + } + + @Override + public void changePassword(long id, String newPassword, String createBy) { + final var user = findById(id); + user.changePassword(newPassword); + + userDao.update(user); + userHistoryDao.log(new UserHistory(user, createBy)); + } + +} diff --git a/app/src/main/java/com/techcourse/service/TxUserService.java b/app/src/main/java/com/techcourse/service/TxUserService.java new file mode 100644 index 0000000000..35faabd098 --- /dev/null +++ b/app/src/main/java/com/techcourse/service/TxUserService.java @@ -0,0 +1,31 @@ +package com.techcourse.service; + +import com.techcourse.domain.User; +import org.springframework.transaction.support.TransactionManager; + +public class TxUserService implements UserService { + + private final TransactionManager transactionManager; + private final UserService userService; + + public TxUserService(TransactionManager transactionManager, UserService userService) { + this.transactionManager = transactionManager; + this.userService = userService; + } + + @Override + public User findById(long id) { + return transactionManager.query(() -> userService.findById(id)); + } + + @Override + public void insert(User user) { + transactionManager.execute(() -> userService.insert(user)); + } + + + @Override + public void changePassword(long id, String newPassword, String createBy) { + transactionManager.execute(() -> userService.changePassword(id, newPassword, createBy)); + } +} diff --git a/app/src/main/java/com/techcourse/service/UserService.java b/app/src/main/java/com/techcourse/service/UserService.java index 5a7c714f0f..b14dbcacbf 100644 --- a/app/src/main/java/com/techcourse/service/UserService.java +++ b/app/src/main/java/com/techcourse/service/UserService.java @@ -1,76 +1,12 @@ package com.techcourse.service; -import com.techcourse.config.DataSourceConfig; -import com.techcourse.dao.UserDao; -import com.techcourse.dao.UserHistoryDao; import com.techcourse.domain.User; -import com.techcourse.domain.UserHistory; -import java.sql.Connection; -import java.sql.SQLException; -import javax.sql.DataSource; -import org.springframework.dao.DataAccessException; -public class UserService { +public interface UserService { - private final UserDao userDao; - private final UserHistoryDao userHistoryDao; + User findById(final long id); - public UserService(final UserDao userDao, final UserHistoryDao userHistoryDao) { - this.userDao = userDao; - this.userHistoryDao = userHistoryDao; - } - - public User findById(final long id) { - return userDao.findById(id); - } - - public void insert(final User user) { - userDao.insert(user); - } - - public void changePassword(final long id, final String newPassword, final String createBy) { - final var user = findById(id); - user.changePassword(newPassword); - - DataSource dataSource = DataSourceConfig.getInstance(); - Connection connection = null; - - try { - connection = dataSource.getConnection(); - connection.setAutoCommit(false); - - userDao.update(connection, user); - userHistoryDao.log(connection, new UserHistory(user, createBy)); - - connection.commit(); - } catch (RuntimeException | SQLException e) { - rollback(connection); - throw new DataAccessException(e); - } finally { - close(connection); - } - } - - private void rollback(Connection connection) { - if (connection == null) { - return; - } - try { - connection.rollback(); - } catch (final SQLException e) { - throw new DataAccessException(e); - } - } - - private void close(final Connection connection) { - if (connection == null) { - return; - } - try { - connection.close(); - } catch (final SQLException e) { - throw new DataAccessException(e); - } - } + void insert(final User user); + void changePassword(final long id, final String newPassword, final String createBy); } diff --git a/app/src/main/java/com/techcourse/support/jdbc/init/DatabasePopulatorUtils.java b/app/src/main/java/com/techcourse/support/jdbc/init/DatabasePopulatorUtils.java index 0a371ecfa3..db8f5cd5de 100644 --- a/app/src/main/java/com/techcourse/support/jdbc/init/DatabasePopulatorUtils.java +++ b/app/src/main/java/com/techcourse/support/jdbc/init/DatabasePopulatorUtils.java @@ -32,13 +32,17 @@ public static void execute(final DataSource dataSource) { if (statement != null) { statement.close(); } - } catch (SQLException ignored) {} + } catch (SQLException e) { + log.warn(String.valueOf(e)); + } try { if (connection != null) { connection.close(); } - } catch (SQLException ignored) {} + } catch (SQLException e) { + log.warn(String.valueOf(e)); + } } } diff --git a/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java b/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java index d1166e1d0c..c720ab0d58 100644 --- a/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java +++ b/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java @@ -2,7 +2,6 @@ import com.techcourse.dao.UserHistoryDao; import com.techcourse.domain.UserHistory; -import java.sql.Connection; import org.springframework.dao.DataAccessException; import org.springframework.jdbc.core.JdbcTemplate; @@ -17,8 +16,4 @@ public void log(final UserHistory userHistory) { throw new DataAccessException(); } - @Override - public void log(Connection connection, UserHistory userHistory) { - throw new DataAccessException(); - } } diff --git a/app/src/test/java/com/techcourse/service/UserServiceTest.java b/app/src/test/java/com/techcourse/service/UserServiceTest.java index 2f31959edb..e0f643e778 100644 --- a/app/src/test/java/com/techcourse/service/UserServiceTest.java +++ b/app/src/test/java/com/techcourse/service/UserServiceTest.java @@ -12,6 +12,7 @@ import org.junit.jupiter.api.Test; import org.springframework.dao.DataAccessException; import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.transaction.support.TransactionManager; class UserServiceTest { @@ -31,7 +32,7 @@ void setUp() { @Test void testChangePassword() { final var userHistoryDao = new UserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao); + final var userService = new AppUserService(userDao, userHistoryDao); final var newPassword = "qqqqq"; final var createBy = "gugu"; @@ -46,7 +47,12 @@ void testChangePassword() { void testTransactionRollback() { // 트랜잭션 롤백 테스트를 위해 mock으로 교체 final var userHistoryDao = new MockUserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao); + // 애플리케이션 서비스 + final var appUserService = new AppUserService(userDao, userHistoryDao); + TransactionManager transactionManager = new TransactionManager(DataSourceConfig.getInstance()); + + // 트랜잭션 서비스 추상화 + final var userService = new TxUserService(transactionManager,appUserService); final var newPassword = "newPassword"; final var createBy = "gugu"; diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index e4159e5e11..ae957c4ae4 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -11,6 +11,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.datasource.DataSourceUtils; public class JdbcTemplate { @@ -24,8 +25,9 @@ public JdbcTemplate(final DataSource dataSource) { } public void update(final String sql, final Object... args) { + Connection connection = DataSourceUtils.getConnection(dataSource); + try ( - Connection connection = dataSource.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(sql) ) { log.debug(QUERY_FORMAT, sql); @@ -37,20 +39,11 @@ public void update(final String sql, final Object... args) { } } - public void update(Connection connection, String sql, Object... args) { - try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { - log.debug(QUERY_FORMAT, sql); - - bindStatementWithArgs(args, preparedStatement); - preparedStatement.executeUpdate(); - } catch (SQLException e) { - throw new DataAccessException(e); - } - } public Optional queryForObject(final String sql, final RowMapper rowMapper, final Object... args) { + Connection connection = DataSourceUtils.getConnection(dataSource); + try ( - Connection connection = dataSource.getConnection(); PreparedStatement preparedStatement = getPreparedStatement(connection, sql, args); ResultSet resultSet = preparedStatement.executeQuery() ) { @@ -68,26 +61,10 @@ public Optional queryForObject(final String sql, final RowMapper rowMa } } - private PreparedStatement getPreparedStatement( - Connection connection, - String sql, - Object[] args - ) throws SQLException { - PreparedStatement preparedStatement = connection.prepareStatement(sql); - bindStatementWithArgs(args, preparedStatement); - - return preparedStatement; - } - - private void bindStatementWithArgs(Object[] args, PreparedStatement preparedStatement) throws SQLException { - for (int i = 1; i <= args.length; i++) { - preparedStatement.setObject(i, args[i - 1]); - } - } - public List query(final String sql, final RowMapper rowMapper) { + Connection connection = DataSourceUtils.getConnection(dataSource); + try ( - Connection connection = dataSource.getConnection(); PreparedStatement preparedStatement = connection.prepareStatement(sql); ResultSet resultSet = preparedStatement.executeQuery() ) { @@ -106,4 +83,21 @@ public List query(final String sql, final RowMapper rowMapper) { } } + private PreparedStatement getPreparedStatement( + Connection connection, + String sql, + Object[] args + ) throws SQLException { + PreparedStatement preparedStatement = connection.prepareStatement(sql); + bindStatementWithArgs(args, preparedStatement); + + return preparedStatement; + } + + private void bindStatementWithArgs(Object[] args, PreparedStatement preparedStatement) throws SQLException { + for (int i = 1; i <= args.length; i++) { + preparedStatement.setObject(i, args[i - 1]); + } + } + } diff --git a/jdbc/src/main/java/org/springframework/transaction/TransactionException.java b/jdbc/src/main/java/org/springframework/transaction/TransactionException.java new file mode 100644 index 0000000000..6e08ab6c43 --- /dev/null +++ b/jdbc/src/main/java/org/springframework/transaction/TransactionException.java @@ -0,0 +1,9 @@ +package org.springframework.transaction; + +public class TransactionException extends RuntimeException { + + public TransactionException(Exception e) { + super(e); + } + +} diff --git a/jdbc/src/main/java/org/springframework/transaction/support/TransactionManager.java b/jdbc/src/main/java/org/springframework/transaction/support/TransactionManager.java new file mode 100644 index 0000000000..811d92aa3b --- /dev/null +++ b/jdbc/src/main/java/org/springframework/transaction/support/TransactionManager.java @@ -0,0 +1,86 @@ +package org.springframework.transaction.support; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.function.Supplier; +import javax.sql.DataSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.datasource.DataSourceUtils; +import org.springframework.transaction.TransactionException; + +public class TransactionManager { + + private static final Logger log = LoggerFactory.getLogger(TransactionManager.class); + + private final DataSource dataSource; + + public TransactionManager(DataSource dataSource) { + this.dataSource = dataSource; + } + + public void execute(Runnable commandExecutor) { + Connection connection = DataSourceUtils.getConnection(dataSource); + + try { + connection.setAutoCommit(false); + + commandExecutor.run(); + + connection.commit(); + } catch (Exception e) { + rollback(connection); + throw new DataAccessException(e); + } finally { + release(connection); + } + } + + public T query(Supplier queryExecutor) { + Connection connection = DataSourceUtils.getConnection(dataSource); + + try { + connection.setReadOnly(true); + + connection.setAutoCommit(false); + T result = queryExecutor.get(); + + connection.commit(); + return result; + } catch (Exception e) { + rollback(connection); + throw new TransactionException(e); + } finally { + release(connection); + } + } + + private void rollback(Connection connection) { + if (connection == null) { + return; + } + + try { + connection.rollback(); + } catch (SQLException e) { + throw new TransactionException(e); + } + } + + private void release(Connection connection) { + if (connection == null) { + return; + } + + try { + connection.setAutoCommit(true); + + DataSourceUtils.releaseConnection(connection, dataSource); + TransactionSynchronizationManager.unbindResource(dataSource); + } catch (Exception e) { + log.warn(String.valueOf(e)); + } + } + +} diff --git a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java index 715557fc66..c59b208818 100644 --- a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java +++ b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java @@ -1,23 +1,38 @@ package org.springframework.transaction.support; -import javax.sql.DataSource; import java.sql.Connection; +import java.util.HashMap; import java.util.Map; +import javax.sql.DataSource; public abstract class TransactionSynchronizationManager { - private static final ThreadLocal> resources = new ThreadLocal<>(); + private static final ThreadLocal> resources = ThreadLocal.withInitial(HashMap::new); - private TransactionSynchronizationManager() {} + private TransactionSynchronizationManager() { + } public static Connection getResource(DataSource key) { - return null; + Map dataSourceConnectionMap = resources.get(); + + return dataSourceConnectionMap.get(key); } public static void bindResource(DataSource key, Connection value) { + Map dataSourceConnectionMap = resources.get(); + + dataSourceConnectionMap.put(key, value); } public static Connection unbindResource(DataSource key) { - return null; + Map dataSourceConnectionMap = resources.get(); + Connection value = dataSourceConnectionMap.remove(key); + + if (dataSourceConnectionMap.isEmpty()) { + resources.remove(); + } + + return value; } + } diff --git a/jdbc/src/test/java/org/springframework/transaction/support/TransactionManagerTest.java b/jdbc/src/test/java/org/springframework/transaction/support/TransactionManagerTest.java new file mode 100644 index 0000000000..2409eec721 --- /dev/null +++ b/jdbc/src/test/java/org/springframework/transaction/support/TransactionManagerTest.java @@ -0,0 +1,120 @@ +package org.springframework.transaction.support; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.function.Supplier; +import javax.sql.DataSource; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.datasource.DataSourceUtils; +import org.springframework.transaction.TransactionException; + +class TransactionManagerTest { + + @Mock + private DataSource dataSource; + + @Mock + private Connection connection; + + private TransactionManager transactionManager; + + @BeforeEach + public void setUp() { + MockitoAnnotations.openMocks(this); + + when(DataSourceUtils.getConnection(dataSource)).thenReturn(connection); + transactionManager = new TransactionManager(dataSource); + } + + @Nested + class Execute { + @Test + void commit() throws SQLException { + // given + Runnable command = mock(Runnable.class); + + // when + transactionManager.execute(command); + + // then + verify(command, times(1)).run(); + verify(connection, times(1)).setAutoCommit(false); + verify(connection, times(1)).commit(); + } + + @Test + void rollback() throws SQLException { + // given + Runnable command = () -> { + throw new RuntimeException("Test"); + }; + + // when + // then + assertThatThrownBy(() -> transactionManager.execute(command)) + .isInstanceOf(DataAccessException.class); + + verify(connection, times(1)).rollback(); + } + } + + @Nested + class Query { + @Test + void commit() throws SQLException { + // given + Supplier query = () -> "result"; + + // when + String result = transactionManager.query(query); + + // then + verify(connection, times(1)).setReadOnly(true); + verify(connection, times(1)).setAutoCommit(false); + verify(connection, times(1)).commit(); + + assertThat(result).isEqualTo("result"); + } + + @Test + void rollback() throws SQLException { + // given + Supplier query = () -> { + throw new RuntimeException("Test"); + }; + + // when + // then + assertThatThrownBy(() -> transactionManager.query(query)) + .isInstanceOf(TransactionException.class); + + verify(connection, times(1)).rollback(); + } + } + + @Test + void release() throws SQLException { + // given + Runnable command = mock(Runnable.class); + + // when + transactionManager.execute(command); + + // then + verify(connection, times(1)).setAutoCommit(true); + verify(connection, times(1)).close(); + } + +} diff --git a/jdbc/src/test/java/org/springframework/transaction/support/TransactionSynchronizationManagerTest.java b/jdbc/src/test/java/org/springframework/transaction/support/TransactionSynchronizationManagerTest.java new file mode 100644 index 0000000000..b4fdbb69de --- /dev/null +++ b/jdbc/src/test/java/org/springframework/transaction/support/TransactionSynchronizationManagerTest.java @@ -0,0 +1,50 @@ +package org.springframework.transaction.support; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.sql.Connection; +import javax.sql.DataSource; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; + +class TransactionSynchronizationManagerTest { + + @Mock + private DataSource dataSource; + @Mock + private Connection connection; + + @Test + void getResource() { + // when + Connection actual = TransactionSynchronizationManager.getResource(dataSource); + + // then + assertThat(actual).isNull(); + } + + @Test + void bindResource() { + // when + TransactionSynchronizationManager.bindResource(dataSource, connection); + + // then + Connection actual = TransactionSynchronizationManager.getResource(dataSource); + assertThat(actual).isEqualTo(connection); + } + + @Test + void unbindResource() { + // given + TransactionSynchronizationManager.bindResource(dataSource, connection); + + // when + Connection unboundConnection = TransactionSynchronizationManager.unbindResource(dataSource); + + // then + Connection actual = TransactionSynchronizationManager.getResource(dataSource); + + assertThat(unboundConnection).isEqualTo(connection); + assertThat(actual).isNull(); + } +}