diff --git a/app/src/main/java/com/techcourse/dao/UserDao.java b/app/src/main/java/com/techcourse/dao/UserDao.java index 07c234b160..1fb78501e5 100644 --- a/app/src/main/java/com/techcourse/dao/UserDao.java +++ b/app/src/main/java/com/techcourse/dao/UserDao.java @@ -46,20 +46,18 @@ public void update(final User user) { public List findAll() { final String sql = "select id, account, password, email from users"; - return jdbcTemplate.query(sql, USER_ROW_MAPPER); + return jdbcTemplate.queryForList(sql, USER_ROW_MAPPER); } public User findById(final Long id) { final var sql = "select id, account, password, email from users where id = ?"; final PreparedStatementSetter preparedStatementSetter = pstmt -> pstmt.setLong(1, id); - return jdbcTemplate.query(sql, USER_ROW_MAPPER, preparedStatementSetter); - + return jdbcTemplate.queryForObject(sql, USER_ROW_MAPPER, preparedStatementSetter); } public User findByAccount(final String account) { final var sql = "select id, account, password, email from users where account = ?"; - final PreparedStatementSetter preparedStatementSetter = pstmt -> pstmt.setString(1, account); - return jdbcTemplate.query(sql, USER_ROW_MAPPER, preparedStatementSetter); + return jdbcTemplate.queryForObject(sql, USER_ROW_MAPPER, account); } } diff --git a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java index edb4338caa..b85213a199 100644 --- a/app/src/main/java/com/techcourse/dao/UserHistoryDao.java +++ b/app/src/main/java/com/techcourse/dao/UserHistoryDao.java @@ -1,62 +1,29 @@ package com.techcourse.dao; import com.techcourse.domain.UserHistory; -import org.springframework.jdbc.core.JdbcTemplate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - -import javax.sql.DataSource; -import java.sql.Connection; -import java.sql.PreparedStatement; -import java.sql.SQLException; +import org.springframework.jdbc.core.JdbcTemplate; public class UserHistoryDao { private static final Logger log = LoggerFactory.getLogger(UserHistoryDao.class); - private final DataSource dataSource; - - public UserHistoryDao(final DataSource dataSource) { - this.dataSource = dataSource; - } + private final JdbcTemplate jdbcTemplate; public UserHistoryDao(final JdbcTemplate jdbcTemplate) { - this.dataSource = null; + this.jdbcTemplate = jdbcTemplate; } public void log(final UserHistory userHistory) { - final var sql = "insert into user_history (user_id, account, password, email, created_at, created_by) values (?, ?, ?, ?, ?, ?)"; - - Connection conn = null; - PreparedStatement pstmt = null; - try { - conn = dataSource.getConnection(); - pstmt = conn.prepareStatement(sql); - - log.debug("query : {}", sql); - + final String sql = "insert into user_history (user_id, account, password, email, created_at, created_by) values (?, ?, ?, ?, ?, ?)"; + jdbcTemplate.update(sql, pstmt -> { pstmt.setLong(1, userHistory.getUserId()); pstmt.setString(2, userHistory.getAccount()); pstmt.setString(3, userHistory.getPassword()); pstmt.setString(4, userHistory.getEmail()); pstmt.setObject(5, userHistory.getCreatedAt()); pstmt.setString(6, userHistory.getCreateBy()); - pstmt.executeUpdate(); - } catch (SQLException e) { - log.error(e.getMessage(), e); - throw new RuntimeException(e); - } finally { - try { - if (pstmt != null) { - pstmt.close(); - } - } catch (SQLException ignored) {} - - try { - if (conn != null) { - conn.close(); - } - } catch (SQLException ignored) {} - } + }); } } diff --git a/app/src/main/java/com/techcourse/service/AppUserService.java b/app/src/main/java/com/techcourse/service/AppUserService.java new file mode 100644 index 0000000000..a35cdc7bfe --- /dev/null +++ b/app/src/main/java/com/techcourse/service/AppUserService.java @@ -0,0 +1,36 @@ +package com.techcourse.service; + +import com.techcourse.dao.UserDao; +import com.techcourse.dao.UserHistoryDao; +import com.techcourse.domain.User; +import com.techcourse.domain.UserHistory; +import org.springframework.dao.DataAccessException; + +public class AppUserService implements UserService { + + private final UserDao userDao; + private final UserHistoryDao userHistoryDao; + + public AppUserService(final UserDao userDao, final UserHistoryDao userHistoryDao) { + this.userDao = userDao; + this.userHistoryDao = userHistoryDao; + } + + @Override + public void insert(final User user) { + userDao.insert(user); + } + + @Override + public void changePassword(final long id, final String newPassword, final String createBy) throws DataAccessException { + final var user = findById(id); + user.changePassword(newPassword); + userDao.update(user); + userHistoryDao.log(new UserHistory(user, createBy)); + } + + @Override + public User findById(final long id) { + return userDao.findById(id); + } +} diff --git a/app/src/main/java/com/techcourse/service/TxUserService.java b/app/src/main/java/com/techcourse/service/TxUserService.java new file mode 100644 index 0000000000..8e888474ea --- /dev/null +++ b/app/src/main/java/com/techcourse/service/TxUserService.java @@ -0,0 +1,46 @@ +package com.techcourse.service; + +import com.techcourse.config.DataSourceConfig; +import com.techcourse.domain.User; +import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.datasource.DataSourceUtils; + +import java.sql.Connection; +import java.sql.SQLException; + +public class TxUserService implements UserService { + private final UserService userService; + + public TxUserService(final UserService userService) { + this.userService = userService; + } + + @Override + public User findById(final long id) { + return userService.findById(id); + } + + @Override + public void insert(final User user) { + userService.insert(user); + } + + @Override + public void changePassword(final long id, final String newPassword, final String createBy) { + final Connection connection = DataSourceUtils.getConnection(DataSourceConfig.getInstance()); + try { + connection.setAutoCommit(false); + userService.changePassword(id, newPassword, createBy); + connection.commit(); + } catch (final Exception e) { + try { + connection.rollback(); + } catch (SQLException ex) { + throw new DataAccessException(ex); + } + throw new DataAccessException(e); + } finally { + DataSourceUtils.releaseConnection(connection, DataSourceConfig.getInstance()); + } + } +} diff --git a/app/src/main/java/com/techcourse/service/UserService.java b/app/src/main/java/com/techcourse/service/UserService.java index fcf2159dc8..805425e246 100644 --- a/app/src/main/java/com/techcourse/service/UserService.java +++ b/app/src/main/java/com/techcourse/service/UserService.java @@ -1,32 +1,11 @@ package com.techcourse.service; -import com.techcourse.dao.UserDao; -import com.techcourse.dao.UserHistoryDao; import com.techcourse.domain.User; -import com.techcourse.domain.UserHistory; -public class UserService { +public interface UserService { + User findById(final long id); - private final UserDao userDao; - private final UserHistoryDao userHistoryDao; + void insert(User user); - public UserService(final UserDao userDao, final UserHistoryDao userHistoryDao) { - this.userDao = userDao; - this.userHistoryDao = userHistoryDao; - } - - public User findById(final long id) { - return userDao.findById(id); - } - - public void insert(final User user) { - userDao.insert(user); - } - - public void changePassword(final long id, final String newPassword, final String createBy) { - final var user = findById(id); - user.changePassword(newPassword); - userDao.update(user); - userHistoryDao.log(new UserHistory(user, createBy)); - } + void changePassword(long id, String newPassword, String createBy); } diff --git a/app/src/test/java/com/techcourse/dao/UserDaoTest.java b/app/src/test/java/com/techcourse/dao/UserDaoTest.java index 7e9cc2b01a..e7995b4c90 100644 --- a/app/src/test/java/com/techcourse/dao/UserDaoTest.java +++ b/app/src/test/java/com/techcourse/dao/UserDaoTest.java @@ -11,13 +11,13 @@ class UserDaoTest { + private final JdbcTemplate jdbcTemplate = new JdbcTemplate(DataSourceConfig.getInstance()); private UserDao userDao; @BeforeEach void setup() { DatabasePopulatorUtils.execute(DataSourceConfig.getInstance()); - - userDao = new UserDao(new JdbcTemplate(DataSourceConfig.getInstance())); + userDao = new UserDao(jdbcTemplate); final var user = new User("gugu", "password", "hkkang@woowahan.com"); userDao.insert(user); } diff --git a/app/src/test/java/com/techcourse/service/UserServiceTest.java b/app/src/test/java/com/techcourse/service/AppUserServiceTest.java similarity index 85% rename from app/src/test/java/com/techcourse/service/UserServiceTest.java rename to app/src/test/java/com/techcourse/service/AppUserServiceTest.java index 255a0ebfe7..6b588471e7 100644 --- a/app/src/test/java/com/techcourse/service/UserServiceTest.java +++ b/app/src/test/java/com/techcourse/service/AppUserServiceTest.java @@ -5,17 +5,15 @@ import com.techcourse.dao.UserHistoryDao; import com.techcourse.domain.User; import com.techcourse.support.jdbc.init.DatabasePopulatorUtils; -import org.springframework.dao.DataAccessException; -import org.springframework.jdbc.core.JdbcTemplate; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.core.JdbcTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; -@Disabled -class UserServiceTest { +class AppUserServiceTest { private JdbcTemplate jdbcTemplate; private UserDao userDao; @@ -33,7 +31,7 @@ void setUp() { @Test void testChangePassword() { final var userHistoryDao = new UserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao); + final var userService = new AppUserService(userDao, userHistoryDao); final var newPassword = "qqqqq"; final var createBy = "gugu"; @@ -48,7 +46,10 @@ void testChangePassword() { void testTransactionRollback() { // 트랜잭션 롤백 테스트를 위해 mock으로 교체 final var userHistoryDao = new MockUserHistoryDao(jdbcTemplate); - final var userService = new UserService(userDao, userHistoryDao); + // 애플리케이션 서비스 + final var appUserService = new AppUserService(userDao, userHistoryDao); + // 트랜잭션 서비스 추상화 + final var userService = new TxUserService(appUserService); final var newPassword = "newPassword"; final var createBy = "gugu"; diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index 45aa7b6e39..4e86b77bcb 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -3,6 +3,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.datasource.DataSourceUtils; import javax.sql.DataSource; import java.sql.Connection; @@ -22,70 +23,80 @@ public JdbcTemplate(final DataSource dataSource) { this.dataSource = dataSource; } - public void update(final String sql, final PreparedStatementSetter preparedStatementSetter) throws DataAccessException { - try (Connection conn = dataSource.getConnection()) { - try (PreparedStatement pstmt = conn.prepareStatement(sql)) { - preparedStatementSetter.setValues(pstmt); - pstmt.executeUpdate(); - } - } catch (final SQLException e) { - throw new DataAccessException(e); - } + public int update(final String sql, final Object... args) throws DataAccessException { + return update(sql, pstmt -> { + setPreparedStatementWithArgs(args, pstmt); + }); } - public T query(final String sql, final RowMapper rowMapper, final PreparedStatementSetter preparedStatementSetter) { - try (final Connection connection = dataSource.getConnection(); - final PreparedStatement pstmt = connection.prepareStatement(sql)) { - try (final ResultSet rs = executeQuery(preparedStatementSetter, pstmt)) { - log.debug("query : {}", sql); - if (rs.next()) { - return rowMapper.mapRow(rs, rs.getRow()); - } - throw new DataAccessException("Empty Result"); - } - } catch (final SQLException e) { - throw new DataAccessException(e); + public int update(final String sql, final PreparedStatementSetter preparedStatementSetter) throws DataAccessException { + return execute(sql, pstmt -> { + preparedStatementSetter.setValues(pstmt); + return pstmt.executeUpdate(); + }); + } + + public T queryForObject(final String sql, final RowMapper rowMapper, final Object... args) throws DataAccessException { + return queryForObject(sql, rowMapper, pstmt -> { + setPreparedStatementWithArgs(args, pstmt); + }); + } + + public T queryForObject(final String sql, final RowMapper rowMapper, final PreparedStatementSetter preparedStatementSetter) throws DataAccessException { + final List result = query(sql, extractList(rowMapper), preparedStatementSetter); + if (result.isEmpty()) { + throw new DataAccessException("No results"); + } + if (result.size() > 1) { + throw new DataAccessException("Too many results"); } + return result.get(0); } - private ResultSet executeQuery(final PreparedStatementSetter preparedStatementSetter, final PreparedStatement preparedStatement) throws SQLException { - preparedStatementSetter.setValues(preparedStatement); - return preparedStatement.executeQuery(); + public List queryForList(final String sql, final RowMapper rowMapper, Object... args) throws DataAccessException { + return queryForList(sql, rowMapper, pstmt -> { + setPreparedStatementWithArgs(args, pstmt); + }); } - public T query(final String sql, final RowMapper rowMapper, final Object... args) { - try (final Connection connection = dataSource.getConnection(); - final PreparedStatement pstmt = connection.prepareStatement(sql)) { - try (final ResultSet rs = executeQuery(args, pstmt)) { - if (rs.next()) { - return rowMapper.mapRow(rs, rs.getRow()); - } - throw new DataAccessException("Empty Result"); + public List queryForList(final String sql, final RowMapper rowMapper, final PreparedStatementSetter preparedStatementSetter) throws DataAccessException { + return query(sql, extractList(rowMapper), preparedStatementSetter); + } + + + private T query(final String sql, ResultSetExtractor rse, PreparedStatementSetter pss) throws DataAccessException { + return execute(sql, pstmt -> { + pss.setValues(pstmt); + try (final ResultSet rs = pstmt.executeQuery()) { + return rse.extractData(rs); } - } catch (final SQLException e) { - throw new DataAccessException(e); + }); + } + + private T execute(final String sql, final PreparedStatementCallback preparedStatementCallback) throws DataAccessException { + final Connection connection = DataSourceUtils.getConnection(dataSource); + try (final PreparedStatement pstmt = connection.prepareStatement(sql)) { + return preparedStatementCallback.doInPreparedStatement(pstmt); + } catch (SQLException e) { + throw new DataAccessException(e.getMessage()); + } finally { + DataSourceUtils.releaseConnection(connection, dataSource); } } - private ResultSet executeQuery(final Object[] args, final PreparedStatement pstmt) throws SQLException { + private void setPreparedStatementWithArgs(final Object[] args, final PreparedStatement pstmt) throws SQLException { for (int i = 0; i < args.length; i++) { pstmt.setObject(i + 1, args[i]); } - return pstmt.executeQuery(); } - public List query(final String sql, final RowMapper rowMapper) throws DataAccessException { - try (final Connection connection = dataSource.getConnection(); - final PreparedStatement pstmt = connection.prepareStatement(sql)) { - try (final ResultSet rs = pstmt.executeQuery()) { - List result = new ArrayList<>(); - while (rs.next()) { - result.add(rowMapper.mapRow(rs, rs.getRow())); - } - return result; + private ResultSetExtractor> extractList(final RowMapper rowMapper) { + return rs -> { + final List results = new ArrayList<>(); + while (rs.next()) { + results.add(rowMapper.mapRow(rs, rs.getRow())); } - } catch (final SQLException e) { - throw new DataAccessException(e); - } + return results; + }; } } diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCallback.java b/jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCallback.java new file mode 100644 index 0000000000..10046145ae --- /dev/null +++ b/jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCallback.java @@ -0,0 +1,9 @@ +package org.springframework.jdbc.core; + +import java.sql.PreparedStatement; +import java.sql.SQLException; + +@FunctionalInterface +public interface PreparedStatementCallback { + T doInPreparedStatement(PreparedStatement ps) throws SQLException; +} diff --git a/jdbc/src/main/java/org/springframework/jdbc/core/ResultSetExtractor.java b/jdbc/src/main/java/org/springframework/jdbc/core/ResultSetExtractor.java new file mode 100644 index 0000000000..dfde52820a --- /dev/null +++ b/jdbc/src/main/java/org/springframework/jdbc/core/ResultSetExtractor.java @@ -0,0 +1,9 @@ +package org.springframework.jdbc.core; + +import java.sql.ResultSet; +import java.sql.SQLException; + +@FunctionalInterface +public interface ResultSetExtractor { + T extractData(ResultSet rs) throws SQLException; +} diff --git a/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java b/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java index 3c40bfec52..26d0933826 100644 --- a/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java +++ b/jdbc/src/main/java/org/springframework/jdbc/datasource/DataSourceUtils.java @@ -10,7 +10,8 @@ // 4단계 미션에서 사용할 것 public abstract class DataSourceUtils { - private DataSourceUtils() {} + private DataSourceUtils() { + } public static Connection getConnection(DataSource dataSource) throws CannotGetJdbcConnectionException { Connection connection = TransactionSynchronizationManager.getResource(dataSource); @@ -29,6 +30,10 @@ public static Connection getConnection(DataSource dataSource) throws CannotGetJd public static void releaseConnection(Connection connection, DataSource dataSource) { try { + final Connection resource = TransactionSynchronizationManager.getResource(dataSource); + if (resource.equals(connection)) { + TransactionSynchronizationManager.unbindResource(dataSource); + } connection.close(); } catch (SQLException ex) { throw new CannotGetJdbcConnectionException("Failed to close JDBC Connection"); diff --git a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java index 715557fc66..2de08f4f68 100644 --- a/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java +++ b/jdbc/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java @@ -2,22 +2,31 @@ import javax.sql.DataSource; import java.sql.Connection; +import java.util.HashMap; import java.util.Map; public abstract class TransactionSynchronizationManager { - private static final ThreadLocal> resources = new ThreadLocal<>(); + private static final ThreadLocal> resources = ThreadLocal.withInitial(HashMap::new); - private TransactionSynchronizationManager() {} + private TransactionSynchronizationManager() { + } public static Connection getResource(DataSource key) { - return null; + return resources.get().get(key); } public static void bindResource(DataSource key, Connection value) { + if (resources.get().containsKey(key)) { + throw new IllegalStateException("Already value [" + resources.get().get(key) + "] for key [" + key + "] bound to thread [" + Thread.currentThread().getName() + "]"); + } + resources.get().put(key, value); } public static Connection unbindResource(DataSource key) { - return null; + if (!resources.get().containsKey(key)) { + throw new IllegalStateException("No value for key [" + key + "] bound to thread [" + Thread.currentThread().getName() + "]"); + } + return resources.get().remove(key); } } diff --git a/jdbc/src/test/java/nextstep/jdbc/JdbcTemplateTest.java b/jdbc/src/test/java/nextstep/jdbc/JdbcTemplateTest.java index 7177ebdf7d..4754ff15d3 100644 --- a/jdbc/src/test/java/nextstep/jdbc/JdbcTemplateTest.java +++ b/jdbc/src/test/java/nextstep/jdbc/JdbcTemplateTest.java @@ -6,8 +6,8 @@ import org.springframework.dao.DataAccessException; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.datasource.DataSourceUtils; -import javax.sql.DataSource; import java.sql.Connection; import java.sql.SQLException; import java.sql.Statement; @@ -23,8 +23,8 @@ class JdbcTemplateTest { @BeforeEach public void setUp() { - final DataSource instance = DataSourceConfig.getInstance(); - try (final Connection connection = instance.getConnection()) { + Connection connection = DataSourceUtils.getConnection(DataSourceConfig.getInstance()); + try { final Statement statement = connection.createStatement(); statement.execute("drop table if exists users"); statement.execute("create table if not exists users (id bigint auto_increment, account varchar(255), password varchar(255), email varchar(255), primary key (id))"); @@ -53,7 +53,7 @@ void update() { "email", resultSet.getString("email") ); String selectSql = "select * from users"; - final List> result = jdbcTemplate.query(selectSql, rowMapper); + final List> result = jdbcTemplate.queryForList(selectSql, rowMapper); assertSoftly(softly -> { softly.assertThat(result).hasSize(1); softly.assertThat(result.get(0).get("account")).isEqualTo("account"); @@ -79,7 +79,7 @@ public void queryList() { //when String selectSql = "select * from users"; - final List> result = jdbcTemplate.query(selectSql, (resultSet, rowNum) -> Map.of( + final List> result = jdbcTemplate.queryForList(selectSql, (resultSet, rowNum) -> Map.of( "id", resultSet.getLong("id"), "account", resultSet.getString("account"), "password", resultSet.getString("password"), @@ -105,7 +105,7 @@ void vargsQuery() { //when String selectSql = "select * from users where account = ?"; - final Map result = jdbcTemplate.query(selectSql, (resultSet, rowNum) -> Map.of( + final Map result = jdbcTemplate.queryForObject(selectSql, (resultSet, rowNum) -> Map.of( "id", resultSet.getLong("id"), "account", resultSet.getString("account"), "password", resultSet.getString("password"),