From a31b0caa473cd088e9d7f1e4ec9c6fced342140f Mon Sep 17 00:00:00 2001 From: Murali Basani Date: Fri, 22 Nov 2024 16:29:21 +0100 Subject: [PATCH] [KCON35] : Improvement : Read files with a stream instead of loading them fully (#351) Currently the transformers load each file in full and return a list of records, which can cause performance issues for large files. * With Stream/StreamSupport, a record is transformed only when next() is called on the iterator, so records are produced lazily. --- s3-source-connector/build.gradle.kts | 1 - .../connect/s3/source/IntegrationTest.java | 55 ++++------- .../s3/source/input/AvroTransformer.java | 26 +++++- .../s3/source/input/ByteArrayTransformer.java | 49 +++++----- .../s3/source/input/JsonTransformer.java | 93 +++++++++++++------ .../s3/source/input/ParquetTransformer.java | 71 +++++++++----- .../connect/s3/source/input/Transformer.java | 7 +- .../s3/source/utils/SourceRecordIterator.java | 49 +++++----- .../input/ByteArrayTransformerTest.java | 48 ++++------ .../s3/source/input/JsonTransformerTest.java | 54 +++++++++-- .../source/input/ParquetTransformerTest.java | 49 +++++++++- .../utils/SourceRecordIteratorTest.java | 5 +- 12 files changed, 327 insertions(+), 180 deletions(-) diff --git a/s3-source-connector/build.gradle.kts b/s3-source-connector/build.gradle.kts index ad2c69d2a..943dbc75c 100644 --- a/s3-source-connector/build.gradle.kts +++ b/s3-source-connector/build.gradle.kts @@ -117,7 +117,6 @@ dependencies { exclude(group = "org.apache.commons", module = "commons-math3") exclude(group = "org.apache.httpcomponents", module = "httpclient") exclude(group = "commons-codec", module = "commons-codec") - exclude(group = "commons-io", module = "commons-io") exclude(group = "commons-net", module = "commons-net") exclude(group = "org.eclipse.jetty") exclude(group = "org.eclipse.jetty.websocket") diff --git a/s3-source-connector/src/integration-test/java/io/aiven/kafka/connect/s3/source/IntegrationTest.java b/s3-source-connector/src/integration-test/java/io/aiven/kafka/connect/s3/source/IntegrationTest.java index bab6d1587..eb0e86003 100644 --- a/s3-source-connector/src/integration-test/java/io/aiven/kafka/connect/s3/source/IntegrationTest.java +++ b/s3-source-connector/src/integration-test/java/io/aiven/kafka/connect/s3/source/IntegrationTest.java @@ -22,9 +22,7 @@ import static io.aiven.kafka.connect.s3.source.config.S3SourceConfig.AWS_S3_ENDPOINT_CONFIG; import static io.aiven.kafka.connect.s3.source.config.S3SourceConfig.AWS_S3_PREFIX_CONFIG; import static io.aiven.kafka.connect.s3.source.config.S3SourceConfig.AWS_SECRET_ACCESS_KEY_CONFIG; -import static io.aiven.kafka.connect.s3.source.config.S3SourceConfig.EXPECTED_MAX_MESSAGE_BYTES; import static io.aiven.kafka.connect.s3.source.config.S3SourceConfig.INPUT_FORMAT_KEY; -import static io.aiven.kafka.connect.s3.source.config.S3SourceConfig.MAX_POLL_RECORDS; import static io.aiven.kafka.connect.s3.source.config.S3SourceConfig.SCHEMA_REGISTRY_URL; import static io.aiven.kafka.connect.s3.source.config.S3SourceConfig.TARGET_TOPICS; import static io.aiven.kafka.connect.s3.source.config.S3SourceConfig.TARGET_TOPIC_PARTITIONS; @@ -41,7 +39,6 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; -import java.time.Duration; import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; import java.util.HashMap; @@ -68,7 +65,6 @@ import org.apache.avro.generic.GenericDatumWriter; import org.apache.avro.generic.GenericRecord; import org.apache.avro.io.DatumWriter; -import 
org.awaitility.Awaitility; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; @@ -185,33 +181,6 @@ void bytesTest(final TestInfo testInfo) throws IOException { verifyOffsetPositions(offsetKeys, 4); } - @Test - void bytesTestBasedOnMaxMessageBytes(final TestInfo testInfo) throws IOException, InterruptedException { - final String testData = "AABBCCDDEE"; - final var topicName = IntegrationBase.topicName(testInfo); - final Map connectorConfig = getConfig(CONNECTOR_NAME, topicName, 3); - connectorConfig.put(INPUT_FORMAT_KEY, InputFormat.BYTES.getValue()); - connectorConfig.put(EXPECTED_MAX_MESSAGE_BYTES, "2"); // For above test data of 10 bytes length, with 2 bytes - // each - // in source record, we expect 5 records. - connectorConfig.put(MAX_POLL_RECORDS, "2"); // In 3 polls all the 5 records should be processed - - connectRunner.configureConnector(CONNECTOR_NAME, connectorConfig); - final String offsetKey = writeToS3(topicName, testData.getBytes(StandardCharsets.UTF_8), "00000"); - - // Poll messages from the Kafka topic and verify the consumed data - final List records = IntegrationBase.consumeMessages(topicName, 5, connectRunner.getBootstrapServers()); - - // Verify that the correct data is read from the S3 bucket and pushed to Kafka - assertThat(records).containsExactly("AA", "BB", "CC", "DD", "EE"); - - Awaitility.await().atMost(Duration.ofMinutes(2)).untilAsserted(() -> { - final Map offsetRecs = IntegrationBase.consumeOffsetStorageMessages( - "connect-offset-topic-" + CONNECTOR_NAME, 1, connectRunner.getBootstrapServers()); - assertThat(offsetRecs).containsExactly(entry(offsetKey, 5)); - }); - } - @Test void avroTest(final TestInfo testInfo) throws IOException, InterruptedException { final var topicName = IntegrationBase.topicName(testInfo); @@ -227,16 +196,19 @@ void avroTest(final TestInfo testInfo) throws IOException, InterruptedException final Schema schema = parser.parse(schemaJson); final byte[] outputStream1 = getAvroRecord(schema, 1, 100); - final byte[] outputStream2 = getAvroRecord(schema, 2, 100); + final byte[] outputStream2 = getAvroRecord(schema, 101, 100); + final byte[] outputStream3 = getAvroRecord(schema, 201, 100); + final byte[] outputStream4 = getAvroRecord(schema, 301, 100); + final byte[] outputStream5 = getAvroRecord(schema, 401, 100); final Set offsetKeys = new HashSet<>(); offsetKeys.add(writeToS3(topicName, outputStream1, "00001")); offsetKeys.add(writeToS3(topicName, outputStream2, "00001")); - offsetKeys.add(writeToS3(topicName, outputStream1, "00002")); - offsetKeys.add(writeToS3(topicName, outputStream2, "00002")); - offsetKeys.add(writeToS3(topicName, outputStream2, "00002")); + offsetKeys.add(writeToS3(topicName, outputStream3, "00002")); + offsetKeys.add(writeToS3(topicName, outputStream4, "00002")); + offsetKeys.add(writeToS3(topicName, outputStream5, "00002")); assertThat(testBucketAccessor.listObjects()).hasSize(5); @@ -249,7 +221,12 @@ void avroTest(final TestInfo testInfo) throws IOException, InterruptedException assertThat(records).hasSize(500) .map(record -> entry(record.get("id"), String.valueOf(record.get("message")))) .contains(entry(1, "Hello, Kafka Connect S3 Source! object 1"), - entry(2, "Hello, Kafka Connect S3 Source! object 2")); + entry(2, "Hello, Kafka Connect S3 Source! object 2"), + entry(100, "Hello, Kafka Connect S3 Source! object 100"), + entry(200, "Hello, Kafka Connect S3 Source! object 200"), + entry(300, "Hello, Kafka Connect S3 Source! 
object 300"), + entry(400, "Hello, Kafka Connect S3 Source! object 400"), + entry(500, "Hello, Kafka Connect S3 Source! object 500")); Thread.sleep(10_000); @@ -327,17 +304,17 @@ void jsonTest(final TestInfo testInfo) throws IOException { verifyOffsetPositions(offsetKeys, 1); } - private static byte[] getAvroRecord(final Schema schema, final int messageId, final int noOfAvroRecs) - throws IOException { + private static byte[] getAvroRecord(final Schema schema, int messageId, final int noOfAvroRecs) throws IOException { final DatumWriter datumWriter = new GenericDatumWriter<>(schema); try (DataFileWriter dataFileWriter = new DataFileWriter<>(datumWriter); ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) { dataFileWriter.create(schema, outputStream); for (int i = 0; i < noOfAvroRecs; i++) { final GenericRecord avroRecord = new GenericData.Record(schema); // NOPMD - avroRecord.put("message", "Hello, Kafka Connect S3 Source! object " + i); + avroRecord.put("message", "Hello, Kafka Connect S3 Source! object " + messageId); avroRecord.put("id", messageId); dataFileWriter.append(avroRecord); + messageId++; // NOPMD } dataFileWriter.flush(); diff --git a/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/AvroTransformer.java b/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/AvroTransformer.java index a781f6bd1..dd2516692 100644 --- a/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/AvroTransformer.java +++ b/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/AvroTransformer.java @@ -24,15 +24,21 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; import io.aiven.kafka.connect.s3.source.config.S3SourceConfig; import com.amazonaws.util.IOUtils; import org.apache.avro.file.DataFileReader; +import org.apache.avro.file.DataFileStream; import org.apache.avro.file.SeekableByteArrayInput; import org.apache.avro.generic.GenericDatumReader; import org.apache.avro.generic.GenericRecord; import org.apache.avro.io.DatumReader; +import org.apache.commons.io.function.IOSupplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,10 +52,10 @@ public void configureValueConverter(final Map config, final S3So } @Override - public List getRecords(final InputStream inputStream, final String topic, final int topicPartition, - final S3SourceConfig s3SourceConfig) { + public Stream getRecords(final IOSupplier inputStreamIOSupplier, final String topic, + final int topicPartition, final S3SourceConfig s3SourceConfig) { final DatumReader datumReader = new GenericDatumReader<>(); - return readAvroRecords(inputStream, datumReader); + return readAvroRecordsAsStream(inputStreamIOSupplier, datumReader); } @Override @@ -58,6 +64,20 @@ public byte[] getValueBytes(final Object record, final String topic, final S3Sou s3SourceConfig); } + private Stream readAvroRecordsAsStream(final IOSupplier inputStreamIOSupplier, + final DatumReader datumReader) { + try (DataFileStream dataFileStream = new DataFileStream<>(inputStreamIOSupplier.get(), + datumReader)) { + // Wrap DataFileStream in a Stream using a Spliterator for lazy processing + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize(dataFileStream, Spliterator.ORDERED | Spliterator.NONNULL), + false); + } catch (IOException e) { + LOGGER.error("Error in DataFileStream: {}", 
e.getMessage(), e); + return Stream.empty(); // Return an empty stream if initialization fails + } + } + List readAvroRecords(final InputStream content, final DatumReader datumReader) { final List records = new ArrayList<>(); try (SeekableByteArrayInput sin = new SeekableByteArrayInput(IOUtils.toByteArray(content))) { diff --git a/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/ByteArrayTransformer.java b/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/ByteArrayTransformer.java index bc53e6330..8e36cab8c 100644 --- a/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/ByteArrayTransformer.java +++ b/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/ByteArrayTransformer.java @@ -16,16 +16,17 @@ package io.aiven.kafka.connect.s3.source.input; -import static io.aiven.kafka.connect.s3.source.config.S3SourceConfig.EXPECTED_MAX_MESSAGE_BYTES; - import java.io.IOException; import java.io.InputStream; -import java.util.ArrayList; -import java.util.List; import java.util.Map; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; import io.aiven.kafka.connect.s3.source.config.S3SourceConfig; +import org.apache.commons.io.function.IOSupplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,29 +38,31 @@ public void configureValueConverter(final Map config, final S3So // For byte array transformations, ByteArrayConverter is the converter which is the default config. } - @SuppressWarnings("PMD.AvoidInstantiatingObjectsInLoops") @Override - public List getRecords(final InputStream inputStream, final String topic, final int topicPartition, - final S3SourceConfig s3SourceConfig) { + public Stream getRecords(final IOSupplier inputStreamIOSupplier, final String topic, + final int topicPartition, final S3SourceConfig s3SourceConfig) { - final int maxMessageBytesSize = s3SourceConfig.getInt(EXPECTED_MAX_MESSAGE_BYTES); - final byte[] buffer = new byte[maxMessageBytesSize]; - int bytesRead; + // Create a Stream that processes each chunk lazily + return StreamSupport.stream(new Spliterators.AbstractSpliterator<>(Long.MAX_VALUE, Spliterator.ORDERED) { + final byte[] buffer = new byte[4096]; - final List chunks = new ArrayList<>(); - try { - bytesRead = inputStream.read(buffer); - while (bytesRead != -1) { - final byte[] chunk = new byte[bytesRead]; - System.arraycopy(buffer, 0, chunk, 0, bytesRead); - chunks.add(chunk); - bytesRead = inputStream.read(buffer); + @Override + public boolean tryAdvance(final java.util.function.Consumer action) { + try (InputStream inputStream = inputStreamIOSupplier.get()) { + final int bytesRead = inputStream.read(buffer); + if (bytesRead == -1) { + return false; + } + final byte[] chunk = new byte[bytesRead]; + System.arraycopy(buffer, 0, chunk, 0, bytesRead); + action.accept(chunk); + return true; + } catch (IOException e) { + LOGGER.error("Error trying to advance byte stream: {}", e.getMessage(), e); + return false; + } } - } catch (IOException e) { - LOGGER.error("Error reading from input stream: {}", e.getMessage(), e); - } - - return chunks; + }, false); } @Override diff --git a/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/JsonTransformer.java b/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/JsonTransformer.java index 5cda04f1a..80827fd8a 100644 --- 
a/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/JsonTransformer.java +++ b/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/JsonTransformer.java @@ -23,15 +23,18 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.List; import java.util.Map; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; import io.aiven.kafka.connect.s3.source.config.S3SourceConfig; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.commons.io.function.IOSupplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,30 +50,9 @@ public void configureValueConverter(final Map config, final S3So } @Override - public List getRecords(final InputStream inputStream, final String topic, final int topicPartition, - final S3SourceConfig s3SourceConfig) { - final List jsonNodeList = new ArrayList<>(); - JsonNode jsonNode; - try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { - String line = reader.readLine(); - while (line != null) { - line = line.trim(); - if (!line.isEmpty()) { - try { - // Parse each line as a separate JSON object - jsonNode = objectMapper.readTree(line.trim()); // Parse the current line into a JsonNode - jsonNodeList.add(jsonNode); // Add parsed JSON object to the list - } catch (IOException e) { - LOGGER.error("Error parsing JSON record from S3 input stream: {}", e.getMessage(), e); - } - } - - line = reader.readLine(); - } - } catch (IOException e) { - LOGGER.error("Error reading S3 object stream: {}", e.getMessage()); - } - return jsonNodeList; + public Stream getRecords(final IOSupplier inputStreamIOSupplier, final String topic, + final int topicPartition, final S3SourceConfig s3SourceConfig) { + return readJsonRecordsAsStream(inputStreamIOSupplier); } @Override @@ -82,4 +64,63 @@ public byte[] getValueBytes(final Object record, final String topic, final S3Sou return new byte[0]; } } + + private Stream readJsonRecordsAsStream(final IOSupplier inputStreamIOSupplier) { + // Use a Stream that lazily processes each line as a JSON object + CustomSpliterator customSpliteratorParam; + try { + customSpliteratorParam = new CustomSpliterator(inputStreamIOSupplier); + } catch (IOException e) { + LOGGER.error("Error creating Json transformer CustomSpliterator: {}", e.getMessage(), e); + return Stream.empty(); + } + return StreamSupport.stream(customSpliteratorParam, false).onClose(() -> { + try { + customSpliteratorParam.reader.close(); // Ensure the reader is closed after streaming + } catch (IOException e) { + LOGGER.error("Error closing BufferedReader: {}", e.getMessage(), e); + } + }); + } + + /* + * This CustomSpliterator class ensures that the BufferedReader is not closed before all the records from the + * stream have been consumed; the reader is closed through the onClose callback registered on the returned stream. 
+ */ + final class CustomSpliterator extends Spliterators.AbstractSpliterator { + BufferedReader reader; + String line; + CustomSpliterator(final IOSupplier inputStreamIOSupplier) throws IOException { + super(Long.MAX_VALUE, Spliterator.ORDERED | Spliterator.NONNULL); + reader = new BufferedReader(new InputStreamReader(inputStreamIOSupplier.get(), StandardCharsets.UTF_8)); + } + + @Override + public boolean tryAdvance(final java.util.function.Consumer action) { + try { + if (line == null) { + line = reader.readLine(); + } + while (line != null) { + line = line.trim(); + if (!line.isEmpty()) { + try { + final JsonNode jsonNode = objectMapper.readTree(line); // Parse the JSON + // line + action.accept(jsonNode); // Provide the parsed JSON node to the stream + } catch (IOException e) { + LOGGER.error("Error parsing JSON record: {}", e.getMessage(), e); + } + line = null; // NOPMD + return true; + } + line = reader.readLine(); + } + return false; // End of file + } catch (IOException e) { + LOGGER.error("Error reading S3 object stream: {}", e.getMessage(), e); + return false; + } + } + } } diff --git a/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/ParquetTransformer.java b/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/ParquetTransformer.java index 39fec83de..48b0abd33 100644 --- a/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/ParquetTransformer.java +++ b/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/ParquetTransformer.java @@ -25,15 +25,18 @@ import java.nio.file.Files; import java.nio.file.Path; import java.time.Instant; -import java.util.ArrayList; import java.util.Collections; -import java.util.List; import java.util.Map; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; import io.aiven.kafka.connect.s3.source.config.S3SourceConfig; import org.apache.avro.generic.GenericRecord; import org.apache.commons.compress.utils.IOUtils; +import org.apache.commons.io.function.IOSupplier; import org.apache.parquet.avro.AvroParquetReader; import org.apache.parquet.io.InputFile; import org.apache.parquet.io.LocalInputFile; @@ -50,9 +53,9 @@ public void configureValueConverter(final Map config, final S3So } @Override - public List getRecords(final InputStream inputStream, final String topic, final int topicPartition, - final S3SourceConfig s3SourceConfig) { - return getParquetRecords(inputStream, topic, topicPartition); + public Stream getRecords(final IOSupplier inputStreamIOSupplier, final String topic, + final int topicPartition, final S3SourceConfig s3SourceConfig) { + return getParquetStreamRecords(inputStreamIOSupplier, topic, topicPartition); } @Override @@ -61,35 +64,59 @@ public byte[] getValueBytes(final Object record, final String topic, final S3Sou s3SourceConfig); } - private List getParquetRecords(final InputStream inputStream, final String topic, - final int topicPartition) { + private Stream getParquetStreamRecords(final IOSupplier inputStreamIOSupplier, + final String topic, final int topicPartition) { final String timestamp = String.valueOf(Instant.now().toEpochMilli()); File parquetFile; - final List records = new ArrayList<>(); + try { + // Create a temporary file for the Parquet data parquetFile = File.createTempFile(topic + "_" + topicPartition + "_" + timestamp, ".parquet"); } catch (IOException e) { - LOGGER.error("Error in reading s3 object stream {}", e.getMessage(), e); - return records; 
+ LOGGER.error("Error creating temp file for Parquet data: {}", e.getMessage(), e); + return Stream.empty(); } - try (OutputStream outputStream = Files.newOutputStream(parquetFile.toPath())) { - IOUtils.copy(inputStream, outputStream); + try (OutputStream outputStream = Files.newOutputStream(parquetFile.toPath()); + InputStream inputStream = inputStreamIOSupplier.get();) { + IOUtils.copy(inputStream, outputStream); // Copy input stream to temporary file + final InputFile inputFile = new LocalInputFile(parquetFile.toPath()); - try (var parquetReader = AvroParquetReader.builder(inputFile).build()) { - GenericRecord record; - record = parquetReader.read(); - while (record != null) { - records.add(record); - record = parquetReader.read(); + final var parquetReader = AvroParquetReader.builder(inputFile).build(); + + return StreamSupport.stream(new Spliterators.AbstractSpliterator(Long.MAX_VALUE, + Spliterator.ORDERED | Spliterator.NONNULL) { + @Override + public boolean tryAdvance(final java.util.function.Consumer action) { + try { + final GenericRecord record = parquetReader.read(); + if (record != null) { + action.accept(record); // Pass record to the stream + return true; + } else { + parquetReader.close(); // Close reader at end of file + deleteTmpFile(parquetFile.toPath()); + return false; + } + } catch (IOException | RuntimeException e) { // NOPMD + LOGGER.error("Error reading Parquet record: {}", e.getMessage(), e); + deleteTmpFile(parquetFile.toPath()); + return false; + } } - } + }, false).onClose(() -> { + try { + parquetReader.close(); // Ensure reader is closed when the stream is closed + } catch (IOException e) { + LOGGER.error("Error closing Parquet reader: {}", e.getMessage(), e); + } + deleteTmpFile(parquetFile.toPath()); + }); } catch (IOException | RuntimeException e) { // NOPMD - LOGGER.error("Error in reading s3 object stream {}", e.getMessage(), e); - } finally { + LOGGER.error("Error processing Parquet data: {}", e.getMessage(), e); deleteTmpFile(parquetFile.toPath()); + return Stream.empty(); } - return records; } static void deleteTmpFile(final Path parquetFile) { diff --git a/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/Transformer.java b/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/Transformer.java index 70fe28d96..616cfdb77 100644 --- a/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/Transformer.java +++ b/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/input/Transformer.java @@ -17,16 +17,19 @@ package io.aiven.kafka.connect.s3.source.input; import java.io.InputStream; -import java.util.List; import java.util.Map; +import java.util.stream.Stream; import io.aiven.kafka.connect.s3.source.config.S3SourceConfig; +import org.apache.commons.io.function.IOSupplier; + public interface Transformer { void configureValueConverter(Map config, S3SourceConfig s3SourceConfig); - List getRecords(InputStream inputStream, String topic, int topicPartition, S3SourceConfig s3SourceConfig); + Stream getRecords(IOSupplier inputStreamIOSupplier, String topic, int topicPartition, + S3SourceConfig s3SourceConfig); byte[] getValueBytes(Object record, String topic, S3SourceConfig s3SourceConfig); } diff --git a/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/utils/SourceRecordIterator.java b/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/utils/SourceRecordIterator.java index 8c1fcb77d..3a6c40812 100644 --- 
a/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/utils/SourceRecordIterator.java +++ b/s3-source-connector/src/main/java/io/aiven/kafka/connect/s3/source/utils/SourceRecordIterator.java @@ -19,7 +19,6 @@ import static io.aiven.kafka.connect.s3.source.config.S3SourceConfig.MAX_POLL_RECORDS; import java.io.IOException; -import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; @@ -28,6 +27,7 @@ import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Stream; import io.aiven.kafka.connect.s3.source.config.S3SourceConfig; import io.aiven.kafka.connect.s3.source.input.Transformer; @@ -35,7 +35,6 @@ import com.amazonaws.AmazonClientException; import com.amazonaws.services.s3.AmazonS3; import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.amazonaws.services.s3.model.S3ObjectSummary; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -95,8 +94,7 @@ private void nextS3Object() { } private Iterator createIteratorForCurrentFile() throws IOException { - try (S3Object s3Object = s3Client.getObject(bucketName, currentObjectKey); - S3ObjectInputStream inputStream = s3Object.getObjectContent()) { + try (S3Object s3Object = s3Client.getObject(bucketName, currentObjectKey);) { final Matcher fileMatcher = FILE_DEFAULT_PATTERN.matcher(currentObjectKey); String topicName; @@ -107,7 +105,6 @@ private Iterator createIteratorForCurrentFile() throws IOExcepti defaultPartitionId = Integer.parseInt(fileMatcher.group(PATTERN_PARTITION_KEY)); } else { LOGGER.error("File naming doesn't match to any topic. {}", currentObjectKey); - inputStream.abort(); s3Object.close(); return Collections.emptyIterator(); } @@ -118,13 +115,13 @@ private Iterator createIteratorForCurrentFile() throws IOExcepti final Map partitionMap = ConnectUtils.getPartitionMap(topicName, defaultPartitionId, bucketName); - return getObjectIterator(inputStream, finalTopic, defaultPartitionId, defaultStartOffsetId, transformer, + return getObjectIterator(s3Object, finalTopic, defaultPartitionId, defaultStartOffsetId, transformer, partitionMap); } } @SuppressWarnings("PMD.CognitiveComplexity") - private Iterator getObjectIterator(final InputStream valueInputStream, final String topic, + private Iterator getObjectIterator(final S3Object s3Object, final String topic, final int topicPartition, final long startOffset, final Transformer transformer, final Map partitionMap) { return new Iterator<>() { @@ -136,24 +133,34 @@ private List readNext() { int numOfProcessedRecs = 1; boolean checkOffsetMap = true; - for (final Object record : transformer.getRecords(valueInputStream, topic, topicPartition, - s3SourceConfig)) { - if (offsetManager.shouldSkipRecord(partitionMap, currentObjectKey, numOfProcessedRecs) - && checkOffsetMap) { + try (Stream recordStream = transformer.getRecords(s3Object::getObjectContent, topic, + topicPartition, s3SourceConfig)) { + final Iterator recordIterator = recordStream.iterator(); + while (recordIterator.hasNext()) { + final Object record = recordIterator.next(); + + // Check if the record should be skipped based on the offset + if (offsetManager.shouldSkipRecord(partitionMap, currentObjectKey, numOfProcessedRecs) + && checkOffsetMap) { + numOfProcessedRecs++; + continue; + } + + final byte[] valueBytes = transformer.getValueBytes(record, topic, s3SourceConfig); + checkOffsetMap = false; + + 
sourceRecords.add(getSourceRecord(keyBytes, valueBytes, topic, topicPartition, offsetManager, + startOffset, partitionMap)); + numOfProcessedRecs++; - continue; - } - final byte[] valueBytes = transformer.getValueBytes(record, topic, s3SourceConfig); - checkOffsetMap = false; - sourceRecords.add(getSourceRecord(keyBytes, valueBytes, topic, topicPartition, offsetManager, - startOffset, partitionMap)); - if (sourceRecords.size() >= s3SourceConfig.getInt(MAX_POLL_RECORDS)) { - break; + // Break if we have reached the max records per poll + if (sourceRecords.size() >= s3SourceConfig.getInt(MAX_POLL_RECORDS)) { + break; + } } - - numOfProcessedRecs++; } + return sourceRecords; } diff --git a/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/input/ByteArrayTransformerTest.java b/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/input/ByteArrayTransformerTest.java index db743748f..2486cfadd 100644 --- a/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/input/ByteArrayTransformerTest.java +++ b/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/input/ByteArrayTransformerTest.java @@ -16,18 +16,17 @@ package io.aiven.kafka.connect.s3.source.input; -import static io.aiven.kafka.connect.s3.source.config.S3SourceConfig.EXPECTED_MAX_MESSAGE_BYTES; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.Mockito.when; +import static org.assertj.core.api.Assertions.assertThat; import java.io.ByteArrayInputStream; -import java.io.IOException; import java.io.InputStream; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; import io.aiven.kafka.connect.s3.source.config.S3SourceConfig; +import org.apache.commons.io.function.IOSupplier; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -37,6 +36,7 @@ @ExtendWith(MockitoExtension.class) final class ByteArrayTransformerTest { + public static final String TEST_TOPIC = "test-topic"; private ByteArrayTransformer byteArrayTransformer; @Mock @@ -51,45 +51,33 @@ void setUp() { void testGetRecordsSingleChunk() { final byte[] data = { 1, 2, 3, 4, 5 }; final InputStream inputStream = new ByteArrayInputStream(data); + final IOSupplier inputStreamIOSupplier = () -> inputStream; - when(s3SourceConfig.getInt(EXPECTED_MAX_MESSAGE_BYTES)).thenReturn(10_000); // Larger than data size + final Stream records = byteArrayTransformer.getRecords(inputStreamIOSupplier, TEST_TOPIC, 0, + s3SourceConfig); - final List records = byteArrayTransformer.getRecords(inputStream, "test-topic", 0, s3SourceConfig); - - assertEquals(1, records.size()); - assertArrayEquals(data, (byte[]) records.get(0)); - } - - @Test - void testGetRecordsMultipleChunks() { - final byte[] data = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; - final InputStream inputStream = new ByteArrayInputStream(data); - - when(s3SourceConfig.getInt(EXPECTED_MAX_MESSAGE_BYTES)).thenReturn(5); // Smaller than data size - - final List records = byteArrayTransformer.getRecords(inputStream, "test-topic", 0, s3SourceConfig); - - assertEquals(2, records.size()); - assertArrayEquals(new byte[] { 1, 2, 3, 4, 5 }, (byte[]) records.get(0)); - assertArrayEquals(new byte[] { 6, 7, 8, 9, 10 }, (byte[]) records.get(1)); + final List recs = records.collect(Collectors.toList()); + assertThat(recs).hasSize(1); + assertThat((byte[]) recs.get(0)).isEqualTo(data); } @Test - 
void testGetRecordsEmptyInputStream() throws IOException { + void testGetRecordsEmptyInputStream() { final InputStream inputStream = new ByteArrayInputStream(new byte[] {}); - when(s3SourceConfig.getInt(EXPECTED_MAX_MESSAGE_BYTES)).thenReturn(5); + final IOSupplier inputStreamIOSupplier = () -> inputStream; - final List records = byteArrayTransformer.getRecords(inputStream, "test-topic", 0, s3SourceConfig); + final Stream records = byteArrayTransformer.getRecords(inputStreamIOSupplier, TEST_TOPIC, 0, + s3SourceConfig); - assertEquals(0, records.size()); + assertThat(records).hasSize(0); } @Test void testGetValueBytes() { final byte[] record = { 1, 2, 3 }; - final byte[] result = byteArrayTransformer.getValueBytes(record, "test-topic", s3SourceConfig); + final byte[] result = byteArrayTransformer.getValueBytes(record, TEST_TOPIC, s3SourceConfig); - assertArrayEquals(record, result); + assertThat(result).containsExactlyInAnyOrder(record); } } diff --git a/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/input/JsonTransformerTest.java b/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/input/JsonTransformerTest.java index e24711f36..bdf4780d1 100644 --- a/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/input/JsonTransformerTest.java +++ b/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/input/JsonTransformerTest.java @@ -20,31 +20,39 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.HashMap; -import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; import io.aiven.kafka.connect.s3.source.config.S3SourceConfig; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.commons.io.function.IOSupplier; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; @ExtendWith(MockitoExtension.class) final class JsonTransformerTest { + public static final String TESTTOPIC = "testtopic"; JsonTransformer jsonTransformer; S3SourceConfig s3SourceConfig; + @Mock + private IOSupplier inputStreamIOSupplierMock; + @BeforeEach void setUp() { jsonTransformer = new JsonTransformer(); @@ -63,32 +71,64 @@ void testConfigureValueConverter() { void testHandleValueDataWithValidJson() { final InputStream validJsonInputStream = new ByteArrayInputStream( "{\"key\":\"value\"}".getBytes(StandardCharsets.UTF_8)); - final List jsonNodes = jsonTransformer.getRecords(validJsonInputStream, "testtopic", 1, s3SourceConfig); + final IOSupplier inputStreamIOSupplier = () -> validJsonInputStream; + final Stream jsonNodes = jsonTransformer.getRecords(inputStreamIOSupplier, TESTTOPIC, 1, + s3SourceConfig); - assertThat(jsonNodes.size()).isEqualTo(1); + assertThat(jsonNodes.collect(Collectors.toList())).hasSize(1); } @Test void testHandleValueDataWithInvalidJson() { final InputStream invalidJsonInputStream = new ByteArrayInputStream( "invalid-json".getBytes(StandardCharsets.UTF_8)); + final IOSupplier inputStreamIOSupplier = () -> invalidJsonInputStream; - final List jsonNodes = 
jsonTransformer.getRecords(invalidJsonInputStream, "testtopic", 1, + final Stream jsonNodes = jsonTransformer.getRecords(inputStreamIOSupplier, TESTTOPIC, 1, s3SourceConfig); - assertThat(jsonNodes.size()).isEqualTo(0); + assertThat(jsonNodes.collect(Collectors.toList())).hasSize(0); } @Test void testSerializeJsonDataValid() throws IOException { final InputStream validJsonInputStream = new ByteArrayInputStream( "{\"key\":\"value\"}".getBytes(StandardCharsets.UTF_8)); - final List jsonNodes = jsonTransformer.getRecords(validJsonInputStream, "testtopic", 1, s3SourceConfig); - final byte[] serializedData = jsonTransformer.getValueBytes(jsonNodes.get(0), "testtopic", s3SourceConfig); + final IOSupplier inputStreamIOSupplier = () -> validJsonInputStream; + final Stream jsonNodes = jsonTransformer.getRecords(inputStreamIOSupplier, TESTTOPIC, 1, + s3SourceConfig); + final byte[] serializedData = jsonTransformer.getValueBytes(jsonNodes.findFirst().get(), TESTTOPIC, + s3SourceConfig); final ObjectMapper objectMapper = new ObjectMapper(); final JsonNode expectedData = objectMapper.readTree(serializedData); assertThat(expectedData.get("key").asText()).isEqualTo("value"); } + + @Test + void testGetRecordsWithIOException() throws IOException { + when(inputStreamIOSupplierMock.get()).thenThrow(new IOException("Test IOException")); + final Stream resultStream = jsonTransformer.getRecords(inputStreamIOSupplierMock, "topic", 0, null); + + assertThat(resultStream).isEmpty(); + } + + @Test + void testCustomSpliteratorStreamProcessing() throws IOException { + final String jsonContent = "{\"key\":\"value\"}\n{\"key2\":\"value2\"}"; + final InputStream inputStream = new ByteArrayInputStream(jsonContent.getBytes(StandardCharsets.UTF_8)); + final IOSupplier supplier = () -> inputStream; + + final JsonTransformer.CustomSpliterator spliterator = jsonTransformer.new CustomSpliterator(supplier); + assertThat(spliterator.tryAdvance(jsonNode -> assertThat(jsonNode).isNotNull())).isTrue(); + } + + @Test + void testCustomSpliteratorWithIOExceptionDuringInitialization() throws IOException { + when(inputStreamIOSupplierMock.get()).thenThrow(new IOException("Test IOException during initialization")); + final Stream resultStream = jsonTransformer.getRecords(inputStreamIOSupplierMock, "topic", 0, null); + + assertThat(resultStream).isEmpty(); + } } diff --git a/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/input/ParquetTransformerTest.java b/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/input/ParquetTransformerTest.java index 69d7ac493..08f462595 100644 --- a/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/input/ParquetTransformerTest.java +++ b/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/input/ParquetTransformerTest.java @@ -17,24 +17,32 @@ package io.aiven.kafka.connect.s3.source.input; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.io.ByteArrayInputStream; +import java.io.File; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; import io.aiven.kafka.connect.s3.source.config.S3SourceConfig; import 
io.aiven.kafka.connect.s3.source.testutils.ContentUtils; import com.amazonaws.util.IOUtils; import org.apache.avro.generic.GenericRecord; +import org.apache.commons.io.function.IOSupplier; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; @ExtendWith(MockitoExtension.class) @@ -50,11 +58,13 @@ public void setUp() { void testHandleValueDataWithZeroBytes() { final byte[] mockParquetData = new byte[0]; final InputStream inputStream = new ByteArrayInputStream(mockParquetData); + final IOSupplier inputStreamIOSupplier = () -> inputStream; final S3SourceConfig s3SourceConfig = mock(S3SourceConfig.class); final String topic = "test-topic"; final int topicPartition = 0; - final List recs = parquetTransformer.getRecords(inputStream, topic, topicPartition, s3SourceConfig); + final Stream recs = parquetTransformer.getRecords(inputStreamIOSupplier, topic, topicPartition, + s3SourceConfig); assertThat(recs).isEmpty(); } @@ -63,12 +73,15 @@ void testHandleValueDataWithZeroBytes() { void testGetRecordsWithValidData() throws Exception { final byte[] mockParquetData = generateMockParquetData(); final InputStream inputStream = new ByteArrayInputStream(mockParquetData); + final IOSupplier inputStreamIOSupplier = () -> inputStream; final S3SourceConfig s3SourceConfig = mock(S3SourceConfig.class); final String topic = "test-topic"; final int topicPartition = 0; - final List records = parquetTransformer.getRecords(inputStream, topic, topicPartition, s3SourceConfig); + final List records = parquetTransformer + .getRecords(inputStreamIOSupplier, topic, topicPartition, s3SourceConfig) + .collect(Collectors.toList()); assertThat(records).isNotEmpty(); assertThat(records).extracting(record -> ((GenericRecord) record).get("name").toString()) @@ -80,12 +93,15 @@ void testGetRecordsWithValidData() throws Exception { void testGetRecordsWithInvalidData() { final byte[] invalidData = "invalid data".getBytes(StandardCharsets.UTF_8); final InputStream inputStream = new ByteArrayInputStream(invalidData); + final IOSupplier inputStreamIOSupplier = () -> inputStream; + final S3SourceConfig s3SourceConfig = mock(S3SourceConfig.class); final String topic = "test-topic"; final int topicPartition = 0; - final List records = parquetTransformer.getRecords(inputStream, topic, topicPartition, s3SourceConfig); + final Stream records = parquetTransformer.getRecords(inputStreamIOSupplier, topic, topicPartition, + s3SourceConfig); assertThat(records).isEmpty(); } @@ -102,4 +118,31 @@ private byte[] generateMockParquetData() throws IOException { final Path path = ContentUtils.getTmpFilePath("name"); return IOUtils.toByteArray(Files.newInputStream(path)); } + + @Test + void testIOExceptionCreatingTempFile() { + try (var mockStatic = Mockito.mockStatic(File.class)) { + mockStatic.when(() -> File.createTempFile(anyString(), anyString())) + .thenThrow(new IOException("Test IOException for temp file")); + + final IOSupplier inputStreamSupplier = mock(IOSupplier.class); + final Stream resultStream = parquetTransformer.getRecords(inputStreamSupplier, "test-topic", 1, + null); + + assertThat(resultStream).isEmpty(); + } + } + + @Test + void testIOExceptionDuringDataCopy() throws IOException { + try (InputStream inputStreamMock = mock(InputStream.class)) { + when(inputStreamMock.read(any(byte[].class))).thenThrow(new IOException("Test IOException during copy")); + + final IOSupplier 
inputStreamSupplier = () -> inputStreamMock; + final Stream resultStream = parquetTransformer.getRecords(inputStreamSupplier, "test-topic", 1, + null); + + assertThat(resultStream).isEmpty(); + } + } } diff --git a/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/utils/SourceRecordIteratorTest.java b/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/utils/SourceRecordIteratorTest.java index a2fb31d9f..e6ba44756 100644 --- a/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/utils/SourceRecordIteratorTest.java +++ b/s3-source-connector/src/test/java/io/aiven/kafka/connect/s3/source/utils/SourceRecordIteratorTest.java @@ -30,6 +30,7 @@ import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.List; +import java.util.stream.Stream; import io.aiven.kafka.connect.s3.source.config.S3SourceConfig; import io.aiven.kafka.connect.s3.source.input.Transformer; @@ -78,8 +79,7 @@ void testIteratorProcessesS3Objects() throws Exception { when(mockS3Client.getObject(anyString(), anyString())).thenReturn(mockS3Object); when(mockS3Object.getObjectContent()).thenReturn(mockInputStream); - when(mockTransformer.getRecords(any(), anyString(), anyInt(), any())) - .thenReturn(Collections.singletonList(new Object())); + when(mockTransformer.getRecords(any(), anyString(), anyInt(), any())).thenReturn(Stream.of(new Object())); final String outStr = "this is a test"; when(mockTransformer.getValueBytes(any(), anyString(), any())) @@ -102,7 +102,6 @@ void testIteratorProcessesS3Objects() throws Exception { assertTrue(iterator.hasNext()); assertNotNull(iterator.next()); } - } private ListObjectsV2Result mockListObjectsResult(final List summaries) {
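A note on the pattern used throughout this change: each transformer converts a pull-based reader (DataFileStream, BufferedReader, ParquetReader) into a lazy java.util.stream.Stream by wrapping it in a Spliterators.AbstractSpliterator that emits one record per tryAdvance() call, and by releasing the underlying resource through Stream.onClose(). The following self-contained sketch illustrates that pattern for line-oriented input; the class and method names are illustrative only and are not part of the connector code above.

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

// Illustrative sketch of the lazy-streaming pattern: one record is produced per
// tryAdvance() call, and the underlying reader is released via Stream.onClose().
// Class and method names here are hypothetical, not connector APIs.
public final class LazyRecordStreamSketch {

    private LazyRecordStreamSketch() {
    }

    static Stream<String> lines(final InputStream inputStream) {
        final BufferedReader reader = new BufferedReader(
                new InputStreamReader(inputStream, StandardCharsets.UTF_8));

        final Spliterator<String> spliterator = new Spliterators.AbstractSpliterator<String>(
                Long.MAX_VALUE, Spliterator.ORDERED | Spliterator.NONNULL) {
            @Override
            public boolean tryAdvance(final Consumer<? super String> action) {
                try {
                    final String line = reader.readLine(); // pull exactly one record
                    if (line == null) {
                        return false; // end of input
                    }
                    action.accept(line);
                    return true;
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            }
        };

        // The reader stays open while the stream is consumed and is closed with the stream.
        return StreamSupport.stream(spliterator, false).onClose(() -> {
            try {
                reader.close();
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        });
    }

    public static void main(final String[] args) {
        final byte[] data = "{\"id\":1}\n{\"id\":2}\n".getBytes(StandardCharsets.UTF_8);
        // try-with-resources on the Stream runs the onClose() action, closing the reader.
        try (Stream<String> records = lines(new ByteArrayInputStream(data))) {
            records.forEach(System.out::println); // each line is read only when consumed
        }
    }
}

Because the stream is lazy, the underlying reader has to remain open until the consumer finishes iterating; registering the close action with onClose() and consuming the stream inside try-with-resources keeps the resource lifetime tied to the stream's lifetime, which is how SourceRecordIterator consumes transformer.getRecords(...) above.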