diff --git a/docs/sink-connector-config-options.rst b/docs/sink-connector-config-options.rst
index da2a2e1f..6f93b8e9 100644
--- a/docs/sink-connector-config-options.rst
+++ b/docs/sink-connector-config-options.rst
@@ -58,6 +58,10 @@ Writes

       Use standard SQL ``INSERT`` statements.

+  ``multi``
+
+      Use multi-row inserts, e.g. ``INSERT INTO table_name (column_list) VALUES (value_list_1), (value_list_2), ..., (value_list_n);``
+
   ``upsert``

       Use the appropriate upsert semantics for the target database if it is supported by the connector, e.g. ``INSERT .. ON CONFLICT .. DO UPDATE SET ..``.
@@ -68,7 +72,7 @@ Writes

   * Type: string
   * Default: insert
-  * Valid Values: [insert, upsert, update]
+  * Valid Values: [insert, multi, upsert, update]
   * Importance: high

 ``batch.size``
diff --git a/docs/sink-connector.md b/docs/sink-connector.md
index 450f2501..b2e55a19 100644
--- a/docs/sink-connector.md
+++ b/docs/sink-connector.md
@@ -77,6 +77,14 @@ from Kafka.

 This mode is used by default. To enable it explicitly, set `insert.mode=insert`.

+### Multi Mode
+
+In this mode, the connector executes an `INSERT` SQL query with multiple
+values (effectively inserting multiple rows/records per query).
+Supported in `SqliteDatabaseDialect` and `PostgreSqlDatabaseDialect`.
+
+To use this mode, set `insert.mode=multi`.
+
 ### Update Mode

 In this mode, the connector executes `UPDATE` SQL query on each record
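Note on the docs above: the connector is configured exactly like the other insert modes; only `insert.mode` changes. A minimal sketch of the relevant properties (the connection URL and topic here are placeholders, not part of this change):

```java
import java.util.Map;

// Hypothetical sink properties enabling multi-row inserts; only
// insert.mode=multi is specific to this change, the remaining keys
// are the connector's existing options.
public class MultiModeConfigExample {
    public static void main(final String[] args) {
        final Map<String, Object> props = Map.of(
                "connection.url", "jdbc:postgresql://localhost:5432/example", // placeholder
                "topics", "example-topic",                                    // placeholder
                "insert.mode", "multi",
                "batch.size", 500, // each flush becomes one multi-row INSERT
                "auto.create", true
        );
        props.forEach((k, v) -> System.out.println(k + "=" + v));
    }
}
```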
diff --git a/src/main/java/io/aiven/connect/jdbc/dialect/DatabaseDialect.java b/src/main/java/io/aiven/connect/jdbc/dialect/DatabaseDialect.java
index 57166c97..2a4befdc 100644
--- a/src/main/java/io/aiven/connect/jdbc/dialect/DatabaseDialect.java
+++ b/src/main/java/io/aiven/connect/jdbc/dialect/DatabaseDialect.java
@@ -324,6 +324,24 @@ String buildInsertStatement(
         Collection<ColumnId> nonKeyColumns
     );

+    /**
+     * Build an INSERT statement for multiple rows.
+     *
+     * @param table         the identifier of the table; may not be null
+     * @param records       number of rows which will be inserted; must be a positive number
+     * @param keyColumns    the identifiers of the columns in the primary/unique key; may not be null
+     *                      but may be empty
+     * @param nonKeyColumns the identifiers of the other columns in the table; may not be null but may
+     *                      be empty
+     * @return the INSERT statement; may not be null
+     */
+    String buildMultiInsertStatement(
+        TableId table,
+        int records,
+        Collection<ColumnId> keyColumns,
+        Collection<ColumnId> nonKeyColumns
+    );
+
     /**
      * Build the INSERT prepared statement expression for the given table and its columns.
      *
@@ -494,7 +512,19 @@ interface StatementBinder {
          * @param record the sink record with values to be bound into the statement; never null
          * @throws SQLException if there is a problem binding values into the statement
          */
-        void bindRecord(SinkRecord record) throws SQLException;
+        default void bindRecord(SinkRecord record) throws SQLException {
+            bindRecord(1, record);
+        }
+
+        /**
+         * Bind the values in the supplied record, starting at the specified index.
+         *
+         * @param index  the index at which binding starts; must be positive
+         * @param record the sink record with values to be bound into the statement; never null
+         * @return the index of the next bind variable, i.e. one past the last index bound by this call
+         * @throws SQLException if there is a problem binding values into the statement
+         */
+        int bindRecord(int index, SinkRecord record) throws SQLException;
     }

     /**
diff --git a/src/main/java/io/aiven/connect/jdbc/dialect/GenericDatabaseDialect.java b/src/main/java/io/aiven/connect/jdbc/dialect/GenericDatabaseDialect.java
index 95a232ed..9372508b 100644
--- a/src/main/java/io/aiven/connect/jdbc/dialect/GenericDatabaseDialect.java
+++ b/src/main/java/io/aiven/connect/jdbc/dialect/GenericDatabaseDialect.java
@@ -51,6 +51,7 @@
 import java.util.TimeZone;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;

 import org.apache.kafka.common.config.types.Password;
 import org.apache.kafka.connect.data.Date;
@@ -85,6 +86,10 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

+import static io.aiven.connect.jdbc.util.CollectionUtils.isEmpty;
+import static java.util.Objects.requireNonNull;
+import static java.util.stream.IntStream.range;
+
 /**
  * A {@link DatabaseDialect} implementation that provides functionality based upon JDBC and SQL.
  *
@@ -1350,6 +1355,44 @@ public String buildInsertStatement(
         return builder.toString();
     }

+    @Override
+    public String buildMultiInsertStatement(final TableId table,
+                                            final int records,
+                                            final Collection<ColumnId> keyColumns,
+                                            final Collection<ColumnId> nonKeyColumns) {
+
+        if (records < 1) {
+            throw new IllegalArgumentException("number of records must be a positive number, but got: " + records);
+        }
+        if (isEmpty(keyColumns) && isEmpty(nonKeyColumns)) {
+            throw new IllegalArgumentException("no columns specified");
+        }
+        requireNonNull(table, "table must not be null");
+
+        final String insertStatement = expressionBuilder()
+            .append("INSERT INTO ")
+            .append(table)
+            .append("(")
+            .appendList()
+            .delimitedBy(",")
+            .transformedBy(ExpressionBuilder.columnNames())
+            .of(keyColumns, nonKeyColumns)
+            .append(") VALUES ")
+            .toString();
+
+        final String singleRowPlaceholder = expressionBuilder()
+            .append("(")
+            .appendMultiple(",", "?", keyColumns.size() + nonKeyColumns.size())
+            .append(")")
+            .toString();
+
+        final String allRowsPlaceholder = range(1, records + 1)
+            .mapToObj(i -> singleRowPlaceholder)
+            .collect(Collectors.joining(","));
+
+        return insertStatement + allRowsPlaceholder;
+    }
+
     @Override
     public String buildUpdateStatement(
         final TableId table,
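To make the placeholder arithmetic in `buildMultiInsertStatement` concrete, here is a standalone sketch of the statement shape it produces (plain string building instead of the `ExpressionBuilder` API, and without identifier quoting):

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Illustrates the SQL shape only; the real implementation quotes
// identifiers via ExpressionBuilder and the dialect's rules.
public class MultiInsertShape {
    public static void main(final String[] args) {
        final List<String> columns = List.of("name", "planetid"); // example columns
        final int records = 2;

        // One "(?,...,?)" group per record, one "?" per column.
        final String singleRow = columns.stream()
                .map(c -> "?")
                .collect(Collectors.joining(",", "(", ")"));
        final String sql = "INSERT INTO planets(" + String.join(",", columns) + ") VALUES "
                + IntStream.range(0, records)
                        .mapToObj(i -> singleRow)
                        .collect(Collectors.joining(","));

        // Prints: INSERT INTO planets(name,planetid) VALUES (?,?),(?,?)
        System.out.println(sql);
    }
}
```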
diff --git a/src/main/java/io/aiven/connect/jdbc/sink/BufferedRecords.java b/src/main/java/io/aiven/connect/jdbc/sink/BufferedRecords.java
index e0741845..9d93bdea 100644
--- a/src/main/java/io/aiven/connect/jdbc/sink/BufferedRecords.java
+++ b/src/main/java/io/aiven/connect/jdbc/sink/BufferedRecords.java
@@ -41,6 +41,8 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

+import static io.aiven.connect.jdbc.sink.JdbcSinkConfig.InsertMode.MULTI;
+
 public class BufferedRecords {
     private static final Logger log = LoggerFactory.getLogger(BufferedRecords.class);

@@ -53,6 +55,7 @@ public class BufferedRecords {
     private List<SinkRecord> records = new ArrayList<>();
     private SchemaPair currentSchemaPair;
     private FieldsMetadata fieldsMetadata;
+    private TableDefinition tableDefinition;
     private PreparedStatement preparedStatement;
     private StatementBinder preparedStatementBinder;

@@ -76,39 +79,10 @@ public List<SinkRecord> add(final SinkRecord record) throws SQLException {
             record.valueSchema()
         );

-        if (currentSchemaPair == null) {
-            currentSchemaPair = schemaPair;
-            // re-initialize everything that depends on the record schema
-            fieldsMetadata = FieldsMetadata.extract(
-                tableId.tableName(),
-                config.pkMode,
-                config.pkFields,
-                config.fieldsWhitelist,
-                currentSchemaPair
-            );
-            dbStructure.createOrAmendIfNecessary(
-                config,
-                connection,
-                tableId,
-                fieldsMetadata
-            );
+        log.debug("buffered records in list {}", records.size());

-            final TableDefinition tableDefinition = dbStructure.tableDefinitionFor(tableId, connection);
-            final String sql = getInsertSql(tableDefinition);
-            log.debug(
-                "{} sql: {}",
-                config.insertMode,
-                sql
-            );
-            close();
-            preparedStatement = connection.prepareStatement(sql);
-            preparedStatementBinder = dbDialect.statementBinder(
-                preparedStatement,
-                config.pkMode,
-                schemaPair,
-                fieldsMetadata,
-                config.insertMode
-            );
+        if (currentSchemaPair == null) {
+            reInitialize(schemaPair);
         }

         final List<SinkRecord> flushed;
@@ -134,27 +108,74 @@ public List<SinkRecord> add(final SinkRecord record) throws SQLException {
         return flushed;
     }

+    private void prepareStatement() throws SQLException {
+        final String sql;
+        log.debug("Generating query for insert mode {} and {} records", config.insertMode, records.size());
+        if (config.insertMode == MULTI) {
+            sql = getMultiInsertSql();
+        } else {
+            sql = getInsertSql();
+        }
+
+        log.debug("Prepared SQL {} for insert mode {}", sql, config.insertMode);
+
+        close();
+        preparedStatement = connection.prepareStatement(sql);
+        preparedStatementBinder = dbDialect.statementBinder(
+            preparedStatement,
+            config.pkMode,
+            currentSchemaPair,
+            fieldsMetadata,
+            config.insertMode
+        );
+    }
+
+    /**
+     * Re-initialize everything that depends on the record schema.
+     */
+    private void reInitialize(final SchemaPair schemaPair) throws SQLException {
+        currentSchemaPair = schemaPair;
+        fieldsMetadata = FieldsMetadata.extract(
+            tableId.tableName(),
+            config.pkMode,
+            config.pkFields,
+            config.fieldsWhitelist,
+            currentSchemaPair
+        );
+        dbStructure.createOrAmendIfNecessary(
+            config,
+            connection,
+            tableId,
+            fieldsMetadata
+        );
+
+        tableDefinition = dbStructure.tableDefinitionFor(tableId, connection);
+    }
+
     public List<SinkRecord> flush() throws SQLException {
         if (records.isEmpty()) {
             log.debug("Records is empty");
             return new ArrayList<>();
         }
-        log.debug("Flushing {} buffered records", records.size());
-        for (final SinkRecord record : records) {
-            preparedStatementBinder.bindRecord(record);
-        }
+
+        prepareStatement();
+        bindRecords();
+
         int totalUpdateCount = 0;
         boolean successNoInfo = false;
-        for (final int updateCount : preparedStatement.executeBatch()) {
+
+        log.debug("Executing batch...");
+        for (final int updateCount : executeBatch()) {
             if (updateCount == Statement.SUCCESS_NO_INFO) {
                 successNoInfo = true;
                 continue;
             }
             totalUpdateCount += updateCount;
         }
+        log.debug("Done executing batch.");
         if (totalUpdateCount != records.size() && !successNoInfo) {
             switch (config.insertMode) {
                 case INSERT:
+                case MULTI:
                     throw new ConnectException(String.format(
                         "Update count (%d) did not sum up to total number of records inserted (%d)",
                         totalUpdateCount,
@@ -186,6 +207,30 @@ public List<SinkRecord> flush() throws SQLException {
         return flushedRecords;
     }

+    private int[] executeBatch() throws SQLException {
+        if (config.insertMode == MULTI) {
+            preparedStatement.addBatch();
+        }
+        log.debug("Executing batch with insert mode {}", config.insertMode);
+        return preparedStatement.executeBatch();
+    }
+
+    private void bindRecords() throws SQLException {
+        log.debug("Binding {} buffered records", records.size());
+        int index = 1;
+        for (final SinkRecord record : records) {
+            if (config.insertMode == MULTI) {
+                // All records are bound to the same prepared statement,
+                // so when binding fields for record N (N > 0)
+                // we need to start at the index where binding fields for record N - 1 stopped.
+                index = preparedStatementBinder.bindRecord(index, record);
+            } else {
+                preparedStatementBinder.bindRecord(record);
+            }
+        }
+        log.debug("Done binding records.");
+    }
+
     public void close() throws SQLException {
         log.info("Closing BufferedRecords with preparedStatement: {}", preparedStatement);
         if (preparedStatement != null) {
@@ -194,7 +239,30 @@ public void close() throws SQLException {
         }
     }

-    private String getInsertSql(final TableDefinition tableDefinition) {
+    private String getMultiInsertSql() {
+        if (config.insertMode != MULTI) {
+            throw new ConnectException(String.format(
+                "Multi-row insert SQL is not supported by insert mode %s",
+                config.insertMode
+            ));
+        }
+        try {
+            return dbDialect.buildMultiInsertStatement(
+                tableId,
+                records.size(),
+                asColumns(fieldsMetadata.keyFieldNames),
+                asColumns(fieldsMetadata.nonKeyFieldNames)
+            );
+        } catch (final UnsupportedOperationException e) {
+            throw new ConnectException(String.format(
+                "Write to table '%s' in MULTI mode is not supported with the %s dialect.",
+                tableId,
+                dbDialect.name()
+            ), e);
+        }
+    }
+
+    private String getInsertSql() {
         switch (config.insertMode) {
             case INSERT:
                 return dbDialect.buildInsertStatement(
diff --git a/src/main/java/io/aiven/connect/jdbc/sink/JdbcSinkConfig.java b/src/main/java/io/aiven/connect/jdbc/sink/JdbcSinkConfig.java
index de70e224..04be94c1 100644
--- a/src/main/java/io/aiven/connect/jdbc/sink/JdbcSinkConfig.java
+++ b/src/main/java/io/aiven/connect/jdbc/sink/JdbcSinkConfig.java
@@ -39,6 +39,7 @@ public class JdbcSinkConfig extends JdbcConfig {

     public enum InsertMode {
         INSERT,
+        MULTI,
         UPSERT,
         UPDATE;
     }
@@ -122,6 +123,8 @@ public enum PrimaryKeyMode {
         "The insertion mode to use. Supported modes are:\n"
             + "``insert``\n"
             + "    Use standard SQL ``INSERT`` statements.\n"
+            + "``multi``\n"
+            + "    Use multi-row ``INSERT`` statements.\n"
             + "``upsert``\n"
             + "    Use the appropriate upsert semantics for the target database if it is supported by "
             + "the connector, e.g. ``INSERT .. ON CONFLICT .. DO UPDATE SET ..``.\n"
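The `MULTI` branch added to the update-count check in `flush()` relies on the whole buffer being a single batch entry, so the driver reports one count covering all rows (or `SUCCESS_NO_INFO`). A minimal sketch of that accounting, with hypothetical driver responses:

```java
import java.sql.Statement;

// Mirrors the totalUpdateCount/successNoInfo logic in flush(); the
// int[] inputs stand in for what a driver's executeBatch() might return.
public class UpdateCountCheck {

    static boolean sumsUpTo(final int[] updateCounts, final int expectedRecords) {
        int total = 0;
        boolean successNoInfo = false;
        for (final int count : updateCounts) {
            if (count == Statement.SUCCESS_NO_INFO) {
                successNoInfo = true;
                continue;
            }
            total += count;
        }
        // When the driver gives no per-statement info, the size check is skipped.
        return successNoInfo || total == expectedRecords;
    }

    public static void main(final String[] args) {
        System.out.println(sumsUpTo(new int[] {5}, 5));                         // true: one multi-row INSERT of 5 rows
        System.out.println(sumsUpTo(new int[] {Statement.SUCCESS_NO_INFO}, 5)); // true: no info reported
        System.out.println(sumsUpTo(new int[] {3}, 5));                         // false: would raise ConnectException
    }
}
```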
diff --git a/src/main/java/io/aiven/connect/jdbc/sink/PreparedStatementBinder.java b/src/main/java/io/aiven/connect/jdbc/sink/PreparedStatementBinder.java
index 59a4fa4c..f49e30a4 100644
--- a/src/main/java/io/aiven/connect/jdbc/sink/PreparedStatementBinder.java
+++ b/src/main/java/io/aiven/connect/jdbc/sink/PreparedStatementBinder.java
@@ -31,6 +31,8 @@
 import io.aiven.connect.jdbc.sink.metadata.FieldsMetadata;
 import io.aiven.connect.jdbc.sink.metadata.SchemaPair;

+import static io.aiven.connect.jdbc.sink.JdbcSinkConfig.InsertMode.MULTI;
+
 public class PreparedStatementBinder implements StatementBinder {

     private final JdbcSinkConfig.PrimaryKeyMode pkMode;
@@ -41,12 +43,12 @@ public class PreparedStatementBinder implements StatementBinder {
     private final DatabaseDialect dialect;

     public PreparedStatementBinder(
-        final DatabaseDialect dialect,
-        final PreparedStatement statement,
-        final JdbcSinkConfig.PrimaryKeyMode pkMode,
-        final SchemaPair schemaPair,
-        final FieldsMetadata fieldsMetadata,
-        final JdbcSinkConfig.InsertMode insertMode
+            final DatabaseDialect dialect,
+            final PreparedStatement statement,
+            final JdbcSinkConfig.PrimaryKeyMode pkMode,
+            final SchemaPair schemaPair,
+            final FieldsMetadata fieldsMetadata,
+            final JdbcSinkConfig.InsertMode insertMode
     ) {
         this.dialect = dialect;
         this.pkMode = pkMode;
@@ -58,6 +60,12 @@ public PreparedStatementBinder(

     @Override
     public void bindRecord(final SinkRecord record) throws SQLException {
+        // backwards compatibility
+        bindRecord(1, record);
+    }
+
+    @Override
+    public int bindRecord(int index, final SinkRecord record) throws SQLException {
         final Struct valueStruct = (Struct) record.value();

         // Assumption: the relevant SQL has placeholders for keyFieldNames first followed by
@@ -65,23 +73,28 @@ public void bindRecord(final SinkRecord record) throws SQLException {
         // the relevant SQL has placeholders for nonKeyFieldNames first followed by
         // keyFieldNames, in iteration order for all UPDATE queries

-        int index = 1;
+        final int nextIndex;
         switch (insertMode) {
             case INSERT:
+            case MULTI:
             case UPSERT:
                 index = bindKeyFields(record, index);
-                bindNonKeyFields(record, valueStruct, index);
+                nextIndex = bindNonKeyFields(record, valueStruct, index);
                 break;
             case UPDATE:
                 index = bindNonKeyFields(record, valueStruct, index);
-                bindKeyFields(record, index);
+                nextIndex = bindKeyFields(record, index);
                 break;
             default:
                 throw new AssertionError();
         }
-        statement.addBatch();
+        // in a multi-row insert, all records are a single item in the batch
+        if (insertMode != MULTI) {
+            statement.addBatch();
+        }
+        return nextIndex;
     }

     protected int bindKeyFields(final SinkRecord record, int index) throws SQLException {
@@ -128,9 +141,9 @@ protected int bindKeyFields(final SinkRecord record, int index) throws SQLExcept
     }

     protected int bindNonKeyFields(
-        final SinkRecord record,
-        final Struct valueStruct,
-        int index
+            final SinkRecord record,
+            final Struct valueStruct,
+            int index
     ) throws SQLException {
         for (final String fieldName : fieldsMetadata.nonKeyFieldNames) {
             final Field field = record.valueSchema().field(fieldName);
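The contract of `bindRecord(int, SinkRecord)` is that each call returns the index at which binding for the next record must start. A standalone sketch of that index threading, assuming one key field and two non-key fields per record:

```java
// Hypothetical field counts; shows how the 1-based JDBC parameter index
// walks across (?,?,?),(?,?,?) when all records share one prepared statement.
public class BindIndexThreading {
    public static void main(final String[] args) {
        final int keyFields = 1;
        final int nonKeyFields = 2;

        int index = 1; // JDBC parameter indexes are 1-based
        for (int record = 0; record < 2; record++) {
            // bindKeyFields consumes the next keyFields positions ...
            System.out.printf("record %d: key fields at %d..%d%n",
                    record, index, index + keyFields - 1);
            index += keyFields;
            // ... and bindNonKeyFields the next nonKeyFields positions;
            // the resulting index is handed to the next record's bind call.
            System.out.printf("record %d: non-key fields at %d..%d%n",
                    record, index, index + nonKeyFields - 1);
            index += nonKeyFields;
        }
    }
}
```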
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.aiven.connect.jdbc.util; + +import java.util.Collection; + +public final class CollectionUtils { + private CollectionUtils() { + } + + public static boolean isEmpty(final Collection collection) { + return collection == null || collection.isEmpty(); + } +} diff --git a/src/test/java/io/aiven/connect/jdbc/sink/BufferedRecordsTest.java b/src/test/java/io/aiven/connect/jdbc/sink/BufferedRecordsTest.java index f1e2dab2..dd878c1f 100644 --- a/src/test/java/io/aiven/connect/jdbc/sink/BufferedRecordsTest.java +++ b/src/test/java/io/aiven/connect/jdbc/sink/BufferedRecordsTest.java @@ -21,10 +21,10 @@ import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; -import java.sql.Statement; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.Map; import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.data.SchemaBuilder; @@ -39,16 +39,22 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.mockito.Matchers; -import org.mockito.Mockito; +import org.mockito.ArgumentCaptor; +import static java.sql.Statement.SUCCESS_NO_INFO; import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class BufferedRecordsTest { private final SqliteHelper sqliteHelper = new SqliteHelper(getClass().getSimpleName()); + private final String dbUrl = sqliteHelper.sqliteUri(); @Before public void setUp() throws IOException, SQLException { @@ -63,34 +69,33 @@ public void tearDown() throws IOException, SQLException { @Test public void correctBatching() throws SQLException { final HashMap props = new HashMap<>(); - props.put("connection.url", sqliteHelper.sqliteUri()); + props.put("connection.url", dbUrl); props.put("auto.create", true); props.put("auto.evolve", true); props.put("batch.size", 1000); // sufficiently high to not cause flushes due to buffer being full final JdbcSinkConfig config = new JdbcSinkConfig(props); - final String url = sqliteHelper.sqliteUri(); - final DatabaseDialect dbDialect = DatabaseDialects.findBestFor(url, config); + final DatabaseDialect dbDialect = DatabaseDialects.findBestFor(dbUrl, config); final DbStructure dbStructure = new DbStructure(dbDialect); final TableId tableId = new TableId(null, null, "dummy"); final BufferedRecords buffer = new BufferedRecords( - config, tableId, dbDialect, dbStructure, sqliteHelper.connection); + config, tableId, dbDialect, dbStructure, sqliteHelper.connection); final Schema schemaA = SchemaBuilder.struct() - .field("name", Schema.STRING_SCHEMA) - .build(); + .field("name", Schema.STRING_SCHEMA) + .build(); final Struct valueA = new Struct(schemaA) - .put("name", "cuba"); - final SinkRecord recordA = new SinkRecord("dummy", 0, null, null, 
         final Schema schemaB = SchemaBuilder.struct()
-            .field("name", Schema.STRING_SCHEMA)
-            .field("age", Schema.OPTIONAL_INT32_SCHEMA)
-            .build();
+                .field("name", Schema.STRING_SCHEMA)
+                .field("age", Schema.OPTIONAL_INT32_SCHEMA)
+                .build();
         final Struct valueB = new Struct(schemaB)
-            .put("name", "cuba")
-            .put("age", 4);
+                .put("name", "cuba")
+                .put("age", 4);
         final SinkRecord recordB = new SinkRecord("dummy", 1, null, null, schemaB, valueB, 1);

         // test records are batched correctly based on schema equality as records are added
@@ -116,44 +121,40 @@ public void testFlushSuccessNoInfo() throws SQLException {
         props.put("batch.size", 1000);
         final JdbcSinkConfig config = new JdbcSinkConfig(props);

-        final String url = sqliteHelper.sqliteUri();
-        final DatabaseDialect dbDialect = DatabaseDialects.findBestFor(url, config);
+        final DatabaseDialect dbDialect = DatabaseDialects.findBestFor(dbUrl, config);

-        final int[] batchResponse = new int[2];
-        batchResponse[0] = Statement.SUCCESS_NO_INFO;
-        batchResponse[1] = Statement.SUCCESS_NO_INFO;
+        final int[] batchResponse = new int[] {SUCCESS_NO_INFO, SUCCESS_NO_INFO};

         final DbStructure dbStructureMock = mock(DbStructure.class);
-        when(dbStructureMock.createOrAmendIfNecessary(Matchers.any(JdbcSinkConfig.class),
-            Matchers.any(Connection.class),
-            Matchers.any(TableId.class),
-            Matchers.any(FieldsMetadata.class)))
-            .thenReturn(true);
+        when(dbStructureMock.createOrAmendIfNecessary(any(JdbcSinkConfig.class),
+                any(Connection.class),
+                any(TableId.class),
+                any(FieldsMetadata.class)))
+                .thenReturn(true);

         final PreparedStatement preparedStatementMock = mock(PreparedStatement.class);
         when(preparedStatementMock.executeBatch()).thenReturn(batchResponse);

         final Connection connectionMock = mock(Connection.class);
-        when(connectionMock.prepareStatement(Matchers.anyString())).thenReturn(preparedStatementMock);
+        when(connectionMock.prepareStatement(anyString())).thenReturn(preparedStatementMock);

         final TableId tableId = new TableId(null, null, "dummy");
         final BufferedRecords buffer = new BufferedRecords(config, tableId, dbDialect,
-            dbStructureMock, connectionMock);
+                dbStructureMock, connectionMock);

         final Schema schemaA = SchemaBuilder.struct().field("name", Schema.STRING_SCHEMA).build();
         final Struct valueA = new Struct(schemaA).put("name", "cuba");
-        final SinkRecord recordA = new SinkRecord("dummy", 0, null, null, schemaA, valueA, 0);
+        final SinkRecord recordA = wrapInSinkRecord(valueA);
         buffer.add(recordA);

         final Schema schemaB = SchemaBuilder.struct().field("name", Schema.STRING_SCHEMA).build();
         final Struct valueB = new Struct(schemaA).put("name", "cubb");
-        final SinkRecord recordB = new SinkRecord("dummy", 0, null, null, schemaB, valueB, 0);
+        final SinkRecord recordB = wrapInSinkRecord(valueB);
         buffer.add(recordB);

         buffer.flush();
     }

-
     @Test
     public void testInsertModeUpdate() throws SQLException {
         final HashMap<String, Object> props = new HashMap<>();
@@ -164,27 +165,131 @@ public void testInsertModeUpdate() throws SQLException {
         props.put("insert.mode", "update");
         final JdbcSinkConfig config = new JdbcSinkConfig(props);

-        final String url = sqliteHelper.sqliteUri();
-        final DatabaseDialect dbDialect = DatabaseDialects.findBestFor(url, config);
+        final DatabaseDialect dbDialect = DatabaseDialects.findBestFor(dbUrl, config);
         final DbStructure dbStructureMock = mock(DbStructure.class);
-        when(dbStructureMock.createOrAmendIfNecessary(Matchers.any(JdbcSinkConfig.class),
-            Matchers.any(Connection.class),
-            Matchers.any(TableId.class),
-            Matchers.any(FieldsMetadata.class)))
-            .thenReturn(true);
+        when(dbStructureMock.createOrAmendIfNecessary(any(JdbcSinkConfig.class),
+                any(Connection.class),
+                any(TableId.class),
+                any(FieldsMetadata.class)))
+                .thenReturn(true);

         final Connection connectionMock = mock(Connection.class);
+        final PreparedStatement preparedStatement = mock(PreparedStatement.class);
+        when(connectionMock.prepareStatement(anyString())).thenReturn(preparedStatement);
+        when(preparedStatement.executeBatch()).thenReturn(new int[1]);
+
         final TableId tableId = new TableId(null, null, "dummy");
         final BufferedRecords buffer = new BufferedRecords(config, tableId, dbDialect, dbStructureMock,
-            connectionMock);
+                connectionMock);

         final Schema schemaA = SchemaBuilder.struct().field("name", Schema.STRING_SCHEMA).build();
         final Struct valueA = new Struct(schemaA).put("name", "cuba");
-        final SinkRecord recordA = new SinkRecord("dummy", 0, null, null, schemaA, valueA, 0);
+        final SinkRecord recordA = wrapInSinkRecord(valueA);
         buffer.add(recordA);
+        buffer.flush();
+
+        verify(connectionMock).prepareStatement(eq("UPDATE \"dummy\" SET \"name\" = ?"));
+    }
+
+    @Test
+    public void testInsertModeMultiAutomaticFlush() throws SQLException {
+        final JdbcSinkConfig config = multiModeConfig(2);
+
+        final DatabaseDialect dbDialect = DatabaseDialects.findBestFor(dbUrl, config);
+        final DbStructure dbStructureMock = mock(DbStructure.class);
+        when(dbStructureMock.createOrAmendIfNecessary(any(JdbcSinkConfig.class),
+                any(Connection.class),
+                any(TableId.class),
+                any(FieldsMetadata.class)))
+                .thenReturn(true);
+
+        final Connection connection = mock(Connection.class);
+        final PreparedStatement preparedStatement = mock(PreparedStatement.class);
+        when(connection.prepareStatement(anyString())).thenReturn(preparedStatement);
+        when(preparedStatement.executeBatch()).thenReturn(new int[]{2});
+
+        final TableId tableId = new TableId(null, null, "planets");
+        final BufferedRecords buffer = new BufferedRecords(config, tableId, dbDialect, dbStructureMock,
+                connection);
+
+        final Schema schema = newPlanetSchema();
+        for (int i = 1; i <= 5; i++) {
+            buffer.add(wrapInSinkRecord(newPlanet(schema, 1, "planet name " + i)));
+        }
+
+        final ArgumentCaptor<String> sqlCaptor = ArgumentCaptor.forClass(String.class);
+        // Given the 5 records, and batch size of 2, we expect 2 inserts.
+        // One record is still waiting in the buffer, and that is expected.
+        verify(connection, times(2)).prepareStatement(sqlCaptor.capture());
+        assertEquals(
+                "INSERT INTO \"planets\"(\"name\",\"planetid\") VALUES (?,?),(?,?)",
+                sqlCaptor.getAllValues().get(0)
+        );
+        assertEquals(
+                "INSERT INTO \"planets\"(\"name\",\"planetid\") VALUES (?,?),(?,?)",
+                sqlCaptor.getAllValues().get(1)
+        );
+    }
+
+    @Test
+    public void testInsertModeMultiExplicitFlush() throws SQLException {
+        final JdbcSinkConfig config = multiModeConfig(100);
+
+        final DatabaseDialect dbDialect = DatabaseDialects.findBestFor(dbUrl, config);
+        final DbStructure dbStructureMock = mock(DbStructure.class);
+        when(dbStructureMock.createOrAmendIfNecessary(any(JdbcSinkConfig.class),
+                any(Connection.class),
+                any(TableId.class),
+                any(FieldsMetadata.class)))
+                .thenReturn(true);
+
+        final Connection connection = mock(Connection.class);
+        final PreparedStatement preparedStatement = mock(PreparedStatement.class);
+        when(connection.prepareStatement(anyString())).thenReturn(preparedStatement);
+        when(preparedStatement.executeBatch()).thenReturn(new int[]{2});
+
+        final TableId tableId = new TableId(null, null, "planets");
+        final BufferedRecords buffer = new BufferedRecords(config, tableId, dbDialect, dbStructureMock,
+                connection);
+
+        final Schema schema = newPlanetSchema();
+        final Struct valueA = newPlanet(schema, 1, "mercury");
+        final Struct valueB = newPlanet(schema, 2, "venus");
+        buffer.add(wrapInSinkRecord(valueA));
+        buffer.add(wrapInSinkRecord(valueB));
+        buffer.flush();
+
+        verify(connection).prepareStatement(
+                "INSERT INTO \"planets\"(\"name\",\"planetid\") VALUES (?,?),(?,?)"
+        );
+    }
+
+    private Struct newPlanet(final Schema schema, final int id, final String name) {
+        return new Struct(schema)
+                .put("planetid", id)
+                .put("name", name);
+    }
+
+    private Schema newPlanetSchema() {
+        return SchemaBuilder.struct()
+                .field("name", Schema.STRING_SCHEMA)
+                .field("planetid", Schema.INT32_SCHEMA)
+                .build();
+    }
+
+    private JdbcSinkConfig multiModeConfig(final int batchSize) {
+        return new JdbcSinkConfig(Map.of(
+                "connection.url", "",
+                "auto.create", true,
+                "auto.evolve", true,
+                "batch.size", batchSize,
+                "insert.mode", "multi"
+        ));
+    }
+
+    private SinkRecord wrapInSinkRecord(final Struct value) {
+        return new SinkRecord("dummy-topic", 0, null, null, value.schema(), value, 0);
+    }
 }
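Beyond the mocked tests, the behaviour is easy to sanity-check end to end. A minimal sketch against an in-memory SQLite database (assuming the sqlite-jdbc driver already used by `SqliteHelper` is on the classpath; the table and values are made up):

```java
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.Statement;

// Hand-rolled smoke test of the multi-row INSERT shape: bind two records
// into one statement, add it as a single batch entry, execute.
public class MultiInsertSmokeTest {
    public static void main(final String[] args) throws Exception {
        try (Connection connection = DriverManager.getConnection("jdbc:sqlite::memory:")) {
            try (Statement ddl = connection.createStatement()) {
                ddl.execute("CREATE TABLE planets (name TEXT, planetid INTEGER)");
            }

            final String sql = "INSERT INTO planets(name, planetid) VALUES (?,?),(?,?)";
            try (PreparedStatement insert = connection.prepareStatement(sql)) {
                int index = 1;                        // thread the index across records,
                insert.setString(index++, "mercury"); // as the binder does in MULTI mode
                insert.setInt(index++, 1);
                insert.setString(index++, "venus");
                insert.setInt(index++, 2);
                insert.addBatch();                    // whole buffer = one batch entry

                final int[] counts = insert.executeBatch();
                System.out.println("batch entries: " + counts.length
                        + ", rows reported: " + counts[0]); // expect 1 entry covering 2 rows
            }
        }
    }
}
```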