diff --git a/app/src/main/java/com/techcourse/dao/UserDao.java b/app/src/main/java/com/techcourse/dao/UserDao.java index c5839dbe09..d6e5f74009 100644 --- a/app/src/main/java/com/techcourse/dao/UserDao.java +++ b/app/src/main/java/com/techcourse/dao/UserDao.java @@ -1,7 +1,6 @@ package com.techcourse.dao; import com.techcourse.domain.User; -import java.sql.Connection; import java.util.List; import java.util.NoSuchElementException; import javax.sql.DataSource; @@ -14,8 +13,6 @@ public class UserDao { private static final Logger log = LoggerFactory.getLogger(UserDao.class); - private final JdbcTemplate jdbcTemplate; - private static final RowMapper userRowMapper = (resultSet, rowNum) -> new User( resultSet.getLong("id"), resultSet.getString("account"), @@ -23,6 +20,8 @@ public class UserDao { resultSet.getString("email") ); + private final JdbcTemplate jdbcTemplate; + public UserDao(final DataSource dataSource) { this.jdbcTemplate = new JdbcTemplate(dataSource); } @@ -36,63 +35,25 @@ public void insert(final User user) { jdbcTemplate.update(sql, user.getAccount(), user.getPassword(), user.getEmail()); } - public void insert(final Connection connection, final User user) { - final String sql = "insert into users (account, password, email) values (?, ?, ?)"; - jdbcTemplate.update(connection, - sql, - user.getAccount(), user.getPassword(), user.getEmail()); - } - public void update(final User user) { final String sql = "update users set account = ?, password =?, email = ? where id = ?"; jdbcTemplate.update(sql, user.getAccount(), user.getPassword(), user.getEmail(), user.getId()); } - public void update(final Connection connection, final User user) { - final String sql = "update users set account = ?, password =?, email = ? where id = ?"; - jdbcTemplate.update(connection, - sql, - user.getAccount(), user.getPassword(), user.getEmail(), user.getId()); - } - public List findAll() { final String sql = "select id, account, password, email from users"; return jdbcTemplate.query(sql, userRowMapper); } - public List findAll(final Connection connection) { - final String sql = "select id, account, password, email from users"; - return jdbcTemplate.query(connection, - sql, - userRowMapper); - } - public User findById(final Long id) { final String sql = "select id, account, password, email from users where id = ?"; return jdbcTemplate.queryForObject(sql, userRowMapper, id) .orElseThrow(() -> new NoSuchElementException("id에 해당하는 user가 없습니다.")); } - public User findById(final Connection connection, final Long id) { - final String sql = "select id, account, password, email from users where id = ?"; - return jdbcTemplate.queryForObject(connection, - sql, - userRowMapper, - id) - .orElseThrow(() -> new NoSuchElementException("id에 해당하는 user가 없습니다.")); - } - public User findByAccount(final String account) { final String sql = "select id, account, password, email from users where account = ?"; return jdbcTemplate.queryForObject(sql, userRowMapper, account) .orElseThrow(() -> new NoSuchElementException("account에 해당하는 user가 없습니다.")); } - - public User findByAccount(final Connection connection, final String account) { - final String sql = "select id, account, password, email from users where account = ?"; - return jdbcTemplate.queryForObject(connection, - sql, - userRowMapper, account) - .orElseThrow(() -> new NoSuchElementException("account에 해당하는 user가 없습니다.")); - } } diff --git a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java index a740de8a9c..86d28240e9 100644 --- a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java +++ b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java @@ -1,7 +1,6 @@ package com.techcourse.dao; import com.techcourse.domain.UserHistory; -import java.sql.Connection; import javax.sql.DataSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,17 +31,4 @@ public void log(final UserHistory userHistory) { userHistory.getCreateBy() ); } - - public void log(final Connection connection, final UserHistory userHistory) { - final String sql = "insert into user_history (user_id, account, password, email, created_at, created_by) values (?, ?, ?, ?, ?, ?)"; - jdbcTemplate.update(connection, - sql, - userHistory.getUserId(), - userHistory.getAccount(), - userHistory.getPassword(), - userHistory.getEmail(), - userHistory.getCreatedAt(), - userHistory.getCreateBy() - ); - } } diff --git a/app/src/main/java/com/techcourse/service/AppUserService.java b/app/src/main/java/com/techcourse/service/AppUserService.java new file mode 100644 index 0000000000..feeeca2c24 --- /dev/null +++ b/app/src/main/java/com/techcourse/service/AppUserService.java @@ -0,0 +1,37 @@ +package com.techcourse.service; + +import com.techcourse.dao.UserDao; +import com.techcourse.dao.UserHistoryDao; +import com.techcourse.domain.User; +import com.techcourse.domain.UserHistory; + +public class AppUserService implements UserService { + + private final UserDao userDao; + private final UserHistoryDao userHistoryDao; + + public AppUserService(final UserDao userDao, final UserHistoryDao userHistoryDao) { + this.userDao = userDao; + this.userHistoryDao = userHistoryDao; + } + + @Override + public User findById(final long id) { + return userDao.findById(id); + } + + @Override + public void insert(final User user) { + userDao.insert(user); + } + + @Override + public void changePassword(final long id, + final String newPassword, + final String createBy) { + final User user = userDao.findById(id); + user.changePassword(newPassword); + userDao.update(user); + userHistoryDao.log(new UserHistory(user, createBy)); + } +} diff --git a/app/src/main/java/com/techcourse/service/TxUserService.java b/app/src/main/java/com/techcourse/service/TxUserService.java new file mode 100644 index 0000000000..1ba3379b95 --- /dev/null +++ b/app/src/main/java/com/techcourse/service/TxUserService.java @@ -0,0 +1,43 @@ +package com.techcourse.service; + +import com.techcourse.domain.User; +import org.springframework.transaction.TransactionManager; + +public class TxUserService implements UserService { + + private final UserService userService; + private final TransactionManager transactionManager; + + public TxUserService(final UserService userService, final TransactionManager transactionManager) { + this.userService = userService; + this.transactionManager = transactionManager; + } + + @Override + public User findById(final long id) { + return transactionManager.execute( + () -> userService.findById(id) + ); + } + + @Override + public void insert(final User user) { + transactionManager.execute( + () -> { + userService.insert(user); + return null; + }); + } + + @Override + public void changePassword(final long id, + final String newPassword, + final String createBy) { + transactionManager.execute( + () -> { + userService.changePassword(id, newPassword, createBy); + return null; + } + ); + } +} diff --git a/app/src/main/java/com/techcourse/service/UserService.java b/app/src/main/java/com/techcourse/service/UserService.java index e63313eb5c..b14dbcacbf 100644 --- a/app/src/main/java/com/techcourse/service/UserService.java +++ b/app/src/main/java/com/techcourse/service/UserService.java @@ -1,43 +1,12 @@ package com.techcourse.service; -import com.techcourse.config.DataSourceConfig; -import com.techcourse.dao.UserDao; -import com.techcourse.dao.UserHistoryDao; import com.techcourse.domain.User; -import com.techcourse.domain.UserHistory; -import org.springframework.transaction.TransactionManager; -public class UserService { +public interface UserService { - private final UserDao userDao; - private final UserHistoryDao userHistoryDao; - private final TransactionManager transactionManager; + User findById(final long id); - public UserService(final UserDao userDao, final UserHistoryDao userHistoryDao) { - this.userDao = userDao; - this.userHistoryDao = userHistoryDao; - this.transactionManager = new TransactionManager(DataSourceConfig.getInstance()); - } + void insert(final User user); - public User findById(final long id) { - return transactionManager.execute(connection -> userDao.findById(connection, id)); - } - - public void insert(final User user) { - transactionManager.execute(connection -> { - userDao.insert(connection, user); - return null; - }); - } - - public void changePassword(final long id, final String newPassword, final String createBy) { - transactionManager.execute(connection -> { - final var user = userDao.findById(connection, id); - user.changePassword(newPassword); - userDao.update(connection, user); - userHistoryDao.log(connection, new UserHistory(user, createBy)); - return null; - } - ); - } + void changePassword(final long id, final String newPassword, final String createBy); } diff --git a/app/src/test/java/com/techcourse/dao/UserDaoTest.java b/app/src/test/java/com/techcourse/dao/UserDaoTest.java index 04b32e8928..e5de85b69e 100644 --- a/app/src/test/java/com/techcourse/dao/UserDaoTest.java +++ b/app/src/test/java/com/techcourse/dao/UserDaoTest.java @@ -6,8 +6,6 @@ import com.techcourse.config.DataSourceConfig; import com.techcourse.domain.User; import com.techcourse.support.jdbc.init.DatabasePopulatorUtils; -import java.sql.Connection; -import java.sql.SQLException; import java.util.NoSuchElementException; import javax.sql.DataSource; import org.junit.jupiter.api.BeforeEach; @@ -30,32 +28,28 @@ class UserDaoTest { private final DataSource dataSource = DataSourceConfig.getInstance(); - private Connection connection; - @BeforeEach - void setup() throws SQLException { - connection = dataSource.getConnection(); - - JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource); - jdbcTemplate.execute(connection, INIT_USER_TABLE_SQL); + void setup() { + final JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource); + jdbcTemplate.execute(INIT_USER_TABLE_SQL); DatabasePopulatorUtils.execute(dataSource); userDao = new UserDao(dataSource); - final var user = new User("gugu", "password", "hkkang@woowahan.com"); - userDao.insert(connection, user); + final User user = new User("gugu", "password", "hkkang@woowahan.com"); + userDao.insert(user); } @Test void findAll() { - final var users = userDao.findAll(connection); + final var users = userDao.findAll(); assertThat(users).isNotEmpty(); } @Test void findById() { - final var user = userDao.findById(connection, 1L); + final var user = userDao.findById(1L); assertThat(user.getAccount()).isEqualTo("gugu"); } @@ -63,7 +57,7 @@ void findById() { @Test void findById_fail() { assertThatThrownBy( - () -> userDao.findById(connection, -1L) + () -> userDao.findById(-1L) ).isInstanceOf(NoSuchElementException.class) .hasMessage("id에 해당하는 user가 없습니다."); } @@ -71,7 +65,7 @@ void findById_fail() { @Test void findByAccount() { final var account = "gugu"; - final var user = userDao.findByAccount(connection, account); + final var user = userDao.findByAccount(account); assertThat(user.getAccount()).isEqualTo(account); } @@ -79,7 +73,7 @@ void findByAccount() { @Test void findByAccount_fail() { assertThatThrownBy( - () -> userDao.findByAccount(connection, "joy") + () -> userDao.findByAccount("joy") ).isInstanceOf(NoSuchElementException.class) .hasMessage("account에 해당하는 user가 없습니다."); } @@ -88,9 +82,9 @@ void findByAccount_fail() { void insert() { final var account = "insert-gugu"; final var user = new User(account, "password", "hkkang@woowahan.com"); - userDao.insert(connection, user); + userDao.insert(user); - final var actual = userDao.findById(connection, 2L); + final var actual = userDao.findById(2L); assertThat(actual.getAccount()).isEqualTo(account); } @@ -98,12 +92,12 @@ void insert() { @Test void update() { final var newPassword = "password99"; - final var user = userDao.findById(connection, 1L); + final var user = userDao.findById(1L); user.changePassword(newPassword); - userDao.update(connection, user); + userDao.update(user); - final var actual = userDao.findById(connection, 1L); + final var actual = userDao.findById(1L); assertThat(actual.getPassword()).isEqualTo(newPassword); } diff --git a/app/src/test/java/com/techcourse/dao/UserHistoryDaoTest.java b/app/src/test/java/com/techcourse/dao/UserHistoryDaoTest.java index 4bc0510f9c..84de774f40 100644 --- a/app/src/test/java/com/techcourse/dao/UserHistoryDaoTest.java +++ b/app/src/test/java/com/techcourse/dao/UserHistoryDaoTest.java @@ -5,15 +5,13 @@ import com.techcourse.config.DataSourceConfig; import com.techcourse.domain.UserHistory; import com.techcourse.support.jdbc.init.DatabasePopulatorUtils; -import java.sql.Connection; -import java.sql.SQLException; import javax.sql.DataSource; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; class UserHistoryDaoTest { - public static final UserHistory USER_HISTORY = new UserHistory(1L, + private static final UserHistory USER_HISTORY = new UserHistory(1L, 1L, "joy", "joy1234", @@ -22,15 +20,12 @@ class UserHistoryDaoTest { private final DataSource dataSource = DataSourceConfig.getInstance(); - private Connection connection; - private UserHistoryDao userHistoryDao; @BeforeEach - void setUp() throws SQLException { + void setUp() { DatabasePopulatorUtils.execute(dataSource); userHistoryDao = new UserHistoryDao(dataSource); - connection = dataSource.getConnection(); } @Test @@ -39,11 +34,4 @@ void log() { () -> userHistoryDao.log(USER_HISTORY) ); } - - @Test - void connection_log() { - assertDoesNotThrow( - () -> userHistoryDao.log(connection, USER_HISTORY) - ); - } } diff --git a/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java b/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java index 4f6f9aba26..2ee12b195f 100644 --- a/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java +++ b/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java @@ -2,7 +2,6 @@ import com.techcourse.dao.UserHistoryDao; import com.techcourse.domain.UserHistory; -import java.sql.Connection; import org.springframework.dao.DataAccessException; import org.springframework.jdbc.core.JdbcTemplate; @@ -13,7 +12,7 @@ public MockUserHistoryDao(final JdbcTemplate jdbcTemplate) { } @Override - public void log(final Connection connection, final UserHistory userHistory) { + public void log(final UserHistory userHistory) { throw new DataAccessException(); } } diff --git a/app/src/test/java/com/techcourse/service/UserServiceTest.java b/app/src/test/java/com/techcourse/service/UserServiceTest.java index 037e11c1f7..5676494367 100644 --- a/app/src/test/java/com/techcourse/service/UserServiceTest.java +++ b/app/src/test/java/com/techcourse/service/UserServiceTest.java @@ -8,35 +8,35 @@ import com.techcourse.dao.UserHistoryDao; import com.techcourse.domain.User; import com.techcourse.support.jdbc.init.DatabasePopulatorUtils; -import java.sql.Connection; -import java.sql.SQLException; import javax.sql.DataSource; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.dao.DataAccessException; import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.transaction.TransactionManager; class UserServiceTest { + private final DataSource dataSource = DataSourceConfig.getInstance(); private JdbcTemplate jdbcTemplate; private UserDao userDao; @BeforeEach - void setUp() throws SQLException { - final DataSource dataSource = DataSourceConfig.getInstance(); - final Connection connection = dataSource.getConnection(); + void setUp() { this.jdbcTemplate = new JdbcTemplate(dataSource); this.userDao = new UserDao(jdbcTemplate); DatabasePopulatorUtils.execute(dataSource); final var user = new User("gugu", "password", "hkkang@woowahan.com"); - userDao.insert(connection, user); + userDao.insert(user); } @Test void testChangePassword() { final var userHistoryDao = new UserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao); + final var appUserService = new AppUserService(userDao, userHistoryDao); + final TransactionManager transactionManager = new TransactionManager(dataSource); + final var userService = new TxUserService(appUserService, transactionManager); final var newPassword = "qqqqq"; final var createBy = "gugu"; @@ -51,7 +51,11 @@ void testChangePassword() { void testTransactionRollback() { // 트랜잭션 롤백 테스트를 위해 mock으로 교체 final var userHistoryDao = new MockUserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao); + // 애플리케이션 서비스 + final var appUserService = new AppUserService(userDao, userHistoryDao); + // 트랜잭션 서비스 추상화 + final TransactionManager transactionManager = new TransactionManager(dataSource); + final var userService = new TxUserService(appUserService, transactionManager); final var newPassword = "newPassword"; final var createBy = "gugu"; diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/BaseJdbcTemplate.java b/jdbc/src/main/java/org/springframework/jdbc/core/BaseJdbcTemplate.java index fd8ad6cb45..b0ff56f187 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/core/BaseJdbcTemplate.java +++ b/jdbc/src/main/java/org/springframework/jdbc/core/BaseJdbcTemplate.java @@ -1,7 +1,5 @@ package org.springframework.jdbc.core; -import static java.util.Objects.requireNonNull; - import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; @@ -9,6 +7,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.datasource.DataSourceUtils; public class BaseJdbcTemplate { @@ -16,26 +15,15 @@ public class BaseJdbcTemplate { private final DataSource dataSource; - public BaseJdbcTemplate(DataSource dataSource) { + public BaseJdbcTemplate(final DataSource dataSource) { this.dataSource = dataSource; } public T execute(final String sql, - final PreparedStatementAction action) { - try ( - final Connection connection = requireNonNull(dataSource).getConnection(); - final PreparedStatement preparedStatement = connection.prepareStatement(sql) - ) { - return action.execute(preparedStatement); - } catch (SQLException e) { - log.error(e.getMessage(), e); - throw new DataAccessException(e); - } - } + final PreparedStatementAction action + ) { + final Connection connection = DataSourceUtils.getConnection(dataSource); - public T execute(final Connection connection, - final String sql, - final PreparedStatementAction action) { try ( final PreparedStatement preparedStatement = connection.prepareStatement(sql) ) { diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index 029397a53a..8ce9adeae5 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -1,6 +1,5 @@ package org.springframework.jdbc.core; -import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; @@ -22,7 +21,8 @@ public JdbcTemplate(final DataSource dataSource) { this.baseJdbcTemplate = new BaseJdbcTemplate(dataSource); } - public void update(final String sql, final Object... args) { + public void update(final String sql, + final Object... args) { baseJdbcTemplate.execute(sql, preparedStatement -> { setArguments(args, preparedStatement); @@ -32,30 +32,10 @@ public void update(final String sql, final Object... args) { ); } - public void update(final Connection connection, final String sql, final Object... args) { - baseJdbcTemplate.execute(connection, - sql, - preparedStatement -> { - setArguments(args, preparedStatement); - preparedStatement.executeUpdate(); - return null; - } - ); - } - - public Optional queryForObject(final String sql, final RowMapper rowMapper, final Object... args) { - return baseJdbcTemplate.execute(sql, - preparedStatement -> { - setArguments(args, preparedStatement); - final List results = getResults(rowMapper, preparedStatement.executeQuery()); - return Optional.ofNullable(getResult(results)); - }); - } - - public Optional queryForObject(final Connection connection, final String sql, final RowMapper rowMapper, + public Optional queryForObject(final String sql, + final RowMapper rowMapper, final Object... args) { - return baseJdbcTemplate.execute(connection, - sql, + return baseJdbcTemplate.execute(sql, preparedStatement -> { setArguments(args, preparedStatement); final List results = getResults(rowMapper, preparedStatement.executeQuery()); @@ -63,30 +43,26 @@ public Optional queryForObject(final Connection connection, final String }); } - private T getResult(List results) { + private T getResult(final List results) { if (results.isEmpty()) { return null; } if (results.size() > 1) { - throw new DataAccessException("조회된 데이터 수가 1을 초과합니다"); + throw new DataAccessException("조회된 데이터 수가 1을 초과합니다. count: " + results.size()); } return results.get(0); } - public List query(final String sql, final RowMapper rowMapper) { + public List query(final String sql, + final RowMapper rowMapper) { return baseJdbcTemplate.execute(sql, preparedStatement -> getResults(rowMapper, preparedStatement.executeQuery())); } - public List query(final Connection connection, final String sql, final RowMapper rowMapper) { - return baseJdbcTemplate.execute(connection, - sql, - preparedStatement -> getResults(rowMapper, preparedStatement.executeQuery())); - } - - private List getResults(final RowMapper rowMapper, final ResultSet resultSet) throws SQLException { + private List getResults(final RowMapper rowMapper, + final ResultSet resultSet) throws SQLException { final ArrayList results = new ArrayList<>(); while (resultSet.next()) { final T result = rowMapper.mapRow(resultSet, resultSet.getRow()); @@ -95,7 +71,9 @@ private List getResults(final RowMapper rowMapper, final ResultSet res return results; } - private void setArguments(final Object[] args, final PreparedStatement preparedStatement) throws SQLException { + private void setArguments(final Object[] args, + final PreparedStatement preparedStatement + ) throws SQLException { for (int i = 0; i < args.length; i++) { preparedStatement.setObject(i + 1, args[i]); } @@ -104,10 +82,4 @@ private void setArguments(final Object[] args, final PreparedStatement preparedS public void execute(final String sql) { baseJdbcTemplate.execute(sql, PreparedStatement::execute); } - - public void execute(final Connection connection, final String sql) { - baseJdbcTemplate.execute(connection, - sql, - PreparedStatement::execute); - } } diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementAction.java b/jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementAction.java index c24edb47af..349fd608aa 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementAction.java +++ b/jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementAction.java @@ -6,5 +6,5 @@ @FunctionalInterface public interface PreparedStatementAction { - T execute(PreparedStatement preparedStatement) throws SQLException; + T execute(final PreparedStatement preparedStatement) throws SQLException; } diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/RowMapper.java b/jdbc/src/main/java/org/springframework/jdbc/core/RowMapper.java index b7f9133b54..8e21c74cf7 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/core/RowMapper.java +++ b/jdbc/src/main/java/org/springframework/jdbc/core/RowMapper.java @@ -6,5 +6,5 @@ @FunctionalInterface public interface RowMapper { - T mapRow(ResultSet resultSet, int rowNum) throws SQLException; + T mapRow(final ResultSet resultSet, final int rowNum) throws SQLException; } diff --git a/jdbc/src/main/java/org/springframework/transaction/ConnectionAction.java b/jdbc/src/main/java/org/springframework/transaction/ConnectionAction.java deleted file mode 100644 index 28d9017379..0000000000 --- a/jdbc/src/main/java/org/springframework/transaction/ConnectionAction.java +++ /dev/null @@ -1,9 +0,0 @@ -package org.springframework.transaction; - -import java.sql.Connection; - -@FunctionalInterface -public interface ConnectionAction { - - T execute(Connection connection); -} diff --git a/jdbc/src/main/java/org/springframework/transaction/TransactionAction.java b/jdbc/src/main/java/org/springframework/transaction/TransactionAction.java new file mode 100644 index 0000000000..0a996ce3f0 --- /dev/null +++ b/jdbc/src/main/java/org/springframework/transaction/TransactionAction.java @@ -0,0 +1,7 @@ +package org.springframework.transaction; + +@FunctionalInterface +public interface TransactionAction { + + T execute(); +} diff --git a/jdbc/src/main/java/org/springframework/transaction/TransactionManager.java b/jdbc/src/main/java/org/springframework/transaction/TransactionManager.java index 582e88d16f..c688bf76a2 100644 --- a/jdbc/src/main/java/org/springframework/transaction/TransactionManager.java +++ b/jdbc/src/main/java/org/springframework/transaction/TransactionManager.java @@ -1,13 +1,13 @@ package org.springframework.transaction; -import static java.util.Objects.requireNonNull; - import java.sql.Connection; import java.sql.SQLException; import javax.sql.DataSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.datasource.DataSourceUtils; +import org.springframework.transaction.support.TransactionSynchronizationManager; public class TransactionManager { @@ -15,30 +15,36 @@ public class TransactionManager { private final DataSource dataSource; - public TransactionManager(DataSource dataSource) { + public TransactionManager(final DataSource dataSource) { this.dataSource = dataSource; } - public T execute(final ConnectionAction action) { - try (Connection connection = requireNonNull(dataSource).getConnection()) { + public T execute(final TransactionAction action) { + final Connection connection = DataSourceUtils.getConnection(dataSource); + try { return executeAction(action, connection); } catch (SQLException e) { throw new DataAccessException(e); + } finally { + DataSourceUtils.releaseConnection(connection, dataSource); + TransactionSynchronizationManager.unbindResource(dataSource); } } - private T executeAction(final ConnectionAction action, final Connection connection) throws SQLException { + private T executeAction(final TransactionAction action, + final Connection connection) throws SQLException { try { return commit(action, connection); - } catch (SQLException e) { + } catch (RuntimeException e) { connection.rollback(); throw new DataAccessException(e); } } - private T commit(final ConnectionAction action, final Connection connection) throws SQLException { + private T commit(final TransactionAction action, + final Connection connection) throws SQLException { connection.setAutoCommit(false); - final T result = action.execute(connection); + final T result = action.execute(); connection.commit(); return result; diff --git a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java index 715557fc66..08cf1011a4 100644 --- a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java +++ b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java @@ -1,23 +1,30 @@ package org.springframework.transaction.support; -import javax.sql.DataSource; import java.sql.Connection; +import java.util.HashMap; import java.util.Map; +import javax.sql.DataSource; public abstract class TransactionSynchronizationManager { - private static final ThreadLocal> resources = new ThreadLocal<>(); + private static final ThreadLocal> resources = ThreadLocal.withInitial(HashMap::new); - private TransactionSynchronizationManager() {} + private TransactionSynchronizationManager() { + } - public static Connection getResource(DataSource key) { - return null; + public static Connection getResource(final DataSource key) { + final Map resource = resources.get(); + return resource.get(key); } - public static void bindResource(DataSource key, Connection value) { + public static void bindResource(final DataSource key, + final Connection value) { + final Map resource = resources.get(); + resource.put(key, value); } - public static Connection unbindResource(DataSource key) { - return null; + public static Connection unbindResource(final DataSource key) { + final Map resource = resources.get(); + return resource.remove(key); } }