diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8ba2595..f51e877 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -25,6 +25,9 @@ jobs: uses: actions/setup-java@v1 with: java-version: 1.8 + - name: scalafmt + if: github.event.pull_request.merged == true + run: ./gradlew spotlessCheck - name: Test with Gradle if: github.event.pull_request.merged == true run: ./gradlew test diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4c77bab..661e1a7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,6 +21,8 @@ jobs: uses: actions/setup-java@v1 with: java-version: 1.8 + - name: scalafmt + run: ./gradlew spotlessCheck - name: Test with Gradle run: ./gradlew test diff --git a/.scalafmt.conf b/.scalafmt.conf new file mode 100644 index 0000000..7a9fe74 --- /dev/null +++ b/.scalafmt.conf @@ -0,0 +1,5 @@ +# https://scalameta.org/scalafmt/#Configuration + +version = "2.3.2" +newlines.alwaysBeforeElseAfterCurlyIf = true +newlines.alwaysBeforeTopLevelStatements = true diff --git a/CHANGELOG.md b/CHANGELOG.md index d1018d4..9c14acd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +0.2.0 (2020-03-10) +================== + +* [Enhancement] [#23](https://github.com/civitaspo/embulk-output-s3_parquet/pull/23) Limit the usage of swapping ContextClassLoader +* [BugFix] [#24](https://github.com/civitaspo/embulk-output-s3_parquet/pull/24) Use basic credentials correctly +* [Enhancement] [#20](https://github.com/civitaspo/embulk-output-s3_parquet/pull/20) Update gradle 4.1 -> 6.1 +* [Enhancement] [#20](https://github.com/civitaspo/embulk-output-s3_parquet/pull/20) Update parquet-{column,common,encoding,hadoop,jackson,tools} 1.10.1 -> 1.11.0 with the latest parquet-format 2.4.0 -> 2.7.0 + * [parquet-format CHANGELOG](https://github.com/apache/parquet-format/blob/master/CHANGES.md) + * [parquet-mr CHANGELOG](https://github.com/apache/parquet-mr/blob/apache-parquet-1.11.0/CHANGES.md#version-1110) +* [Enhancement] [#20](https://github.com/civitaspo/embulk-output-s3_parquet/pull/20) Update aws-java-sdk 1.11.676 -> 1.11.739 +* [Enhancement] [#20](https://github.com/civitaspo/embulk-output-s3_parquet/pull/20) Update embulk 0.9.20 -> 0.9.23 with embulk-deps-{config,buffer} +* [Enhancement] [#19](https://github.com/civitaspo/embulk-output-s3_parquet/pull/19) Use scalafmt instead of the Intellij formatter. +* [Enhancement] [#19](https://github.com/civitaspo/embulk-output-s3_parquet/pull/19) Use scalafmt in CI. +* [Enhancement] [#19](https://github.com/civitaspo/embulk-output-s3_parquet/pull/19) Enable to run examples locally with some prepared scripts. 
+ 0.1.0 (2019-11-17) ================== diff --git a/README.md b/README.md index 000ceaa..1cc751e 100644 --- a/README.md +++ b/README.md @@ -131,6 +131,8 @@ out: ### Run example: ```shell +$ ./run_s3_local.sh +$ ./example/prepare_s3_bucket.sh $ ./gradlew classpath $ embulk run example/config.yml -Ilib ``` @@ -138,8 +140,7 @@ $ embulk run example/config.yml -Ilib ### Run test: ```shell -## Run fake S3 with localstack -$ docker run -it --rm -p 4572:4572 -e SERVICES=s3 localstack/localstack +$ ./run_s3_local.sh $ ./gradlew test ``` diff --git a/build.gradle b/build.gradle index 8f33f8c..7ecb279 100644 --- a/build.gradle +++ b/build.gradle @@ -3,6 +3,7 @@ plugins { id "com.jfrog.bintray" version "1.1" id "com.github.jruby-gradle.base" version "1.5.0" id "com.adarshr.test-logger" version "1.6.0" // For Pretty test logging + id "com.diffplug.gradle.spotless" version "3.27.1" } import com.github.jrubygradle.JRubyExec repositories { @@ -13,29 +14,32 @@ configurations { provided } -version = "0.1.0" +version = "0.2.0" sourceCompatibility = 1.8 targetCompatibility = 1.8 dependencies { - compile "org.embulk:embulk-core:0.9.20" - provided "org.embulk:embulk-core:0.9.20" + compile "org.embulk:embulk-core:0.9.23" + provided "org.embulk:embulk-core:0.9.23" compile 'org.scala-lang:scala-library:2.13.1' ['glue', 's3', 'sts'].each { v -> - compile "com.amazonaws:aws-java-sdk-${v}:1.11.676" + compile "com.amazonaws:aws-java-sdk-${v}:1.11.739" } - ['column', 'common', 'encoding', 'format', 'hadoop', 'jackson'].each { v -> - compile "org.apache.parquet:parquet-${v}:1.10.1" + ['column', 'common', 'encoding', 'hadoop', 'jackson'].each { v -> + compile "org.apache.parquet:parquet-${v}:1.11.0" } + // ref. https://github.com/apache/parquet-mr/blob/apache-parquet-1.11.0/pom.xml#L85 + compile 'org.apache.parquet:parquet-format:2.7.0' compile 'org.apache.hadoop:hadoop-common:2.9.2' compile 'org.xerial.snappy:snappy-java:1.1.7.3' + ['test', 'standards', 'deps-buffer', 'deps-config'].each { v -> + testCompile "org.embulk:embulk-${v}:0.9.23" + } testCompile 'org.scalatest:scalatest_2.13:3.0.8' - testCompile 'org.embulk:embulk-test:0.9.20' - testCompile 'org.embulk:embulk-standards:0.9.20' - testCompile 'org.apache.parquet:parquet-tools:1.10.1' + testCompile 'org.apache.parquet:parquet-tools:1.11.0' testCompile 'org.apache.hadoop:hadoop-client:2.9.2' } @@ -43,6 +47,12 @@ testlogger { theme "mocha" } +spotless { + scala { + scalafmt('2.3.2').configFile('.scalafmt.conf') + } +} + task classpath(type: Copy, dependsOn: ["jar"]) { doFirst { file("classpath").deleteDir() } from (configurations.runtime - configurations.provided + files(jar.archivePath)) diff --git a/example/config.yml b/example/config.yml index 8241c5f..dbeff17 100644 --- a/example/config.yml +++ b/example/config.yml @@ -17,7 +17,9 @@ in: out: type: s3_parquet - bucket: my-bucket + bucket: example + region: us-east-1 + endpoint: http://127.0.0.1:4572 path_prefix: path/to/my-obj. 
file_ext: snappy.parquet compression_codec: snappy diff --git a/example/prepare_s3_bucket.sh b/example/prepare_s3_bucket.sh new file mode 100755 index 0000000..ab5c17c --- /dev/null +++ b/example/prepare_s3_bucket.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +aws s3 mb s3://example \ + --endpoint-url http://localhost:4572 \ + --region us-east-1 + diff --git a/example/with_catalog.yml b/example/with_catalog.yml index 6431f52..39b4ca0 100644 --- a/example/with_catalog.yml +++ b/example/with_catalog.yml @@ -17,7 +17,9 @@ in: out: type: s3_parquet - bucket: dev-baikal-workspace + bucket: example + region: us-east-1 + endpoint: http://127.0.0.1:4572 path_prefix: path/to/my-obj-2. file_ext: snappy.parquet compression_codec: snappy diff --git a/example/with_logicaltypes.yml b/example/with_logicaltypes.yml index 37f8b46..5970cc6 100644 --- a/example/with_logicaltypes.yml +++ b/example/with_logicaltypes.yml @@ -17,7 +17,9 @@ in: out: type: s3_parquet - bucket: my-bucket + bucket: example + region: us-east-1 + endpoint: http://127.0.0.1:4572 path_prefix: path/to/my-obj-2. file_ext: snappy.parquet compression_codec: snappy diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 7a3265e..f3d88b1 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index f16d266..ba94df8 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-6.1-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-4.1-bin.zip diff --git a/gradlew b/gradlew index cccdd3d..2fe81a7 100755 --- a/gradlew +++ b/gradlew @@ -1,5 +1,21 @@ #!/usr/bin/env sh +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + ############################################################################## ## ## Gradle start up script for UN*X @@ -28,7 +44,7 @@ APP_NAME="Gradle" APP_BASE_NAME=`basename "$0"` # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS="" +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. 
MAX_FD="maximum" @@ -109,8 +125,8 @@ if $darwin; then GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" fi -# For Cygwin, switch paths to Windows format before running java -if $cygwin ; then +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then APP_HOME=`cygpath --path --mixed "$APP_HOME"` CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` JAVACMD=`cygpath --unix "$JAVACMD"` @@ -138,19 +154,19 @@ if $cygwin ; then else eval `echo args$i`="\"$arg\"" fi - i=$((i+1)) + i=`expr $i + 1` done case $i in - (0) set -- ;; - (1) set -- "$args0" ;; - (2) set -- "$args0" "$args1" ;; - (3) set -- "$args0" "$args1" "$args2" ;; - (4) set -- "$args0" "$args1" "$args2" "$args3" ;; - (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; esac fi @@ -159,14 +175,9 @@ save () { for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done echo " " } -APP_ARGS=$(save "$@") +APP_ARGS=`save "$@"` # Collect all arguments for the java command, following the shell quoting and substitution rules eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" -# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong -if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then - cd "$(dirname "$0")" -fi - exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat index e95643d..24467a1 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -1,3 +1,19 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + @if "%DEBUG%" == "" @echo off @rem ########################################################################## @rem @@ -14,7 +30,7 @@ set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 
-set DEFAULT_JVM_OPTS= +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" @rem Find java.exe if defined JAVA_HOME goto findJavaFromJavaHome diff --git a/run_s3_local.sh b/run_s3_local.sh new file mode 100755 index 0000000..436e0d4 --- /dev/null +++ b/run_s3_local.sh @@ -0,0 +1,7 @@ +#!/bin/sh + +docker run -it -d --rm \ + -p 4572:4572 \ + -e SERVICES=s3 \ + localstack/localstack + diff --git a/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala b/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala index 27ff4de..e81a7d5 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala @@ -1,202 +1,250 @@ package org.embulk.output.s3_parquet - import java.util.{Optional, Map => JMap} -import com.amazonaws.services.glue.model.{Column, CreateTableRequest, DeleteTableRequest, GetTableRequest, SerDeInfo, StorageDescriptor, TableInput} +import com.amazonaws.services.glue.model.{ + Column, + CreateTableRequest, + DeleteTableRequest, + GetTableRequest, + SerDeInfo, + StorageDescriptor, + TableInput +} import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.embulk.config.{Config, ConfigDefault, ConfigException} import org.embulk.output.s3_parquet.aws.Aws import org.embulk.output.s3_parquet.CatalogRegistrator.ColumnOptions import org.embulk.spi.Schema -import org.embulk.spi.`type`.{BooleanType, DoubleType, JsonType, LongType, StringType, TimestampType, Type} +import org.embulk.spi.`type`.{ + BooleanType, + DoubleType, + JsonType, + LongType, + StringType, + TimestampType, + Type +} import org.slf4j.{Logger, LoggerFactory} import scala.jdk.CollectionConverters._ import scala.util.Try - -object CatalogRegistrator -{ - trait Task - extends org.embulk.config.Task - { - @Config("catalog_id") - @ConfigDefault("null") - def getCatalogId: Optional[String] - - @Config("database") - def getDatabase: String - - @Config("table") - def getTable: String - - @Config("column_options") - @ConfigDefault("{}") - def getColumnOptions: JMap[String, ColumnOptions] - - @Config("operation_if_exists") - @ConfigDefault("\"delete\"") - def getOperationIfExists: String - } - - trait ColumnOptions - { - @Config("type") - def getType: String - } - - def apply(aws: Aws, - task: Task, - schema: Schema, - location: String, - compressionCodec: CompressionCodecName, - loggerOption: Option[Logger] = None, - parquetColumnLogicalTypes: Map[String, String] = Map.empty): CatalogRegistrator = - { - new CatalogRegistrator(aws, task, schema, location, compressionCodec, loggerOption, parquetColumnLogicalTypes) - } +object CatalogRegistrator { + + trait Task extends org.embulk.config.Task { + + @Config("catalog_id") + @ConfigDefault("null") + def getCatalogId: Optional[String] + + @Config("database") + def getDatabase: String + + @Config("table") + def getTable: String + + @Config("column_options") + @ConfigDefault("{}") + def getColumnOptions: JMap[String, ColumnOptions] + + @Config("operation_if_exists") + @ConfigDefault("\"delete\"") + def getOperationIfExists: String + } + + trait ColumnOptions { + + @Config("type") + def getType: String + } + + def apply( + aws: Aws, + task: Task, + schema: Schema, + location: String, + compressionCodec: CompressionCodecName, + loggerOption: Option[Logger] = None, + parquetColumnLogicalTypes: Map[String, String] = Map.empty + ): CatalogRegistrator = { + new CatalogRegistrator( + aws, + task, + schema, + location, + compressionCodec, + loggerOption, + 
parquetColumnLogicalTypes + ) + } } -class CatalogRegistrator(aws: Aws, - task: CatalogRegistrator.Task, - schema: Schema, - location: String, - compressionCodec: CompressionCodecName, - loggerOption: Option[Logger] = None, - parquetColumnLogicalTypes: Map[String, String] = Map.empty) -{ - val logger: Logger = loggerOption.getOrElse(LoggerFactory.getLogger(classOf[CatalogRegistrator])) - - def run(): Unit = - { - if (doesTableExists()) { - task.getOperationIfExists match { - case "skip" => - logger.info(s"Skip to register the table: ${task.getDatabase}.${task.getTable}") - return - - case "delete" => - logger.info(s"Delete the table: ${task.getDatabase}.${task.getTable}") - deleteTable() - - case unknown => - throw new ConfigException(s"Unsupported operation: $unknown") - } - } - registerNewParquetTable() - showNewTableInfo() +class CatalogRegistrator( + aws: Aws, + task: CatalogRegistrator.Task, + schema: Schema, + location: String, + compressionCodec: CompressionCodecName, + loggerOption: Option[Logger] = None, + parquetColumnLogicalTypes: Map[String, String] = Map.empty +) { + + val logger: Logger = + loggerOption.getOrElse(LoggerFactory.getLogger(classOf[CatalogRegistrator])) + + def run(): Unit = { + if (doesTableExists()) { + task.getOperationIfExists match { + case "skip" => + logger.info( + s"Skip to register the table: ${task.getDatabase}.${task.getTable}" + ) + return + + case "delete" => + logger.info(s"Delete the table: ${task.getDatabase}.${task.getTable}") + deleteTable() + + case unknown => + throw new ConfigException(s"Unsupported operation: $unknown") + } } - - def showNewTableInfo(): Unit = - { - val req = new GetTableRequest() - task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) - req.setDatabaseName(task.getDatabase) - req.setName(task.getTable) - - val t = aws.withGlue(_.getTable(req)).getTable - logger.info(s"Created a table: ${t.toString}") - } - - def doesTableExists(): Boolean = - { - val req = new GetTableRequest() - task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) - req.setDatabaseName(task.getDatabase) - req.setName(task.getTable) - - Try(aws.withGlue(_.getTable(req))).isSuccess + registerNewParquetTable() + showNewTableInfo() + } + + def showNewTableInfo(): Unit = { + val req = new GetTableRequest() + task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) + req.setDatabaseName(task.getDatabase) + req.setName(task.getTable) + + val t = aws.withGlue(_.getTable(req)).getTable + logger.info(s"Created a table: ${t.toString}") + } + + def doesTableExists(): Boolean = { + val req = new GetTableRequest() + task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) + req.setDatabaseName(task.getDatabase) + req.setName(task.getTable) + + Try(aws.withGlue(_.getTable(req))).isSuccess + } + + def deleteTable(): Unit = { + val req = new DeleteTableRequest() + task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) + req.setDatabaseName(task.getDatabase) + req.setName(task.getTable) + aws.withGlue(_.deleteTable(req)) + } + + def registerNewParquetTable(): Unit = { + logger.info(s"Create a new table: ${task.getDatabase}.${task.getTable}") + val req = new CreateTableRequest() + task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) + req.setDatabaseName(task.getDatabase) + req.setTableInput( + new TableInput() + .withName(task.getTable) + .withDescription("Created by embulk-output-s3_parquet") + .withTableType("EXTERNAL_TABLE") + .withParameters( + Map( + "EXTERNAL" -> "TRUE", + "classification" -> "parquet", + "parquet.compression" -> 
compressionCodec.name() + ).asJava + ) + .withStorageDescriptor( + new StorageDescriptor() + .withColumns(getGlueSchema: _*) + .withLocation(location) + .withCompressed(isCompressed) + .withInputFormat( + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat" + ) + .withOutputFormat( + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat" + ) + .withSerdeInfo( + new SerDeInfo() + .withSerializationLibrary( + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + ) + .withParameters(Map("serialization.format" -> "1").asJava) + ) + ) + ) + aws.withGlue(_.createTable(req)) + } + + private def getGlueSchema: Seq[Column] = { + val columnOptions: Map[String, ColumnOptions] = + task.getColumnOptions.asScala.toMap + schema.getColumns.asScala.toSeq.map { c => + val cType: String = + if (columnOptions.contains(c.getName)) columnOptions(c.getName).getType + else if (parquetColumnLogicalTypes.contains(c.getName)) + convertParquetLogicalTypeToGlueType( + parquetColumnLogicalTypes(c.getName) + ) + else convertEmbulkTypeToGlueType(c.getType) + new Column() + .withName(c.getName) + .withType(cType) } - - def deleteTable(): Unit = - { - val req = new DeleteTableRequest() - task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) - req.setDatabaseName(task.getDatabase) - req.setName(task.getTable) - aws.withGlue(_.deleteTable(req)) - } - - def registerNewParquetTable(): Unit = - { - logger.info(s"Create a new table: ${task.getDatabase}.${task.getTable}") - val req = new CreateTableRequest() - task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) - req.setDatabaseName(task.getDatabase) - req.setTableInput(new TableInput() - .withName(task.getTable) - .withDescription("Created by embulk-output-s3_parquet") - .withTableType("EXTERNAL_TABLE") - .withParameters(Map("EXTERNAL" -> "TRUE", - "classification" -> "parquet", - "parquet.compression" -> compressionCodec.name()).asJava) - .withStorageDescriptor(new StorageDescriptor() - .withColumns(getGlueSchema: _*) - .withLocation(location) - .withCompressed(isCompressed) - .withInputFormat("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat") - .withOutputFormat("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat") - .withSerdeInfo(new SerDeInfo() - .withSerializationLibrary("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe") - .withParameters(Map("serialization.format" -> "1").asJava) - ) - ) - ) - aws.withGlue(_.createTable(req)) + } + + private def convertParquetLogicalTypeToGlueType(t: String): String = { + t match { + case "timestamp-millis" => "timestamp" + case "timestamp-micros" => + "bigint" // Glue cannot recognize timestamp-micros. + case "int8" => "tinyint" + case "int16" => "smallint" + case "int32" => "int" + case "int64" => "bigint" + case "uint8" => + "smallint" // Glue tinyint is a minimum value of -2^7 and a maximum value of 2^7-1 + case "uint16" => + "int" // Glue smallint is a minimum value of -2^15 and a maximum value of 2^15-1. + case "uint32" => + "bigint" // Glue int is a minimum value of-2^31 and a maximum value of 2^31-1. + case "uint64" => + throw new ConfigException( + "Cannot convert uint64 to Glue data types automatically" + + " because the Glue bigint supports a 64-bit signed integer." + + " Please use `catalog.column_options` to define the type." + ) + case "json" => "string" + case _ => + throw new ConfigException( + s"Unsupported a parquet logical type: $t. Please use `catalog.column_options` to define the type." 
+ ) } - private def getGlueSchema: Seq[Column] = - { - val columnOptions: Map[String, ColumnOptions] = task.getColumnOptions.asScala.toMap - schema.getColumns.asScala.toSeq.map { c => - val cType: String = - if (columnOptions.contains(c.getName)) columnOptions(c.getName).getType - else if (parquetColumnLogicalTypes.contains(c.getName)) convertParquetLogicalTypeToGlueType(parquetColumnLogicalTypes(c.getName)) - else convertEmbulkTypeToGlueType(c.getType) - new Column() - .withName(c.getName) - .withType(cType) - } + } + + private def convertEmbulkTypeToGlueType(t: Type): String = { + t match { + case _: BooleanType => "boolean" + case _: LongType => "bigint" + case _: DoubleType => "double" + case _: StringType => "string" + case _: TimestampType => "string" + case _: JsonType => "string" + case unknown => + throw new ConfigException( + s"Unsupported embulk type: ${unknown.getName}" + ) } + } - private def convertParquetLogicalTypeToGlueType(t: String): String = - { - t match { - case "timestamp-millis" => "timestamp" - case "timestamp-micros" => "bigint" // Glue cannot recognize timestamp-micros. - case "int8" => "tinyint" - case "int16" => "smallint" - case "int32" => "int" - case "int64" => "bigint" - case "uint8" => "smallint" // Glue tinyint is a minimum value of -2^7 and a maximum value of 2^7-1 - case "uint16" => "int" // Glue smallint is a minimum value of -2^15 and a maximum value of 2^15-1. - case "uint32" => "bigint" // Glue int is a minimum value of-2^31 and a maximum value of 2^31-1. - case "uint64" => throw new ConfigException("Cannot convert uint64 to Glue data types automatically" + - " because the Glue bigint supports a 64-bit signed integer." + - " Please use `catalog.column_options` to define the type.") - case "json" => "string" - case _ => throw new ConfigException(s"Unsupported a parquet logical type: $t. Please use `catalog.column_options` to define the type.") - } - - } - - private def convertEmbulkTypeToGlueType(t: Type): String = - { - t match { - case _: BooleanType => "boolean" - case _: LongType => "bigint" - case _: DoubleType => "double" - case _: StringType => "string" - case _: TimestampType => "string" - case _: JsonType => "string" - case unknown => throw new ConfigException(s"Unsupported embulk type: ${unknown.getName}") - } - } - - private def isCompressed: Boolean = - { - !compressionCodec.equals(CompressionCodecName.UNCOMPRESSED) - } + private def isCompressed: Boolean = { + !compressionCodec.equals(CompressionCodecName.UNCOMPRESSED) + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/ContextClassLoaderSwapper.scala b/src/main/scala/org/embulk/output/s3_parquet/ContextClassLoaderSwapper.scala new file mode 100644 index 0000000..cadd27c --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/ContextClassLoaderSwapper.scala @@ -0,0 +1,18 @@ +package org.embulk.output.s3_parquet + +// WARNING: This object should be used for limited purposes only. 
+object ContextClassLoaderSwapper { + + def using[A](klass: Class[_])(f: => A): A = { + val currentTread = Thread.currentThread() + val original = currentTread.getContextClassLoader + val target = klass.getClassLoader + currentTread.setContextClassLoader(target) + try f + finally currentTread.setContextClassLoader(original) + } + + def usingPluginClass[A](f: => A): A = { + using(classOf[S3ParquetOutputPlugin])(f) + } +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala index 53b623f..cfa109c 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala @@ -1,18 +1,44 @@ package org.embulk.output.s3_parquet - import java.nio.file.{Files, Paths} -import java.util.{IllegalFormatException, Locale, Optional, List => JList, Map => JMap} +import java.util.{ + IllegalFormatException, + Locale, + Optional, + List => JList, + Map => JMap +} import com.amazonaws.services.s3.model.CannedAccessControlList import org.apache.parquet.column.ParquetProperties import org.apache.parquet.hadoop.ParquetWriter import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.embulk.config.{Config, ConfigDefault, ConfigDiff, ConfigException, ConfigSource, Task, TaskReport, TaskSource} -import org.embulk.output.s3_parquet.S3ParquetOutputPlugin.{ColumnOptionTask, PluginTask} +import org.embulk.config.{ + Config, + ConfigDefault, + ConfigDiff, + ConfigException, + ConfigSource, + Task, + TaskReport, + TaskSource +} +import org.embulk.output.s3_parquet.S3ParquetOutputPlugin.{ + ColumnOptionTask, + PluginTask +} import org.embulk.output.s3_parquet.aws.Aws -import org.embulk.output.s3_parquet.parquet.{LogicalTypeHandlerStore, ParquetFileWriter} -import org.embulk.spi.{Exec, OutputPlugin, PageReader, Schema, TransactionalPageOutput} +import org.embulk.output.s3_parquet.parquet.{ + LogicalTypeHandlerStore, + ParquetFileWriter +} +import org.embulk.spi.{ + Exec, + OutputPlugin, + PageReader, + Schema, + TransactionalPageOutput +} import org.embulk.spi.time.TimestampFormatter import org.embulk.spi.time.TimestampFormatter.TimestampColumnOption import org.embulk.spi.util.Timestamps @@ -21,239 +47,302 @@ import org.slf4j.{Logger, LoggerFactory} import scala.jdk.CollectionConverters._ import scala.util.chaining._ +object S3ParquetOutputPlugin { -object S3ParquetOutputPlugin -{ + trait PluginTask extends Task with TimestampFormatter.Task with Aws.Task { - trait PluginTask - extends Task - with TimestampFormatter.Task - with Aws.Task - { + @Config("bucket") + def getBucket: String - @Config("bucket") - def getBucket: String + @Config("path_prefix") + @ConfigDefault("\"\"") + def getPathPrefix: String - @Config("path_prefix") - @ConfigDefault("\"\"") - def getPathPrefix: String + @Config("sequence_format") + @ConfigDefault("\"%03d.%02d.\"") + def getSequenceFormat: String - @Config("sequence_format") - @ConfigDefault("\"%03d.%02d.\"") - def getSequenceFormat: String + @Config("file_ext") + @ConfigDefault("\"parquet\"") + def getFileExt: String - @Config("file_ext") - @ConfigDefault("\"parquet\"") - def getFileExt: String + @Config("compression_codec") + @ConfigDefault("\"uncompressed\"") + def getCompressionCodecString: String - @Config("compression_codec") - @ConfigDefault("\"uncompressed\"") - def getCompressionCodecString: String + def setCompressionCodec(v: CompressionCodecName): Unit - def 
setCompressionCodec(v: CompressionCodecName): Unit + def getCompressionCodec: CompressionCodecName - def getCompressionCodec: CompressionCodecName + @Config("column_options") + @ConfigDefault("{}") + def getColumnOptions: JMap[String, ColumnOptionTask] - @Config("column_options") - @ConfigDefault("{}") - def getColumnOptions: JMap[String, ColumnOptionTask] + @Config("canned_acl") + @ConfigDefault("\"private\"") + def getCannedAclString: String - @Config("canned_acl") - @ConfigDefault("\"private\"") - def getCannedAclString: String + def setCannedAcl(v: CannedAccessControlList): Unit - def setCannedAcl(v: CannedAccessControlList): Unit + def getCannedAcl: CannedAccessControlList - def getCannedAcl: CannedAccessControlList + @Config("block_size") + @ConfigDefault("null") + def getBlockSize: Optional[Int] - @Config("block_size") - @ConfigDefault("null") - def getBlockSize: Optional[Int] + @Config("page_size") + @ConfigDefault("null") + def getPageSize: Optional[Int] - @Config("page_size") - @ConfigDefault("null") - def getPageSize: Optional[Int] + @Config("max_padding_size") + @ConfigDefault("null") + def getMaxPaddingSize: Optional[Int] - @Config("max_padding_size") - @ConfigDefault("null") - def getMaxPaddingSize: Optional[Int] + @Config("enable_dictionary_encoding") + @ConfigDefault("null") + def getEnableDictionaryEncoding: Optional[Boolean] - @Config("enable_dictionary_encoding") - @ConfigDefault("null") - def getEnableDictionaryEncoding: Optional[Boolean] + @Config("buffer_dir") + @ConfigDefault("null") + def getBufferDir: Optional[String] - @Config("buffer_dir") - @ConfigDefault("null") - def getBufferDir: Optional[String] + @Config("catalog") + @ConfigDefault("null") + def getCatalog: Optional[CatalogRegistrator.Task] - @Config("catalog") - @ConfigDefault("null") - def getCatalog: Optional[CatalogRegistrator.Task] + @Config("type_options") + @ConfigDefault("{}") + def getTypeOptions: JMap[String, TypeOptionTask] + } - @Config("type_options") - @ConfigDefault("{}") - def getTypeOptions: JMap[String, TypeOptionTask] - } + trait ColumnOptionTask + extends Task + with TimestampColumnOption + with LogicalTypeOption - trait ColumnOptionTask - extends Task with TimestampColumnOption with LogicalTypeOption + trait TypeOptionTask extends Task with LogicalTypeOption - trait TypeOptionTask - extends Task with LogicalTypeOption + trait LogicalTypeOption { - trait LogicalTypeOption - { - @Config("logical_type") - def getLogicalType: Optional[String] - } + @Config("logical_type") + def getLogicalType: Optional[String] + } } -class S3ParquetOutputPlugin - extends OutputPlugin -{ - - val logger: Logger = LoggerFactory.getLogger(classOf[S3ParquetOutputPlugin]) - - private def withPluginContextClassLoader[A](f: => A): A = - { - val original: ClassLoader = Thread.currentThread.getContextClassLoader - Thread.currentThread.setContextClassLoader(classOf[S3ParquetOutputPlugin].getClassLoader) - try f - finally Thread.currentThread.setContextClassLoader(original) - } - - override def transaction(config: ConfigSource, - schema: Schema, - taskCount: Int, - control: OutputPlugin.Control): ConfigDiff = - { - val task: PluginTask = config.loadConfig(classOf[PluginTask]) - - withPluginContextClassLoader { - configure(task, schema) - control.run(task.dump) - } - task.getCatalog.ifPresent { catalog => - val location = s"s3://${task.getBucket}/${task.getPathPrefix.replaceFirst("(.*/)[^/]+$", "$1")}" - val parquetColumnLogicalTypes: Map[String, String] = Map.newBuilder[String, String].pipe {builder => - val cOptions = 
task.getColumnOptions.asScala - val tOptions = task.getTypeOptions.asScala - schema.getColumns.asScala.foreach {c => - cOptions.get(c.getName) - if (cOptions.contains(c.getName) && cOptions(c.getName).getLogicalType.isPresent) { - builder.addOne(c.getName -> cOptions(c.getName).getLogicalType.get()) - } - else if (tOptions.contains(c.getType.getName) && tOptions(c.getType.getName).getLogicalType.isPresent) { - builder.addOne(c.getName -> tOptions(c.getType.getName).getLogicalType.get()) - } - } - builder.result() +class S3ParquetOutputPlugin extends OutputPlugin { + + val logger: Logger = LoggerFactory.getLogger(classOf[S3ParquetOutputPlugin]) + + override def transaction( + config: ConfigSource, + schema: Schema, + taskCount: Int, + control: OutputPlugin.Control + ): ConfigDiff = { + val task: PluginTask = config.loadConfig(classOf[PluginTask]) + + configure(task, schema) + control.run(task.dump) + + task.getCatalog.ifPresent { catalog => + val location = + s"s3://${task.getBucket}/${task.getPathPrefix.replaceFirst("(.*/)[^/]+$", "$1")}" + val parquetColumnLogicalTypes: Map[String, String] = + Map.newBuilder[String, String].pipe { builder => + val cOptions = task.getColumnOptions.asScala + val tOptions = task.getTypeOptions.asScala + schema.getColumns.asScala.foreach { c => + cOptions.get(c.getName) + if (cOptions + .contains(c.getName) && cOptions(c.getName).getLogicalType.isPresent) { + builder + .addOne(c.getName -> cOptions(c.getName).getLogicalType.get()) } - val cr = CatalogRegistrator(aws = Aws(task), - task = catalog, - schema = schema, - location = location, - compressionCodec = task.getCompressionCodec, - parquetColumnLogicalTypes = parquetColumnLogicalTypes) - cr.run() - } - - Exec.newConfigDiff - } - - private def configure(task: PluginTask, - schema: Schema): Unit = - { - // sequence_format - try String.format(task.getSequenceFormat, 0: Integer, 0: Integer) - catch { - case e: IllegalFormatException => throw new ConfigException(s"Invalid sequence_format: ${task.getSequenceFormat}", e) - } - - // compression_codec - CompressionCodecName.values().find(v => v.name().toLowerCase(Locale.ENGLISH).equals(task.getCompressionCodecString)) match { - case Some(v) => task.setCompressionCodec(v) - case None => - val unsupported: String = task.getCompressionCodecString - val supported: String = CompressionCodecName.values().map(v => s"'${v.name().toLowerCase}'").mkString(", ") - throw new ConfigException(s"'$unsupported' is unsupported: `compression_codec` must be one of [$supported].") - } - - // column_options - task.getColumnOptions.forEach { (k: String, - opt: ColumnOptionTask) => - val c = schema.lookupColumn(k) - val useTimestampOption = opt.getFormat.isPresent || opt.getTimeZoneId.isPresent - if (!c.getType.getName.equals("timestamp") && useTimestampOption) { - throw new ConfigException(s"column:$k is not 'timestamp' type.") + else if (tOptions.contains(c.getType.getName) && tOptions( + c.getType.getName + ).getLogicalType.isPresent) { + builder.addOne( + c.getName -> tOptions(c.getType.getName).getLogicalType.get() + ) } + } + builder.result() } + val cr = CatalogRegistrator( + aws = Aws(task), + task = catalog, + schema = schema, + location = location, + compressionCodec = task.getCompressionCodec, + parquetColumnLogicalTypes = parquetColumnLogicalTypes + ) + cr.run() + } - // canned_acl - CannedAccessControlList.values().find(v => v.toString.equals(task.getCannedAclString)) match { - case Some(v) => task.setCannedAcl(v) - case None => - val unsupported: String = 
task.getCannedAclString - val supported: String = CannedAccessControlList.values().map(v => s"'${v.toString}'").mkString(", ") - throw new ConfigException(s"'$unsupported' is unsupported: `canned_acl` must be one of [$supported].") - } + Exec.newConfigDiff + } + + private def configure(task: PluginTask, schema: Schema): Unit = { + // sequence_format + try String.format(task.getSequenceFormat, 0: Integer, 0: Integer) + catch { + case e: IllegalFormatException => + throw new ConfigException( + s"Invalid sequence_format: ${task.getSequenceFormat}", + e + ) } - override def resume(taskSource: TaskSource, - schema: Schema, - taskCount: Int, - control: OutputPlugin.Control): ConfigDiff = - { - throw new UnsupportedOperationException("s3_parquet output plugin does not support resuming") + // compression_codec + CompressionCodecName + .values() + .find(v => + v.name() + .toLowerCase(Locale.ENGLISH) + .equals(task.getCompressionCodecString) + ) match { + case Some(v) => task.setCompressionCodec(v) + case None => + val unsupported: String = task.getCompressionCodecString + val supported: String = CompressionCodecName + .values() + .map(v => s"'${v.name().toLowerCase}'") + .mkString(", ") + throw new ConfigException( + s"'$unsupported' is unsupported: `compression_codec` must be one of [$supported]." + ) } - override def cleanup(taskSource: TaskSource, - schema: Schema, - taskCount: Int, - successTaskReports: JList[TaskReport]): Unit = - { - successTaskReports.forEach { tr => - logger.info( - s"Created: s3://${tr.get(classOf[String], "bucket")}/${tr.get(classOf[String], "key")}, " - + s"version_id: ${tr.get(classOf[String], "version_id", null)}, " - + s"etag: ${tr.get(classOf[String], "etag", null)}") - } + // column_options + task.getColumnOptions.forEach { (k: String, opt: ColumnOptionTask) => + val c = schema.lookupColumn(k) + val useTimestampOption = opt.getFormat.isPresent || opt.getTimeZoneId.isPresent + if (!c.getType.getName.equals("timestamp") && useTimestampOption) { + throw new ConfigException(s"column:$k is not 'timestamp' type.") + } } - override def open(taskSource: TaskSource, - schema: Schema, - taskIndex: Int): TransactionalPageOutput = - { - val task = taskSource.loadTask(classOf[PluginTask]) - val bufferDir: String = task.getBufferDir.orElse(Files.createTempDirectory("embulk-output-s3_parquet-").toString) - val bufferFile: String = Paths.get(bufferDir, s"embulk-output-s3_parquet-task-$taskIndex-0.parquet").toString - val destS3bucket: String = task.getBucket - val destS3Key: String = task.getPathPrefix + String.format(task.getSequenceFormat, taskIndex: Integer, 0: Integer) + task.getFileExt - - - val pageReader: PageReader = new PageReader(schema) - val aws: Aws = Aws(task) - val timestampFormatters: Seq[TimestampFormatter] = Timestamps.newTimestampColumnFormatters(task, schema, task.getColumnOptions).toSeq - val logicalTypeHandlers = LogicalTypeHandlerStore.fromEmbulkOptions(task.getTypeOptions, task.getColumnOptions) - val parquetWriter: ParquetWriter[PageReader] = ParquetFileWriter.builder() - .withPath(bufferFile) - .withSchema(schema) - .withLogicalTypeHandlers(logicalTypeHandlers) - .withTimestampFormatters(timestampFormatters) - .withCompressionCodec(task.getCompressionCodec) - .withDictionaryEncoding(task.getEnableDictionaryEncoding.orElse(ParquetProperties.DEFAULT_IS_DICTIONARY_ENABLED)) - .withDictionaryPageSize(task.getPageSize.orElse(ParquetProperties.DEFAULT_DICTIONARY_PAGE_SIZE)) - 
.withMaxPaddingSize(task.getMaxPaddingSize.orElse(ParquetWriter.MAX_PADDING_SIZE_DEFAULT)) - .withPageSize(task.getPageSize.orElse(ParquetProperties.DEFAULT_PAGE_SIZE)) - .withRowGroupSize(task.getBlockSize.orElse(ParquetWriter.DEFAULT_BLOCK_SIZE)) - .withValidation(ParquetWriter.DEFAULT_IS_VALIDATING_ENABLED) - .withWriteMode(org.apache.parquet.hadoop.ParquetFileWriter.Mode.CREATE) - .withWriterVersion(ParquetProperties.DEFAULT_WRITER_VERSION) - .build() - - logger.info(s"Local Buffer File: $bufferFile, Destination: s3://$destS3bucket/$destS3Key") - - S3ParquetPageOutput(bufferFile, pageReader, parquetWriter, aws, destS3bucket, destS3Key) + // canned_acl + CannedAccessControlList + .values() + .find(v => v.toString.equals(task.getCannedAclString)) match { + case Some(v) => task.setCannedAcl(v) + case None => + val unsupported: String = task.getCannedAclString + val supported: String = CannedAccessControlList + .values() + .map(v => s"'${v.toString}'") + .mkString(", ") + throw new ConfigException( + s"'$unsupported' is unsupported: `canned_acl` must be one of [$supported]." + ) + } + } + + override def resume( + taskSource: TaskSource, + schema: Schema, + taskCount: Int, + control: OutputPlugin.Control + ): ConfigDiff = { + throw new UnsupportedOperationException( + "s3_parquet output plugin does not support resuming" + ) + } + + override def cleanup( + taskSource: TaskSource, + schema: Schema, + taskCount: Int, + successTaskReports: JList[TaskReport] + ): Unit = { + successTaskReports.forEach { tr => + logger.info( + s"Created: s3://${tr.get(classOf[String], "bucket")}/${tr.get(classOf[String], "key")}, " + + s"version_id: ${tr.get(classOf[String], "version_id", null)}, " + + s"etag: ${tr.get(classOf[String], "etag", null)}" + ) } + } + + override def open( + taskSource: TaskSource, + schema: Schema, + taskIndex: Int + ): TransactionalPageOutput = { + val task = taskSource.loadTask(classOf[PluginTask]) + val bufferDir: String = task.getBufferDir.orElse( + Files.createTempDirectory("embulk-output-s3_parquet-").toString + ) + val bufferFile: String = Paths + .get(bufferDir, s"embulk-output-s3_parquet-task-$taskIndex-0.parquet") + .toString + val destS3bucket: String = task.getBucket + val destS3Key: String = task.getPathPrefix + String.format( + task.getSequenceFormat, + taskIndex: Integer, + 0: Integer + ) + task.getFileExt + + val pageReader: PageReader = new PageReader(schema) + val aws: Aws = Aws(task) + val timestampFormatters: Seq[TimestampFormatter] = Timestamps + .newTimestampColumnFormatters(task, schema, task.getColumnOptions) + .toSeq + val logicalTypeHandlers = LogicalTypeHandlerStore.fromEmbulkOptions( + task.getTypeOptions, + task.getColumnOptions + ) + val parquetWriter: ParquetWriter[PageReader] = + ContextClassLoaderSwapper.usingPluginClass { + ParquetFileWriter + .builder() + .withPath(bufferFile) + .withSchema(schema) + .withLogicalTypeHandlers(logicalTypeHandlers) + .withTimestampFormatters(timestampFormatters) + .withCompressionCodec(task.getCompressionCodec) + .withDictionaryEncoding( + task.getEnableDictionaryEncoding.orElse( + ParquetProperties.DEFAULT_IS_DICTIONARY_ENABLED + ) + ) + .withDictionaryPageSize( + task.getPageSize.orElse( + ParquetProperties.DEFAULT_DICTIONARY_PAGE_SIZE + ) + ) + .withMaxPaddingSize( + task.getMaxPaddingSize.orElse( + ParquetWriter.MAX_PADDING_SIZE_DEFAULT + ) + ) + .withPageSize( + task.getPageSize.orElse(ParquetProperties.DEFAULT_PAGE_SIZE) + ) + .withRowGroupSize( + task.getBlockSize.orElse(ParquetWriter.DEFAULT_BLOCK_SIZE) + ) + 
.withValidation(ParquetWriter.DEFAULT_IS_VALIDATING_ENABLED) + .withWriteMode( + org.apache.parquet.hadoop.ParquetFileWriter.Mode.CREATE + ) + .withWriterVersion(ParquetProperties.DEFAULT_WRITER_VERSION) + .build() + } + + logger.info( + s"Local Buffer File: $bufferFile, Destination: s3://$destS3bucket/$destS3Key" + ) + + S3ParquetPageOutput( + bufferFile, + pageReader, + parquetWriter, + aws, + destS3bucket, + destS3Key + ) + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetPageOutput.scala b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetPageOutput.scala index e3e0776..eb0cc22 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetPageOutput.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetPageOutput.scala @@ -1,6 +1,5 @@ package org.embulk.output.s3_parquet - import java.io.File import java.nio.file.{Files, Paths} @@ -11,63 +10,61 @@ import org.embulk.config.TaskReport import org.embulk.output.s3_parquet.aws.Aws import org.embulk.spi.{Exec, Page, PageReader, TransactionalPageOutput} +case class S3ParquetPageOutput( + outputLocalFile: String, + reader: PageReader, + writer: ParquetWriter[PageReader], + aws: Aws, + destBucket: String, + destKey: String +) extends TransactionalPageOutput { -case class S3ParquetPageOutput(outputLocalFile: String, - reader: PageReader, - writer: ParquetWriter[PageReader], - aws: Aws, - destBucket: String, - destKey: String) - extends TransactionalPageOutput -{ - - private var isClosed: Boolean = false + private var isClosed: Boolean = false - override def add(page: Page): Unit = - { - reader.setPage(page) - while (reader.nextRecord()) { - writer.write(reader) - } + override def add(page: Page): Unit = { + reader.setPage(page) + while (reader.nextRecord()) { + writer.write(reader) } + } - override def finish(): Unit = - { - } + override def finish(): Unit = {} - override def close(): Unit = - { - synchronized { - if (!isClosed) { - writer.close() - isClosed = true - } + override def close(): Unit = { + synchronized { + if (!isClosed) { + ContextClassLoaderSwapper.usingPluginClass { + writer.close() } + isClosed = true + } } + } - override def abort(): Unit = - { - close() - cleanup() - } + override def abort(): Unit = { + close() + cleanup() + } - override def commit(): TaskReport = - { - close() - val result: UploadResult = aws.withTransferManager { xfer: TransferManager => - val upload: Upload = xfer.upload(destBucket, destKey, new File(outputLocalFile)) - upload.waitForUploadResult() - } - cleanup() - Exec.newTaskReport() - .set("bucket", result.getBucketName) - .set("key", result.getKey) - .set("etag", result.getETag) - .set("version_id", result.getVersionId) + override def commit(): TaskReport = { + close() + val result: UploadResult = ContextClassLoaderSwapper.usingPluginClass { + aws.withTransferManager { xfer: TransferManager => + val upload: Upload = + xfer.upload(destBucket, destKey, new File(outputLocalFile)) + upload.waitForUploadResult() + } } + cleanup() + Exec + .newTaskReport() + .set("bucket", result.getBucketName) + .set("key", result.getKey) + .set("etag", result.getETag) + .set("version_id", result.getVersionId) + } - private def cleanup(): Unit = - { - Files.delete(Paths.get(outputLocalFile)) - } + private def cleanup(): Unit = { + Files.delete(Paths.get(outputLocalFile)) + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/aws/Aws.scala b/src/main/scala/org/embulk/output/s3_parquet/aws/Aws.scala index a388aaa..de08024 100644 --- 
a/src/main/scala/org/embulk/output/s3_parquet/aws/Aws.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/aws/Aws.scala @@ -1,63 +1,59 @@ package org.embulk.output.s3_parquet.aws - import com.amazonaws.client.builder.AwsClientBuilder import com.amazonaws.services.glue.{AWSGlue, AWSGlueClientBuilder} import com.amazonaws.services.s3.{AmazonS3, AmazonS3ClientBuilder} -import com.amazonaws.services.s3.transfer.{TransferManager, TransferManagerBuilder} - - -object Aws -{ - - trait Task - extends AwsCredentials.Task - with AwsEndpointConfiguration.Task - with AwsClientConfiguration.Task - with AwsS3Configuration.Task - - def apply(task: Task): Aws = - { - new Aws(task) - } - +import com.amazonaws.services.s3.transfer.{ + TransferManager, + TransferManagerBuilder } -class Aws(task: Aws.Task) -{ - - def withS3[A](f: AmazonS3 => A): A = - { - val builder: AmazonS3ClientBuilder = AmazonS3ClientBuilder.standard() - AwsS3Configuration(task).configureAmazonS3ClientBuilder(builder) - val svc = createService(builder) - try f(svc) - finally svc.shutdown() - } +object Aws { - def withTransferManager[A](f: TransferManager => A): A = - { - withS3 { s3 => - val svc = TransferManagerBuilder.standard().withS3Client(s3).build() - try f(svc) - finally svc.shutdownNow(false) - } - } + trait Task + extends AwsCredentials.Task + with AwsEndpointConfiguration.Task + with AwsClientConfiguration.Task + with AwsS3Configuration.Task - def withGlue[A](f: AWSGlue => A): A = - { - val builder: AWSGlueClientBuilder = AWSGlueClientBuilder.standard() - val svc = createService(builder) - try f(svc) - finally svc.shutdown() - } + def apply(task: Task): Aws = { + new Aws(task) + } - def createService[S <: AwsClientBuilder[S, T], T](builder: AwsClientBuilder[S, T]): T = - { - AwsEndpointConfiguration(task).configureAwsClientBuilder(builder) - AwsClientConfiguration(task).configureAwsClientBuilder(builder) - builder.setCredentials(AwsCredentials(task).createAwsCredentialsProvider) +} - builder.build() +class Aws(task: Aws.Task) { + + def withS3[A](f: AmazonS3 => A): A = { + val builder: AmazonS3ClientBuilder = AmazonS3ClientBuilder.standard() + AwsS3Configuration(task).configureAmazonS3ClientBuilder(builder) + val svc = createService(builder) + try f(svc) + finally svc.shutdown() + } + + def withTransferManager[A](f: TransferManager => A): A = { + withS3 { s3 => + val svc = TransferManagerBuilder.standard().withS3Client(s3).build() + try f(svc) + finally svc.shutdownNow(false) } + } + + def withGlue[A](f: AWSGlue => A): A = { + val builder: AWSGlueClientBuilder = AWSGlueClientBuilder.standard() + val svc = createService(builder) + try f(svc) + finally svc.shutdown() + } + + def createService[S <: AwsClientBuilder[S, T], T]( + builder: AwsClientBuilder[S, T] + ): T = { + AwsEndpointConfiguration(task).configureAwsClientBuilder(builder) + AwsClientConfiguration(task).configureAwsClientBuilder(builder) + builder.setCredentials(AwsCredentials(task).createAwsCredentialsProvider) + + builder.build() + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsClientConfiguration.scala b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsClientConfiguration.scala index 6f0e975..063bb7b 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsClientConfiguration.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsClientConfiguration.scala @@ -1,6 +1,5 @@ package org.embulk.output.s3_parquet.aws - import java.util.Optional import com.amazonaws.ClientConfiguration @@ -8,35 +7,31 @@ import 
com.amazonaws.client.builder.AwsClientBuilder import org.embulk.config.{Config, ConfigDefault} import org.embulk.output.s3_parquet.aws.AwsClientConfiguration.Task +object AwsClientConfiguration { -object AwsClientConfiguration -{ - - trait Task - { + trait Task { - @Config("http_proxy") - @ConfigDefault("null") - def getHttpProxy: Optional[HttpProxy.Task] + @Config("http_proxy") + @ConfigDefault("null") + def getHttpProxy: Optional[HttpProxy.Task] - } + } - def apply(task: Task): AwsClientConfiguration = - { - new AwsClientConfiguration(task) - } + def apply(task: Task): AwsClientConfiguration = { + new AwsClientConfiguration(task) + } } -class AwsClientConfiguration(task: Task) -{ +class AwsClientConfiguration(task: Task) { - def configureAwsClientBuilder[S <: AwsClientBuilder[S, T], T](builder: AwsClientBuilder[S, T]): Unit = - { - task.getHttpProxy.ifPresent { v => - val cc = new ClientConfiguration - HttpProxy(v).configureClientConfiguration(cc) - builder.setClientConfiguration(cc) - } + def configureAwsClientBuilder[S <: AwsClientBuilder[S, T], T]( + builder: AwsClientBuilder[S, T] + ): Unit = { + task.getHttpProxy.ifPresent { v => + val cc = new ClientConfiguration + HttpProxy(v).configureClientConfiguration(cc) + builder.setClientConfiguration(cc) } + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsCredentials.scala b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsCredentials.scala index d20177a..02bf136 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsCredentials.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsCredentials.scala @@ -1,147 +1,174 @@ package org.embulk.output.s3_parquet.aws - import java.util.Optional -import com.amazonaws.auth.{AnonymousAWSCredentials, AWSCredentialsProvider, AWSStaticCredentialsProvider, BasicAWSCredentials, BasicSessionCredentials, DefaultAWSCredentialsProviderChain, EC2ContainerCredentialsProviderWrapper, EnvironmentVariableCredentialsProvider, STSAssumeRoleSessionCredentialsProvider, SystemPropertiesCredentialsProvider, WebIdentityTokenCredentialsProvider} -import com.amazonaws.auth.profile.{ProfileCredentialsProvider, ProfilesConfigFile} +import com.amazonaws.auth.{ + AnonymousAWSCredentials, + AWSCredentialsProvider, + AWSStaticCredentialsProvider, + BasicAWSCredentials, + BasicSessionCredentials, + DefaultAWSCredentialsProviderChain, + EC2ContainerCredentialsProviderWrapper, + EnvironmentVariableCredentialsProvider, + STSAssumeRoleSessionCredentialsProvider, + SystemPropertiesCredentialsProvider, + WebIdentityTokenCredentialsProvider +} +import com.amazonaws.auth.profile.{ + ProfileCredentialsProvider, + ProfilesConfigFile +} import org.embulk.config.{Config, ConfigDefault, ConfigException} import org.embulk.output.s3_parquet.aws.AwsCredentials.Task import org.embulk.spi.unit.LocalFile +object AwsCredentials { -object AwsCredentials -{ + trait Task { - trait Task - { + @Config("auth_method") + @ConfigDefault("\"default\"") + def getAuthMethod: String - @Config("auth_method") - @ConfigDefault("\"default\"") - def getAuthMethod: String + @Config("access_key_id") + @ConfigDefault("null") + def getAccessKeyId: Optional[String] - @Config("access_key_id") - @ConfigDefault("null") - def getAccessKeyId: Optional[String] + @Config("secret_access_key") + @ConfigDefault("null") + def getSecretAccessKey: Optional[String] - @Config("secret_access_key") - @ConfigDefault("null") - def getSecretAccessKey: Optional[String] + @Config("session_token") + @ConfigDefault("null") + def getSessionToken: 
Optional[String] - @Config("session_token") - @ConfigDefault("null") - def getSessionToken: Optional[String] + @Config("profile_file") + @ConfigDefault("null") + def getProfileFile: Optional[LocalFile] - @Config("profile_file") - @ConfigDefault("null") - def getProfileFile: Optional[LocalFile] + @Config("profile_name") + @ConfigDefault("\"default\"") + def getProfileName: String - @Config("profile_name") - @ConfigDefault("\"default\"") - def getProfileName: String + @Config("role_arn") + @ConfigDefault("null") + def getRoleArn: Optional[String] - @Config("role_arn") - @ConfigDefault("null") - def getRoleArn: Optional[String] + @Config("role_session_name") + @ConfigDefault("null") + def getRoleSessionName: Optional[String] - @Config("role_session_name") - @ConfigDefault("null") - def getRoleSessionName: Optional[String] + @Config("role_external_id") + @ConfigDefault("null") + def getRoleExternalId: Optional[String] - @Config("role_external_id") - @ConfigDefault("null") - def getRoleExternalId: Optional[String] + @Config("role_session_duration_seconds") + @ConfigDefault("null") + def getRoleSessionDurationSeconds: Optional[Int] - @Config("role_session_duration_seconds") - @ConfigDefault("null") - def getRoleSessionDurationSeconds: Optional[Int] + @Config("scope_down_policy") + @ConfigDefault("null") + def getScopeDownPolicy: Optional[String] - @Config("scope_down_policy") - @ConfigDefault("null") - def getScopeDownPolicy: Optional[String] - - @Config("web_identity_token_file") - @ConfigDefault("null") - def getWebIdentityTokenFile: Optional[String] - } + @Config("web_identity_token_file") + @ConfigDefault("null") + def getWebIdentityTokenFile: Optional[String] + } - def apply(task: Task): AwsCredentials = - { - new AwsCredentials(task) - } + def apply(task: Task): AwsCredentials = { + new AwsCredentials(task) + } } -class AwsCredentials(task: Task) -{ - - def createAwsCredentialsProvider: AWSCredentialsProvider = - { - task.getAuthMethod match { - case "basic" => - new AWSStaticCredentialsProvider(new BasicAWSCredentials( - getRequiredOption(task.getAccessKeyId, "access_key_id"), - getRequiredOption(task.getAccessKeyId, "secret_access_key") - )) - - case "env" => - new EnvironmentVariableCredentialsProvider - - case "instance" => - // NOTE: combination of InstanceProfileCredentialsProvider and ContainerCredentialsProvider - new EC2ContainerCredentialsProviderWrapper - - case "profile" => - if (task.getProfileFile.isPresent) { - val pf: ProfilesConfigFile = new ProfilesConfigFile(task.getProfileFile.get().getFile) - new ProfileCredentialsProvider(pf, task.getProfileName) - } - else new ProfileCredentialsProvider(task.getProfileName) - - case "properties" => - new SystemPropertiesCredentialsProvider - - case "anonymous" => - new AWSStaticCredentialsProvider(new AnonymousAWSCredentials) - - case "session" => - new AWSStaticCredentialsProvider(new BasicSessionCredentials( - getRequiredOption(task.getAccessKeyId, "access_key_id"), - getRequiredOption(task.getSecretAccessKey, "secret_access_key"), - getRequiredOption(task.getSessionToken, "session_token") - )) - - case "assume_role" => - // NOTE: Are http_proxy, endpoint, region required when assuming role? 
- val builder = new STSAssumeRoleSessionCredentialsProvider.Builder( - getRequiredOption(task.getRoleArn, "role_arn"), - getRequiredOption(task.getRoleSessionName, "role_session_name") - ) - task.getRoleExternalId.ifPresent(v => builder.withExternalId(v)) - task.getRoleSessionDurationSeconds.ifPresent(v => builder.withRoleSessionDurationSeconds(v)) - task.getScopeDownPolicy.ifPresent(v => builder.withScopeDownPolicy(v)) - - builder.build() - - case "web_identity_token" => - WebIdentityTokenCredentialsProvider.builder() - .roleArn(getRequiredOption(task.getRoleArn, "role_arn")) - .roleSessionName(getRequiredOption(task.getRoleSessionName, "role_session_name")) - .webIdentityTokenFile(getRequiredOption(task.getWebIdentityTokenFile, "web_identity_token_file")) - .build() - - case "default" => - new DefaultAWSCredentialsProviderChain - - case am => - throw new ConfigException(s"'$am' is unsupported: `auth_method` must be one of ['basic', 'env', 'instance', 'profile', 'properties', 'anonymous', 'session', 'assume_role', 'default'].") +class AwsCredentials(task: Task) { + + def createAwsCredentialsProvider: AWSCredentialsProvider = { + task.getAuthMethod match { + case "basic" => + new AWSStaticCredentialsProvider( + new BasicAWSCredentials( + getRequiredOption(task.getAccessKeyId, "access_key_id"), + getRequiredOption(task.getSecretAccessKey, "secret_access_key") + ) + ) + + case "env" => + new EnvironmentVariableCredentialsProvider + + case "instance" => + // NOTE: combination of InstanceProfileCredentialsProvider and ContainerCredentialsProvider + new EC2ContainerCredentialsProviderWrapper + + case "profile" => + if (task.getProfileFile.isPresent) { + val pf: ProfilesConfigFile = new ProfilesConfigFile( + task.getProfileFile.get().getFile + ) + new ProfileCredentialsProvider(pf, task.getProfileName) } + else new ProfileCredentialsProvider(task.getProfileName) + + case "properties" => + new SystemPropertiesCredentialsProvider + + case "anonymous" => + new AWSStaticCredentialsProvider(new AnonymousAWSCredentials) + + case "session" => + new AWSStaticCredentialsProvider( + new BasicSessionCredentials( + getRequiredOption(task.getAccessKeyId, "access_key_id"), + getRequiredOption(task.getSecretAccessKey, "secret_access_key"), + getRequiredOption(task.getSessionToken, "session_token") + ) + ) + + case "assume_role" => + // NOTE: Are http_proxy, endpoint, region required when assuming role? + val builder = new STSAssumeRoleSessionCredentialsProvider.Builder( + getRequiredOption(task.getRoleArn, "role_arn"), + getRequiredOption(task.getRoleSessionName, "role_session_name") + ) + task.getRoleExternalId.ifPresent(v => builder.withExternalId(v)) + task.getRoleSessionDurationSeconds.ifPresent(v => + builder.withRoleSessionDurationSeconds(v) + ) + task.getScopeDownPolicy.ifPresent(v => builder.withScopeDownPolicy(v)) + + builder.build() + + case "web_identity_token" => + WebIdentityTokenCredentialsProvider + .builder() + .roleArn(getRequiredOption(task.getRoleArn, "role_arn")) + .roleSessionName( + getRequiredOption(task.getRoleSessionName, "role_session_name") + ) + .webIdentityTokenFile( + getRequiredOption( + task.getWebIdentityTokenFile, + "web_identity_token_file" + ) + ) + .build() + + case "default" => + new DefaultAWSCredentialsProviderChain + + case am => + throw new ConfigException( + s"'$am' is unsupported: `auth_method` must be one of ['basic', 'env', 'instance', 'profile', 'properties', 'anonymous', 'session', 'assume_role', 'default']." 
+ ) } - - private def getRequiredOption[A](o: Optional[A], - name: String): A = - { - o.orElseThrow(() => new ConfigException(s"`$name` must be set when `auth_method` is ${task.getAuthMethod}.")) - } - + } + + private def getRequiredOption[A](o: Optional[A], name: String): A = { + o.orElseThrow(() => + new ConfigException( + s"`$name` must be set when `auth_method` is ${task.getAuthMethod}." + ) + ) + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsEndpointConfiguration.scala b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsEndpointConfiguration.scala index e0303aa..47a20f6 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsEndpointConfiguration.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsEndpointConfiguration.scala @@ -1,6 +1,5 @@ package org.embulk.output.s3_parquet.aws - import java.util.Optional import com.amazonaws.client.builder.AwsClientBuilder @@ -11,47 +10,45 @@ import org.embulk.output.s3_parquet.aws.AwsEndpointConfiguration.Task import scala.util.Try +object AwsEndpointConfiguration { -object AwsEndpointConfiguration -{ - - trait Task - { + trait Task { - @Config("endpoint") - @ConfigDefault("null") - def getEndpoint: Optional[String] + @Config("endpoint") + @ConfigDefault("null") + def getEndpoint: Optional[String] - @Config("region") - @ConfigDefault("null") - def getRegion: Optional[String] + @Config("region") + @ConfigDefault("null") + def getRegion: Optional[String] - } + } - def apply(task: Task): AwsEndpointConfiguration = - { - new AwsEndpointConfiguration(task) - } + def apply(task: Task): AwsEndpointConfiguration = { + new AwsEndpointConfiguration(task) + } } -class AwsEndpointConfiguration(task: Task) -{ - - def configureAwsClientBuilder[S <: AwsClientBuilder[S, T], T](builder: AwsClientBuilder[S, T]): Unit = - { - if (task.getRegion.isPresent && task.getEndpoint.isPresent) { - val ec = new EndpointConfiguration(task.getEndpoint.get, task.getRegion.get) - builder.setEndpointConfiguration(ec) - } - else if (task.getRegion.isPresent && !task.getEndpoint.isPresent) { - builder.setRegion(task.getRegion.get) - } - else if (!task.getRegion.isPresent && task.getEndpoint.isPresent) { - val r: String = Try(new DefaultAwsRegionProviderChain().getRegion).getOrElse(Regions.DEFAULT_REGION.getName) - val e: String = task.getEndpoint.get - val ec = new EndpointConfiguration(e, r) - builder.setEndpointConfiguration(ec) - } +class AwsEndpointConfiguration(task: Task) { + + def configureAwsClientBuilder[S <: AwsClientBuilder[S, T], T]( + builder: AwsClientBuilder[S, T] + ): Unit = { + if (task.getRegion.isPresent && task.getEndpoint.isPresent) { + val ec = + new EndpointConfiguration(task.getEndpoint.get, task.getRegion.get) + builder.setEndpointConfiguration(ec) + } + else if (task.getRegion.isPresent && !task.getEndpoint.isPresent) { + builder.setRegion(task.getRegion.get) + } + else if (!task.getRegion.isPresent && task.getEndpoint.isPresent) { + val r: String = Try(new DefaultAwsRegionProviderChain().getRegion) + .getOrElse(Regions.DEFAULT_REGION.getName) + val e: String = task.getEndpoint.get + val ec = new EndpointConfiguration(e, r) + builder.setEndpointConfiguration(ec) } + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsS3Configuration.scala b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsS3Configuration.scala index 2e306f3..4c9538f 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsS3Configuration.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsS3Configuration.scala @@ 
-1,64 +1,68 @@ package org.embulk.output.s3_parquet.aws - import java.util.Optional import com.amazonaws.services.s3.AmazonS3ClientBuilder import org.embulk.config.{Config, ConfigDefault} import org.embulk.output.s3_parquet.aws.AwsS3Configuration.Task - /* * These are advanced settings, so write no documentation. */ -object AwsS3Configuration -{ - trait Task - { +object AwsS3Configuration { - @Config("accelerate_mode_enabled") - @ConfigDefault("null") - def getAccelerateModeEnabled: Optional[Boolean] + trait Task { - @Config("chunked_encoding_disabled") - @ConfigDefault("null") - def getChunkedEncodingDisabled: Optional[Boolean] + @Config("accelerate_mode_enabled") + @ConfigDefault("null") + def getAccelerateModeEnabled: Optional[Boolean] - @Config("dualstack_enabled") - @ConfigDefault("null") - def getDualstackEnabled: Optional[Boolean] + @Config("chunked_encoding_disabled") + @ConfigDefault("null") + def getChunkedEncodingDisabled: Optional[Boolean] - @Config("force_global_bucket_access_enabled") - @ConfigDefault("null") - def getForceGlobalBucketAccessEnabled: Optional[Boolean] + @Config("dualstack_enabled") + @ConfigDefault("null") + def getDualstackEnabled: Optional[Boolean] - @Config("path_style_access_enabled") - @ConfigDefault("null") - def getPathStyleAccessEnabled: Optional[Boolean] + @Config("force_global_bucket_access_enabled") + @ConfigDefault("null") + def getForceGlobalBucketAccessEnabled: Optional[Boolean] - @Config("payload_signing_enabled") - @ConfigDefault("null") - def getPayloadSigningEnabled: Optional[Boolean] + @Config("path_style_access_enabled") + @ConfigDefault("null") + def getPathStyleAccessEnabled: Optional[Boolean] - } + @Config("payload_signing_enabled") + @ConfigDefault("null") + def getPayloadSigningEnabled: Optional[Boolean] - def apply(task: Task): AwsS3Configuration = - { - new AwsS3Configuration(task) - } -} + } -class AwsS3Configuration(task: Task) -{ + def apply(task: Task): AwsS3Configuration = { + new AwsS3Configuration(task) + } +} - def configureAmazonS3ClientBuilder(builder: AmazonS3ClientBuilder): Unit = - { - task.getAccelerateModeEnabled.ifPresent(v => builder.setAccelerateModeEnabled(v)) - task.getChunkedEncodingDisabled.ifPresent(v => builder.setChunkedEncodingDisabled(v)) - task.getDualstackEnabled.ifPresent(v => builder.setDualstackEnabled(v)) - task.getForceGlobalBucketAccessEnabled.ifPresent(v => builder.setForceGlobalBucketAccessEnabled(v)) - task.getPathStyleAccessEnabled.ifPresent(v => builder.setPathStyleAccessEnabled(v)) - task.getPayloadSigningEnabled.ifPresent(v => builder.setPayloadSigningEnabled(v)) - } +class AwsS3Configuration(task: Task) { + + def configureAmazonS3ClientBuilder(builder: AmazonS3ClientBuilder): Unit = { + task.getAccelerateModeEnabled.ifPresent(v => + builder.setAccelerateModeEnabled(v) + ) + task.getChunkedEncodingDisabled.ifPresent(v => + builder.setChunkedEncodingDisabled(v) + ) + task.getDualstackEnabled.ifPresent(v => builder.setDualstackEnabled(v)) + task.getForceGlobalBucketAccessEnabled.ifPresent(v => + builder.setForceGlobalBucketAccessEnabled(v) + ) + task.getPathStyleAccessEnabled.ifPresent(v => + builder.setPathStyleAccessEnabled(v) + ) + task.getPayloadSigningEnabled.ifPresent(v => + builder.setPayloadSigningEnabled(v) + ) + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/aws/HttpProxy.scala b/src/main/scala/org/embulk/output/s3_parquet/aws/HttpProxy.scala index 4318538..68e2aa3 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/aws/HttpProxy.scala +++ 
b/src/main/scala/org/embulk/output/s3_parquet/aws/HttpProxy.scala @@ -1,64 +1,61 @@ package org.embulk.output.s3_parquet.aws - import java.util.Optional import com.amazonaws.{ClientConfiguration, Protocol} import org.embulk.config.{Config, ConfigDefault, ConfigException} import org.embulk.output.s3_parquet.aws.HttpProxy.Task +object HttpProxy { -object HttpProxy -{ - - trait Task - { + trait Task { - @Config("host") - @ConfigDefault("null") - def getHost: Optional[String] + @Config("host") + @ConfigDefault("null") + def getHost: Optional[String] - @Config("port") - @ConfigDefault("null") - def getPort: Optional[Int] + @Config("port") + @ConfigDefault("null") + def getPort: Optional[Int] - @Config("protocol") - @ConfigDefault("\"https\"") - def getProtocol: String + @Config("protocol") + @ConfigDefault("\"https\"") + def getProtocol: String - @Config("user") - @ConfigDefault("null") - def getUser: Optional[String] + @Config("user") + @ConfigDefault("null") + def getUser: Optional[String] - @Config("password") - @ConfigDefault("null") - def getPassword: Optional[String] + @Config("password") + @ConfigDefault("null") + def getPassword: Optional[String] - } + } - def apply(task: Task): HttpProxy = - { - new HttpProxy(task) - } + def apply(task: Task): HttpProxy = { + new HttpProxy(task) + } } -class HttpProxy(task: Task) -{ - - def configureClientConfiguration(cc: ClientConfiguration): Unit = - { - task.getHost.ifPresent(v => cc.setProxyHost(v)) - task.getPort.ifPresent(v => cc.setProxyPort(v)) - - Protocol.values.find(p => p.name().equals(task.getProtocol)) match { - case Some(v) => - cc.setProtocol(v) - case None => - throw new ConfigException(s"'${task.getProtocol}' is unsupported: `protocol` must be one of [${Protocol.values.map(v => s"'$v'").mkString(", ")}].") - } - - task.getUser.ifPresent(v => cc.setProxyUsername(v)) - task.getPassword.ifPresent(v => cc.setProxyPassword(v)) +class HttpProxy(task: Task) { + + def configureClientConfiguration(cc: ClientConfiguration): Unit = { + task.getHost.ifPresent(v => cc.setProxyHost(v)) + task.getPort.ifPresent(v => cc.setProxyPort(v)) + + Protocol.values.find(p => p.name().equals(task.getProtocol)) match { + case Some(v) => + cc.setProtocol(v) + case None => + throw new ConfigException( + s"'${task.getProtocol}' is unsupported: `protocol` must be one of [${Protocol.values + .map(v => s"'$v'") + .mkString(", ")}]." 
+ ) } + + task.getUser.ifPresent(v => cc.setProxyUsername(v)) + task.getPassword.ifPresent(v => cc.setProxyPassword(v)) + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/EmbulkMessageType.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/EmbulkMessageType.scala index a4c9892..61cc8cc 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/EmbulkMessageType.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/EmbulkMessageType.scala @@ -1,107 +1,153 @@ package org.embulk.output.s3_parquet.parquet - import com.google.common.collect.ImmutableList -import org.apache.parquet.schema.{MessageType, OriginalType, PrimitiveType, Type} +import org.apache.parquet.schema.{ + MessageType, + OriginalType, + PrimitiveType, + Type +} import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.embulk.spi.{Column, ColumnVisitor, Schema} +object EmbulkMessageType { + + def builder(): Builder = { + Builder() + } + + case class Builder( + name: String = "embulk", + schema: Schema = Schema.builder().build(), + logicalTypeHandlers: LogicalTypeHandlerStore = + LogicalTypeHandlerStore.empty + ) { + + def withName(name: String): Builder = { + Builder( + name = name, + schema = schema, + logicalTypeHandlers = logicalTypeHandlers + ) + } + + def withSchema(schema: Schema): Builder = { + Builder( + name = name, + schema = schema, + logicalTypeHandlers = logicalTypeHandlers + ) + } + + def withLogicalTypeHandlers( + logicalTypeHandlers: LogicalTypeHandlerStore + ): Builder = { + Builder( + name = name, + schema = schema, + logicalTypeHandlers = logicalTypeHandlers + ) + } + + def build(): MessageType = { + val builder: ImmutableList.Builder[Type] = ImmutableList.builder[Type]() + schema.visitColumns( + EmbulkMessageTypeColumnVisitor(builder, logicalTypeHandlers) + ) + new MessageType("embulk", builder.build()) + } + + } + + private case class EmbulkMessageTypeColumnVisitor( + builder: ImmutableList.Builder[Type], + logicalTypeHandlers: LogicalTypeHandlerStore = + LogicalTypeHandlerStore.empty + ) extends ColumnVisitor { + + override def booleanColumn(column: Column): Unit = { + builder.add( + new PrimitiveType( + Type.Repetition.OPTIONAL, + PrimitiveTypeName.BOOLEAN, + column.getName + ) + ) + } -object EmbulkMessageType -{ + override def longColumn(column: Column): Unit = { + val name = column.getName + val et = column.getType + + val t = logicalTypeHandlers.get(name, et) match { + case Some(h) if h.isConvertible(et) => h.newSchemaFieldType(name) + case _ => + new PrimitiveType( + Type.Repetition.OPTIONAL, + PrimitiveTypeName.INT64, + column.getName + ) + } + + builder.add(t) + } - def builder(): Builder = - { - Builder() + override def doubleColumn(column: Column): Unit = { + builder.add( + new PrimitiveType( + Type.Repetition.OPTIONAL, + PrimitiveTypeName.DOUBLE, + column.getName + ) + ) } - case class Builder(name: String = "embulk", - schema: Schema = Schema.builder().build(), - logicalTypeHandlers: LogicalTypeHandlerStore = LogicalTypeHandlerStore.empty) - { - - def withName(name: String): Builder = - { - Builder(name = name, schema = schema, logicalTypeHandlers = logicalTypeHandlers) - } - - def withSchema(schema: Schema): Builder = - { - Builder(name = name, schema = schema, logicalTypeHandlers = logicalTypeHandlers) - } - - def withLogicalTypeHandlers(logicalTypeHandlers: LogicalTypeHandlerStore): Builder = - { - Builder(name = name, schema = schema, logicalTypeHandlers = logicalTypeHandlers) - } - - def build(): MessageType = - { - val builder: 
ImmutableList.Builder[Type] = ImmutableList.builder[Type]() - schema.visitColumns(EmbulkMessageTypeColumnVisitor(builder, logicalTypeHandlers)) - new MessageType("embulk", builder.build()) - } + override def stringColumn(column: Column): Unit = { + builder.add( + new PrimitiveType( + Type.Repetition.OPTIONAL, + PrimitiveTypeName.BINARY, + column.getName, + OriginalType.UTF8 + ) + ) + } + override def timestampColumn(column: Column): Unit = { + val name = column.getName + val et = column.getType + + val t = logicalTypeHandlers.get(name, et) match { + case Some(h) if h.isConvertible(et) => h.newSchemaFieldType(name) + case _ => + new PrimitiveType( + Type.Repetition.OPTIONAL, + PrimitiveTypeName.BINARY, + name, + OriginalType.UTF8 + ) + } + + builder.add(t) } - private case class EmbulkMessageTypeColumnVisitor(builder: ImmutableList.Builder[Type], - logicalTypeHandlers: LogicalTypeHandlerStore = LogicalTypeHandlerStore.empty) - extends ColumnVisitor - { - - override def booleanColumn(column: Column): Unit = - { - builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.BOOLEAN, column.getName)) - } - - override def longColumn(column: Column): Unit = - { - val name = column.getName - val et = column.getType - - val t = logicalTypeHandlers.get(name, et) match { - case Some(h) if h.isConvertible(et) => h.newSchemaFieldType(name) - case _ => new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.INT64, column.getName) - } - - builder.add(t) - } - - override def doubleColumn(column: Column): Unit = - { - builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.DOUBLE, column.getName)) - } - - override def stringColumn(column: Column): Unit = - { - builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.BINARY, column.getName, OriginalType.UTF8)) - } - - override def timestampColumn(column: Column): Unit = - { - val name = column.getName - val et = column.getType - - val t = logicalTypeHandlers.get(name, et) match { - case Some(h) if h.isConvertible(et) => h.newSchemaFieldType(name) - case _ => new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.BINARY, name, OriginalType.UTF8) - } - - builder.add(t) - } - - override def jsonColumn(column: Column): Unit = - { - val name = column.getName - val et = column.getType - - val t = logicalTypeHandlers.get(name, et) match { - case Some(h) if h.isConvertible(et) => h.newSchemaFieldType(name) - case _ => new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.BINARY, name, OriginalType.UTF8) - } - - builder.add(t) - } + override def jsonColumn(column: Column): Unit = { + val name = column.getName + val et = column.getType + + val t = logicalTypeHandlers.get(name, et) match { + case Some(h) if h.isConvertible(et) => h.newSchemaFieldType(name) + case _ => + new PrimitiveType( + Type.Repetition.OPTIONAL, + PrimitiveTypeName.BINARY, + name, + OriginalType.UTF8 + ) + } + + builder.add(t) } + } -} \ No newline at end of file +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandler.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandler.scala index 108bcae..d02784e 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandler.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandler.scala @@ -1,6 +1,5 @@ package org.embulk.output.s3_parquet.parquet - import org.apache.parquet.io.api.{Binary, RecordConsumer} import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import 
org.apache.parquet.schema.{Type => PType} @@ -11,135 +10,151 @@ import org.embulk.spi.`type`.Types import org.embulk.spi.time.Timestamp import org.msgpack.value.Value - /** - * Handle Apache Parquet 'Logical Types' on schema/value conversion. - * ref. https://github.com/apache/parquet-format/blob/master/LogicalTypes.md - * - * It focuses on only older representation because newer supported since 1.11 is not used actually yet. - * TODO Support both of older and newer representation after 1.11+ is published and other middleware supports it. - * - */ -sealed trait LogicalTypeHandler -{ - def isConvertible(t: EType): Boolean - - def newSchemaFieldType(name: String): PrimitiveType - - def consume(orig: Any, - recordConsumer: RecordConsumer): Unit + * Handle Apache Parquet 'Logical Types' on schema/value conversion. + * ref. https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + * + * It focuses on only older representation because newer supported since 1.11 is not used actually yet. + * TODO Support both of older and newer representation after 1.11+ is published and other middleware supports it. + * + */ +sealed trait LogicalTypeHandler { + def isConvertible(t: EType): Boolean + + def newSchemaFieldType(name: String): PrimitiveType + + def consume(orig: Any, recordConsumer: RecordConsumer): Unit } abstract class IntLogicalTypeHandler(ot: OriginalType) - extends LogicalTypeHandler -{ - override def isConvertible(t: EType): Boolean = - { - t == Types.LONG - } - - override def newSchemaFieldType(name: String): PrimitiveType = - { - new PrimitiveType(PType.Repetition.OPTIONAL, PrimitiveTypeName.INT64, name, ot) - } - - override def consume(orig: Any, - recordConsumer: RecordConsumer): Unit = - { - orig match { - case v: Long => recordConsumer.addLong(v) - case _ => throw new DataException("given mismatched type value; expected type is long") - } + extends LogicalTypeHandler { + + override def isConvertible(t: EType): Boolean = { + t == Types.LONG + } + + override def newSchemaFieldType(name: String): PrimitiveType = { + new PrimitiveType( + PType.Repetition.OPTIONAL, + PrimitiveTypeName.INT64, + name, + ot + ) + } + + override def consume(orig: Any, recordConsumer: RecordConsumer): Unit = { + orig match { + case v: Long => recordConsumer.addLong(v) + case _ => + throw new DataException( + "given mismatched type value; expected type is long" + ) } + } } -object TimestampMillisLogicalTypeHandler - extends LogicalTypeHandler -{ - override def isConvertible(t: EType): Boolean = - { - t == Types.TIMESTAMP - } - - override def newSchemaFieldType(name: String): PrimitiveType = - { - new PrimitiveType(PType.Repetition.OPTIONAL, PrimitiveTypeName.INT64, name, OriginalType.TIMESTAMP_MILLIS) - } - - override def consume(orig: Any, - recordConsumer: RecordConsumer): Unit = - { - orig match { - case ts: Timestamp => recordConsumer.addLong(ts.toEpochMilli) - case _ => throw new DataException("given mismatched type value; expected type is timestamp") - } +object TimestampMillisLogicalTypeHandler extends LogicalTypeHandler { + + override def isConvertible(t: EType): Boolean = { + t == Types.TIMESTAMP + } + + override def newSchemaFieldType(name: String): PrimitiveType = { + new PrimitiveType( + PType.Repetition.OPTIONAL, + PrimitiveTypeName.INT64, + name, + OriginalType.TIMESTAMP_MILLIS + ) + } + + override def consume(orig: Any, recordConsumer: RecordConsumer): Unit = { + orig match { + case ts: Timestamp => recordConsumer.addLong(ts.toEpochMilli) + case _ => + throw new DataException( + "given 
mismatched type value; expected type is timestamp" + ) } + } } -object TimestampMicrosLogicalTypeHandler - extends LogicalTypeHandler -{ - override def isConvertible(t: EType): Boolean = - { - t == Types.TIMESTAMP - } - - override def newSchemaFieldType(name: String): PrimitiveType = - { - new PrimitiveType(PType.Repetition.OPTIONAL, PrimitiveTypeName.INT64, name, OriginalType.TIMESTAMP_MICROS) - } - - override def consume(orig: Any, - recordConsumer: RecordConsumer): Unit = - { - orig match { - case ts: Timestamp => - val v = (ts.getEpochSecond * 1_000_000L) + (ts.getNano.asInstanceOf[Long] / 1_000L) - recordConsumer.addLong(v) - case _ => throw new DataException("given mismatched type value; expected type is timestamp") - } +object TimestampMicrosLogicalTypeHandler extends LogicalTypeHandler { + + override def isConvertible(t: EType): Boolean = { + t == Types.TIMESTAMP + } + + override def newSchemaFieldType(name: String): PrimitiveType = { + new PrimitiveType( + PType.Repetition.OPTIONAL, + PrimitiveTypeName.INT64, + name, + OriginalType.TIMESTAMP_MICROS + ) + } + + override def consume(orig: Any, recordConsumer: RecordConsumer): Unit = { + orig match { + case ts: Timestamp => + val v = (ts.getEpochSecond * 1_000_000L) + (ts.getNano + .asInstanceOf[Long] / 1_000L) + recordConsumer.addLong(v) + case _ => + throw new DataException( + "given mismatched type value; expected type is timestamp" + ) } + } } -object Int8LogicalTypeHandler - extends IntLogicalTypeHandler(OriginalType.INT_8) +object Int8LogicalTypeHandler extends IntLogicalTypeHandler(OriginalType.INT_8) + object Int16LogicalTypeHandler extends IntLogicalTypeHandler(OriginalType.INT_16) + object Int32LogicalTypeHandler extends IntLogicalTypeHandler(OriginalType.INT_32) + object Int64LogicalTypeHandler extends IntLogicalTypeHandler(OriginalType.INT_64) object Uint8LogicalTypeHandler extends IntLogicalTypeHandler(OriginalType.UINT_8) + object Uint16LogicalTypeHandler extends IntLogicalTypeHandler(OriginalType.UINT_16) + object Uint32LogicalTypeHandler extends IntLogicalTypeHandler(OriginalType.UINT_32) + object Uint64LogicalTypeHandler extends IntLogicalTypeHandler(OriginalType.UINT_64) -object JsonLogicalTypeHandler - extends LogicalTypeHandler -{ - override def isConvertible(t: EType): Boolean = - { - t == Types.JSON - } - - override def newSchemaFieldType(name: String): PrimitiveType = - { - new PrimitiveType(PType.Repetition.OPTIONAL, PrimitiveTypeName.BINARY, name, OriginalType.JSON) - } - - override def consume(orig: Any, - recordConsumer: RecordConsumer): Unit = - { - orig match { - case msgPack: Value => - val bin = Binary.fromString(msgPack.toJson) - recordConsumer.addBinary(bin) - case _ => throw new DataException("given mismatched type value; expected type is json") - } +object JsonLogicalTypeHandler extends LogicalTypeHandler { + + override def isConvertible(t: EType): Boolean = { + t == Types.JSON + } + + override def newSchemaFieldType(name: String): PrimitiveType = { + new PrimitiveType( + PType.Repetition.OPTIONAL, + PrimitiveTypeName.BINARY, + name, + OriginalType.JSON + ) + } + + override def consume(orig: Any, recordConsumer: RecordConsumer): Unit = { + orig match { + case msgPack: Value => + val bin = Binary.fromString(msgPack.toJson) + recordConsumer.addBinary(bin) + case _ => + throw new DataException( + "given mismatched type value; expected type is json" + ) } + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala 
b/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala index d2c2d91..65e1a6d 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/LogicalTypeHandlerStore.scala @@ -1,107 +1,114 @@ package org.embulk.output.s3_parquet.parquet - import org.embulk.spi.`type`.{Type, Types} import java.util.{Map => JMap} import org.embulk.config.ConfigException -import org.embulk.output.s3_parquet.S3ParquetOutputPlugin.{ColumnOptionTask, TypeOptionTask} +import org.embulk.output.s3_parquet.S3ParquetOutputPlugin.{ + ColumnOptionTask, + TypeOptionTask +} import scala.jdk.CollectionConverters._ - /** - * A storage has mapping from logical type query (column name, type) to handler. - * - * @param fromEmbulkType - * @param fromColumnName - */ -case class LogicalTypeHandlerStore private(fromEmbulkType: Map[Type, LogicalTypeHandler], - fromColumnName: Map[String, LogicalTypeHandler]) -{ + * A storage has mapping from logical type query (column name, type) to handler. + * + * @param fromEmbulkType + * @param fromColumnName + */ +case class LogicalTypeHandlerStore private ( + fromEmbulkType: Map[Type, LogicalTypeHandler], + fromColumnName: Map[String, LogicalTypeHandler] +) { - // Try column name lookup, then column type - def get(n: String, - t: Type): Option[LogicalTypeHandler] = - { - get(n) match { - case Some(h) => Some(h) - case _ => - get(t) match { - case Some(h) => Some(h) - case _ => None - } + // Try column name lookup, then column type + def get(n: String, t: Type): Option[LogicalTypeHandler] = { + get(n) match { + case Some(h) => Some(h) + case _ => + get(t) match { + case Some(h) => Some(h) + case _ => None } } + } - def get(t: Type): Option[LogicalTypeHandler] = - { - fromEmbulkType.get(t) - } + def get(t: Type): Option[LogicalTypeHandler] = { + fromEmbulkType.get(t) + } - def get(n: String): Option[LogicalTypeHandler] = - { - fromColumnName.get(n) - } + def get(n: String): Option[LogicalTypeHandler] = { + fromColumnName.get(n) + } } -object LogicalTypeHandlerStore -{ - private val STRING_TO_EMBULK_TYPE = Map[String, Type]( - "boolean" -> Types.BOOLEAN, - "long" -> Types.LONG, - "double" -> Types.DOUBLE, - "string" -> Types.STRING, - "timestamp" -> Types.TIMESTAMP, - "json" -> Types.JSON - ) +object LogicalTypeHandlerStore { - // Listed only older logical types that we can convert from embulk type - private val STRING_TO_LOGICAL_TYPE = Map[String, LogicalTypeHandler]( - "timestamp-millis" -> TimestampMillisLogicalTypeHandler, - "timestamp-micros" -> TimestampMicrosLogicalTypeHandler, - "int8" -> Int8LogicalTypeHandler, - "int16" -> Int16LogicalTypeHandler, - "int32" -> Int32LogicalTypeHandler, - "int64" -> Int64LogicalTypeHandler, - "uint8" -> Uint8LogicalTypeHandler, - "uint16" -> Uint16LogicalTypeHandler, - "uint32" -> Uint32LogicalTypeHandler, - "uint64" -> Uint64LogicalTypeHandler, - "json" -> JsonLogicalTypeHandler - ) + private val STRING_TO_EMBULK_TYPE = Map[String, Type]( + "boolean" -> Types.BOOLEAN, + "long" -> Types.LONG, + "double" -> Types.DOUBLE, + "string" -> Types.STRING, + "timestamp" -> Types.TIMESTAMP, + "json" -> Types.JSON + ) - def empty: LogicalTypeHandlerStore = - { - LogicalTypeHandlerStore(Map.empty[Type, LogicalTypeHandler], Map.empty[String, LogicalTypeHandler]) - } + // Listed only older logical types that we can convert from embulk type + private val STRING_TO_LOGICAL_TYPE = Map[String, LogicalTypeHandler]( + "timestamp-millis" -> 
TimestampMillisLogicalTypeHandler, + "timestamp-micros" -> TimestampMicrosLogicalTypeHandler, + "int8" -> Int8LogicalTypeHandler, + "int16" -> Int16LogicalTypeHandler, + "int32" -> Int32LogicalTypeHandler, + "int64" -> Int64LogicalTypeHandler, + "uint8" -> Uint8LogicalTypeHandler, + "uint16" -> Uint16LogicalTypeHandler, + "uint32" -> Uint32LogicalTypeHandler, + "uint64" -> Uint64LogicalTypeHandler, + "json" -> JsonLogicalTypeHandler + ) - def fromEmbulkOptions(typeOpts: JMap[String, TypeOptionTask], - columnOpts: JMap[String, ColumnOptionTask]): LogicalTypeHandlerStore = - { - val fromEmbulkType = typeOpts.asScala - .filter(_._2.getLogicalType.isPresent) - .map[Type, LogicalTypeHandler] { case (k, v) => - val t = STRING_TO_EMBULK_TYPE.get(k) - val h = STRING_TO_LOGICAL_TYPE.get(v.getLogicalType.get) - (t, h) match { - case (Some(tt), Some(hh)) => (tt, hh) - case _ => throw new ConfigException("invalid logical types in type_options") - } - } - .toMap + def empty: LogicalTypeHandlerStore = { + LogicalTypeHandlerStore( + Map.empty[Type, LogicalTypeHandler], + Map.empty[String, LogicalTypeHandler] + ) + } - val fromColumnName = columnOpts.asScala - .filter(_._2.getLogicalType.isPresent) - .map[String, LogicalTypeHandler] { case (k, v) => - val h = STRING_TO_LOGICAL_TYPE.get(v.getLogicalType.get) - h match { - case Some(hh) => (k, hh) - case _ => throw new ConfigException("invalid logical types in column_options") - } - } - .toMap + def fromEmbulkOptions( + typeOpts: JMap[String, TypeOptionTask], + columnOpts: JMap[String, ColumnOptionTask] + ): LogicalTypeHandlerStore = { + val fromEmbulkType = typeOpts.asScala + .filter(_._2.getLogicalType.isPresent) + .map[Type, LogicalTypeHandler] { + case (k, v) => + val t = STRING_TO_EMBULK_TYPE.get(k) + val h = STRING_TO_LOGICAL_TYPE.get(v.getLogicalType.get) + (t, h) match { + case (Some(tt), Some(hh)) => (tt, hh) + case _ => + throw new ConfigException("invalid logical types in type_options") + } + } + .toMap - LogicalTypeHandlerStore(fromEmbulkType, fromColumnName) - } + val fromColumnName = columnOpts.asScala + .filter(_._2.getLogicalType.isPresent) + .map[String, LogicalTypeHandler] { + case (k, v) => + val h = STRING_TO_LOGICAL_TYPE.get(v.getLogicalType.get) + h match { + case Some(hh) => (k, hh) + case _ => + throw new ConfigException( + "invalid logical types in column_options" + ) + } + } + .toMap + + LogicalTypeHandlerStore(fromEmbulkType, fromColumnName) + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala index b0deccf..76ff8dc 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala @@ -1,6 +1,5 @@ package org.embulk.output.s3_parquet.parquet - import org.apache.hadoop.conf.Configuration import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext @@ -11,32 +10,34 @@ import org.embulk.spi.time.TimestampFormatter import scala.jdk.CollectionConverters._ - -private[parquet] case class ParquetFileWriteSupport(schema: Schema, - timestampFormatters: Seq[TimestampFormatter], - logicalTypeHandlers: LogicalTypeHandlerStore = LogicalTypeHandlerStore.empty) - extends WriteSupport[PageReader] -{ - - private var currentParquetFileWriter: ParquetFileWriter = _ - - override def init(configuration: Configuration): WriteContext = - { - val messageType: 
MessageType = EmbulkMessageType.builder() - .withSchema(schema) - .withLogicalTypeHandlers(logicalTypeHandlers) - .build() - val metadata: Map[String, String] = Map.empty // NOTE: When is this used? - new WriteContext(messageType, metadata.asJava) - } - - override def prepareForWrite(recordConsumer: RecordConsumer): Unit = - { - currentParquetFileWriter = ParquetFileWriter(recordConsumer, schema, timestampFormatters, logicalTypeHandlers) - } - - override def write(record: PageReader): Unit = - { - currentParquetFileWriter.write(record) - } +private[parquet] case class ParquetFileWriteSupport( + schema: Schema, + timestampFormatters: Seq[TimestampFormatter], + logicalTypeHandlers: LogicalTypeHandlerStore = LogicalTypeHandlerStore.empty +) extends WriteSupport[PageReader] { + + private var currentParquetFileWriter: ParquetFileWriter = _ + + override def init(configuration: Configuration): WriteContext = { + val messageType: MessageType = EmbulkMessageType + .builder() + .withSchema(schema) + .withLogicalTypeHandlers(logicalTypeHandlers) + .build() + val metadata: Map[String, String] = Map.empty // NOTE: When is this used? + new WriteContext(messageType, metadata.asJava) + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + currentParquetFileWriter = ParquetFileWriter( + recordConsumer, + schema, + timestampFormatters, + logicalTypeHandlers + ) + } + + override def write(record: PageReader): Unit = { + currentParquetFileWriter.write(record) + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriter.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriter.scala index a9f88c9..5eb5701 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriter.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriter.scala @@ -1,6 +1,5 @@ package org.embulk.output.s3_parquet.parquet - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.ParquetWriter @@ -9,168 +8,160 @@ import org.apache.parquet.io.api.{Binary, RecordConsumer} import org.embulk.spi.{Column, ColumnVisitor, PageReader, Schema} import org.embulk.spi.time.TimestampFormatter +object ParquetFileWriter { + + case class Builder( + path: Path = null, + schema: Schema = null, + timestampFormatters: Seq[TimestampFormatter] = null, + logicalTypeHandlers: LogicalTypeHandlerStore = + LogicalTypeHandlerStore.empty + ) extends ParquetWriter.Builder[PageReader, Builder](path) { -object ParquetFileWriter -{ - - case class Builder(path: Path = null, - schema: Schema = null, - timestampFormatters: Seq[TimestampFormatter] = null, - logicalTypeHandlers: LogicalTypeHandlerStore = LogicalTypeHandlerStore.empty) - extends ParquetWriter.Builder[PageReader, Builder](path) - { - - def withPath(path: Path): Builder = - { - copy(path = path) - } - - def withPath(pathString: String): Builder = - { - copy(path = new Path(pathString)) - } - - def withSchema(schema: Schema): Builder = - { - copy(schema = schema) - } - - def withTimestampFormatters(timestampFormatters: Seq[TimestampFormatter]): Builder = - { - copy(timestampFormatters = timestampFormatters) - } - - def withLogicalTypeHandlers(logicalTypeHandlers: LogicalTypeHandlerStore): Builder = - { - copy(logicalTypeHandlers = logicalTypeHandlers) - } - - override def self(): Builder = - { - this - } - - override def getWriteSupport(conf: Configuration): WriteSupport[PageReader] = - { - ParquetFileWriteSupport(schema, timestampFormatters, 
logicalTypeHandlers) - } + def withPath(path: Path): Builder = { + copy(path = path) } - def builder(): Builder = - { - Builder() + def withPath(pathString: String): Builder = { + copy(path = new Path(pathString)) } -} + def withSchema(schema: Schema): Builder = { + copy(schema = schema) + } + + def withTimestampFormatters( + timestampFormatters: Seq[TimestampFormatter] + ): Builder = { + copy(timestampFormatters = timestampFormatters) + } + def withLogicalTypeHandlers( + logicalTypeHandlers: LogicalTypeHandlerStore + ): Builder = { + copy(logicalTypeHandlers = logicalTypeHandlers) + } -private[parquet] case class ParquetFileWriter(recordConsumer: RecordConsumer, - schema: Schema, - timestampFormatters: Seq[TimestampFormatter], - logicalTypeHandlers: LogicalTypeHandlerStore = LogicalTypeHandlerStore.empty) -{ + override def self(): Builder = { + this + } - def write(record: PageReader): Unit = - { - recordConsumer.startMessage() - writeRecord(record) - recordConsumer.endMessage() + override def getWriteSupport( + conf: Configuration + ): WriteSupport[PageReader] = { + ParquetFileWriteSupport(schema, timestampFormatters, logicalTypeHandlers) } + } + + def builder(): Builder = { + Builder() + } + +} + +private[parquet] case class ParquetFileWriter( + recordConsumer: RecordConsumer, + schema: Schema, + timestampFormatters: Seq[TimestampFormatter], + logicalTypeHandlers: LogicalTypeHandlerStore = LogicalTypeHandlerStore.empty +) { + + def write(record: PageReader): Unit = { + recordConsumer.startMessage() + writeRecord(record) + recordConsumer.endMessage() + } - private def writeRecord(record: PageReader): Unit = - { - - schema.visitColumns(new ColumnVisitor() - { - - override def booleanColumn(column: Column): Unit = - { - nullOr(column, { - withWriteFieldContext(column, { - recordConsumer.addBoolean(record.getBoolean(column)) - }) - }) - } - - override def longColumn(column: Column): Unit = - { - nullOr(column, { - withWriteFieldContext(column, { - recordConsumer.addLong(record.getLong(column)) - }) - }) - } - - override def doubleColumn(column: Column): Unit = - { - nullOr(column, { - withWriteFieldContext(column, { - recordConsumer.addDouble(record.getDouble(column)) - }) - }) - } - - override def stringColumn(column: Column): Unit = - { - nullOr(column, { - withWriteFieldContext(column, { - val bin = Binary.fromString(record.getString(column)) - recordConsumer.addBinary(bin) - }) - }) - } - - override def timestampColumn(column: Column): Unit = - { - nullOr(column, { - withWriteFieldContext(column, { - val t = record.getTimestamp(column) - - logicalTypeHandlers.get(column.getName, column.getType) match { - case Some(h) => - h.consume(t, recordConsumer) - case _ => - val ft = timestampFormatters(column.getIndex).format(t) - val bin = Binary.fromString(ft) - recordConsumer.addBinary(bin) - } - }) - }) - } - - override def jsonColumn(column: Column): Unit = - { - nullOr(column, { - withWriteFieldContext(column, { - val msgPack = record.getJson(column) - - logicalTypeHandlers.get(column.getName, column.getType) match { - case Some(h) => - h.consume(msgPack, recordConsumer) - case _ => - val bin = Binary.fromString(msgPack.toJson) - recordConsumer.addBinary(bin) - } - }) - }) - } - - private def nullOr(column: Column, - f: => Unit): Unit = - { - if (!record.isNull(column)) f - } - - private def withWriteFieldContext(column: Column, - f: => Unit): Unit = - { - recordConsumer.startField(column.getName, column.getIndex) - f - recordConsumer.endField(column.getName, column.getIndex) - } + private 
def writeRecord(record: PageReader): Unit = { + schema.visitColumns(new ColumnVisitor() { + + override def booleanColumn(column: Column): Unit = { + nullOr(column, { + withWriteFieldContext(column, { + recordConsumer.addBoolean(record.getBoolean(column)) + }) }) + } - } + override def longColumn(column: Column): Unit = { + nullOr(column, { + withWriteFieldContext(column, { + recordConsumer.addLong(record.getLong(column)) + }) + }) + } -} \ No newline at end of file + override def doubleColumn(column: Column): Unit = { + nullOr(column, { + withWriteFieldContext(column, { + recordConsumer.addDouble(record.getDouble(column)) + }) + }) + } + + override def stringColumn(column: Column): Unit = { + nullOr(column, { + withWriteFieldContext(column, { + val bin = Binary.fromString(record.getString(column)) + recordConsumer.addBinary(bin) + }) + }) + } + + override def timestampColumn(column: Column): Unit = { + nullOr( + column, { + withWriteFieldContext( + column, { + val t = record.getTimestamp(column) + + logicalTypeHandlers.get(column.getName, column.getType) match { + case Some(h) => + h.consume(t, recordConsumer) + case _ => + val ft = timestampFormatters(column.getIndex).format(t) + val bin = Binary.fromString(ft) + recordConsumer.addBinary(bin) + } + } + ) + } + ) + } + + override def jsonColumn(column: Column): Unit = { + nullOr( + column, { + withWriteFieldContext( + column, { + val msgPack = record.getJson(column) + + logicalTypeHandlers.get(column.getName, column.getType) match { + case Some(h) => + h.consume(msgPack, recordConsumer) + case _ => + val bin = Binary.fromString(msgPack.toJson) + recordConsumer.addBinary(bin) + } + } + ) + } + ) + } + + private def nullOr(column: Column, f: => Unit): Unit = { + if (!record.isNull(column)) f + } + + private def withWriteFieldContext(column: Column, f: => Unit): Unit = { + recordConsumer.startField(column.getName, column.getIndex) + f + recordConsumer.endField(column.getName, column.getIndex) + } + + }) + + } + +} diff --git a/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPlugin.scala b/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPlugin.scala index 3d29532..046a929 100644 --- a/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPlugin.scala +++ b/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPlugin.scala @@ -1,6 +1,5 @@ package org.embulk.output.s3_parquet - import java.io.File import java.nio.file.FileSystems @@ -17,138 +16,166 @@ import org.embulk.spi.OutputPlugin import org.embulk.test.{EmbulkTests, TestingEmbulk} import org.junit.Rule import org.junit.runner.RunWith -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, DiagrammedAssertions, FunSuite} +import org.scalatest.{ + BeforeAndAfter, + BeforeAndAfterAll, + DiagrammedAssertions, + FunSuite +} import org.scalatestplus.junit.JUnitRunner import scala.annotation.meta.getter import scala.jdk.CollectionConverters._ - @RunWith(classOf[JUnitRunner]) class TestS3ParquetOutputPlugin extends FunSuite - with BeforeAndAfter - with BeforeAndAfterAll - with DiagrammedAssertions -{ - - val RESOURCE_NAME_PREFIX: String = "org/embulk/output/s3_parquet/" - val TEST_S3_ENDPOINT: String = "http://localhost:4572" - val TEST_S3_REGION: String = "us-east-1" - val TEST_S3_ACCESS_KEY_ID: String = "test" - val TEST_S3_SECRET_ACCESS_KEY: String = "test" - val TEST_BUCKET_NAME: String = "my-bucket" - - @(Rule@getter) - val embulk: TestingEmbulk = TestingEmbulk.builder() - .registerPlugin(classOf[OutputPlugin], "s3_parquet", 
classOf[S3ParquetOutputPlugin]) - .build() - - before { - withLocalStackS3Client(_.createBucket(TEST_BUCKET_NAME)) - } - - after { - withLocalStackS3Client(_.deleteBucket(TEST_BUCKET_NAME)) - } - - def defaultOutConfig(): ConfigSource = - { - embulk.newConfig() - .set("type", "s3_parquet") - .set("endpoint", "http://localhost:4572") // See https://github.com/localstack/localstack#overview - .set("bucket", TEST_BUCKET_NAME) - .set("path_prefix", "path/to/p") - .set("auth_method", "basic") - .set("access_key_id", TEST_S3_ACCESS_KEY_ID) - .set("secret_access_key", TEST_S3_SECRET_ACCESS_KEY) - .set("path_style_access_enabled", true) - .set("default_timezone", "Asia/Tokyo") - } - - - test("first test") { - val inPath = toPath("in1.csv") - val outConfig = defaultOutConfig() - - val result: TestingEmbulk.RunResult = embulk.runOutput(outConfig, inPath) - - - val outRecords: Seq[Map[String, String]] = result.getOutputTaskReports.asScala.map { tr => - val b = tr.get(classOf[String], "bucket") - val k = tr.get(classOf[String], "key") - readParquetFile(b, k) - }.foldLeft(Seq[Map[String, String]]()) { (merged, - records) => - merged ++ records - } - - val inRecords: Seq[Seq[String]] = EmbulkTests.readResource(RESOURCE_NAME_PREFIX + "out1.tsv") - .stripLineEnd - .split("\n") - .map(record => record.split("\t").toSeq) - .toSeq - - inRecords.zipWithIndex.foreach { - case (record, recordIndex) => - 0.to(5).foreach { columnIndex => - val columnName = s"c$columnIndex" - val inData: String = inRecords(recordIndex)(columnIndex) - val outData: String = outRecords(recordIndex).getOrElse(columnName, "") - - assert(outData === inData, s"record: $recordIndex, column: $columnName") - } - } - } - - def readParquetFile(bucket: String, - key: String): Seq[Map[String, String]] = - { - val createdParquetFile = embulk.createTempFile("in") - withLocalStackS3Client {s3 => - val xfer = TransferManagerBuilder.standard() - .withS3Client(s3) - .build() - try xfer.download(bucket, key, createdParquetFile.toFile).waitForCompletion() - finally xfer.shutdownNow() + with BeforeAndAfter + with BeforeAndAfterAll + with DiagrammedAssertions { + + val RESOURCE_NAME_PREFIX: String = "org/embulk/output/s3_parquet/" + val TEST_S3_ENDPOINT: String = "http://localhost:4572" + val TEST_S3_REGION: String = "us-east-1" + val TEST_S3_ACCESS_KEY_ID: String = "test" + val TEST_S3_SECRET_ACCESS_KEY: String = "test" + val TEST_BUCKET_NAME: String = "my-bucket" + + @(Rule @getter) + val embulk: TestingEmbulk = TestingEmbulk + .builder() + .registerPlugin( + classOf[OutputPlugin], + "s3_parquet", + classOf[S3ParquetOutputPlugin] + ) + .build() + + before { + withLocalStackS3Client(_.createBucket(TEST_BUCKET_NAME)) + } + + after { + withLocalStackS3Client(_.deleteBucket(TEST_BUCKET_NAME)) + } + + def defaultOutConfig(): ConfigSource = { + embulk + .newConfig() + .set("type", "s3_parquet") + .set("endpoint", "http://localhost:4572") // See https://github.com/localstack/localstack#overview + .set("bucket", TEST_BUCKET_NAME) + .set("path_prefix", "path/to/p") + .set("auth_method", "basic") + .set("access_key_id", TEST_S3_ACCESS_KEY_ID) + .set("secret_access_key", TEST_S3_SECRET_ACCESS_KEY) + .set("path_style_access_enabled", true) + .set("default_timezone", "Asia/Tokyo") + } + + test("first test") { + val inPath = toPath("in1.csv") + val outConfig = defaultOutConfig() + + val result: TestingEmbulk.RunResult = embulk.runOutput(outConfig, inPath) + + val outRecords: Seq[Map[String, String]] = + result.getOutputTaskReports.asScala + .map { tr => + val b = 
tr.get(classOf[String], "bucket") + val k = tr.get(classOf[String], "key") + readParquetFile(b, k) } - - val reader: ParquetReader[SimpleRecord] = ParquetReader - .builder(new SimpleReadSupport(), new HadoopPath(createdParquetFile.toString)) - .build() - - def read(reader: ParquetReader[SimpleRecord], - records: Seq[Map[String, String]] = Seq()): Seq[Map[String, String]] = - { - val simpleRecord: SimpleRecord = reader.read() - if (simpleRecord != null) { - val r: Map[String, String] = simpleRecord.getValues.asScala.map(v => v.getName -> v.getValue.toString).toMap - return read(reader, records :+ r) - } - records + .foldLeft(Seq[Map[String, String]]()) { (merged, records) => + merged ++ records } - try read(reader) - finally { - reader.close() - + val inRecords: Seq[Seq[String]] = EmbulkTests + .readResource(RESOURCE_NAME_PREFIX + "out1.tsv") + .stripLineEnd + .split("\n") + .map(record => record.split("\t").toSeq) + .toSeq + + inRecords.zipWithIndex.foreach { + case (record, recordIndex) => + 0.to(5).foreach { columnIndex => + val columnName = s"c$columnIndex" + val inData: String = inRecords(recordIndex)(columnIndex) + val outData: String = + outRecords(recordIndex).getOrElse(columnName, "") + + assert( + outData === inData, + s"record: $recordIndex, column: $columnName" + ) } } + } + + def readParquetFile(bucket: String, key: String): Seq[Map[String, String]] = { + val createdParquetFile = embulk.createTempFile("in") + withLocalStackS3Client { s3 => + val xfer = TransferManagerBuilder + .standard() + .withS3Client(s3) + .build() + try xfer + .download(bucket, key, createdParquetFile.toFile) + .waitForCompletion() + finally xfer.shutdownNow() + } - private def toPath(fileName: String) = - { - val url = Resources.getResource(RESOURCE_NAME_PREFIX + fileName) - FileSystems.getDefault.getPath(new File(url.toURI).getAbsolutePath) + val reader: ParquetReader[SimpleRecord] = ParquetReader + .builder( + new SimpleReadSupport(), + new HadoopPath(createdParquetFile.toString) + ) + .build() + + def read( + reader: ParquetReader[SimpleRecord], + records: Seq[Map[String, String]] = Seq() + ): Seq[Map[String, String]] = { + val simpleRecord: SimpleRecord = reader.read() + if (simpleRecord != null) { + val r: Map[String, String] = simpleRecord.getValues.asScala + .map(v => v.getName -> v.getValue.toString) + .toMap + return read(reader, records :+ r) + } + records } - private def withLocalStackS3Client[A](f: AmazonS3 => A): A = { - val client: AmazonS3 = AmazonS3ClientBuilder.standard - .withEndpointConfiguration(new EndpointConfiguration(TEST_S3_ENDPOINT, TEST_S3_REGION)) - .withCredentials(new AWSStaticCredentialsProvider(new BasicAWSCredentials(TEST_S3_ACCESS_KEY_ID, TEST_S3_SECRET_ACCESS_KEY))) - .withPathStyleAccessEnabled(true) - .build() + try read(reader) + finally { + reader.close() - try f(client) - finally client.shutdown() } + } + + private def toPath(fileName: String) = { + val url = Resources.getResource(RESOURCE_NAME_PREFIX + fileName) + FileSystems.getDefault.getPath(new File(url.toURI).getAbsolutePath) + } + + private def withLocalStackS3Client[A](f: AmazonS3 => A): A = { + val client: AmazonS3 = AmazonS3ClientBuilder.standard + .withEndpointConfiguration( + new EndpointConfiguration(TEST_S3_ENDPOINT, TEST_S3_REGION) + ) + .withCredentials( + new AWSStaticCredentialsProvider( + new BasicAWSCredentials( + TEST_S3_ACCESS_KEY_ID, + TEST_S3_SECRET_ACCESS_KEY + ) + ) + ) + .withPathStyleAccessEnabled(true) + .build() + + try f(client) + finally client.shutdown() + } } diff --git 
a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandler.scala b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandler.scala index d8a4b73..5f4f609 100644 --- a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandler.scala +++ b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandler.scala @@ -1,6 +1,5 @@ package org.embulk.output.s3_parquet.parquet - import org.embulk.spi.DataException import org.embulk.spi.`type`.Types import org.junit.runner.RunWith @@ -9,70 +8,77 @@ import org.scalatestplus.junit.JUnitRunner import scala.util.Try - @RunWith(classOf[JUnitRunner]) -class TestLogicalTypeHandler - extends FunSuite -{ - - test("IntLogicalTypeHandler.isConvertible() returns true for long") { - val h = Int8LogicalTypeHandler - - assert(h.isConvertible(Types.LONG)) - assert(!h.isConvertible(Types.BOOLEAN)) - } - - test("IntLogicalTypeHandler.consume() raises DataException if given type is not long") { - val h = Int8LogicalTypeHandler - val actual = Try(h.consume("invalid", null)) - - assert(actual.isFailure) - assert(actual.failed.get.isInstanceOf[DataException]) - } - - - test("TimestampMillisLogicalTypeHandler.isConvertible() returns true for timestamp") { - val h = TimestampMillisLogicalTypeHandler - - assert(h.isConvertible(Types.TIMESTAMP)) - assert(!h.isConvertible(Types.BOOLEAN)) - } - - test("TimestampMillisLogicalTypeHandler.consume() raises DataException if given type is not timestamp") { - val h = TimestampMillisLogicalTypeHandler - val actual = Try(h.consume("invalid", null)) - - assert(actual.isFailure) - assert(actual.failed.get.isInstanceOf[DataException]) - } - - - test("TimestampMicrosLogicalTypeHandler.isConvertible() returns true for timestamp") { - val h = TimestampMicrosLogicalTypeHandler - - assert(h.isConvertible(Types.TIMESTAMP)) - assert(!h.isConvertible(Types.BOOLEAN)) - } - - test("TimestampMicrosLogicalTypeHandler.consume() raises DataException if given type is not timestamp") { - val h = TimestampMicrosLogicalTypeHandler - val actual = Try(h.consume("invalid", null)) - - assert(actual.isFailure) - assert(actual.failed.get.isInstanceOf[DataException]) - } - - test("JsonLogicalTypeHandler.isConvertible() returns true for json") { - val h = JsonLogicalTypeHandler - - assert(h.isConvertible(Types.JSON)) - assert(!h.isConvertible(Types.BOOLEAN)) - } - - test("JsonLogicalTypeHandler.consume() raises DataException if given type is not json") { - val h = JsonLogicalTypeHandler - val actual = Try(h.consume("invalid", null)) - assert(actual.isFailure) - assert(actual.failed.get.isInstanceOf[DataException]) - } +class TestLogicalTypeHandler extends FunSuite { + + test("IntLogicalTypeHandler.isConvertible() returns true for long") { + val h = Int8LogicalTypeHandler + + assert(h.isConvertible(Types.LONG)) + assert(!h.isConvertible(Types.BOOLEAN)) + } + + test( + "IntLogicalTypeHandler.consume() raises DataException if given type is not long" + ) { + val h = Int8LogicalTypeHandler + val actual = Try(h.consume("invalid", null)) + + assert(actual.isFailure) + assert(actual.failed.get.isInstanceOf[DataException]) + } + + test( + "TimestampMillisLogicalTypeHandler.isConvertible() returns true for timestamp" + ) { + val h = TimestampMillisLogicalTypeHandler + + assert(h.isConvertible(Types.TIMESTAMP)) + assert(!h.isConvertible(Types.BOOLEAN)) + } + + test( + "TimestampMillisLogicalTypeHandler.consume() raises DataException if given type is not timestamp" + ) { + val h = TimestampMillisLogicalTypeHandler + 
val actual = Try(h.consume("invalid", null)) + + assert(actual.isFailure) + assert(actual.failed.get.isInstanceOf[DataException]) + } + + test( + "TimestampMicrosLogicalTypeHandler.isConvertible() returns true for timestamp" + ) { + val h = TimestampMicrosLogicalTypeHandler + + assert(h.isConvertible(Types.TIMESTAMP)) + assert(!h.isConvertible(Types.BOOLEAN)) + } + + test( + "TimestampMicrosLogicalTypeHandler.consume() raises DataException if given type is not timestamp" + ) { + val h = TimestampMicrosLogicalTypeHandler + val actual = Try(h.consume("invalid", null)) + + assert(actual.isFailure) + assert(actual.failed.get.isInstanceOf[DataException]) + } + + test("JsonLogicalTypeHandler.isConvertible() returns true for json") { + val h = JsonLogicalTypeHandler + + assert(h.isConvertible(Types.JSON)) + assert(!h.isConvertible(Types.BOOLEAN)) + } + + test( + "JsonLogicalTypeHandler.consume() raises DataException if given type is not json" + ) { + val h = JsonLogicalTypeHandler + val actual = Try(h.consume("invalid", null)) + assert(actual.isFailure) + assert(actual.failed.get.isInstanceOf[DataException]) + } } diff --git a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandlerStore.scala b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandlerStore.scala index db0aa0d..7600492 100644 --- a/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandlerStore.scala +++ b/src/test/scala/org/embulk/output/s3_parquet/parquet/TestLogicalTypeHandlerStore.scala @@ -1,11 +1,13 @@ package org.embulk.output.s3_parquet.parquet - import java.util.Optional import com.google.common.base.{Optional => GOptional} import org.embulk.config.{ConfigException, TaskSource} -import org.embulk.output.s3_parquet.S3ParquetOutputPlugin.{ColumnOptionTask, TypeOptionTask} +import org.embulk.output.s3_parquet.S3ParquetOutputPlugin.{ + ColumnOptionTask, + TypeOptionTask +} import org.embulk.spi.`type`.{Types, Type => EType} import org.junit.runner.RunWith import org.scalatest.FunSuite @@ -14,149 +16,164 @@ import org.scalatestplus.junit.JUnitRunner import scala.jdk.CollectionConverters._ import scala.util.Try - @RunWith(classOf[JUnitRunner]) -class TestLogicalTypeHandlerStore - extends FunSuite -{ - test("empty() returns empty maps") { - val rv = LogicalTypeHandlerStore.empty - - assert(rv.fromColumnName.isEmpty) - assert(rv.fromEmbulkType.isEmpty) +class TestLogicalTypeHandlerStore extends FunSuite { + test("empty() returns empty maps") { + val rv = LogicalTypeHandlerStore.empty + + assert(rv.fromColumnName.isEmpty) + assert(rv.fromEmbulkType.isEmpty) + } + + test("fromEmbulkOptions() returns handlers for valid option tasks") { + val typeOpts = Map[String, TypeOptionTask]( + "timestamp" -> DummyTypeOptionTask( + Optional.of[String]("timestamp-millis") + ) + ).asJava + val columnOpts = Map[String, ColumnOptionTask]( + "col1" -> DummyColumnOptionTask(Optional.of[String]("timestamp-micros")) + ).asJava + + val expected1 = Map[EType, LogicalTypeHandler]( + Types.TIMESTAMP -> TimestampMillisLogicalTypeHandler + ) + val expected2 = Map[String, LogicalTypeHandler]( + "col1" -> TimestampMicrosLogicalTypeHandler + ) + + val rv = LogicalTypeHandlerStore.fromEmbulkOptions(typeOpts, columnOpts) + + assert(rv.fromEmbulkType == expected1) + assert(rv.fromColumnName == expected2) + } + + test( + "fromEmbulkOptions() raises ConfigException if invalid option tasks given" + ) { + val emptyTypeOpts = Map.empty[String, TypeOptionTask].asJava + val emptyColumnOpts = Map.empty[String, 
ColumnOptionTask].asJava + + val invalidTypeOpts = Map[String, TypeOptionTask]( + "unknown-embulk-type-name" -> DummyTypeOptionTask( + Optional.of[String]("timestamp-millis") + ), + "timestamp" -> DummyTypeOptionTask( + Optional.of[String]("unknown-parquet-logical-type-name") + ) + ).asJava + val invalidColumnOpts = Map[String, ColumnOptionTask]( + "col1" -> DummyColumnOptionTask( + Optional.of[String]("unknown-parquet-logical-type-name") + ) + ).asJava + + val try1 = Try( + LogicalTypeHandlerStore + .fromEmbulkOptions(invalidTypeOpts, emptyColumnOpts) + ) + assert(try1.isFailure) + assert(try1.failed.get.isInstanceOf[ConfigException]) + + val try2 = Try( + LogicalTypeHandlerStore + .fromEmbulkOptions(emptyTypeOpts, invalidColumnOpts) + ) + assert(try2.isFailure) + assert(try2.failed.get.isInstanceOf[ConfigException]) + + val try3 = Try( + LogicalTypeHandlerStore + .fromEmbulkOptions(invalidTypeOpts, invalidColumnOpts) + ) + assert(try3.isFailure) + assert(try3.failed.get.isInstanceOf[ConfigException]) + } + + test("get() returns a handler matched with primary column name condition") { + val typeOpts = Map[String, TypeOptionTask]( + "timestamp" -> DummyTypeOptionTask( + Optional.of[String]("timestamp-millis") + ) + ).asJava + val columnOpts = Map[String, ColumnOptionTask]( + "col1" -> DummyColumnOptionTask(Optional.of[String]("timestamp-micros")) + ).asJava + + val handlers = + LogicalTypeHandlerStore.fromEmbulkOptions(typeOpts, columnOpts) + + // It matches both of column name and embulk type, and column name should be primary + val expected = Some(TimestampMicrosLogicalTypeHandler) + val actual = handlers.get("col1", Types.TIMESTAMP) + + assert(actual == expected) + } + + test("get() returns a handler matched with type name condition") { + val typeOpts = Map[String, TypeOptionTask]( + "timestamp" -> DummyTypeOptionTask( + Optional.of[String]("timestamp-millis") + ) + ).asJava + val columnOpts = Map.empty[String, ColumnOptionTask].asJava + + val handlers = + LogicalTypeHandlerStore.fromEmbulkOptions(typeOpts, columnOpts) + + // It matches column name + val expected = Some(TimestampMillisLogicalTypeHandler) + val actual = handlers.get("col1", Types.TIMESTAMP) + + assert(actual == expected) + } + + test("get() returns None if not matched") { + val typeOpts = Map.empty[String, TypeOptionTask].asJava + val columnOpts = Map.empty[String, ColumnOptionTask].asJava + + val handlers = + LogicalTypeHandlerStore.fromEmbulkOptions(typeOpts, columnOpts) + + // It matches embulk type + val actual = handlers.get("col1", Types.TIMESTAMP) + + assert(actual.isEmpty) + } + + private case class DummyTypeOptionTask(lt: Optional[String]) + extends TypeOptionTask { + + override def getLogicalType: Optional[String] = { + lt } - test("fromEmbulkOptions() returns handlers for valid option tasks") { - val typeOpts = Map[String, TypeOptionTask]( - "timestamp" -> DummyTypeOptionTask(Optional.of[String]("timestamp-millis")), - ).asJava - val columnOpts = Map[String, ColumnOptionTask]( - "col1" -> DummyColumnOptionTask(Optional.of[String]("timestamp-micros")), - ).asJava - - val expected1 = Map[EType, LogicalTypeHandler]( - Types.TIMESTAMP -> TimestampMillisLogicalTypeHandler, - ) - val expected2 = Map[String, LogicalTypeHandler]( - "col1" -> TimestampMicrosLogicalTypeHandler, - ) - - val rv = LogicalTypeHandlerStore.fromEmbulkOptions(typeOpts, columnOpts) - - assert(rv.fromEmbulkType == expected1) - assert(rv.fromColumnName == expected2) - } + override def validate(): Unit = {} - test("fromEmbulkOptions() raises 
ConfigException if invalid option tasks given") { - val emptyTypeOpts = Map.empty[String, TypeOptionTask].asJava - val emptyColumnOpts = Map.empty[String, ColumnOptionTask].asJava - - val invalidTypeOpts = Map[String, TypeOptionTask]( - "unknown-embulk-type-name" -> DummyTypeOptionTask(Optional.of[String]("timestamp-millis")), - "timestamp" -> DummyTypeOptionTask(Optional.of[String]("unknown-parquet-logical-type-name")), - ).asJava - val invalidColumnOpts = Map[String, ColumnOptionTask]( - "col1" -> DummyColumnOptionTask(Optional.of[String]("unknown-parquet-logical-type-name")), - ).asJava - - val try1 = Try(LogicalTypeHandlerStore.fromEmbulkOptions(invalidTypeOpts, emptyColumnOpts)) - assert(try1.isFailure) - assert(try1.failed.get.isInstanceOf[ConfigException]) - - val try2 = Try(LogicalTypeHandlerStore.fromEmbulkOptions(emptyTypeOpts, invalidColumnOpts)) - assert(try2.isFailure) - assert(try2.failed.get.isInstanceOf[ConfigException]) - - val try3 = Try(LogicalTypeHandlerStore.fromEmbulkOptions(invalidTypeOpts, invalidColumnOpts)) - assert(try3.isFailure) - assert(try3.failed.get.isInstanceOf[ConfigException]) + override def dump(): TaskSource = { + null } + } - test("get() returns a handler matched with primary column name condition") { - val typeOpts = Map[String, TypeOptionTask]( - "timestamp" -> DummyTypeOptionTask(Optional.of[String]("timestamp-millis")), - ).asJava - val columnOpts = Map[String, ColumnOptionTask]( - "col1" -> DummyColumnOptionTask(Optional.of[String]("timestamp-micros")), - ).asJava + private case class DummyColumnOptionTask(lt: Optional[String]) + extends ColumnOptionTask { - val handlers = LogicalTypeHandlerStore.fromEmbulkOptions(typeOpts, columnOpts) - - // It matches both of column name and embulk type, and column name should be primary - val expected = Some(TimestampMicrosLogicalTypeHandler) - val actual = handlers.get("col1", Types.TIMESTAMP) - - assert(actual == expected) + override def getTimeZoneId: GOptional[String] = { + GOptional.absent[String] } - test("get() returns a handler matched with type name condition") { - val typeOpts = Map[String, TypeOptionTask]( - "timestamp" -> DummyTypeOptionTask(Optional.of[String]("timestamp-millis")), - ).asJava - val columnOpts = Map.empty[String, ColumnOptionTask].asJava - - val handlers = LogicalTypeHandlerStore.fromEmbulkOptions(typeOpts, columnOpts) - - // It matches column name - val expected = Some(TimestampMillisLogicalTypeHandler) - val actual = handlers.get("col1", Types.TIMESTAMP) - - assert(actual == expected) + override def getFormat: GOptional[String] = { + GOptional.absent[String] } - test("get() returns None if not matched") { - val typeOpts = Map.empty[String, TypeOptionTask].asJava - val columnOpts = Map.empty[String, ColumnOptionTask].asJava - - val handlers = LogicalTypeHandlerStore.fromEmbulkOptions(typeOpts, columnOpts) - - // It matches embulk type - val actual = handlers.get("col1", Types.TIMESTAMP) - - assert(actual.isEmpty) + override def getLogicalType: Optional[String] = { + lt } - private case class DummyTypeOptionTask(lt: Optional[String]) - extends TypeOptionTask - { - override def getLogicalType: Optional[String] = - { - lt - } - - override def validate(): Unit = - {} - - override def dump(): TaskSource = - { - null - } - } + override def validate(): Unit = {} - private case class DummyColumnOptionTask(lt: Optional[String]) - extends ColumnOptionTask - { - override def getTimeZoneId: GOptional[String] = - { - GOptional.absent[String] - } - - override def getFormat: GOptional[String] = 
- { - GOptional.absent[String] - } - - override def getLogicalType: Optional[String] = - { - lt - } - - override def validate(): Unit = - {} - - override def dump(): TaskSource = - { - null - } + override def dump(): TaskSource = { + null } + } }
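
A note for readers skimming the reformatted TestLogicalTypeHandler cases above: every test exercises the same contract. A handler advertises the embulk types it accepts through `isConvertible()`, and `consume()` is expected to fail with a `DataException` when handed a value of any other type (the tests pass `null` for the record consumer because only the failure path is exercised). The sketch below is a self-contained, hypothetical stand-in for that contract; `EmbulkType`, `DataException`, and `Int8Handler` here are simplified placeholders, not the plugin's real `Int8LogicalTypeHandler` or consumer types.

```scala
// A minimal sketch of the handler contract exercised above, assuming
// simplified stand-in types (EmbulkType, DataException, the handler itself);
// the plugin's real Int8LogicalTypeHandler and consumer types are not used.
import scala.util.Try

object LogicalTypeHandlerContractSketch {

  sealed trait EmbulkType
  case object LongType extends EmbulkType
  case object BooleanType extends EmbulkType

  final class DataException(msg: String) extends RuntimeException(msg)

  trait LogicalTypeHandler {
    // Which embulk column type this handler can narrow into the logical type.
    def isConvertible(t: EmbulkType): Boolean

    // Write the value through the (here ignored) record consumer, or fail.
    def consume(orig: Any, consumer: AnyRef): Unit
  }

  // An INT8-style handler: only long input is convertible.
  object Int8Handler extends LogicalTypeHandler {
    override def isConvertible(t: EmbulkType): Boolean = t == LongType

    override def consume(orig: Any, consumer: AnyRef): Unit = orig match {
      case _: Long => () // a real handler would emit the narrowed value here
      case other   => throw new DataException(s"Unsupported value: $other")
    }
  }

  def main(args: Array[String]): Unit = {
    assert(Int8Handler.isConvertible(LongType))
    assert(!Int8Handler.isConvertible(BooleanType))
    // Mirrors the tests: a non-long value makes consume() fail.
    assert(Try(Int8Handler.consume("invalid", null)).isFailure)
    assert(Try(Int8Handler.consume(42L, null)).isSuccess)
    println("handler contract sketch: OK")
  }
}
```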
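
The TestLogicalTypeHandlerStore cases also pin down the resolution order: when both a per-column option (`ColumnOptionTask`, keyed by column name) and a per-type option (`TypeOptionTask`, keyed by embulk type) could supply a logical-type handler, the column-name entry wins and the embulk-type entry is only a fallback. The sketch below models just that ordering with simplified stand-ins; `HandlerStore` and its fields are hypothetical and not the plugin's actual `LogicalTypeHandlerStore` implementation.

```scala
// A minimal model of the lookup order asserted above: fromColumnName first,
// then fromEmbulkType. HandlerStore and the types below are simplified,
// hypothetical stand-ins for the plugin's LogicalTypeHandlerStore.
object HandlerResolutionSketch {

  sealed trait EmbulkType
  case object TimestampType extends EmbulkType

  sealed trait Handler
  case object TimestampMillisHandler extends Handler
  case object TimestampMicrosHandler extends Handler

  final case class HandlerStore(
      fromColumnName: Map[String, Handler],
      fromEmbulkType: Map[EmbulkType, Handler]
  ) {

    // Column-name mapping is primary; the embulk-type mapping is the fallback.
    def get(columnName: String, embulkType: EmbulkType): Option[Handler] =
      fromColumnName.get(columnName).orElse(fromEmbulkType.get(embulkType))
  }

  def main(args: Array[String]): Unit = {
    val store = HandlerStore(
      fromColumnName = Map("col1" -> TimestampMicrosHandler),
      fromEmbulkType = Map(TimestampType -> TimestampMillisHandler)
    )

    // Both mappings could match "col1", so the column-name entry wins.
    assert(store.get("col1", TimestampType).contains(TimestampMicrosHandler))
    // An unknown column name falls back to the embulk-type mapping.
    assert(store.get("col2", TimestampType).contains(TimestampMillisHandler))
    // Nothing matches when both mappings are empty.
    assert(HandlerStore(Map.empty, Map.empty).get("col1", TimestampType).isEmpty)
    println("resolution order sketch: OK")
  }
}
```

This mirrors `handlers.get("col1", Types.TIMESTAMP)` in the tests, where `TimestampMicrosLogicalTypeHandler` is expected even though the timestamp type option would otherwise map to `TimestampMillisLogicalTypeHandler`.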