Skip to content

Commit

Permalink
Merge pull request #17 from sranka/16/mssql_identity_insert
Browse files Browse the repository at this point in the history
fix(mssql): automatically set identity_insert only when inserting to identity column
  • Loading branch information
sranka authored Sep 5, 2022
2 parents 19a4002 + a702a33 commit 7dec4ec
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 26 deletions.
8 changes: 7 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,17 @@
<artifactId>commons-dbcp2</artifactId>
<version>2.1.1</version>
</dependency>
<!-- logging -->
<dependency>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
<version>1.2</version>
</dependency>
<!-- tests -->
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.1</version>
<version>4.13.2</version>
<scope>test</scope>
</dependency>
</dependencies>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.math.BigDecimal;
import java.sql.*;
import java.util.Map;
import java.util.function.Consumer;

/**
* Import pushed data into a database.
Expand All @@ -24,6 +25,8 @@ public class DbImportResultConsumer implements ResultConsumer<RowData>{
private final Connection con;
private final DBFacade db;
private final Map<String,String> actualColumns;
private Consumer<ResultSetInfo> notifyOnStartFn = (r) -> {};


// initialize in on start
private PreparedStatement stmt = null;
Expand All @@ -50,8 +53,13 @@ public DbImportResultConsumer(String tableName, Connection connection, DBFacade
this.actualColumns = actualColumns;
}

public void setNotifyOnStartFn(Consumer<ResultSetInfo> consumer){
this.notifyOnStartFn = consumer;
}

@Override
public void onStart(ResultSetInfo info) {
this.notifyOnStartFn.accept(info);
// set connection to info, so blobs can be serialized without extra resources
if (this.db.canCreateBlobs()){
info.connection = con;
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/pz/tool/jdbcimage/main/DBFacade.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.commons.dbcp2.BasicDataSource;

import pz.tool.jdbcimage.LoggedUtils;
import pz.tool.jdbcimage.ResultSetInfo;

/**
* Facade that isolates specifics of a particular database
Expand Down Expand Up @@ -215,6 +216,11 @@ public void beforeImportTable(Connection con, String table, TableInfo tableInfo)
l.beforeImportTable(con, table, tableInfo);
}
}
public void beforeImportTableData(Connection con, String table, TableInfo tableInfo, ResultSetInfo fileInfo) throws SQLException{
for (DBFacadeListener l: listeners) {
l.beforeImportTableData(con, table, tableInfo, fileInfo);
}
}
public void afterImportTable(Connection con, String table, TableInfo tableInfo) throws SQLException{
for (DBFacadeListener l: listeners) {
l.afterImportTable(con, table, tableInfo);
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/pz/tool/jdbcimage/main/DBFacadeListener.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package pz.tool.jdbcimage.main;

import pz.tool.jdbcimage.ResultSetInfo;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.Collections;
Expand Down Expand Up @@ -40,4 +42,6 @@ static List<DBFacadeListener> getInstances(String classNames){
void importFinished();
void beforeImportTable(Connection con, String table, DBFacade.TableInfo tableInfo) throws SQLException;
void afterImportTable(Connection con, String table, DBFacade.TableInfo tableInfo) throws SQLException;
default void beforeImportTableData(Connection con, String table, DBFacade.TableInfo tableInfo, ResultSetInfo fileInfo) throws SQLException {
}
}
56 changes: 32 additions & 24 deletions src/main/java/pz/tool/jdbcimage/main/Mssql.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.stream.Collectors;

import org.apache.commons.dbcp2.BasicDataSource;

import pz.tool.jdbcimage.LoggedUtils;
import pz.tool.jdbcimage.ResultSetInfo;
import pz.tool.jdbcimage.db.SqlExecuteCommand;
import pz.tool.jdbcimage.db.TableGroupedCommands;

Expand All @@ -29,8 +27,8 @@ public void setupDataSource(BasicDataSource bds) {
@Override
public List<String> getDbUserTables(Connection con) throws SQLException {
List<String> retVal = new ArrayList<>();
try(ResultSet tables = con.getMetaData().getTables(con.getCatalog(), "dbo", "%", new String[]{"TABLE"})){
while(tables.next()){
try (ResultSet tables = con.getMetaData().getTables(con.getCatalog(), "dbo", "%", new String[]{"TABLE"})) {
while (tables.next()) {
String tableName = tables.getString(3);
retVal.add(tableName);
}
Expand Down Expand Up @@ -105,37 +103,53 @@ public String getTruncateTableSql(String tableName) {
return "DELETE FROM " + escapeTableName(tableName);
}

private Map<String, Set<String>> tableIdentityColumns = Collections.emptyMap();

private boolean importsToIdentityColumns(TableInfo tableInfo, ResultSetInfo fileInfo) {
Set<String> identityColumns = tableIdentityColumns.get(tableInfo.getTableName());
if (identityColumns != null) {
Set<String> schemaColumns = tableInfo.getTableColumns().keySet();
Set<String> importedColumns = Arrays.stream(fileInfo.columns)
.filter(col -> schemaColumns.contains(col.toLowerCase()))
.collect(Collectors.toSet());
return identityColumns.stream().anyMatch(importedColumns::contains);
}
return false;
}

@Override
public void afterImportTable(Connection con, String table, TableInfo tableInfo) throws SQLException {
super.afterImportTable(con, table, tableInfo);
if (tableInfo.get("hasId")!=null) {
if (tableInfo.get("identity_insert_on") != null) {
try (Statement stmt = con.createStatement()) {
stmt.execute("SET IDENTITY_INSERT [" + table + "] OFF");
}
}
}

@Override
public void beforeImportTable(Connection con, String table, TableInfo tableInfo) throws SQLException {
super.beforeImportTable(con, table, tableInfo);
if (tableInfo.get("hasId")!=null) {
public void beforeImportTableData(Connection con, String table, TableInfo tableInfo, ResultSetInfo fileInfo) throws SQLException {
super.beforeImportTableData(con, table, tableInfo, fileInfo);
if (importsToIdentityColumns(tableInfo, fileInfo)) {
try (Statement stmt = con.createStatement()) {
stmt.execute("SET IDENTITY_INSERT [" + table + "] ON");
}
tableInfo.put("identity_insert_on", true);
}
}

private Map<String, Boolean> hasIdentityColumns = null;

@Override
public void importStarted() {
Map<String, Boolean> retVal = new HashMap<>();
Map<String, Set<String>> retVal = new HashMap<>();
try (Connection con = mainToolBase.getReadOnlyConnection()) {
try (Statement stmt = con.createStatement()) {
try (ResultSet rs = stmt.executeQuery(
"select name from sys.objects where type = 'U' and OBJECTPROPERTY(object_id, 'TableHasIdentity')=1")) {
"SELECT Object_Name(object_id),name FROM sys.columns " +
"WHERE is_identity=1 And Objectproperty(object_id,'IsUserTable')=1")) {
while (rs.next()) {
retVal.put(rs.getString(1), Boolean.TRUE);
String table = rs.getString(1);
String col = rs.getString(2).toLowerCase();
retVal.computeIfAbsent(table, k -> new HashSet<>()).add(col);
}
}
} finally {
Expand All @@ -148,21 +162,15 @@ public void importStarted() {
} catch (SQLException e) {
throw new RuntimeException(e);
}
hasIdentityColumns = retVal;
}
@Override
public TableInfo getTableInfo(String tableName) {
TableInfo retVal = super.getTableInfo(tableName);
retVal.put("hasId", hasIdentityColumns.get(tableName));

return retVal;
tableIdentityColumns = retVal;
}

public static class Types {
public static final int SQL_VARIANT = -156;
public static final int DATETIMEOFFSET = -155;

private Types() { }
private Types() {
}
}

}
10 changes: 9 additions & 1 deletion src/main/java/pz/tool/jdbcimage/main/SingleTableImport.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,15 @@ public long importTable(String tableName, File file, DBFacade.TableInfo tableInf
// import data
tableInfo.setTableColumns(actualColumns);
dbFacade.beforeImportTable(con, tableName, tableInfo);
ResultProducerRunner runner = new ResultProducerRunner(producer, new DbImportResultConsumer(tableName, con, dbFacade, actualColumns));
DbImportResultConsumer consumer = new DbImportResultConsumer(tableName, con, dbFacade, actualColumns);
consumer.setNotifyOnStartFn(fileInfo -> {
try {
dbFacade.beforeImportTableData(con, tableName, tableInfo, fileInfo);
} catch(SQLException e){
throw new RuntimeException(e);
}
});
ResultProducerRunner runner = new ResultProducerRunner(producer, consumer);
long rows = runner.run();
dbFacade.afterImportTable(con, tableName, tableInfo);

Expand Down

0 comments on commit 7dec4ec

Please sign in to comment.