Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

FunctionDefSpec improvements #737

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion core/jvm/src/main/scala/zio/sql/expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,6 @@ trait ExprModule extends NewtypesModule with FeaturesModule with OpsModule {
val Ascii = FunctionDef[String, Int](FunctionName("ascii"))
val CharLength = FunctionDef[String, Int](FunctionName("character_length"))
val Concat = FunctionDef[(String, String), String](FunctionName("concat")) // todo varargs
val ConcatWs2 = FunctionDef[(String, String), String](FunctionName("concat_ws"))
val ConcatWs3 = FunctionDef[(String, String, String), String](FunctionName("concat_ws"))
val ConcatWs4 = FunctionDef[(String, String, String, String), String](FunctionName("concat_ws"))
val Lower = FunctionDef[String, String](FunctionName("lower"))
Expand Down
313 changes: 313 additions & 0 deletions mysql/src/test/scala/zio/sql/mysql/CommonFunctionDefSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
package zio.sql.mysql

import zio.Cause
import zio.stream.ZStream
import zio.test.Assertion._
import zio.test._

object CommonFunctionDefSpec extends MysqlRunnableSpec with ShopSchema {
import FunctionDef.{ CharLength => _, _ }
import Customers._

private def collectAndCompare[R, E](
expected: Seq[String],
testResult: ZStream[R, E, String]
) =
assertZIO(testResult.runCollect)(hasSameElementsDistinct(expected))

override def specLayered = suite("MySQL Common FunctionDef")(
suite("Schema dependent tests")(
test("concat_ws #2 - combine columns") {

// note: you can't use customerId here as it is a UUID, hence not a string in our book
val query = select(ConcatWs3(Customers.fName, Customers.fName, Customers.lName)) from customers

val expected = Seq(
"RonaldRonaldRussell",
"TerrenceTerrenceNoel",
"MilaMilaPaterso",
"AlanaAlanaMurray",
"JoseJoseWiggins"
)

val testResult = execute(query)
collectAndCompare(expected, testResult)
},
test("concat_ws #3 - combine columns and flat values") {

val query = select(ConcatWs4(" ", "Person:", Customers.fName, Customers.lName)) from customers

val expected = Seq(
"Person: Ronald Russell",
"Person: Terrence Noel",
"Person: Mila Paterso",
"Person: Alana Murray",
"Person: Jose Wiggins"
)

val testResult = execute(query)
collectAndCompare(expected, testResult)
},
test("concat_ws #3 - combine function calls together") {

val query = select(
ConcatWs3(" and ", Concat("Name: ", Customers.fName), Concat("Surname: ", Customers.lName))
) from customers

val expected = Seq(
"Name: Ronald and Surname: Russell",
"Name: Terrence and Surname: Noel",
"Name: Mila and Surname: Paterso",
"Name: Alana and Surname: Murray",
"Name: Jose and Surname: Wiggins"
)

val testResult = execute(query)
collectAndCompare(expected, testResult)
},
test("lower") {
val query = select(Lower(fName)) from customers limit (1)

val expected = "ronald"

val testResult = execute(query)

val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("Can concat strings with concat function") {

val query = select(Concat(fName, lName) as "fullname") from customers

val expected = Seq("RonaldRussell", "TerrenceNoel", "MilaPaterso", "AlanaMurray", "JoseWiggins")

val result = execute(query)

val assertion = for {
r <- result.runCollect
} yield assert(r)(hasSameElementsDistinct(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("replace") {
val lastNameReplaced = Replace(lName, "ll", "_") as "lastNameReplaced"
val computedReplace = Replace("special ::ąę::", "ąę", "__") as "computedReplace"

val query = select(lastNameReplaced, computedReplace) from customers

val expected = ("Russe_", "special ::__::")

val testResult =
execute(query).map { case row =>
(row._1, row._2)
}

val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
}
),
suite("Schema independent tests")(
test("concat_ws #1 - combine flat values") {

// note: a plain number (3) would and should not compile
val query = select(ConcatWs4("+", "1", "2", "3"))

val expected = Seq("1+2+3")

val testResult = execute(query)
collectAndCompare(expected, testResult)
},
test("ltrim") {
assertZIO(execute(select(Ltrim(" hello "))).runHead.some)(equalTo("hello "))
},
test("rtrim") {
assertZIO(execute(select(Rtrim(" hello "))).runHead.some)(equalTo(" hello"))
},
test("abs") {
assertZIO(execute(select(Abs(-3.14159))).runHead.some)(equalTo(3.14159))
},
test("log") {
assertZIO(execute(select(Log(2.0, 32.0))).runHead.some)(equalTo(5.0))
},
test("acos") {
assertZIO(execute(select(Acos(-1.0))).runHead.some)(equalTo(3.141592653589793))
},
test("asin") {
assertZIO(execute(select(Asin(0.5))).runHead.some)(equalTo(0.5235987755982989))
},
test("ln") {
assertZIO(execute(select(Ln(3.0))).runHead.some)(equalTo(1.0986122886681097))
},
test("atan") {
assertZIO(execute(select(Atan(10.0))).runHead.some)(equalTo(1.4711276743037347))
},
test("cos") {
assertZIO(execute(select(Cos(3.141592653589793))).runHead.some)(equalTo(-1.0))
},
test("exp") {
assertZIO(execute(select(Exp(1.0))).runHead.some)(equalTo(2.718281828459045))
},
test("floor") {
assertZIO(execute(select(Floor(-3.14159))).runHead.some)(equalTo(-4.0))
},
test("ceil") {
assertZIO(execute(select(Ceil(53.7), Ceil(-53.7))).runHead.some)(equalTo((54.0, -53.0)))
},
test("sin") {
assertZIO(execute(select(Sin(1.0))).runHead.some)(equalTo(0.8414709848078965))
},
test("sqrt") {
val query = select(Sqrt(121.0))

val expected = 11.0

val testResult = execute(query)

assertZIO(testResult.runHead.some)(equalTo(expected))
},
test("round") {
val query = select(Round(10.8124, 2))

val expected = 10.81

val testResult = execute(query)

val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("sign positive") {
val query = select(Sign(3.0))

val expected = 1

val testResult = execute(query)

val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("sign negative") {
val query = select(Sign(-3.0))

val expected = -1

val testResult = execute(query)

val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("sign zero") {
val query = select(Sign(0.0))

val expected = 0

val testResult = execute(query)

val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("power") {
val query = select(Power(7.0, 3.0))

val expected = 343.000000000000000

val testResult = execute(query)

val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("mod") {
val query = select(Mod(-15.0, -4.0))

val expected = -3.0

val testResult = execute(query)

val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("octet_length") {
val query = select(OctetLength("josé"))

val expected = 5

val testResult = execute(query)

val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("ascii") {
val query = select(Ascii("""x"""))

val expected = 120

val testResult = execute(query)

val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("upper") {
val query = (select(Upper("ronald"))).limit(1)

val expected = "RONALD"

val testResult = execute(query)

val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("tan") {
val query = select(Tan(0.7853981634))

val expected = 1.0000000000051035

val testResult = execute(query)

val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("trim") {
assertZIO(execute(select(Trim(" 1234 "))).runHead.some)(equalTo("1234"))
},
test("lower") {
assertZIO(execute(select(Lower("YES"))).runHead.some)(equalTo("yes"))
}
)
)

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,56 +6,12 @@ import zio.test.Assertion._
import java.time.{ LocalDate, LocalTime, ZoneId }
import java.time.format.DateTimeFormatter

object FunctionDefSpec extends MysqlRunnableSpec with ShopSchema {
object CustomFunctionDefSpec extends MysqlRunnableSpec with ShopSchema {

import Customers._
import FunctionDef._
import MysqlFunctionDef._

override def specLayered = suite("MySQL FunctionDef")(
test("lower") {
val query = select(Lower(fName)) from customers limit (1)

val expected = "ronald"

val testResult = execute(query)

assertZIO(testResult.runHead.some)(equalTo(expected))
},
// FIXME: lower with string literal should not refer to a column name
// See: https://www.w3schools.com/sql/trymysql.asp?filename=trysql_func_mysql_lower
// Uncomment the following test when fixed
// test("lower with string literal") {
// val query = select(Lower("LOWER")) from customers limit(1)
//
// val expected = "lower"
//
// val testResult = execute(query.to[String, String](identity))
//
// val assertion = for {
// r <- testResult.runCollect
// } yield assert(r.head)(equalTo(expected))
//
// assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
// },
test("sin") {
val query = select(Sin(1.0))

val expected = 0.8414709848078965

val testResult = execute(query)

assertZIO(testResult.runHead.some)(equalTo(expected))
},
test("abs") {
val query = select(Abs(-32.0))

val expected = 32.0

val testResult = execute(query)

assertZIO(testResult.runHead.some)(equalTo(expected))
},
test("crc32") {
val query = select(Crc32("MySQL")) from customers

Expand Down
Loading