diff --git a/app/src/main/java/com/techcourse/dao/UserDao.java b/app/src/main/java/com/techcourse/dao/UserDao.java index 03a9da78dd..2b1706c293 100644 --- a/app/src/main/java/com/techcourse/dao/UserDao.java +++ b/app/src/main/java/com/techcourse/dao/UserDao.java @@ -37,11 +37,6 @@ public void update(final User user) { jdbcTemplate.update(sql, user.getAccount(), user.getPassword(), user.getEmail(), user.getId()); } - public void update(final Connection connection, final User user) { - final String sql = "update users set account = ?, password = ?, email = ? where id = ?"; - jdbcTemplate.update(connection, sql, user.getAccount(), user.getPassword(), user.getEmail(), user.getId()); - } - public List findAll() { final String sql = "select id, account, password, email from users"; return jdbcTemplate.queryForList(sql, ROW_MAPPER); diff --git a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java index c40975a4f5..51a8b60d3e 100644 --- a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java +++ b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java @@ -30,19 +30,4 @@ public void log(final UserHistory userHistory) { userHistory.getCreateBy() ); } - - public void log(final Connection connection, final UserHistory userHistory) { - final var sql = "insert into user_history (user_id, account, password, email, created_at, created_by) values (?, ?, ?, ?, ?, ?)"; - - jdbcTemplate.update( - connection, - sql, - userHistory.getUserId(), - userHistory.getAccount(), - userHistory.getPassword(), - userHistory.getEmail(), - userHistory.getCreatedAt(), - userHistory.getCreateBy() - ); - } } diff --git a/app/src/main/java/com/techcourse/service/AppUserService.java b/app/src/main/java/com/techcourse/service/AppUserService.java new file mode 100644 index 0000000000..b1e84f38d8 --- /dev/null +++ b/app/src/main/java/com/techcourse/service/AppUserService.java @@ -0,0 +1,35 @@ +package com.techcourse.service; + +import com.techcourse.dao.UserDao; +import com.techcourse.dao.UserHistoryDao; +import com.techcourse.domain.User; +import com.techcourse.domain.UserHistory; + +public class AppUserService implements UserService { + + private final UserDao userDao; + private final UserHistoryDao userHistoryDao; + + public AppUserService(final UserDao userDao, final UserHistoryDao userHistoryDao) { + this.userDao = userDao; + this.userHistoryDao = userHistoryDao; + } + + @Override + public User findById(final long id) { + return userDao.findById(id); + } + + @Override + public void insert(final User user) { + userDao.insert(user); + } + + @Override + public void changePassword(final long id, final String newPassword, final String createBy) { + final User user = findById(id); + user.changePassword(newPassword); + userDao.update(user); + userHistoryDao.log(new UserHistory(user, createBy)); + } +} diff --git a/app/src/main/java/com/techcourse/service/TxUserService.java b/app/src/main/java/com/techcourse/service/TxUserService.java new file mode 100644 index 0000000000..0a602f75f2 --- /dev/null +++ b/app/src/main/java/com/techcourse/service/TxUserService.java @@ -0,0 +1,30 @@ +package com.techcourse.service; + +import com.techcourse.domain.User; +import org.springframework.transaction.TransactionExecutor; + +public class TxUserService implements UserService { + + private final UserService userService; + private final TransactionExecutor transactionExecutor; + + public TxUserService(final UserService userService, final TransactionExecutor transactionExecutor) { + this.userService = userService; + this.transactionExecutor = transactionExecutor; + } + + @Override + public User findById(final long id) { + return transactionExecutor.execute(() -> userService.findById(id)); + } + + @Override + public void insert(final User user) { + transactionExecutor.execute(() -> userService.insert(user)); + } + + @Override + public void changePassword(final long id, final String newPassword, final String createBy) { + transactionExecutor.execute(() -> userService.changePassword(id, newPassword, createBy)); + } +} diff --git a/app/src/main/java/com/techcourse/service/UserService.java b/app/src/main/java/com/techcourse/service/UserService.java index 292ac481d2..42d01bf760 100644 --- a/app/src/main/java/com/techcourse/service/UserService.java +++ b/app/src/main/java/com/techcourse/service/UserService.java @@ -1,64 +1,10 @@ package com.techcourse.service; -import com.techcourse.config.DataSourceConfig; -import com.techcourse.dao.UserDao; -import com.techcourse.dao.UserHistoryDao; import com.techcourse.domain.User; -import com.techcourse.domain.UserHistory; -import java.sql.Connection; -import java.sql.SQLException; -import javax.sql.DataSource; -import org.springframework.dao.DataAccessException; -public class UserService { +public interface UserService { - private final UserDao userDao; - private final UserHistoryDao userHistoryDao; - - public UserService(final UserDao userDao, final UserHistoryDao userHistoryDao) { - this.userDao = userDao; - this.userHistoryDao = userHistoryDao; - } - - public User findById(final long id) { - return userDao.findById(id); - } - - public void insert(final User user) { - userDao.insert(user); - } - - public void changePassword(final long id, final String newPassword, final String createBy) { - final DataSource dataSource = DataSourceConfig.getInstance(); - final Connection connection = getConnection(dataSource); - try (connection) { - connection.setAutoCommit(false); - - final User user = findById(id); - user.changePassword(newPassword); - userDao.update(connection, user); - userHistoryDao.log(connection, new UserHistory(user, createBy)); - - connection.commit(); - } catch (SQLException | DataAccessException e) { - rollback(connection); - throw new DataAccessException(e); - } - } - - private Connection getConnection(final DataSource dataSource) { - try { - return dataSource.getConnection(); - } catch (SQLException e) { - throw new DataAccessException(); - } - } - - private void rollback(final Connection connection) { - try { - connection.rollback(); - } catch (SQLException e) { - throw new DataAccessException(); - } - } + User findById(final long id); + void insert(final User user); + void changePassword(final long id, final String newPassword, final String createBy); } diff --git a/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java b/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java index a768d17b27..3d79d30584 100644 --- a/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java +++ b/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java @@ -16,9 +16,4 @@ public MockUserHistoryDao(final JdbcTemplate jdbcTemplate) { public void log(final UserHistory userHistory) { throw new DataAccessException(); } - - @Override - public void log(final Connection connection, final UserHistory userHistory) { - throw new DataAccessException(); - } } diff --git a/app/src/test/java/com/techcourse/service/UserServiceTest.java b/app/src/test/java/com/techcourse/service/UserServiceTest.java index 2d4abce635..af1a1bbcbb 100644 --- a/app/src/test/java/com/techcourse/service/UserServiceTest.java +++ b/app/src/test/java/com/techcourse/service/UserServiceTest.java @@ -8,8 +8,8 @@ import org.springframework.dao.DataAccessException; import org.springframework.jdbc.core.JdbcTemplate; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.springframework.transaction.TransactionExecutor; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -31,14 +31,15 @@ void setUp() { @Test void testChangePassword() { - final var userHistoryDao = new UserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao); + final UserHistoryDao userHistoryDao = new UserHistoryDao(jdbcTemplate); + final AppUserService appUserService = new AppUserService(userDao, userHistoryDao); + final TxUserService txUserService = new TxUserService(appUserService, new TransactionExecutor(DataSourceConfig.getInstance())); final var newPassword = "qqqqq"; final var createBy = "gugu"; - userService.changePassword(1L, newPassword, createBy); + txUserService.changePassword(1L, newPassword, createBy); - final var actual = userService.findById(1L); + final var actual = txUserService.findById(1L); assertThat(actual.getPassword()).isEqualTo(newPassword); } @@ -46,16 +47,19 @@ void testChangePassword() { @Test void testTransactionRollback() { // 트랜잭션 롤백 테스트를 위해 mock으로 교체 - final var userHistoryDao = new MockUserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao); + final MockUserHistoryDao userHistoryDao = new MockUserHistoryDao(jdbcTemplate); + // 애플리케이션 서비스 + final AppUserService appUserService = new AppUserService(userDao, userHistoryDao); + // 트랜잭션 서비스 추상화 + final TxUserService txUserService = new TxUserService(appUserService, new TransactionExecutor(DataSourceConfig.getInstance())); final var newPassword = "newPassword"; final var createBy = "gugu"; // 트랜잭션이 정상 동작하는지 확인하기 위해 의도적으로 MockUserHistoryDao에서 예외를 발생시킨다. assertThrows(DataAccessException.class, - () -> userService.changePassword(1L, newPassword, createBy)); + () -> txUserService.changePassword(1L, newPassword, createBy)); - final var actual = userService.findById(1L); + final var actual = txUserService.findById(1L); assertThat(actual.getPassword()).isNotEqualTo(newPassword); } diff --git a/jdbc/build.gradle b/jdbc/build.gradle index 83f293f626..bcb0088880 100644 --- a/jdbc/build.gradle +++ b/jdbc/build.gradle @@ -15,6 +15,7 @@ dependencies { implementation "org.apache.commons:commons-lang3:3.13.0" implementation "ch.qos.logback:logback-classic:1.2.12" + testImplementation "com.h2database:h2:2.2.220" testImplementation "org.assertj:assertj-core:3.24.2" testImplementation "org.junit.jupiter:junit-jupiter-api:5.7.2" testImplementation "org.mockito:mockito-core:5.4.0" diff --git a/jdbc/src/main/java/org/springframework/jdbc/ConnectionHolder.java b/jdbc/src/main/java/org/springframework/jdbc/ConnectionHolder.java new file mode 100644 index 0000000000..0ebbf43eaa --- /dev/null +++ b/jdbc/src/main/java/org/springframework/jdbc/ConnectionHolder.java @@ -0,0 +1,31 @@ +package org.springframework.jdbc; + +import static java.util.Objects.isNull; + +import java.sql.Connection; + +public class ConnectionHolder { + + private Connection connection; + private boolean transactionActive = false; + + public ConnectionHolder(final Connection connection) { + this.connection = connection; + } + + public void setTransactionActive(final boolean transactionActive) { + this.transactionActive = transactionActive; + } + + public boolean isTransactionActive() { + return transactionActive; + } + + public Connection getConnection() { + return connection; + } + + public boolean has(final Connection connection) { + return this.connection == connection; + } +} diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index dcae16428c..e99ecd1596 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -8,6 +8,7 @@ import java.util.List; import javax.sql.DataSource; import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.datasource.DataSourceUtils; public class JdbcTemplate { @@ -23,24 +24,13 @@ public int update(final String sql, final Object... args) { private T execute(final PreparedStatementCallback preparedStatementCallback, final ExecutionCallback executionCallback) { - try (final Connection connection = dataSource.getConnection()) { - return executeWithConnection(connection, preparedStatementCallback, executionCallback); - } catch (SQLException e) { - throw new DataAccessException(e); - } - } - - public int update(final Connection connection, final String sql, final Object... args) { - return executeWithConnection(connection, conn -> prepareStatement(sql, conn, args), PreparedStatement::executeUpdate); - } - - private T executeWithConnection(final Connection connection, - final PreparedStatementCallback preparedStatementCallback, - final ExecutionCallback executionCallback) { + final Connection connection = DataSourceUtils.getConnection(dataSource); try (final PreparedStatement pstmt = preparedStatementCallback.prepareStatement(connection)) { return executionCallback.execute(pstmt); } catch (SQLException e) { throw new DataAccessException(e); + } finally { + DataSourceUtils.releaseConnection(connection, dataSource); } } diff --git a/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java b/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java index 3c40bfec52..3f5bd8a0a0 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java +++ b/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java @@ -1,6 +1,10 @@ package org.springframework.jdbc.datasource; +import static java.util.Objects.isNull; + +import org.springframework.dao.DataAccessException; import org.springframework.jdbc.CannotGetJdbcConnectionException; +import org.springframework.jdbc.ConnectionHolder; import org.springframework.transaction.support.TransactionSynchronizationManager; import javax.sql.DataSource; @@ -12,14 +16,14 @@ public abstract class DataSourceUtils { private DataSourceUtils() {} - public static Connection getConnection(DataSource dataSource) throws CannotGetJdbcConnectionException { - Connection connection = TransactionSynchronizationManager.getResource(dataSource); - if (connection != null) { - return connection; + public static Connection getConnection(final DataSource dataSource) throws CannotGetJdbcConnectionException { + final ConnectionHolder connectionHolder = TransactionSynchronizationManager.getResource(dataSource); + if (connectionHolder != null) { + return connectionHolder.getConnection(); } try { - connection = dataSource.getConnection(); + final Connection connection = dataSource.getConnection(); TransactionSynchronizationManager.bindResource(dataSource, connection); return connection; } catch (SQLException ex) { @@ -27,8 +31,44 @@ public static Connection getConnection(DataSource dataSource) throws CannotGetJd } } - public static void releaseConnection(Connection connection, DataSource dataSource) { + public static void startTransaction(final Connection connection, final DataSource dataSource) { + final ConnectionHolder connectionHolder = getConnectionHolder(connection, dataSource); + try{ + connectionHolder.setTransactionActive(true); + connection.setAutoCommit(false); + }catch(SQLException e) { + throw new DataAccessException(); + } + } + + private static ConnectionHolder getConnectionHolder(final Connection connection, final DataSource dataSource) { + final ConnectionHolder connectionHolder = TransactionSynchronizationManager.getResource(dataSource); + if(isNull(connectionHolder)) { + throw new IllegalStateException(); + } + if(!connectionHolder.has(connection)) { + throw new IllegalStateException(); + } + return connectionHolder; + } + + public static void finishTransaction(final Connection connection, final DataSource dataSource) { + final ConnectionHolder connectionHolder = getConnectionHolder(connection, dataSource); + try{ + connection.commit(); + connectionHolder.setTransactionActive(false); + }catch(SQLException e) { + throw new DataAccessException(); + } + } + + public static void releaseConnection(final Connection connection, final DataSource dataSource) { try { + final ConnectionHolder connectionHolder = getConnectionHolder(connection, dataSource); + if(connectionHolder.isTransactionActive()) { + return; + } + TransactionSynchronizationManager.unbindResource(dataSource); connection.close(); } catch (SQLException ex) { throw new CannotGetJdbcConnectionException("Failed to close JDBC Connection"); diff --git a/jdbc/src/main/java/org/springframework/transaction/TransactionExecutor.java b/jdbc/src/main/java/org/springframework/transaction/TransactionExecutor.java new file mode 100644 index 0000000000..aff987b306 --- /dev/null +++ b/jdbc/src/main/java/org/springframework/transaction/TransactionExecutor.java @@ -0,0 +1,54 @@ +package org.springframework.transaction; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.function.Supplier; +import javax.sql.DataSource; +import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.datasource.DataSourceUtils; + +public class TransactionExecutor { + + private final DataSource dataSource; + + public TransactionExecutor(final DataSource dataSource) { + this.dataSource = dataSource; + } + + public T execute(final Supplier supplier) { + final Connection connection = DataSourceUtils.getConnection(dataSource); + try { + DataSourceUtils.startTransaction(connection, dataSource); + final T t = supplier.get(); + DataSourceUtils.finishTransaction(connection, dataSource); + return t; + } catch (DataAccessException e) { + rollback(connection); + throw new DataAccessException(e); + } finally { + DataSourceUtils.releaseConnection(connection, dataSource); + } + } + + public void execute(final Runnable runnable) { + final Connection connection = DataSourceUtils.getConnection(dataSource); + try { + DataSourceUtils.startTransaction(connection, dataSource); + runnable.run(); + DataSourceUtils.finishTransaction(connection, dataSource); + } catch (DataAccessException e) { + rollback(connection); + throw new DataAccessException(e); + } finally { + DataSourceUtils.releaseConnection(connection, dataSource); + } + } + + private void rollback(final Connection connection) { + try { + connection.rollback(); + } catch (SQLException e) { + throw new DataAccessException(); + } + } +} diff --git a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java index 715557fc66..fb1b532596 100644 --- a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java +++ b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java @@ -1,23 +1,36 @@ package org.springframework.transaction.support; +import java.util.HashMap; +import javax.annotation.Nullable; import javax.sql.DataSource; import java.sql.Connection; import java.util.Map; +import org.springframework.jdbc.ConnectionHolder; public abstract class TransactionSynchronizationManager { - private static final ThreadLocal> resources = new ThreadLocal<>(); + private static final ThreadLocal> resources = ThreadLocal.withInitial(HashMap::new); private TransactionSynchronizationManager() {} - public static Connection getResource(DataSource key) { - return null; + @Nullable + public static ConnectionHolder getResource(final DataSource key) { + return getResources().getOrDefault(key, null); } - public static void bindResource(DataSource key, Connection value) { + public static void bindResource(final DataSource key, final Connection value) { + final Map resources = getResources(); + if(resources.containsKey(key)){ + throw new IllegalStateException(); + } + resources.put(key, new ConnectionHolder(value)); } - public static Connection unbindResource(DataSource key) { - return null; + public static ConnectionHolder unbindResource(final DataSource key) { + return getResources().remove(key); + } + + private static Map getResources() { + return resources.get(); } } diff --git a/jdbc/src/test/java/org/springframework/transaction/support/TransactionSynchronizationManagerTest.java b/jdbc/src/test/java/org/springframework/transaction/support/TransactionSynchronizationManagerTest.java new file mode 100644 index 0000000000..2cb8c7280b --- /dev/null +++ b/jdbc/src/test/java/org/springframework/transaction/support/TransactionSynchronizationManagerTest.java @@ -0,0 +1,57 @@ +package org.springframework.transaction.support; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Objects; +import javax.sql.DataSource; +import org.h2.jdbcx.JdbcDataSource; +import org.junit.jupiter.api.Test; + +class TransactionSynchronizationManagerTest { + + @Test + void 두개의_쓰레드로_ThreadLocal_테스트() { + //첫번째 쓰레드에서 리소스 바인딩 + final DataSource dataSource = DataSourceConfig.getInstance(); + TransactionSynchronizationManager.bindResource(dataSource, getConnection(dataSource)); + + //두번째 쓰레드에서 리소스 조회 + new Thread(() -> { + final Connection connection = TransactionSynchronizationManager.getResource(dataSource); + assertThat(connection).isNull(); + }).start(); + } + + private Connection getConnection(final DataSource dataSource) { + try { + return dataSource.getConnection(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + private static class DataSourceConfig { + + private static javax.sql.DataSource INSTANCE; + + public static javax.sql.DataSource getInstance() { + if (Objects.isNull(INSTANCE)) { + INSTANCE = createJdbcDataSource(); + } + return INSTANCE; + } + + private static JdbcDataSource createJdbcDataSource() { + final var jdbcDataSource = new JdbcDataSource(); + jdbcDataSource.setUrl("jdbc:h2:mem:test;DB_CLOSE_DELAY=-1;"); + jdbcDataSource.setUser(""); + jdbcDataSource.setPassword(""); + return jdbcDataSource; + } + + private DataSourceConfig() { + } + } +} diff --git a/study/src/test/java/transaction/stage2/Stage2Test.java b/study/src/test/java/transaction/stage2/Stage2Test.java index 9a4ff4e580..a3992480bb 100644 --- a/study/src/test/java/transaction/stage2/Stage2Test.java +++ b/study/src/test/java/transaction/stage2/Stage2Test.java @@ -113,9 +113,9 @@ void testMandatory() { } /** - * 아래 테스트는 몇 개의 물리적 트랜잭션이 동작할까? 주석 처리 안했을때(required, not supported) 논리적으로는 2개, 물리적으로 1개? + * 아래 테스트는 몇 개의 물리적 트랜잭션이 동작할까? 주석 처리 안했을때(required, not supported) 논리적으로는 1개, 물리적으로 2개? * FirstUserService.saveFirstTransactionWithNotSupported() 메서드의 @Transactional을 주석 처리하자. - * 다시 테스트를 실행하면 몇 개의 물리적 트랜잭션이 동작할까? 주석 처리 했을때(x, not supported) 논리적으로는 1개, 물리적으로는 0개? + * 다시 테스트를 실행하면 몇 개의 물리적 트랜잭션이 동작할까? 주석 처리 했을때(x, not supported) 논리적으로는 0개, 물리적으로는 2개? * not supported: 현재 트랜잭션이 존재하는 경우 먼저 이를 일시 중지한 다음 트랜잭션 없이 실행한다. * 스프링 공식 문서에서 물리적 트랜잭션과 논리적 트랜잭션의 차이점이 무엇인지 찾아보자. * 외부 트랜잭션 범위는 내부 트랜잭션 범위와 논리적으로 독립이지만, 동일한 물리적 트랜잭션에 매핑된다.