Add test for decimal types support
akudiyar committed May 30, 2022
1 parent 281609c commit 2001ddc
Showing 4 changed files with 89 additions and 8 deletions.
20 changes: 13 additions & 7 deletions build.sbt
@@ -33,15 +33,16 @@ ThisBuild / developers := List(
 ThisBuild / scalaVersion := scala211
 
 val commonDependencies = Seq(
-  "io.tarantool" % "cartridge-driver" % "0.7.2",
+  "io.tarantool" % "cartridge-driver" % "0.8.0",
   "junit" % "junit" % "4.12" % Test,
   "com.github.sbt" % "junit-interface" % "0.12" % Test,
-  "org.testcontainers" % "testcontainers" % "1.16.0" % Test,
-  "io.tarantool" % "testcontainers-java-tarantool" % "0.4.7" % Test,
+  "org.testcontainers" % "testcontainers" % "1.17.0" % Test,
+  "io.tarantool" % "testcontainers-java-tarantool" % "0.5.0" % Test,
   "org.scalatest" %% "scalatest" % "3.2.9" % Test,
   "org.scalamock" %% "scalamock" % "5.1.0" % Test,
   "com.dimafeng" %% "testcontainers-scala-scalatest" % "0.39.5" % Test,
-  "ch.qos.logback" % "logback-classic" % "1.2.5" % Test
+  "ch.qos.logback" % "logback-classic" % "1.2.5" % Test,
+  "org.apache.derby" % "derby" % "10.11.1.1" % Test
 )
 
 lazy val root = (project in file("."))
@@ -55,12 +56,14 @@ lazy val root = (project in file("."))
       case Some((2, scalaMajor)) if scalaMajor >= 12 =>
         Seq(
           "org.apache.spark" %% "spark-core" % "2.4.8" % "provided",
-          "org.apache.spark" %% "spark-sql" % "2.4.8" % "provided"
+          "org.apache.spark" %% "spark-sql" % "2.4.8" % "provided",
+          "org.apache.spark" %% "spark-hive" % "2.4.8" % "provided"
         )
       case _ =>
         Seq(
           "org.apache.spark" %% "spark-core" % "2.2.3" % "provided",
-          "org.apache.spark" %% "spark-sql" % "2.2.3" % "provided"
+          "org.apache.spark" %% "spark-sql" % "2.2.3" % "provided",
+          "org.apache.spark" %% "spark-hive" % "2.2.3" % "provided"
         )
     }
   }).map(
@@ -100,7 +103,10 @@ lazy val root = (project in file("."))
       "UTF-8"
     ),
     // Test frameworks options
-    testOptions += Tests.Argument(TestFrameworks.JUnit, "-v"),
+    testOptions ++= Seq(
+      Tests.Argument(TestFrameworks.JUnit, "-v"),
+      Tests.Setup(() => System.setSecurityManager(null)) // SPARK-22918
+    ),
     // Publishing settings
     publishTo := {
       val nexus = "https://oss.sonatype.org/"
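
Note on the build changes above: the new test creates Hive tables through spark.sql, which requires spark-hive on the test classpath, and Hive's embedded metastore runs on Apache Derby, hence the explicit derby test dependency. The Tests.Setup hook clears the JVM security manager because Derby fails to boot under the one SBT installs (SPARK-22918). A minimal sketch of the kind of Hive-enabled local session these dependencies make possible; the helper, app name, and warehouse path are illustrative assumptions, not taken from this commit:

    import org.apache.spark.sql.SparkSession

    // Hypothetical helper, for illustration only: a local, Hive-enabled session.
    // enableHiveSupport() needs spark-hive on the classpath; the embedded Derby
    // metastore it starts is why derby was added in Test scope above.
    object HiveTestSession {
      def create(): SparkSession =
        SparkSession
          .builder()
          .master("local[*]")
          .appName("decimal-types-test")                         // assumed name
          .config("spark.sql.warehouse.dir", "target/warehouse") // assumed path
          .enableHiveSupport()
          .getOrCreate()
    }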
26 changes: 26 additions & 0 deletions src/test/resources/cartridge/app/roles/api_storage.lua
@@ -70,6 +70,32 @@ local function init_space()
         unique = false,
         if_not_exists = true,
     })
+
+    local reg_numbers = box.schema.space.create(
+        'reg_numbers',
+        {
+            format = {
+                {name = 'bucket_id' , type = 'unsigned' , is_nullable = false},
+                {name = 'idreg'     , type = 'decimal'  , is_nullable = false},
+                {name = 'regnum'    , type = 'decimal'  , is_nullable = true},
+            },
+            if_not_exists = true,
+        }
+    )
+
+    reg_numbers:create_index('index_id', {
+        unique = true,
+        type = 'tree',
+        parts = {{2, 'decimal'}},
+        if_not_exists = true,
+    })
+
+    reg_numbers:create_index('bucket_id', {
+        unique = false,
+        type = 'tree',
+        parts = {{1, 'unsigned'}},
+        if_not_exists = true,
+    })
 end
 
 local function init(opts)
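
The new reg_numbers space declares two Tarantool decimal fields. On the Spark side such fields surface as DecimalType columns; below is a sketch of a matching schema, assuming the decimal(38,18) and decimal(38) shapes of the Hive table used in the test further down (Tarantool's decimal stores up to 38 significant digits and does not itself fix a scale):

    import org.apache.spark.sql.types._

    // Field names mirror the space format above. bucket_id is nullable here
    // because the test writes it as null and it is assigned on the Tarantool side.
    val regNumbersSchema: StructType = StructType(
      Seq(
        StructField("bucket_id", IntegerType, nullable = true),
        StructField("idreg", DecimalType(38, 18), nullable = false),
        StructField("regnum", DecimalType(38, 0), nullable = true)
      )
    )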
@@ -38,7 +38,7 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
       sessionBuilder: SparkSession.Builder,
       conf: SparkConf
   ): SparkSession.Builder = {
-    sessionBuilder.config(conf)
+    sessionBuilder.config(conf).enableHiveSupport()
     sessionBuilder
   }
 
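
This one-line change gives every suite mixing in SharedSparkContext a Hive-enabled session, which is what lets the new test create an ORC-backed table through plain spark.sql. An illustrative sanity check, not part of the commit:

    // A Hive-enabled session reports the "hive" catalog implementation;
    // without enableHiveSupport() this would read "in-memory".
    assert(spark.conf.get("spark.sql.catalogImplementation") == "hive")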
@@ -258,6 +258,55 @@ class TarantoolSparkWriteClusterTest
         .save()
     }
   }
+
+  test("should write a Dataset to the space with decimal values") {
+    val space = "reg_numbers"
+
+    spark.sql("create database if not exists dl_raw")
+    spark.sql("drop table if exists DL_RAW.reg_numbers")
+
+    spark.sql("""
+                |create table if not exists DL_RAW.reg_numbers (
+                |  bucket_id integer
+                |  ,idreg decimal(38,18)
+                |  ,regnum decimal(38)
+                |) stored as orc""".stripMargin)
+    spark.sql("""
+                |insert into dl_raw.reg_numbers values
+                |(null, 1085529600000.13452690000413, 404503014700028),
+                |(null, 1086629600000.13452690000413, 404503015800028),
+                |(null, 1087430400000.13452690000413, 304503016900085)
+                |""".stripMargin)
+
+    val ds = spark.table("dl_raw.reg_numbers")
+
+    ds.show(false)
+    ds.printSchema()
+
+    ds.write
+      .format("org.apache.spark.sql.tarantool")
+      .option("tarantool.space", space)
+      .mode(SaveMode.Overwrite)
+      .save()
+
+    val actual = spark.sparkContext.tarantoolSpace(space, Conditions.any()).collect()
+    actual.length should equal(3)
+
+    actual(0).getDecimal("idreg") should equal(
+      BigDecimal("1085529600000.134526900004130000").bigDecimal
+    )
+    actual(0).getDecimal("regnum") should equal(BigDecimal("404503014700028").bigDecimal)
+
+    actual(1).getDecimal("idreg") should equal(
+      BigDecimal("1086629600000.134526900004130000").bigDecimal
+    )
+    actual(1).getDecimal("regnum") should equal(BigDecimal("404503015800028").bigDecimal)
+
+    actual(2).getDecimal("idreg") should equal(
+      BigDecimal("1087430400000.134526900004130000").bigDecimal
+    )
+    actual(2).getDecimal("regnum") should equal(BigDecimal("304503016900085").bigDecimal)
+  }
 }
 
 case class Order(
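
A note on the expected values in the assertions above: the Hive column idreg is decimal(38,18), so Spark stores the inserted literals at scale 18, and java.math.BigDecimal#equals compares scale as well as numeric value, which is why the expected literals spell out the trailing zeros. A small standalone illustration:

    import java.math.BigDecimal

    val inserted = new BigDecimal("1085529600000.13452690000413") // scale 14, as in the INSERT
    val stored   = inserted.setScale(18)                          // scale 18, per decimal(38,18)

    assert(stored.equals(new BigDecimal("1085529600000.134526900004130000"))) // same value and scale
    assert(!stored.equals(inserted))        // equals is scale-sensitive,
    assert(stored.compareTo(inserted) == 0) // while compareTo ignores scale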
