diff --git a/app/src/main/java/com/techcourse/dao/UserDao.java b/app/src/main/java/com/techcourse/dao/UserDao.java index 8de46b98c7..b53aedb837 100644 --- a/app/src/main/java/com/techcourse/dao/UserDao.java +++ b/app/src/main/java/com/techcourse/dao/UserDao.java @@ -6,7 +6,6 @@ import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.RowMapper; -import java.sql.Connection; import java.util.List; import java.util.NoSuchElementException; @@ -20,32 +19,32 @@ public UserDao(final JdbcTemplate jdbcTemplate) { this.jdbcTemplate = jdbcTemplate; } - public void insert(Connection con, final User user) { + public void insert(final User user) { final var sql = "insert into users (account, password, email) values (?, ?, ?)"; - jdbcTemplate.update(con,sql, user.getAccount(), user.getPassword(), user.getEmail()); + jdbcTemplate.update(sql, user.getAccount(), user.getPassword(), user.getEmail()); } - public void update(Connection con,final User user) { + public void update(final User user) { String sql = "UPDATE users SET account = ?, password = ?, email = ? WHERE id = ?"; - jdbcTemplate.update(con,sql, user.getAccount(), user.getPassword(), user.getEmail(), user.getId()); + jdbcTemplate.update(sql, user.getAccount(), user.getPassword(), user.getEmail(), user.getId()); } - public List findAll(Connection con) { + public List findAll() { String sql = "SELECT id, account, password, email FROM users"; - return jdbcTemplate.query(con,sql, getUserRowMapper()); + return jdbcTemplate.query(sql, getUserRowMapper()); } - public User findById(Connection con,final Long id) { + public User findById(final Long id) { final var sql = "select id, account, password, email from users where id = ?"; - return jdbcTemplate.queryForObject(con,sql, getUserRowMapper(), id) + return jdbcTemplate.queryForObject(sql, getUserRowMapper(), id) .orElseThrow(() -> new NoSuchElementException("결과가 존재하지 않습니다")); } - public User findByAccount(Connection con,final String account) { + public User findByAccount(final String account) { String sql = "SELECT id, account, password, email FROM users WHERE account = ?"; - return jdbcTemplate.queryForObject(con,sql, getUserRowMapper(), account) + return jdbcTemplate.queryForObject(sql, getUserRowMapper(), account) .orElseThrow(() -> new NoSuchElementException("결과가 존재하지 않습니다")); } diff --git a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java index 27e9c8e038..b1ae0a4ea9 100644 --- a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java +++ b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java @@ -5,8 +5,6 @@ import org.slf4j.LoggerFactory; import org.springframework.jdbc.core.JdbcTemplate; -import java.sql.Connection; - public class UserHistoryDao { private static final Logger log = LoggerFactory.getLogger(UserHistoryDao.class); @@ -17,9 +15,9 @@ public UserHistoryDao(final JdbcTemplate jdbcTemplate) { this.jdbcTemplate = jdbcTemplate; } - public void log(Connection con, final UserHistory userHistory) { + public void log(final UserHistory userHistory) { final var sql = "insert into user_history (user_id, account, password, email, created_at, created_by) values (?, ?, ?, ?, ?, ?)"; - jdbcTemplate.update(con, sql, + jdbcTemplate.update(sql, userHistory.getUserId(), userHistory.getAccount(), userHistory.getPassword(), diff --git a/app/src/main/java/com/techcourse/service/AppUserService.java b/app/src/main/java/com/techcourse/service/AppUserService.java new file mode 100644 index 0000000000..5216e7e9ff --- /dev/null +++ b/app/src/main/java/com/techcourse/service/AppUserService.java @@ -0,0 +1,33 @@ +package com.techcourse.service; + +import com.techcourse.dao.UserDao; +import com.techcourse.dao.UserHistoryDao; +import com.techcourse.domain.User; +import com.techcourse.domain.UserHistory; + +public class AppUserService implements UserService { + + private final UserDao userDao; + private final UserHistoryDao userHistoryDao; + + public AppUserService(final UserDao userDao, final UserHistoryDao userHistoryDao) { + this.userDao = userDao; + this.userHistoryDao = userHistoryDao; + } + + public User findById(final long id) { + return userDao.findById(id); + } + + + public void insert(final User user) { + userDao.insert(user); + } + + public void changePassword(final long id, final String newPassword, final String createBy) { + final var user = userDao.findById(id); + user.changePassword(newPassword); + userDao.update(user); + userHistoryDao.log(new UserHistory(user, createBy)); + } +} diff --git a/app/src/main/java/com/techcourse/service/TransactionHandler.java b/app/src/main/java/com/techcourse/service/TransactionHandler.java new file mode 100644 index 0000000000..0401e32004 --- /dev/null +++ b/app/src/main/java/com/techcourse/service/TransactionHandler.java @@ -0,0 +1,47 @@ +package com.techcourse.service; + +import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.datasource.DataSourceUtils; +import org.springframework.transaction.support.TransactionSynchronizationManager; + +import javax.sql.DataSource; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.sql.Connection; +import java.sql.SQLException; + +public class TransactionHandler implements InvocationHandler { + + private final Object service; + private final DataSource dataSource; + + public TransactionHandler(Object service, DataSource dataSource) { + this.service = service; + this.dataSource = dataSource; + } + + @Override + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + Connection connection = DataSourceUtils.getConnection(dataSource); + TransactionSynchronizationManager.bindResource(dataSource, connection); + try { + connection.setAutoCommit(false); + Object invoke = method.invoke(service, args); + connection.commit(); + return invoke; + } catch (Exception e) { + rollback(connection); + throw new DataAccessException(e); + } finally { + DataSourceUtils.releaseConnection(connection, dataSource); + } + } + + private void rollback(Connection connection) { + try { + connection.rollback(); + } catch (SQLException ex) { + throw new DataAccessException(ex); + } + } +} diff --git a/app/src/main/java/com/techcourse/service/UserService.java b/app/src/main/java/com/techcourse/service/UserService.java index 7b9d04038d..42d01bf760 100644 --- a/app/src/main/java/com/techcourse/service/UserService.java +++ b/app/src/main/java/com/techcourse/service/UserService.java @@ -1,88 +1,10 @@ package com.techcourse.service; -import com.techcourse.dao.UserDao; -import com.techcourse.dao.UserHistoryDao; import com.techcourse.domain.User; -import com.techcourse.domain.UserHistory; -import org.springframework.dao.DataAccessException; -import javax.sql.DataSource; -import java.sql.Connection; -import java.sql.SQLException; +public interface UserService { -public class UserService { - - private final UserDao userDao; - private final UserHistoryDao userHistoryDao; - private final DataSource dataSource; - - public UserService(final UserDao userDao, final UserHistoryDao userHistoryDao, DataSource dataSource) { - this.userDao = userDao; - this.userHistoryDao = userHistoryDao; - this.dataSource = dataSource; - } - - public User findById(final long id) { - try (Connection con = dataSource.getConnection()) { - try { - con.setAutoCommit(false); - - User user = userDao.findById(con, id); - - con.commit(); - return user; - } catch (Exception e) { - con.rollback(); - throw new DataAccessException(e); - } finally { - con.setAutoCommit(true); - con.close(); - } - } catch (SQLException e) { - throw new DataAccessException(e); - } - } - - public void insert(final User user) { - try (Connection con = dataSource.getConnection()) { - try { - con.setAutoCommit(false); - - userDao.insert(con, user); - - con.commit(); - } catch (Exception e) { - con.rollback(); - throw new DataAccessException(e); - } finally { - con.setAutoCommit(true); - con.close(); - } - } catch (SQLException e) { - throw new DataAccessException(e); - } - } - - public void changePassword(final long id, final String newPassword, final String createBy) { - try (Connection con = dataSource.getConnection()) { - try { - con.setAutoCommit(false); - - final var user = findById(id); - user.changePassword(newPassword); - userDao.update(con, user); - userHistoryDao.log(con, new UserHistory(user, createBy)); - - con.commit(); - } catch (Exception e) { - con.rollback(); - throw new DataAccessException(e); - } finally { - con.setAutoCommit(true); - con.close(); - } - } catch (SQLException e) { - throw new DataAccessException(e); - } - } + User findById(final long id); + void insert(final User user); + void changePassword(final long id, final String newPassword, final String createBy); } diff --git a/app/src/test/java/com/techcourse/dao/UserDaoTest.java b/app/src/test/java/com/techcourse/dao/UserDaoTest.java index 720989aff4..0f5ef02529 100644 --- a/app/src/test/java/com/techcourse/dao/UserDaoTest.java +++ b/app/src/test/java/com/techcourse/dao/UserDaoTest.java @@ -3,14 +3,11 @@ import com.techcourse.config.DataSourceConfig; import com.techcourse.domain.User; import com.techcourse.support.jdbc.init.DatabasePopulatorUtils; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.jdbc.core.JdbcTemplate; import javax.sql.DataSource; - -import java.sql.Connection; import java.sql.SQLException; import static org.assertj.core.api.Assertions.assertThat; @@ -19,77 +16,62 @@ class UserDaoTest { private UserDao userDao; private DataSource dataSource; - private Connection con; @BeforeEach - void setup() throws SQLException { + void setup() { DatabasePopulatorUtils.execute(DataSourceConfig.getInstance()); dataSource = DataSourceConfig.getInstance(); - con = dataSource.getConnection(); - userDao = new UserDao(new JdbcTemplate()); + userDao = new UserDao(new JdbcTemplate(dataSource)); final var user = new User("gugu", "password", "hkkang@woowahan.com"); - userDao.insert(con, user); - } - - @AfterEach - void close() throws SQLException { - con.close(); + userDao.insert(user); } @Test - void findAll() throws SQLException { - Connection con = dataSource.getConnection(); - final var users = userDao.findAll(con); + void findAll() { + final var users = userDao.findAll(); assertThat(users).isNotEmpty(); } @Test - void findById() throws SQLException { - Connection con = dataSource.getConnection(); + void findById() { - final var user = userDao.findById(con, 1L); + final var user = userDao.findById(1L); assertThat(user.getAccount()).isEqualTo("gugu"); } @Test - void findByAccount() throws SQLException { - Connection con = dataSource.getConnection(); + void findByAccount() { final var account = "gugu"; - final var user = userDao.findByAccount(con, account); + final var user = userDao.findByAccount(account); assertThat(user.getAccount()).isEqualTo(account); } @Test - void insert() throws SQLException { - Connection con1 = dataSource.getConnection(); - Connection con2 = dataSource.getConnection(); + void insert() { final var account = "insert-gugu"; final var user = new User(account, "password", "hkkang@woowahan.com"); - userDao.insert(con1, user); + userDao.insert(user); - final var actual = userDao.findById(con2, 2L); + final var actual = userDao.findById(2L); assertThat(actual.getAccount()).isEqualTo(account); } @Test - void update() throws SQLException { - Connection con1 = dataSource.getConnection(); - Connection con2 = dataSource.getConnection(); - Connection con3 = dataSource.getConnection(); + void update() { final var newPassword = "password99"; - final var user = userDao.findById(con1, 1L); + final var user = userDao.findById(1L); user.changePassword(newPassword); - userDao.update(con2, user); + userDao.update(user); - final var actual = userDao.findById(con3, 1L); + final var actual = userDao.findById(1L); assertThat(actual.getPassword()).isEqualTo(newPassword); } diff --git a/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java b/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java index c4caf1f9ff..7cbcb31823 100644 --- a/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java +++ b/app/src/test/java/com/techcourse/service/MockUserHistoryDao.java @@ -5,8 +5,6 @@ import org.springframework.dao.DataAccessException; import org.springframework.jdbc.core.JdbcTemplate; -import java.sql.Connection; - public class MockUserHistoryDao extends UserHistoryDao { public MockUserHistoryDao(final JdbcTemplate jdbcTemplate) { @@ -14,8 +12,8 @@ public MockUserHistoryDao(final JdbcTemplate jdbcTemplate) { } @Override - public void log(Connection con, UserHistory userHistory) { - super.log(con, userHistory); + public void log(UserHistory userHistory) { + super.log(userHistory); throw new DataAccessException(); } } diff --git a/app/src/test/java/com/techcourse/service/UserServiceTest.java b/app/src/test/java/com/techcourse/service/UserServiceTest.java index 9b42461dbe..d2d5982e90 100644 --- a/app/src/test/java/com/techcourse/service/UserServiceTest.java +++ b/app/src/test/java/com/techcourse/service/UserServiceTest.java @@ -11,8 +11,7 @@ import org.springframework.jdbc.core.JdbcTemplate; import javax.sql.DataSource; -import java.sql.Connection; -import java.sql.SQLException; +import java.lang.reflect.Proxy; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -23,23 +22,21 @@ class UserServiceTest { private JdbcTemplate jdbcTemplate; private UserDao userDao; private DataSource dataSource; - private Connection con; @BeforeEach - void setUp() throws SQLException { - this.jdbcTemplate = new JdbcTemplate(); - this.userDao = new UserDao(jdbcTemplate); + void setUp() { DatabasePopulatorUtils.execute(DataSourceConfig.getInstance()); dataSource = DataSourceConfig.getInstance(); - con = dataSource.getConnection(); + this.jdbcTemplate = new JdbcTemplate(dataSource); + this.userDao = new UserDao(jdbcTemplate); final var user = new User("gugu", "password", "hkkang@woowahan.com"); - userDao.insert(con, user); + userDao.insert(user); } @Test void testChangePassword() { final var userHistoryDao = new UserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao, dataSource); + final var userService = getTransactionalUserService(userHistoryDao); final var newPassword = "qqqqq"; final var createBy = "gugu"; @@ -54,7 +51,7 @@ void testChangePassword() { void testTransactionRollback() { // 트랜잭션 롤백 테스트를 위해 mock으로 교체 MockUserHistoryDao userHistoryDao = new MockUserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao, dataSource); + final var userService = getTransactionalUserService(userHistoryDao); final var newPassword = "newPassword"; final var createBy = "gugu"; @@ -66,4 +63,10 @@ void testTransactionRollback() { assertThat(actual.getPassword()).isNotEqualTo(newPassword); } + + private UserService getTransactionalUserService(final UserHistoryDao userHistoryDao) { + return (UserService) Proxy.newProxyInstance(UserService.class.getClassLoader(), + new Class[]{UserService.class}, + new TransactionHandler(new AppUserService(userDao, userHistoryDao), dataSource)); + } } diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index 4173e56358..be43214cee 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -3,8 +3,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.dao.DataAccessException; -import org.springframework.dao.DataUpdateException; +import org.springframework.jdbc.datasource.DataSourceUtils; +import javax.sql.DataSource; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -18,18 +19,25 @@ public class JdbcTemplate { private static final Logger log = LoggerFactory.getLogger(JdbcTemplate.class); - public void update(Connection con, String sql, Object... arguments) { + private final DataSource dataSource; + + public JdbcTemplate(DataSource dataSource) { + this.dataSource = dataSource; + } + + public void update(String sql, Object... arguments) { + Connection con = DataSourceUtils.getConnection(dataSource); try (PreparedStatement pstmt = con.prepareStatement(sql)) { log.debug("query : {}", sql); setArguments(pstmt, arguments); pstmt.executeUpdate(); } catch (SQLException e) { - log.error(e.getMessage(), e); - throw new DataUpdateException(e.getMessage(), e); + throw new DataAccessException(e); } } - public Optional queryForObject(Connection con,String sql, RowMapper rowMapper, Object... arguments) { + public Optional queryForObject(String sql, RowMapper rowMapper, Object... arguments) { + Connection con = DataSourceUtils.getConnection(dataSource); try (PreparedStatement pstmt = con.prepareStatement(sql)) { setArguments(pstmt, arguments); log.debug("query : {}", sql); @@ -45,12 +53,12 @@ public Optional queryForObject(Connection con,String sql, RowMapper ro throw new NoSuchElementException(); }, pstmt); } catch (SQLException e) { - log.error(e.getMessage(), e); - throw new DataAccessException(e.getMessage(), e); + throw new DataAccessException(e); } } - public List query(Connection con,String sql, RowMapper rowMapper, Object... arguments) { + public List query(String sql, RowMapper rowMapper, Object... arguments) { + Connection con = DataSourceUtils.getConnection(dataSource); try (PreparedStatement pstmt = con.prepareStatement(sql)) { setArguments(pstmt, arguments); log.debug("query : {}", sql); @@ -63,12 +71,11 @@ public List query(Connection con,String sql, RowMapper rowMapper, Obje return results; }, pstmt); } catch (SQLException e) { - log.error(e.getMessage(), e); - throw new DataAccessException(e.getMessage(), e); + throw new DataAccessException(e); } } - private void setArguments(PreparedStatement pstmt, Object[] arguments) throws SQLException { + private void setArguments(PreparedStatement pstmt, Object[] arguments) { PreparedStatementSetter psSetter = getPreparedStatementSetter(arguments); psSetter.setValues(pstmt); } diff --git a/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java b/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java index 3c40bfec52..be2f114ac1 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java +++ b/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java @@ -29,6 +29,7 @@ public static Connection getConnection(DataSource dataSource) throws CannotGetJd public static void releaseConnection(Connection connection, DataSource dataSource) { try { + TransactionSynchronizationManager.unbindResource(dataSource); connection.close(); } catch (SQLException ex) { throw new CannotGetJdbcConnectionException("Failed to close JDBC Connection"); diff --git a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java index 715557fc66..d9062a0493 100644 --- a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java +++ b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java @@ -2,22 +2,25 @@ import javax.sql.DataSource; import java.sql.Connection; +import java.util.HashMap; import java.util.Map; public abstract class TransactionSynchronizationManager { - private static final ThreadLocal> resources = new ThreadLocal<>(); + private static final ThreadLocal> resources = ThreadLocal.withInitial(HashMap::new); - private TransactionSynchronizationManager() {} + private TransactionSynchronizationManager() { + } public static Connection getResource(DataSource key) { - return null; + return resources.get().get(key); } public static void bindResource(DataSource key, Connection value) { + resources.get().put(key, value); } public static Connection unbindResource(DataSource key) { - return null; + return resources.get().remove(key); } }