diff --git a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java index edb4338caa..e5222a66fc 100644 --- a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java +++ b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java @@ -6,57 +6,24 @@ import org.slf4j.LoggerFactory; import javax.sql.DataSource; -import java.sql.Connection; -import java.sql.PreparedStatement; -import java.sql.SQLException; public class UserHistoryDao { private static final Logger log = LoggerFactory.getLogger(UserHistoryDao.class); - private final DataSource dataSource; + + private final JdbcTemplate jdbcTemplate; public UserHistoryDao(final DataSource dataSource) { - this.dataSource = dataSource; + this.jdbcTemplate = new JdbcTemplate(dataSource); } public UserHistoryDao(final JdbcTemplate jdbcTemplate) { - this.dataSource = null; + this.jdbcTemplate = jdbcTemplate; } - public void log(final UserHistory userHistory) { - final var sql = "insert into user_history (user_id, account, password, email, created_at, created_by) values (?, ?, ?, ?, ?, ?)"; - - Connection conn = null; - PreparedStatement pstmt = null; - try { - conn = dataSource.getConnection(); - pstmt = conn.prepareStatement(sql); - - log.debug("query : {}", sql); - - pstmt.setLong(1, userHistory.getUserId()); - pstmt.setString(2, userHistory.getAccount()); - pstmt.setString(3, userHistory.getPassword()); - pstmt.setString(4, userHistory.getEmail()); - pstmt.setObject(5, userHistory.getCreatedAt()); - pstmt.setString(6, userHistory.getCreateBy()); - pstmt.executeUpdate(); - } catch (SQLException e) { - log.error(e.getMessage(), e); - throw new RuntimeException(e); - } finally { - try { - if (pstmt != null) { - pstmt.close(); - } - } catch (SQLException ignored) {} - - try { - if (conn != null) { - conn.close(); - } - } catch (SQLException ignored) {} - } + public void log(UserHistory userHistory) { + String sql = "insert into user_history (user_id, account, password, email, created_at, created_by) values (?, ?, ?, ?, ?, ?)"; + jdbcTemplate.update(sql, userHistory.getUserId(), userHistory.getAccount(), userHistory.getPassword(), userHistory.getEmail(), userHistory.getCreatedAt(), userHistory.getCreateBy()); } } diff --git a/app/src/main/java/com/techcourse/service/UserService.java b/app/src/main/java/com/techcourse/service/UserService.java index fcf2159dc8..11df6c21e1 100644 --- a/app/src/main/java/com/techcourse/service/UserService.java +++ b/app/src/main/java/com/techcourse/service/UserService.java @@ -1,12 +1,21 @@ package com.techcourse.service; +import com.techcourse.config.DataSourceConfig; import com.techcourse.dao.UserDao; import com.techcourse.dao.UserHistoryDao; import com.techcourse.domain.User; import com.techcourse.domain.UserHistory; +import java.sql.Connection; +import java.sql.SQLException; +import javax.sql.DataSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.dao.DataAccessException; +import org.springframework.transaction.support.TransactionSynchronizationManager; public class UserService { + private static final Logger log = LoggerFactory.getLogger(UserService.class); private final UserDao userDao; private final UserHistoryDao userHistoryDao; @@ -24,9 +33,29 @@ public void insert(final User user) { } public void changePassword(final long id, final String newPassword, final String createBy) { - final var user = findById(id); - user.changePassword(newPassword); - userDao.update(user); - userHistoryDao.log(new UserHistory(user, createBy)); + DataSource dataSource = DataSourceConfig.getInstance(); + Connection conn = null; + try { + conn = dataSource.getConnection(); + conn.setAutoCommit(false); + TransactionSynchronizationManager.bindResource(dataSource, conn); + User user = findById(id); + user.changePassword(newPassword); + userDao.update(user); + userHistoryDao.log(new UserHistory(user, createBy)); + conn.commit(); + } catch (Exception e) { + log.error(e.getMessage(), e); + if (conn != null) { + try { + TransactionSynchronizationManager.unbindResource(dataSource); + conn.rollback(); + conn.close(); + } catch (SQLException ex) { + log.error(ex.getMessage(), ex); + } + } + throw new DataAccessException(); + } } } diff --git a/app/src/test/java/com/techcourse/service/UserServiceTest.java b/app/src/test/java/com/techcourse/service/UserServiceTest.java index 255a0ebfe7..83bc1d3505 100644 --- a/app/src/test/java/com/techcourse/service/UserServiceTest.java +++ b/app/src/test/java/com/techcourse/service/UserServiceTest.java @@ -8,13 +8,11 @@ import org.springframework.dao.DataAccessException; import org.springframework.jdbc.core.JdbcTemplate; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; -@Disabled class UserServiceTest { private JdbcTemplate jdbcTemplate; diff --git a/jdbc/src/main/java/org/springframework/jdbc/ColumnConversionException.java b/jdbc/src/main/java/org/springframework/jdbc/ColumnConversionException.java new file mode 100644 index 0000000000..c4a1c92432 --- /dev/null +++ b/jdbc/src/main/java/org/springframework/jdbc/ColumnConversionException.java @@ -0,0 +1,7 @@ +package org.springframework.jdbc; + +import org.springframework.dao.DataAccessException; + +public class ColumnConversionException extends DataAccessException { + +} diff --git a/jdbc/src/main/java/org/springframework/jdbc/IncorrectResultSizeDataAccessException.java b/jdbc/src/main/java/org/springframework/jdbc/IncorrectResultSizeDataAccessException.java index 733a885568..7eb8a49f76 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/IncorrectResultSizeDataAccessException.java +++ b/jdbc/src/main/java/org/springframework/jdbc/IncorrectResultSizeDataAccessException.java @@ -1,5 +1,7 @@ package org.springframework.jdbc; -public class IncorrectResultSizeDataAccessException extends RuntimeException { +import org.springframework.dao.DataAccessException; + +public class IncorrectResultSizeDataAccessException extends DataAccessException { } diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/ColumnTypes.java b/jdbc/src/main/java/org/springframework/jdbc/core/ColumnTypes.java index d4cdb5054f..b7ecff3027 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/core/ColumnTypes.java +++ b/jdbc/src/main/java/org/springframework/jdbc/core/ColumnTypes.java @@ -1,5 +1,7 @@ package org.springframework.jdbc.core; +import org.springframework.jdbc.ColumnConversionException; + public class ColumnTypes { /** @@ -14,7 +16,7 @@ public static Class convertToClass(int types) { case 12: return String.class; default: - throw new IllegalStateException(); + throw new ColumnConversionException(); } } } diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/InstantiateUtil.java b/jdbc/src/main/java/org/springframework/jdbc/core/InstantiateUtil.java new file mode 100644 index 0000000000..115f1b0327 --- /dev/null +++ b/jdbc/src/main/java/org/springframework/jdbc/core/InstantiateUtil.java @@ -0,0 +1,22 @@ +package org.springframework.jdbc.core; + +import java.lang.reflect.Constructor; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; + +public class InstantiateUtil { + + public static T instantiate(ResultSet rs, Class requiredType, Object[] initArgs) + throws Exception { + ResultSetMetaData metaData = rs.getMetaData(); + int columnCount = metaData.getColumnCount(); + + Class[] columnTypes = new Class[columnCount]; + for (int i = 1; i <= columnCount; i++) { + int columnType = metaData.getColumnType(i); + columnTypes[i - 1] = ColumnTypes.convertToClass(columnType); + } + Constructor constructor = requiredType.getDeclaredConstructor(columnTypes); + return requiredType.cast(constructor.newInstance(initArgs)); + } +} diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index 1afced8327..fb024480a4 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -1,10 +1,8 @@ package org.springframework.jdbc.core; -import java.lang.reflect.Constructor; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; -import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.util.ArrayList; import java.util.List; @@ -13,6 +11,7 @@ import org.slf4j.LoggerFactory; import org.springframework.dao.DataAccessException; import org.springframework.jdbc.IncorrectResultSizeDataAccessException; +import org.springframework.transaction.support.TransactionSynchronizationManager; public class JdbcTemplate { @@ -24,63 +23,33 @@ public JdbcTemplate(final DataSource dataSource) { this.dataSource = dataSource; } - private Connection getConnection() throws SQLException { - return dataSource.getConnection(); - } - public int update(String sql, Object... args) throws DataAccessException { - try ( - Connection conn = getConnection(); - PreparedStatement pstmt = conn.prepareStatement(sql) - ) { - for (int i = 1; i <= args.length; i++) { - pstmt.setObject(i, args[i - 1]); - } + return execute(sql, (pstmt) -> { + prepareStatement(pstmt, args); return pstmt.executeUpdate(); - } catch (Exception e) { - log.error(e.getMessage(), e); - throw new DataAccessException(e); - } + }); } public T queryForObject(String sql, Class requiredType, Object... args) throws DataAccessException { - try ( - Connection conn = getConnection(); - PreparedStatement pstmt = conn.prepareStatement(sql) - ) { - for (int i = 1; i <= args.length; i++) { - pstmt.setObject(1, args[i - 1]); - } - ResultSet rs = pstmt.executeQuery(); - ResultSetMetaData metaData = rs.getMetaData(); - int columnCount = rs.getMetaData().getColumnCount(); - - Class[] columnTypes = new Class[columnCount]; - for (int i = 1; i <= columnCount; i++) { - int columnType = metaData.getColumnType(i); - columnTypes[i - 1] = ColumnTypes.convertToClass(columnType); - } - Constructor constructor = requiredType.getDeclaredConstructor(columnTypes); - if (rs.first() && rs.isLast()) { - Object[] initargs = new Object[columnCount]; - for (int i = 1; i <= columnCount; i++) { - initargs[i - 1] = rs.getObject(i); + return execute(sql, (pstmt) -> { + prepareStatement(pstmt, args); + try (ResultSet rs = pstmt.executeQuery()) { + int columnCount = rs.getMetaData().getColumnCount(); + if (rs.first() && rs.isLast()) { + Object[] initArgs = new Object[columnCount]; + for (int i = 1; i <= columnCount; i++) { + initArgs[i - 1] = rs.getObject(i); + } + return InstantiateUtil.instantiate(rs, requiredType, initArgs); } - return requiredType.cast(constructor.newInstance(initargs)); + throw new IncorrectResultSizeDataAccessException(); } - throw new IncorrectResultSizeDataAccessException(); - } catch (Exception e) { - log.error(e.getMessage(), e); - throw new DataAccessException(e); - } + }); } public List query(String sql, RowMapper rowMapper) throws DataAccessException { - try ( - Connection conn = getConnection(); - PreparedStatement pstmt = conn.prepareStatement(sql) - ) { + return execute(sql, (pstmt) -> { try (ResultSet resultSet = pstmt.executeQuery()) { List results = new ArrayList<>(); while (resultSet.next()) { @@ -88,9 +57,44 @@ public List query(String sql, RowMapper rowMapper) throws DataAccessEx } return results; } - } catch (SQLException e) { + }); + } + + private void prepareStatement(PreparedStatement pstmt, Object[] args) throws SQLException { + for (int i = 1; i <= args.length; i++) { + pstmt.setObject(i, args[i - 1]); + } + } + + private T execute(String sql, StatementExecution function) { + Connection conn = null; + try { + conn = getConnection(); + PreparedStatement pstmt = conn.prepareStatement(sql); + return function.apply(pstmt); + } catch (Exception e) { log.error(e.getMessage(), e); throw new DataAccessException(e); + } finally { + try { + if (isConnectionManuallyInstantiated(conn)) { + conn.close(); + } + } catch (SQLException e) { + log.error(e.getMessage(), e); + throw new DataAccessException(e); + } } } + + private Connection getConnection() throws SQLException { + if (TransactionSynchronizationManager.getResource(dataSource) != null) { + return TransactionSynchronizationManager.getResource(dataSource); + } + return dataSource.getConnection(); + } + + private boolean isConnectionManuallyInstantiated(Connection conn) { + return conn != null && TransactionSynchronizationManager.getResource(dataSource) == null; + } } diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/StatementExecution.java b/jdbc/src/main/java/org/springframework/jdbc/core/StatementExecution.java new file mode 100644 index 0000000000..fa17285d0e --- /dev/null +++ b/jdbc/src/main/java/org/springframework/jdbc/core/StatementExecution.java @@ -0,0 +1,8 @@ +package org.springframework.jdbc.core; + +@FunctionalInterface +public interface StatementExecution { + + R apply(T t) throws Exception; + +} diff --git a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java index 715557fc66..cef718060a 100644 --- a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java +++ b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java @@ -1,23 +1,31 @@ package org.springframework.transaction.support; -import javax.sql.DataSource; import java.sql.Connection; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import javax.sql.DataSource; -public abstract class TransactionSynchronizationManager { +public class TransactionSynchronizationManager { - private static final ThreadLocal> resources = new ThreadLocal<>(); + private static final ThreadLocal> resources; - private TransactionSynchronizationManager() {} + static { + resources = new ThreadLocal<>(); + resources.set(new ConcurrentHashMap<>()); + } + + private TransactionSynchronizationManager() { + } public static Connection getResource(DataSource key) { - return null; + return resources.get().get(key); } public static void bindResource(DataSource key, Connection value) { + resources.get().put(key, value); } public static Connection unbindResource(DataSource key) { - return null; + return resources.get().remove(key); } }