diff --git a/app/src/main/java/com/techcourse/dao/UserDao.java b/app/src/main/java/com/techcourse/dao/UserDao.java index bdae8e4ee8..e98cd144be 100644 --- a/app/src/main/java/com/techcourse/dao/UserDao.java +++ b/app/src/main/java/com/techcourse/dao/UserDao.java @@ -3,6 +3,7 @@ import com.techcourse.domain.User; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.ResultSetMapper; +import org.springframework.transaction.support.TransactionSynchronizationManager; import javax.sql.DataSource; import java.sql.Connection; @@ -23,14 +24,14 @@ public UserDao(final DataSource dataSource) { this.jdbcTemplate = new JdbcTemplate(dataSource); } - public void insert(final Connection connection, final User user) { + public void insert(final User user) { final String sql = "insert into users (account, password, email) values (?, ?, ?)"; - jdbcTemplate.update(connection, sql, user.getAccount(), user.getPassword(), user.getEmail()); + jdbcTemplate.update(sql, user.getAccount(), user.getPassword(), user.getEmail()); } - public void update(final Connection connection, final User user) { + public void update(final User user) { final String sql = "update users set account = ?, password = ?, email = ? where id = ?"; - jdbcTemplate.update(connection, sql, user.getAccount(), user.getPassword(), user.getEmail(), user.getId()); + jdbcTemplate.update(sql, user.getAccount(), user.getPassword(), user.getEmail(), user.getId()); } public List findAll() { diff --git a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java index f509343709..6b11f05174 100644 --- a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java +++ b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java @@ -12,34 +12,21 @@ public class UserHistoryDao { - private static final Logger log = LoggerFactory.getLogger(UserHistoryDao.class); - - private final DataSource dataSource; + private final JdbcTemplate jdbcTemplate; public UserHistoryDao(final DataSource dataSource) { - this.dataSource = dataSource; - } - - public UserHistoryDao(final JdbcTemplate jdbcTemplate) { - this.dataSource = null; + this.jdbcTemplate = new JdbcTemplate(dataSource); } - public void log(final Connection connection, final UserHistory userHistory) { + public void log(final UserHistory userHistory) { final var sql = "insert into user_history (user_id, account, password, email, created_at, created_by) values (?, ?, ?, ?, ?, ?)"; - - try (final PreparedStatement pstmt = connection.prepareStatement(sql)) { - log.debug("query : {}", sql); - - pstmt.setLong(1, userHistory.getUserId()); - pstmt.setString(2, userHistory.getAccount()); - pstmt.setString(3, userHistory.getPassword()); - pstmt.setString(4, userHistory.getEmail()); - pstmt.setObject(5, userHistory.getCreatedAt()); - pstmt.setString(6, userHistory.getCreateBy()); - pstmt.executeUpdate(); - } catch (SQLException e) { - log.error(e.getMessage(), e); - throw new RuntimeException(e); - } + jdbcTemplate.update(sql, + userHistory.getUserId(), + userHistory.getAccount(), + userHistory.getPassword(), + userHistory.getEmail(), + userHistory.getCreatedAt(), + userHistory.getCreateBy() + ); } } diff --git a/app/src/main/java/com/techcourse/service/AppUserService.java b/app/src/main/java/com/techcourse/service/AppUserService.java new file mode 100644 index 0000000000..a8d453b094 --- /dev/null +++ b/app/src/main/java/com/techcourse/service/AppUserService.java @@ -0,0 +1,35 @@ +package com.techcourse.service; + +import com.techcourse.dao.UserDao; +import com.techcourse.dao.UserHistoryDao; +import com.techcourse.domain.User; +import com.techcourse.domain.UserHistory; + +public class AppUserService implements UserService { + + private final UserDao userDao; + private final UserHistoryDao userHistoryDao; + + public AppUserService(final UserDao userDao, final UserHistoryDao userHistoryDao) { + this.userDao = userDao; + this.userHistoryDao = userHistoryDao; + } + + @Override + public User findById(final long id) { + return userDao.findById(id); + } + + @Override + public void insert(final User user) { + userDao.insert(user); + } + + @Override + public void changePassword(final long id, final String newPassword, final String createBy) { + final var user = findById(id); + user.changePassword(newPassword); + userDao.update(user); + userHistoryDao.log(new UserHistory(user, createBy)); + } +} diff --git a/app/src/main/java/com/techcourse/service/TxUserService.java b/app/src/main/java/com/techcourse/service/TxUserService.java new file mode 100644 index 0000000000..4bc2d4021d --- /dev/null +++ b/app/src/main/java/com/techcourse/service/TxUserService.java @@ -0,0 +1,30 @@ +package com.techcourse.service; + +import com.techcourse.domain.User; +import org.springframework.transaction.support.TransactionTemplate; + +public class TxUserService implements UserService { + + private final AppUserService appUserService; + private final TransactionTemplate transactionTemplate; + + public TxUserService(final AppUserService appUserService, final TransactionTemplate transactionTemplate) { + this.appUserService = appUserService; + this.transactionTemplate = transactionTemplate; + } + + @Override + public User findById(final long id) { + return appUserService.findById(id); + } + + @Override + public void insert(final User user) { + transactionTemplate.execute(connection -> appUserService.insert(user)); + } + + @Override + public void changePassword(final long id, final String newPassword, final String createBy) { + transactionTemplate.execute(connection -> appUserService.changePassword(id, newPassword, createBy)); + } +} diff --git a/app/src/main/java/com/techcourse/service/UserService.java b/app/src/main/java/com/techcourse/service/UserService.java index 589ad7b241..42d01bf760 100644 --- a/app/src/main/java/com/techcourse/service/UserService.java +++ b/app/src/main/java/com/techcourse/service/UserService.java @@ -1,85 +1,10 @@ package com.techcourse.service; -import com.techcourse.config.DataSourceConfig; -import com.techcourse.dao.UserDao; -import com.techcourse.dao.UserHistoryDao; import com.techcourse.domain.User; -import com.techcourse.domain.UserHistory; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.dao.DataAccessException; -import java.sql.Connection; -import java.sql.SQLException; +public interface UserService { -public class UserService { - - private static final Logger log = LoggerFactory.getLogger(UserService.class); - - private final UserDao userDao; - private final UserHistoryDao userHistoryDao; - - public UserService(final UserDao userDao, final UserHistoryDao userHistoryDao) { - this.userDao = userDao; - this.userHistoryDao = userHistoryDao; - } - - public User findById(final long id) { - return userDao.findById(id); - } - - public void insert(final User user) { - Connection connection = null; - try { - connection = DataSourceConfig.getInstance().getConnection(); - connection.setAutoCommit(false); - userDao.insert(connection, user); - connection.commit(); - } catch (SQLException e) { - log.error(e.getMessage()); - rollback(connection); - } finally { - release(connection); - } - } - - public void changePassword(final long id, final String newPassword, final String createBy) { - Connection connection = null; - try { - connection = DataSourceConfig.getInstance().getConnection(); - connection.setAutoCommit(false); - final var user = findById(id); - user.changePassword(newPassword); - userDao.update(connection, user); - userHistoryDao.log(connection, new UserHistory(user, createBy)); - connection.commit(); - } catch (SQLException e) { - log.error(e.getMessage()); - rollback(connection); - } finally { - release(connection); - } - } - - private void rollback(final Connection connection) { - if (connection != null) { - try { - connection.rollback(); - } catch (SQLException e) { - log.error(e.getMessage()); - throw new DataAccessException("rollback 에 실패했습니다."); - } - } - } - - private void release(final Connection connection) { - if (connection != null) { - try { - connection.close(); - } catch (SQLException e) { - log.error(e.getMessage()); - throw new DataAccessException("자원 정리에 실패했습니다."); - } - } - } + User findById(final long id); + void insert(final User user); + void changePassword(final long id, final String newPassword, final String createBy); } diff --git a/app/src/test/java/com/techcourse/dao/UserDaoTest.java b/app/src/test/java/com/techcourse/dao/UserDaoTest.java index b23effaaf8..a85b72f4b3 100644 --- a/app/src/test/java/com/techcourse/dao/UserDaoTest.java +++ b/app/src/test/java/com/techcourse/dao/UserDaoTest.java @@ -28,7 +28,7 @@ void setup() throws SQLException { userDao = new UserDao(dataSource); connection = dataSource.getConnection(); final var user = new User("hongsil", "486", "gurwns9325@gmail.com"); - userDao.insert(connection, user); + userDao.insert(user); } @Test @@ -54,7 +54,7 @@ void findById_make_exception_when_no_result() { @Test void findByAccount() { final var account = "mylove_hongsil"; - userDao.insert(connection, new User(account, "비밀번호486", "love@with.you")); + userDao.insert(new User(account, "비밀번호486", "love@with.you")); final var user = userDao.findByAccount(account); assertThat(user.getAccount()).isEqualTo(account); @@ -63,8 +63,8 @@ void findByAccount() { @Test void findByAccount_make_exception_when_multiple_result() { final var user = new User("ditoo", "password", "ditoo@gmail.com"); - userDao.insert(connection, user); - userDao.insert(connection, user); + userDao.insert(user); + userDao.insert(user); assertThatThrownBy(() -> userDao.findByAccount(user.getAccount())) .isInstanceOf(IncorrectResultSizeDataAccessException.class); } @@ -73,7 +73,7 @@ void findByAccount_make_exception_when_multiple_result() { void insert() { final var account = "insert-gugu"; final var user = new User(account, "password", "hkkang@woowahan.com"); - userDao.insert(connection, user); + userDao.insert(user); final var actual = userDao.findById(2L); @@ -86,7 +86,7 @@ void update() { final var user = userDao.findById(1L); user.changePassword(newPassword); - userDao.update(connection, user); + userDao.update(user); final var actual = userDao.findById(1L); diff --git a/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java b/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java index 9937b42bbc..f3aa28c7a4 100644 --- a/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java +++ b/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java @@ -5,16 +5,17 @@ import org.springframework.dao.DataAccessException; import org.springframework.jdbc.core.JdbcTemplate; +import javax.sql.DataSource; import java.sql.Connection; public class MockUserHistoryDao extends UserHistoryDao { - public MockUserHistoryDao(final JdbcTemplate jdbcTemplate) { - super(jdbcTemplate); + public MockUserHistoryDao(final DataSource dataSource) { + super(dataSource); } @Override - public void log(final Connection connection, final UserHistory userHistory) { + public void log(final UserHistory userHistory) { throw new DataAccessException(); } } diff --git a/app/src/test/java/com/techcourse/service/UserServiceTest.java b/app/src/test/java/com/techcourse/service/TxUserServiceTest.java similarity index 54% rename from app/src/test/java/com/techcourse/service/UserServiceTest.java rename to app/src/test/java/com/techcourse/service/TxUserServiceTest.java index fa14862557..138b3db602 100644 --- a/app/src/test/java/com/techcourse/service/UserServiceTest.java +++ b/app/src/test/java/com/techcourse/service/TxUserServiceTest.java @@ -6,44 +6,47 @@ import com.techcourse.domain.User; import com.techcourse.support.jdbc.init.DatabasePopulatorUtils; import org.springframework.dao.DataAccessException; -import org.springframework.jdbc.core.JdbcTemplate; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.springframework.transaction.support.TransactionTemplate; import javax.sql.DataSource; -import java.sql.SQLException; - import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; -class UserServiceTest { +class TxUserServiceTest { - private JdbcTemplate jdbcTemplate; + private DataSource dataSource; private UserDao userDao; + private AppUserService appUserService; + private TransactionTemplate transactionTemplate; @BeforeEach - void setUp() throws SQLException { - final DataSource dataSource = DataSourceConfig.getInstance(); - this.jdbcTemplate = new JdbcTemplate(dataSource); + void setUp() { + this.dataSource = DataSourceConfig.getInstance(); this.userDao = new UserDao(dataSource); + UserHistoryDao userHistoryDao = new UserHistoryDao(dataSource); + this.appUserService = new AppUserService(userDao, userHistoryDao); + this.transactionTemplate = new TransactionTemplate(dataSource); DatabasePopulatorUtils.execute(DataSourceConfig.getInstance()); final var user = new User("gugu", "password", "hkkang@woowahan.com"); - userDao.insert(dataSource.getConnection(), user); + userDao.insert(user); } @Test void testChangePassword() { - final var userHistoryDao = new UserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao); + final var txUserService = new TxUserService(appUserService, transactionTemplate); final var newPassword = "qqqqq"; final var createBy = "gugu"; - userService.changePassword(1L, newPassword, createBy); + appUserService.changePassword(1L, newPassword, createBy); - final var actual = userService.findById(1L); + final var actual = txUserService.findById(1L); assertThat(actual.getPassword()).isEqualTo(newPassword); } @@ -51,16 +54,16 @@ void testChangePassword() { @Test void testTransactionRollback() { // 트랜잭션 롤백 테스트를 위해 mock으로 교체 - final var userHistoryDao = new MockUserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao); + final AppUserService appUserService = new AppUserService(userDao, new MockUserHistoryDao(dataSource)); + final var txUserService = new TxUserService(appUserService, transactionTemplate); final var newPassword = "newPassword"; final var createBy = "gugu"; // 트랜잭션이 정상 동작하는지 확인하기 위해 의도적으로 MockUserHistoryDao에서 예외를 발생시킨다. assertThrows(DataAccessException.class, - () -> userService.changePassword(1L, newPassword, createBy)); + () -> txUserService.changePassword(1L, newPassword, createBy)); - final var actual = userService.findById(1L); + final var actual = txUserService.findById(1L); assertThat(actual.getPassword()).isNotEqualTo(newPassword); } diff --git a/app/src/test/java/com/techcourse/transaction/NestedTransactionTest.java b/app/src/test/java/com/techcourse/transaction/NestedTransactionTest.java new file mode 100644 index 0000000000..0464da5d52 --- /dev/null +++ b/app/src/test/java/com/techcourse/transaction/NestedTransactionTest.java @@ -0,0 +1,62 @@ +package com.techcourse.transaction; + +import com.techcourse.config.DataSourceConfig; +import com.techcourse.dao.UserDao; +import com.techcourse.dao.UserHistoryDao; +import com.techcourse.domain.User; +import com.techcourse.service.AppUserService; +import com.techcourse.service.TxUserService; +import com.techcourse.support.jdbc.init.DatabasePopulatorUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.transaction.support.TransactionTemplate; + +import javax.sql.DataSource; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class NestedTransactionTest { + + private JdbcTemplate jdbcTemplate; + private UserDao userDao; + private UserHistoryDao userHistoryDao; + private TransactionTemplate transactionTemplate; + + @BeforeEach + void setUp() { + final DataSource dataSource = DataSourceConfig.getInstance(); + this.jdbcTemplate = new JdbcTemplate(dataSource); + this.userDao = new UserDao(dataSource); + + DatabasePopulatorUtils.execute(dataSource); + final var user = new User("gugu", "password", "hkkang@woowahan.com"); + userDao.insert(user); + userHistoryDao = new UserHistoryDao(dataSource); + transactionTemplate = new TransactionTemplate(dataSource); + } + + @Test + @DisplayName("transaction안에서 transaction을 호출하는 경우 기존 transaction에 합류하도록 구현") + void nestedCaseTest() { + //given + final AppUserService appUserService = new AppUserService(userDao, userHistoryDao); + final TxUserService txUserService = new TxUserService(appUserService, transactionTemplate); + final Long id = 1L; + final String newPassword = "newPassword"; + final String createdBy = "hong"; + + //when + assertThatThrownBy(() -> transactionTemplate.execute(connection -> { + txUserService.changePassword(id, newPassword, createdBy); + throw new RuntimeException(); + })); + + //then + final User user = userDao.findById(id); + assertThat(user.getPassword()) + .isNotEqualTo(newPassword); + } +} diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index 247e8d2dac..3ac1568efd 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -6,6 +6,8 @@ import org.springframework.dao.EmptyResultDataAccessException; import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.dao.InvalidDataAccessApiUsageException; +import org.springframework.jdbc.datasource.DataSourceUtils; +import org.springframework.transaction.support.TransactionSynchronizationManager; import javax.sql.DataSource; import java.sql.Connection; @@ -28,23 +30,28 @@ public JdbcTemplate(final DataSource dataSource) { this.dataSource = dataSource; } - public int update(final Connection connection, final String sql, final Object... args) { + public int update(final String sql, final Object... args) { + final Connection connection = DataSourceUtils.getConnection(dataSource); try (final PreparedStatement ps = createPreparedStatement(connection, sql, args)) { return ps.executeUpdate(); } catch (SQLException e) { log.error(e.getMessage()); throw new DataAccessException(e); + } finally { + if (!TransactionSynchronizationManager.isActualTransactionActive()) { + DataSourceUtils.releaseConnection(connection, dataSource); + } } } public List query(final String sql, final ResultSetMapper resultSetMapper) { - try (final Connection connection = dataSource.getConnection(); - final PreparedStatement ps = connection.prepareStatement(sql); + final Connection connection = DataSourceUtils.getConnection(dataSource); + try (final PreparedStatement ps = connection.prepareStatement(sql); final ResultSet resultSet = ps.executeQuery() ) { log.debug("query: {}", sql); final List results = new ArrayList<>(); - while(resultSet.next()) { + while (resultSet.next()) { results.add(resultSetMapper.apply(resultSet)); } return results; diff --git a/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java b/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java index 3c40bfec52..395f375bf5 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java +++ b/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java @@ -32,6 +32,8 @@ public static void releaseConnection(Connection connection, DataSource dataSourc connection.close(); } catch (SQLException ex) { throw new CannotGetJdbcConnectionException("Failed to close JDBC Connection"); + } finally { + TransactionSynchronizationManager.unbindResource(dataSource); } } } diff --git a/jdbc/src/main/java/org/springframework/transaction/support/TransactionCallback.java b/jdbc/src/main/java/org/springframework/transaction/support/TransactionCallback.java new file mode 100644 index 0000000000..02b723950f --- /dev/null +++ b/jdbc/src/main/java/org/springframework/transaction/support/TransactionCallback.java @@ -0,0 +1,9 @@ +package org.springframework.transaction.support; + +import java.sql.Connection; + +@FunctionalInterface +public interface TransactionCallback { + + void doInTransaction(Connection connection); +} diff --git a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java index 715557fc66..68fd3622e1 100644 --- a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java +++ b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java @@ -1,23 +1,58 @@ package org.springframework.transaction.support; +import javax.annotation.Nullable; import javax.sql.DataSource; import java.sql.Connection; +import java.util.HashMap; import java.util.Map; public abstract class TransactionSynchronizationManager { private static final ThreadLocal> resources = new ThreadLocal<>(); + private static final ThreadLocal actualTransactionActive = ThreadLocal.withInitial(() -> Boolean.FALSE); private TransactionSynchronizationManager() {} - public static Connection getResource(DataSource key) { - return null; + @Nullable + public static Connection getResource(final DataSource key) { + final Map map = resources.get(); + if (map == null) { + return null; + } + return map.get(key); } - public static void bindResource(DataSource key, Connection value) { + public static void bindResource(final DataSource key, final Connection value) { + final Map map = resources.get(); + if (map == null) { + final Map newMap = new HashMap<>(); + newMap.put(key, value); + resources.set(newMap); + return; + } + map.put(key, value); + resources.set(map); } - public static Connection unbindResource(DataSource key) { - return null; + public static void unbindResource(final DataSource key) { + final Map map = resources.get(); + if (map == null) { + throw new IllegalStateException("No resource for key [" + key + "] bound to thread"); + } + if (map.get(key) == null) { + throw new IllegalStateException("No value for key [" + key + "] bound to thread"); + } + map.remove(key); + if (map.isEmpty()) { + resources.remove(); + } + } + + public static boolean isActualTransactionActive() { + return actualTransactionActive.get(); + } + + public static void setActualTransactionActiveTrue() { + actualTransactionActive.set(true); } } diff --git a/jdbc/src/main/java/org/springframework/transaction/support/TransactionTemplate.java b/jdbc/src/main/java/org/springframework/transaction/support/TransactionTemplate.java new file mode 100644 index 0000000000..33d360c7fc --- /dev/null +++ b/jdbc/src/main/java/org/springframework/transaction/support/TransactionTemplate.java @@ -0,0 +1,48 @@ +package org.springframework.transaction.support; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.datasource.DataSourceUtils; + +import javax.sql.DataSource; +import java.sql.Connection; +import java.sql.SQLException; + +public class TransactionTemplate { + + private static final Logger log = LoggerFactory.getLogger(TransactionTemplate.class); + + private final DataSource dataSource; + + public TransactionTemplate(final DataSource dataSource) { + this.dataSource = dataSource; + } + + public void execute(final TransactionCallback transactionCallback) { + final Connection connection = DataSourceUtils.getConnection(dataSource); + try { + if (!TransactionSynchronizationManager.isActualTransactionActive()) { + connection.setAutoCommit(false); + TransactionSynchronizationManager.setActualTransactionActiveTrue(); + transactionCallback.doInTransaction(connection); + connection.commit(); + } + } catch (SQLException e) { + log.error(e.getMessage(), e); + rollback(connection); + throw new DataAccessException("transaction 설정에 오류가 발생했습니다."); + } finally { + DataSourceUtils.releaseConnection(connection, dataSource); + } + } + + private void rollback(final Connection connection) { + try { + connection.rollback(); + } catch (SQLException e) { + log.error(e.getMessage(), e); + throw new DataAccessException("rollback 에 실패했습니다."); + } + } +} diff --git a/jdbc/src/test/java/nextstep/jdbc/JdbcTemplateTest.java b/jdbc/src/test/java/nextstep/jdbc/JdbcTemplateTest.java index d777c8becf..ee4f307424 100644 --- a/jdbc/src/test/java/nextstep/jdbc/JdbcTemplateTest.java +++ b/jdbc/src/test/java/nextstep/jdbc/JdbcTemplateTest.java @@ -1,5 +1,19 @@ package nextstep.jdbc; +import org.junit.jupiter.api.BeforeEach; +import org.mockito.Mockito; +import org.springframework.jdbc.core.JdbcTemplate; + +import javax.sql.DataSource; + class JdbcTemplateTest { + DataSource dataSource; + JdbcTemplate jdbcTemplate; + + @BeforeEach + void setUp() { + dataSource = Mockito.mock(DataSource.class); + jdbcTemplate = new JdbcTemplate(dataSource); + } } diff --git a/jdbc/src/test/java/nextstep/jdbc/TransactionSynchronizationManagerTest.java b/jdbc/src/test/java/nextstep/jdbc/TransactionSynchronizationManagerTest.java new file mode 100644 index 0000000000..ce1bcabd39 --- /dev/null +++ b/jdbc/src/test/java/nextstep/jdbc/TransactionSynchronizationManagerTest.java @@ -0,0 +1,75 @@ +package nextstep.jdbc; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.transaction.support.TransactionSynchronizationManager; + +import javax.sql.DataSource; +import java.sql.Connection; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class TransactionSynchronizationManagerTest { + + DataSource dataSource; + + @BeforeEach + void setUp() { + dataSource = Mockito.mock(DataSource.class); + } + + @Test + void getResourceTest_when_datasource_not_exist() { + final Connection resource = TransactionSynchronizationManager.getResource(dataSource); + assertThat(resource).isNull(); + } + + @Test + void getResourceTest_when_datasource_exist() { + // given + final Connection expected = Mockito.mock(Connection.class); + TransactionSynchronizationManager.bindResource(dataSource, expected); + + // when + final Connection actual = TransactionSynchronizationManager.getResource(dataSource); + + // then + assertThat(actual).isEqualTo(expected); + } + + @Test + void unbindResource_success() { + // given + final Connection connection = Mockito.mock(Connection.class); + TransactionSynchronizationManager.bindResource(dataSource, connection); + assertThat(TransactionSynchronizationManager.getResource(dataSource)).isNotNull(); + + // when + TransactionSynchronizationManager.unbindResource(dataSource); + + // then + assertThat(TransactionSynchronizationManager.getResource(dataSource)).isNull(); + } + + @Test + void unbindResource_fail_when_no_dataSource() { + assertThatThrownBy(() -> TransactionSynchronizationManager.unbindResource(dataSource)) + .isInstanceOf(IllegalStateException.class); + } + + @Test + void isActualTransactionActive_false() { + assertThat(TransactionSynchronizationManager.isActualTransactionActive()).isFalse(); + } + + @Test + void isActualTransactionActive_true() { + // given + TransactionSynchronizationManager.setActualTransactionActiveTrue(); + + // then + assertThat(TransactionSynchronizationManager.isActualTransactionActive()).isTrue(); + } +}