From 281609c8a2058070743d27f61bbbd4f043a950a2 Mon Sep 17 00:00:00 2001
From: Alexey Kuzin
Date: Mon, 23 May 2022 17:30:05 -0400
Subject: [PATCH] Support saving DataFrames with named expressions in SQL
 queries

Before this patch, named expressions were not taken into account, so the
fields were matched only by their order in the query and not by their names.
---
 build.sbt                                     |  2 +-
 .../spark/connector/rdd/TarantoolRDD.scala    |  6 +-
 .../spark/connector/util/StringUtils.scala    | 32 ++++++++++
 .../spark/sql/tarantool/MapFunctions.scala    | 25 +++++++-
 src/test/resources/Dockerfile                 | 26 +++++++++
 src/test/resources/cartridge/replicasets.yml  | 37 ++++++++++++
 .../TarantoolSparkWriteClusterTest.scala      | 58 +++++++++++++++++++
 .../connector/util/StringUtilsSpec.scala      | 21 +++++++
 8 files changed, 201 insertions(+), 6 deletions(-)
 create mode 100644 src/main/scala/io/tarantool/spark/connector/util/StringUtils.scala
 create mode 100644 src/test/resources/Dockerfile
 create mode 100644 src/test/resources/cartridge/replicasets.yml
 create mode 100644 src/test/scala/io/tarantool/spark/connector/util/StringUtilsSpec.scala

diff --git a/build.sbt b/build.sbt
index 5a4569d..c0abbb6 100644
--- a/build.sbt
+++ b/build.sbt
@@ -33,7 +33,7 @@ ThisBuild / developers := List(
 ThisBuild / scalaVersion := scala211
 
 val commonDependencies = Seq(
-  "io.tarantool" % "cartridge-driver" % "0.7.0",
+  "io.tarantool" % "cartridge-driver" % "0.7.2",
   "junit" % "junit" % "4.12" % Test,
   "com.github.sbt" % "junit-interface" % "0.12" % Test,
   "org.testcontainers" % "testcontainers" % "1.16.0" % Test,
diff --git a/src/main/scala/io/tarantool/spark/connector/rdd/TarantoolRDD.scala b/src/main/scala/io/tarantool/spark/connector/rdd/TarantoolRDD.scala
index 32e3b5f..d20c305 100644
--- a/src/main/scala/io/tarantool/spark/connector/rdd/TarantoolRDD.scala
+++ b/src/main/scala/io/tarantool/spark/connector/rdd/TarantoolRDD.scala
@@ -50,10 +50,6 @@ class TarantoolRDD[R] private[spark] (
 
   private val globalConfig = TarantoolConfig(sparkContext.getConf)
 
-  @transient private lazy val tupleFactory = new DefaultTarantoolTupleFactory(
-    messagePackMapper
-  )
-
   override def compute(split: Partition, context: TaskContext): Iterator[R] = {
     val partition = split.asInstanceOf[TarantoolPartition]
     val connection = TarantoolConnection()
@@ -96,6 +92,8 @@ class TarantoolRDD[R] private[spark] (
     data.foreachPartition((partition: Iterator[Row]) =>
       if (partition.nonEmpty) {
         val client = connection.client(globalConfig)
+        val spaceMetadata = client.metadata().getSpaceByName(space).get()
+        val tupleFactory = new DefaultTarantoolTupleFactory(messagePackMapper, spaceMetadata)
         var rowCount: Long = 0
         val failedRowsExceptions: ListBuffer[Throwable] = ListBuffer()
 
diff --git a/src/main/scala/io/tarantool/spark/connector/util/StringUtils.scala b/src/main/scala/io/tarantool/spark/connector/util/StringUtils.scala
new file mode 100644
index 0000000..7846c46
--- /dev/null
+++ b/src/main/scala/io/tarantool/spark/connector/util/StringUtils.scala
@@ -0,0 +1,32 @@
+package io.tarantool.spark.connector.util
+
+import scala.annotation.tailrec
+
+/**
+ * Provides helper methods for transforming String instances
+ *
+ * @author Alexey Kuzin
+ */
+object StringUtils {
+
+  /**
+   * Converts from camelCase to snake_case
+   * e.g.: camelCase => camel_case
+   *
+   * Borrowed from https://gist.github.com/sidharthkuruvila/3154845?permalink_comment_id=2622928#gistcomment-2622928
+   *
+   * @param name the camelCase name to convert
+   * @return snake_case version of the string passed
+   */
+  def camelToSnake(name: String): String = {
+    @tailrec
+    def go(accDone: List[Char], acc: List[Char]): List[Char] = acc match {
+      case Nil => accDone
+      case a :: b :: c :: tail if a.isUpper && b.isUpper && c.isLower =>
+        go(accDone ++ List(a, '_', b, c), tail)
+      case a :: b :: tail if a.isLower && b.isUpper => go(accDone ++ List(a, '_', b), tail)
+      case a :: tail => go(accDone :+ a, tail)
+    }
+    go(Nil, name.toList).mkString.toLowerCase
+  }
+}
diff --git a/src/main/scala/org/apache/spark/sql/tarantool/MapFunctions.scala b/src/main/scala/org/apache/spark/sql/tarantool/MapFunctions.scala
index 9abd80d..6e2768e 100644
--- a/src/main/scala/org/apache/spark/sql/tarantool/MapFunctions.scala
+++ b/src/main/scala/org/apache/spark/sql/tarantool/MapFunctions.scala
@@ -2,6 +2,7 @@ package org.apache.spark.sql.tarantool
 
 import io.tarantool.driver.api.tuple.{TarantoolField, TarantoolTuple, TarantoolTupleFactory}
 import io.tarantool.driver.mappers.MessagePackValueMapper
+import io.tarantool.spark.connector.util.StringUtils
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.types._
@@ -26,6 +27,9 @@ import scala.collection.JavaConverters.{mapAsJavaMapConverter, seqAsJavaListConv
  */
 object MapFunctions {
 
+  @transient private lazy val tupleNamesCache: scala.collection.mutable.Map[String, String] =
+    scala.collection.concurrent.TrieMap()
+
   def tupleToRow(
     tuple: TarantoolTuple,
     mapper: MessagePackValueMapper,
@@ -89,12 +93,31 @@ object MapFunctions {
   }
 
   def rowToTuple(tupleFactory: TarantoolTupleFactory, row: Row): TarantoolTuple =
-    tupleFactory.create(
+    Option(row.schema) match {
+      case Some(_) => rowWithSchemaToTuple(tupleFactory, row)
+      case None => rowWithoutSchemaToTuple(tupleFactory, row)
+    }
+
+  def rowWithSchemaToTuple(tupleFactory: TarantoolTupleFactory, row: Row): TarantoolTuple = {
+    val tuple = tupleFactory.create()
+    row.getValuesMap[Any](row.schema.fieldNames).foreach { case (fieldName, value) =>
+      tuple.putObject(transformSchemaFieldName(fieldName), mapToJavaValue(Option(value)).orNull)
+    }
+    tuple
+  }
+
+  def transformSchemaFieldName(fieldName: String): String =
+    tupleNamesCache.getOrElseUpdate(fieldName, StringUtils.camelToSnake(fieldName))
+
+  def rowWithoutSchemaToTuple(tupleFactory: TarantoolTupleFactory, row: Row): TarantoolTuple = {
+    val tuple = tupleFactory.create(
       row.toSeq
         .map(value => mapToJavaValue(Option(value)))
         .map(nullableValue => nullableValue.orNull)
         .asJava
     )
+    tuple
+  }
 
   def mapToJavaValue(value: Option[Any]): Option[Any] =
     if (value.isDefined) {
diff --git a/src/test/resources/Dockerfile b/src/test/resources/Dockerfile
new file mode 100644
index 0000000..17d21dc
--- /dev/null
+++ b/src/test/resources/Dockerfile
@@ -0,0 +1,26 @@
+FROM tgagor/centos:stream8 AS tarantool-base
+ARG TARANTOOL_VERSION=2.8
+ARG TARANTOOL_SERVER_USER="tarantool"
+ARG TARANTOOL_SERVER_UID=1000
+ARG TARANTOOL_SERVER_GROUP="tarantool"
+ARG TARANTOOL_SERVER_GID=1000
+ARG TARANTOOL_WORKDIR="/app"
+ARG TARANTOOL_RUNDIR="/tmp/run"
+ARG TARANTOOL_DATADIR="/tmp/data"
+ARG TARANTOOL_INSTANCES_FILE="./instances.yml"
+ENV TARANTOOL_WORKDIR=$TARANTOOL_WORKDIR
+ENV TARANTOOL_RUNDIR=$TARANTOOL_RUNDIR
+ENV TARANTOOL_DATADIR=$TARANTOOL_DATADIR
+ENV TARANTOOL_INSTANCES_FILE=$TARANTOOL_INSTANCES_FILE
+RUN curl -L https://tarantool.io/installer.sh | VER=$TARANTOOL_VERSION /bin/bash -s -- --repo-only && \
+    yum -y install cmake make gcc gcc-c++ git unzip tarantool tarantool-devel cartridge-cli && \
+    yum clean all
+RUN groupadd -g $TARANTOOL_SERVER_GID $TARANTOOL_SERVER_GROUP && \
+    useradd -u $TARANTOOL_SERVER_UID -g $TARANTOOL_SERVER_GID -m -s /bin/bash $TARANTOOL_SERVER_USER \
+    || true
+USER $TARANTOOL_SERVER_USER:$TARANTOOL_SERVER_GROUP
+RUN cartridge version
+
+FROM tarantool-base AS cartridge-base
+WORKDIR $TARANTOOL_WORKDIR
+CMD cartridge build && cartridge start --run-dir=$TARANTOOL_RUNDIR --data-dir=$TARANTOOL_DATADIR --cfg=$TARANTOOL_INSTANCES_FILE
\ No newline at end of file
diff --git a/src/test/resources/cartridge/replicasets.yml b/src/test/resources/cartridge/replicasets.yml
new file mode 100644
index 0000000..9975ac1
--- /dev/null
+++ b/src/test/resources/cartridge/replicasets.yml
@@ -0,0 +1,37 @@
+app-router:
+  instances:
+  - router
+  roles:
+  - vshard-router
+  - crud-router
+  - app.roles.api_router
+  all_rw: false
+app-router-second:
+  instances:
+  - second-router
+  roles:
+  - vshard-router
+  - crud-router
+  - app.roles.api_router
+  all_rw: false
+s1-storage:
+  instances:
+  - s1-master
+  roles:
+  - vshard-storage
+  - crud-storage
+  - app.roles.api_storage
+  weight: 1
+  all_rw: false
+  vshard_group: default
+s2-storage:
+  instances:
+  - s2-master
+  roles:
+  - vshard-storage
+  - crud-storage
+  - app.roles.api_storage
+  weight: 1
+  all_rw: false
+  vshard_group: default
+
diff --git a/src/test/scala/io/tarantool/spark/connector/integration/TarantoolSparkWriteClusterTest.scala b/src/test/scala/io/tarantool/spark/connector/integration/TarantoolSparkWriteClusterTest.scala
index d9a62ae..72eb849 100644
--- a/src/test/scala/io/tarantool/spark/connector/integration/TarantoolSparkWriteClusterTest.scala
+++ b/src/test/scala/io/tarantool/spark/connector/integration/TarantoolSparkWriteClusterTest.scala
@@ -2,7 +2,10 @@ package io.tarantool.spark.connector.integration
 
 import io.tarantool.driver.api.conditions.Conditions
 import io.tarantool.driver.api.tuple.{DefaultTarantoolTupleFactory, TarantoolTuple}
+import io.tarantool.driver.exceptions.TarantoolException
 import io.tarantool.driver.mappers.DefaultMessagePackMapperFactory
+import io.tarantool.spark.connector.connection.TarantoolConnection
+import io.tarantool.spark.connector.config.TarantoolConfig
 import io.tarantool.spark.connector.toSparkContextFunctions
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{Encoders, Row, SaveMode}
@@ -185,6 +188,61 @@ class TarantoolSparkWriteClusterTest
     actual.foreach(item => item.getString("order_type") should endWith("555"))
   }
 
+  test("should write a Dataset to the space with field names mapping") {
+    val space = "test_space"
+
+    var ds = spark.sql(
+      """
+        |select 1 as id, null as bucketId, 'Don Quixote' as bookName, 'Miguel de Cervantes' as author, 1605 as year union all
+        |select 2, null, 'The Great Gatsby', 'F. Scott Fitzgerald', 1925 union all
+        |select 3, null, 'War and Peace', 'Leo Tolstoy', 1869
+        |""".stripMargin
+    )
+
+    val ex = intercept[SparkException] {
+      ds.write
+        .format("org.apache.spark.sql.tarantool")
+        .mode(SaveMode.Append)
+        .option("tarantool.space", space)
+        .save()
+    }
+    ex.getMessage should include(
+      "Tuple field 3 (unique_key) type does not match one required by operation: expected string, got nil"
+    )
+
+    ds = spark.sql(
+      """
+        |select 1 as id, null as bucketId, 'Miguel de Cervantes' as author, 1605 as year, 'Don Quixote' as bookName, 'lolkek' as uniqueKey union all
+        |select 2, null, 'F. Scott Fitzgerald', 1925, 'The Great Gatsby', 'lolkek1' union all
+        |select 3, null, 'Leo Tolstoy', 1869, 'War and Peace', 'lolkek2'
+        |""".stripMargin
+    )
+
+    ds.write
+      .format("org.apache.spark.sql.tarantool")
+      .mode(SaveMode.Append)
+      .option("tarantool.space", space)
+      .save()
+
+    val actual = spark.sparkContext.tarantoolSpace(space, Conditions.any()).collect()
+    actual.length should equal(3)
+
+    actual(0).getString("author") should equal("Miguel de Cervantes")
+    actual(0).getString("book_name") should equal("Don Quixote")
+    actual(0).getInteger("year") should equal(1605)
+    actual(0).getString("unique_key") should equal("lolkek")
+
+    actual(1).getString("author") should equal("F. Scott Fitzgerald")
+    actual(1).getString("book_name") should equal("The Great Gatsby")
+    actual(1).getInteger("year") should equal(1925)
+    actual(1).getString("unique_key") should equal("lolkek1")
+
+    actual(2).getString("author") should equal("Leo Tolstoy")
+    actual(2).getString("book_name") should equal("War and Peace")
+    actual(2).getInteger("year") should equal(1869)
+    actual(2).getString("unique_key") should equal("lolkek2")
+  }
+
   test("should throw an exception if the space name is not specified") {
     assertThrows[IllegalArgumentException] {
       val orders = Range(1, 10).map(i => Order(i))
diff --git a/src/test/scala/io/tarantool/spark/connector/util/StringUtilsSpec.scala b/src/test/scala/io/tarantool/spark/connector/util/StringUtilsSpec.scala
new file mode 100644
index 0000000..2a6452f
--- /dev/null
+++ b/src/test/scala/io/tarantool/spark/connector/util/StringUtilsSpec.scala
@@ -0,0 +1,21 @@
+package io.tarantool.spark.connector.util
+
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+
+class StringUtilsSpec extends AnyFlatSpec with Matchers {
+  "camelToSnake" should "handle camelCase, all-uppercase and mixed-case names" in {
+    StringUtils.camelToSnake("COLUMN") shouldBe "column"
+    StringUtils.camelToSnake("someColumnNameRespectingCamel") shouldBe "some_column_name_respecting_camel"
+    StringUtils.camelToSnake("columnWITHSomeALLUppercaseWORDS") shouldBe "column_with_some_all_uppercase_words"
+    StringUtils.camelToSnake("_column") shouldBe "_column"
+    StringUtils.camelToSnake("column_") shouldBe "column_"
+    StringUtils.camelToSnake("Column") shouldBe "column"
+    StringUtils.camelToSnake("Column_S") shouldBe "column_s"
+    StringUtils.camelToSnake("Column_SV") shouldBe "column_sv"
+    StringUtils.camelToSnake("Column_String") shouldBe "column_string"
+    StringUtils.camelToSnake("_columnS") shouldBe "_column_s"
+    StringUtils.camelToSnake("column1234") shouldBe "column1234"
+    StringUtils.camelToSnake("column1234string") shouldBe "column1234string"
+  }
+}
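
Usage sketch (illustrative, not part of the patch): with this change, a
DataFrame whose SQL query aliases its columns in camelCase is saved to a
Tarantool space by matching each column to the snake_case space field of the
same name, instead of relying on column position alone. Rows that carry no
schema still take the positional path (rowWithoutSchemaToTuple), so existing
writes are unaffected. A minimal sketch in Scala follows; the space name
test_space, its field layout, and the tarantool.hosts connection option are
assumptions for illustration only:

    import org.apache.spark.sql.{SaveMode, SparkSession}

    object NamedExpressionsWriteExample extends App {

      // Connection settings are an assumption for this sketch;
      // adjust host/port for your cluster.
      val spark = SparkSession
        .builder()
        .master("local[*]")
        .config("tarantool.hosts", "127.0.0.1:3301")
        .getOrCreate()

      // Column aliases are camelCase on the Spark side ...
      val ds = spark.sql(
        """select 1 as id,
          |       null as bucketId,
          |       'Don Quixote' as bookName,
          |       'Miguel de Cervantes' as author,
          |       1605 as year,
          |       'key-1' as uniqueKey
          |""".stripMargin
      )

      // ... and after this patch they are matched by name to the snake_case
      // space fields (book_name, author, year, unique_key), not by position.
      ds.write
        .format("org.apache.spark.sql.tarantool")
        .mode(SaveMode.Append)
        .option("tarantool.space", "test_space")
        .save()

      spark.stop()
    }

Converted names are cached in tupleNamesCache (a concurrent TrieMap with
getOrElseUpdate), so camelToSnake runs only once per distinct column name.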