Skip to content

Commit

Permalink
Support saving DataFrames with named expressions in SQL queries
Browse files Browse the repository at this point in the history
Before this patch, named expressions were not taken into account,
so fields were matched only by their position and not by their names
in the query.
  • Loading branch information
akudiyar committed May 28, 2022
1 parent 1beb35e commit 281609c
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 6 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ ThisBuild / developers := List(
ThisBuild / scalaVersion := scala211

val commonDependencies = Seq(
"io.tarantool" % "cartridge-driver" % "0.7.0",
"io.tarantool" % "cartridge-driver" % "0.7.2",
"junit" % "junit" % "4.12" % Test,
"com.github.sbt" % "junit-interface" % "0.12" % Test,
"org.testcontainers" % "testcontainers" % "1.16.0" % Test,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ class TarantoolRDD[R] private[spark] (

private val globalConfig = TarantoolConfig(sparkContext.getConf)

@transient private lazy val tupleFactory = new DefaultTarantoolTupleFactory(
messagePackMapper
)

override def compute(split: Partition, context: TaskContext): Iterator[R] = {
val partition = split.asInstanceOf[TarantoolPartition]
val connection = TarantoolConnection()
Expand Down Expand Up @@ -96,6 +92,8 @@ class TarantoolRDD[R] private[spark] (
data.foreachPartition((partition: Iterator[Row]) =>
if (partition.nonEmpty) {
val client = connection.client(globalConfig)
val spaceMetadata = client.metadata().getSpaceByName(space).get()
val tupleFactory = new DefaultTarantoolTupleFactory(messagePackMapper, spaceMetadata)

var rowCount: Long = 0
val failedRowsExceptions: ListBuffer[Throwable] = ListBuffer()
Expand Down
32 changes: 32 additions & 0 deletions src/main/scala/io/tarantool/spark/connector/util/StringUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package io.tarantool.spark.connector.util

import scala.annotation.tailrec

/**
 * Provides helper methods for transforming String instances
 *
 * @author Alexey Kuzin
 */
object StringUtils {

  /**
   * Converts from camelCase to snake_case
   * e.g.: camelCase => camel_case
   *
   * An underscore is inserted before an upper-case character when either
   *  - the previous character is lower-case (`someName` => `some_name`), or
   *  - the previous character is upper-case and the next one is lower-case, i.e. the character
   *    starts a new word after an upper-case run (`HTTPServer` => `http_server`,
   *    `isAWord` => `is_a_word`).
   *
   * Characters that are neither upper- nor lower-case (digits, `_`, ...) never trigger an
   * underscore. The result is lower-cased with `Locale.ROOT`, so it does not depend on the
   * default locale of the JVM.
   *
   * Based on https://gist.github.com/sidharthkuruvila/3154845?permalink_comment_id=2622928#gistcomment-2622928
   *
   * @param name the camelCase name to convert
   * @return snake_case version of the string passed
   */
  def camelToSnake(name: String): String = {
    val length = name.length

    // True when the character at `index` (index > 0) is the first character of a new word.
    // Every position is examined on its own, so no boundary can be skipped because a previous
    // step has already consumed the characters around it.
    def startsWord(index: Int): Boolean = {
      val previous = name.charAt(index - 1)
      val current = name.charAt(index)
      current.isUpper && (
        previous.isLower ||
        (previous.isUpper && index + 1 < length && name.charAt(index + 1).isLower)
      )
    }

    @tailrec
    def go(index: Int, acc: StringBuilder): StringBuilder =
      if (index >= length) {
        acc
      } else {
        if (index > 0 && startsWord(index)) {
          acc.append('_')
        }
        go(index + 1, acc.append(name.charAt(index)))
      }

    // StringBuilder gives linear-time appends (appending to an immutable List is O(n) per step)
    go(0, new StringBuilder(length + 8)).toString.toLowerCase(java.util.Locale.ROOT)
  }
}
25 changes: 24 additions & 1 deletion src/main/scala/org/apache/spark/sql/tarantool/MapFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.apache.spark.sql.tarantool

import io.tarantool.driver.api.tuple.{TarantoolField, TarantoolTuple, TarantoolTupleFactory}
import io.tarantool.driver.mappers.MessagePackValueMapper
import io.tarantool.spark.connector.util.StringUtils
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
Expand All @@ -26,6 +27,9 @@ import scala.collection.JavaConverters.{mapAsJavaMapConverter, seqAsJavaListConv
*/
object MapFunctions {

// Cache of already converted field names, keyed by the original schema field name.
// It is filled by transformSchemaFieldName and is backed by a concurrent TrieMap.
@transient private lazy val tupleNamesCache: scala.collection.mutable.Map[String, String] =
scala.collection.concurrent.TrieMap()

def tupleToRow(
tuple: TarantoolTuple,
mapper: MessagePackValueMapper,
Expand Down Expand Up @@ -89,12 +93,31 @@ object MapFunctions {
}

def rowToTuple(tupleFactory: TarantoolTupleFactory, row: Row): TarantoolTuple =
tupleFactory.create(
Option(row.schema) match {
case Some(schema) => rowWithSchemaToTuple(tupleFactory, row)
case None => rowWithoutSchemaToTuple(tupleFactory, row)
}

/**
 * Builds a tuple from a row that carries a schema, setting every value by field name.
 *
 * Each schema field name is passed through transformSchemaFieldName before it is used as the
 * tuple field name; each value is passed through mapToJavaValue, with an empty result stored
 * as null.
 *
 * @param tupleFactory factory used to create the empty tuple
 * @param row row with a non-null schema
 * @return the populated tuple
 */
def rowWithSchemaToTuple(tupleFactory: TarantoolTupleFactory, row: Row): TarantoolTuple = {
  val result = tupleFactory.create()
  val valuesByName = row.getValuesMap[Any](row.schema.fieldNames)
  valuesByName.foreach {
    case (schemaName, rawValue) =>
      val tupleFieldName = transformSchemaFieldName(schemaName)
      val javaValue = mapToJavaValue(Option(rawValue)).orNull
      result.putObject(tupleFieldName, javaValue)
  }
  result
}

/**
 * Returns the tuple field name for the given schema field name.
 *
 * The value is taken from tupleNamesCache when present; otherwise it is computed with
 * StringUtils.camelToSnake and stored in the cache.
 *
 * @param fieldName field name as it appears in the row schema
 * @return the converted field name
 */
def transformSchemaFieldName(fieldName: String): String = {
  def snakeCased: String = StringUtils.camelToSnake(fieldName)
  tupleNamesCache.getOrElseUpdate(fieldName, snakeCased)
}

/**
 * Builds a tuple from a row without a schema, keeping the values in row order.
 *
 * Each value is passed through mapToJavaValue; an empty result is stored as null.
 *
 * @param tupleFactory factory used to create the tuple
 * @param row row whose values are copied positionally
 * @return tuple created from the converted values
 */
def rowWithoutSchemaToTuple(tupleFactory: TarantoolTupleFactory, row: Row): TarantoolTuple = {
  val javaValues = row.toSeq.map(value => mapToJavaValue(Option(value)).orNull)
  tupleFactory.create(javaValues.asJava)
}

def mapToJavaValue(value: Option[Any]): Option[Any] =
if (value.isDefined) {
Expand Down
26 changes: 26 additions & 0 deletions src/test/resources/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Two-stage image: "tarantool-base" installs Tarantool and cartridge-cli,
# "cartridge-base" builds and starts the Cartridge application on container start.
FROM tgagor/centos:stream8 AS tarantool-base
# Build-time parameters and their defaults
ARG TARANTOOL_VERSION=2.8
ARG TARANTOOL_SERVER_USER="tarantool"
ARG TARANTOOL_SERVER_UID=1000
ARG TARANTOOL_SERVER_GROUP="tarantool"
ARG TARANTOOL_SERVER_GID=1000
ARG TARANTOOL_WORKDIR="/app"
ARG TARANTOOL_RUNDIR="/tmp/run"
ARG TARANTOOL_DATADIR="/tmp/data"
ARG TARANTOOL_INSTANCES_FILE="./instances.yml"
# Copy the directory settings into the environment: they are referenced by WORKDIR and CMD below
ENV TARANTOOL_WORKDIR=$TARANTOOL_WORKDIR
ENV TARANTOOL_RUNDIR=$TARANTOOL_RUNDIR
ENV TARANTOOL_DATADIR=$TARANTOOL_DATADIR
ENV TARANTOOL_INSTANCES_FILE=$TARANTOOL_INSTANCES_FILE
# Run the Tarantool installer script with VER=$TARANTOOL_VERSION and --repo-only,
# then install the build tools, Tarantool and cartridge-cli with yum
RUN curl -L https://tarantool.io/installer.sh | VER=$TARANTOOL_VERSION /bin/bash -s -- --repo-only && \
yum -y install cmake make gcc gcc-c++ git unzip tarantool tarantool-devel cartridge-cli && \
yum clean all
# Create the server group and user; because of "|| true" a failure here does not fail the build
# (presumably to tolerate an already existing user/group - confirm)
RUN groupadd -g $TARANTOOL_SERVER_GID $TARANTOOL_SERVER_GROUP && \
useradd -u $TARANTOOL_SERVER_UID -g $TARANTOOL_SERVER_GID -m -s /bin/bash $TARANTOOL_SERVER_USER \
|| true
# All following instructions run as the server user
USER $TARANTOOL_SERVER_USER:$TARANTOOL_SERVER_GROUP
# Check that cartridge-cli can be executed by this user
RUN cartridge version

# Final stage: build the application found in the working directory and start its instances
FROM tarantool-base AS cartridge-base
WORKDIR $TARANTOOL_WORKDIR
CMD cartridge build && cartridge start --run-dir=$TARANTOOL_RUNDIR --data-dir=$TARANTOOL_DATADIR --cfg=$TARANTOOL_INSTANCES_FILE
37 changes: 37 additions & 0 deletions src/test/resources/cartridge/replicasets.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Replicaset layout for the test Cartridge application:
# two router replicasets (app-router, app-router-second) with the vshard-router/crud-router roles
# and two storage replicasets (s1-storage, s2-storage) with the vshard-storage/crud-storage roles.
app-router:
instances:
- router
roles:
- vshard-router
- crud-router
- app.roles.api_router
all_rw: false
app-router-second:
instances:
- second-router
roles:
- vshard-router
- crud-router
- app.roles.api_router
all_rw: false
# Storage replicasets: both have weight 1 and belong to the "default" vshard group
s1-storage:
instances:
- s1-master
roles:
- vshard-storage
- crud-storage
- app.roles.api_storage
weight: 1
all_rw: false
vshard_group: default
s2-storage:
instances:
- s2-master
roles:
- vshard-storage
- crud-storage
- app.roles.api_storage
weight: 1
all_rw: false
vshard_group: default

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package io.tarantool.spark.connector.integration

import io.tarantool.driver.api.conditions.Conditions
import io.tarantool.driver.api.tuple.{DefaultTarantoolTupleFactory, TarantoolTuple}
import io.tarantool.driver.exceptions.TarantoolException
import io.tarantool.driver.mappers.DefaultMessagePackMapperFactory
import io.tarantool.spark.connector.connection.TarantoolConnection
import io.tarantool.spark.connector.config.TarantoolConfig
import io.tarantool.spark.connector.toSparkContextFunctions
import org.apache.spark.SparkException
import org.apache.spark.sql.{Encoders, Row, SaveMode}
Expand Down Expand Up @@ -185,6 +188,61 @@ class TarantoolSparkWriteClusterTest
actual.foreach(item => item.getString("order_type") should endWith("555"))
}

test("should write a Dataset to the space with field names mapping") {
  val space = "test_space"

  // This query has no uniqueKey column, so the write is expected to be rejected
  val withoutUniqueKey = spark.sql(
    """
      |select 1 as id, null as bucketId, 'Don Quixote' as bookName, 'Miguel de Cervantes' as author, 1605 as year union all
      |select 2, null, 'The Great Gatsby', 'F. Scott Fitzgerald', 1925 union all
      |select 2, null, 'War and Peace', 'Leo Tolstoy', 1869
      |""".stripMargin
  )

  val ex = intercept[SparkException] {
    withoutUniqueKey.write
      .format("org.apache.spark.sql.tarantool")
      .mode(SaveMode.Append)
      .option("tarantool.space", space)
      .save()
  }
  ex.getMessage should include(
    "Tuple field 3 (unique_key) type does not match one required by operation: expected string, got nil"
  )

  // This query names every column with a camelCase alias, including uniqueKey
  val withAllColumns = spark.sql(
    """
      |select 1 as id, null as bucketId, 'Miguel de Cervantes' as author, 1605 as year, 'Don Quixote' as bookName, 'lolkek' as uniqueKey union all
      |select 2, null, 'F. Scott Fitzgerald', 1925, 'The Great Gatsby', 'lolkek1' union all
      |select 3, null, 'Leo Tolstoy', 1869, 'War and Peace', 'lolkek2'
      |""".stripMargin
  )

  withAllColumns.write
    .format("org.apache.spark.sql.tarantool")
    .mode(SaveMode.Append)
    .option("tarantool.space", space)
    .save()

  val actual = spark.sparkContext.tarantoolSpace(space, Conditions.any()).collect()
  actual.length should equal(3)

  // Expected (author, book_name, year, unique_key) for each stored tuple, in result order;
  // the values are read back by their snake_case field names
  val expectedBooks = Seq(
    ("Miguel de Cervantes", "Don Quixote", 1605, "lolkek"),
    ("F. Scott Fitzgerald", "The Great Gatsby", 1925, "lolkek1"),
    ("Leo Tolstoy", "War and Peace", 1869, "lolkek2")
  )

  expectedBooks.zipWithIndex.foreach {
    case ((author, bookName, year, uniqueKey), index) =>
      val stored = actual(index)
      stored.getString("author") should equal(author)
      stored.getString("book_name") should equal(bookName)
      stored.getInteger("year") should equal(year)
      stored.getString("unique_key") should equal(uniqueKey)
  }
}

test("should throw an exception if the space name is not specified") {
assertThrows[IllegalArgumentException] {
val orders = Range(1, 10).map(i => Order(i))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package io.tarantool.spark.connector.util

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class StringUtilsSpec extends AnyFlatSpec with Matchers {

  // Pairs of input name -> expected result of StringUtils.camelToSnake
  private val conversions: Seq[(String, String)] = Seq(
    "COLUMN" -> "column",
    "someColumnNameRespectingCamel" -> "some_column_name_respecting_camel",
    "columnWITHSomeALLUppercaseWORDS" -> "column_with_some_all_uppercase_words",
    "_column" -> "_column",
    "column_" -> "column_",
    "Column" -> "column",
    "Column_S" -> "column_s",
    "Column_SV" -> "column_sv",
    "Column_String" -> "column_string",
    "_columnS" -> "_column_s",
    "column1234" -> "column1234",
    "column1234string" -> "column1234string"
  )

  "camelToSnake" should "process right camel case + all upper case + mixed" in {
    conversions.foreach {
      case (source, expected) =>
        StringUtils.camelToSnake(source) shouldBe expected
    }
  }
}

0 comments on commit 281609c

Please sign in to comment.