Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NU-1768] Auto convert SQL enricher returning types #6586

Merged
merged 10 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package pl.touk.nussknacker.sql.db.ignite

import com.typesafe.scalalogging.LazyLogging
import pl.touk.nussknacker.engine.api.typed.typing.Typed
import pl.touk.nussknacker.sql.db.schema.TableDefinition

import java.sql.{Connection, PreparedStatement, ResultSet}
Expand Down Expand Up @@ -29,9 +28,8 @@ class IgniteQueryHelper(getConnection: () => Connection) extends LazyLogging {
}.groupBy { case (tableName, _, _, _) => tableName }
.map { case (tableName, entries) =>
val columnTypings = entries.map { case (_, columnName, klassName, _) =>
columnName -> Typed.typedClass(Class.forName(klassName))
columnName -> klassName
}

tableName -> TableDefinition.applyList(columnTypings)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ trait QueryExecutor {

def execute(statement: PreparedStatement): QueryResult

protected def toTypedMap(tableDef: TableDefinition, resultSet: ResultSet): TypedMap = {
protected def toTypedMap(tableDef: TableDefinition, resultSet: ResultSet): java.util.Map[String, Any] = {
mk-software-pl marked this conversation as resolved.
Show resolved Hide resolved
val fields = tableDef.columnDefs.map { columnDef =>
// we could here use method resultSet.getObject(Int) and pass column number as argument
// but in case of ignite db it is not certain which column index corresponds to which column.
columnDef.name -> resultSet.getObject(columnDef.name)
columnDef.name -> columnDef.extractValue(resultSet)
}.toMap
TypedMap(fields)
}
Expand All @@ -35,7 +33,7 @@ class UpdateQueryExecutor extends QueryExecutor {

class SingleResultQueryExecutor(tableDef: TableDefinition) extends QueryExecutor {

override type QueryResult = TypedMap
override type QueryResult = java.util.Map[String, Any]

def execute(statement: PreparedStatement): QueryResult = {
val resultSet = statement.executeQuery()
Expand All @@ -49,11 +47,11 @@ class SingleResultQueryExecutor(tableDef: TableDefinition) extends QueryExecutor

class ResultSetQueryExecutor(tableDef: TableDefinition) extends QueryExecutor {

override type QueryResult = java.util.List[TypedMap]
override type QueryResult = java.util.List[java.util.Map[String, Any]]

override def execute(statement: PreparedStatement): QueryResult = {
val resultSet = statement.executeQuery()
val results = new util.ArrayList[TypedMap]()
val results = new util.ArrayList[java.util.Map[String, Any]]()
while (resultSet.next()) {
results add toTypedMap(tableDef, resultSet)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,75 @@ package pl.touk.nussknacker.sql.db.schema

import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypingResult}

import java.sql.ResultSetMetaData
import java.io.BufferedReader
import java.sql.{Clob, ResultSet, ResultSetMetaData}
import java.time.{Instant, LocalDate, LocalTime}
import java.util.stream.Collectors
import java.{sql, util}
import scala.util.Using

final case class ColumnDefinition(
name: String,
typing: TypingResult,
valueMapping: Any => Any
) {

def extractValue(resultSet: ResultSet): Any = {
// we could here use method resultSet.getObject(Int) and pass column number as argument
// but in case of ignite db it is not certain which column index corresponds to which column.
val value = resultSet.getObject(name)
Option(value).map(valueMapping).getOrElse(value)
}

}

object ColumnDefinition {
private val sqlArrayClassName = classOf[sql.Array].getName
private val sqlTimeClassName = classOf[sql.Time].getName
private val sqlDateClassName = classOf[sql.Date].getName
private val sqlTimestampClassName = classOf[sql.Timestamp].getName
private val sqlClobClassName = classOf[sql.Clob].getName

def apply(columnNo: Int, resultMeta: ResultSetMetaData): ColumnDefinition =
def apply(columnNo: Int, resultMeta: ResultSetMetaData): ColumnDefinition = {
val (typingResult, valueMapping) = mapValueToSupportedType(resultMeta.getColumnClassName(columnNo))
ColumnDefinition(
name = resultMeta.getColumnName(columnNo),
typing = Typed(Class.forName(resultMeta.getColumnClassName(columnNo)))
typing = typingResult,
valueMapping = valueMapping
)
}

def apply(typing: (String, TypingResult)): ColumnDefinition =
def apply(typing: (String, String)): ColumnDefinition = {
val (typingResult, valueMapping) = mapValueToSupportedType(typing._2)
ColumnDefinition(
name = typing._1,
typing = typing._2
typing = typingResult,
valueMapping = valueMapping
)
}

}
private def mapValueToSupportedType(className: String): (TypingResult, Any => Any) = className match {
case `sqlArrayClassName` => (Typed.typedClass(classOf[util.List[Any]]), v => readArray(v.asInstanceOf[sql.Array]))
case `sqlTimeClassName` => (Typed.typedClass(classOf[LocalTime]), v => v.asInstanceOf[sql.Time].toLocalTime)
case `sqlDateClassName` => (Typed.typedClass(classOf[LocalDate]), v => v.asInstanceOf[sql.Date].toLocalDate)
case `sqlTimestampClassName` => (Typed.typedClass(classOf[Instant]), v => v.asInstanceOf[sql.Timestamp].toInstant)
case `sqlClobClassName` => (Typed.typedClass(classOf[String]), v => readClob(v.asInstanceOf[sql.Clob]))
case _ => (Typed.typedClass(Class.forName(className)), identity)
}

private def readArray(v: sql.Array): util.List[AnyRef] = {
val result = new util.ArrayList[AnyRef]()
val resultSet = v.getResultSet
while (resultSet.next()) {
result.add(resultSet.getObject(1))
}
result
}

final case class ColumnDefinition(name: String, typing: TypingResult)
private def readClob(v: Clob): String = {
Using.resource(new BufferedReader(v.getCharacterStream))(br => readFromStream(br))
}

private def readFromStream(br: BufferedReader): String =
br.lines().collect(Collectors.joining(System.lineSeparator()))
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,14 @@ object TableDefinition {

def apply(resultMeta: ResultSetMetaData): TableDefinition =
TableDefinition(
columnDefs = (1 to resultMeta.getColumnCount).map(ColumnDefinition(_, resultMeta)).toList
columnDefs = (1 to resultMeta.getColumnCount)
.map(ColumnDefinition(_, resultMeta))
.toList
)

def applyList(fields: List[(String, TypingResult)]): TableDefinition = {
val columnDefinitions = fields
.map { typing =>
ColumnDefinition(typing)
}
TableDefinition(
columnDefs = columnDefinitions
)
def applyList(fields: List[(String, String)]): TableDefinition = {
val columnDefinitions = fields.map(ColumnDefinition.apply)
TableDefinition(columnDefs = columnDefinitions)
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class DatabaseEnricherInvoker(
tableDef: TableDefinition,
strategy: QueryResultStrategy,
queryArgumentsExtractor: (Int, Params, Context) => QueryArguments,
val returnType: typing.TypingResult,
val getConnection: () => Connection,
val getTimeMeasurement: () => AsyncExecutionTimeMeasurement,
params: Params,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class DatabaseEnricherInvokerWithCache(
strategy: QueryResultStrategy,
queryArgumentsExtractor: (Int, Params, Context) => QueryArguments,
cacheTTL: Duration,
override val returnType: typing.TypingResult,
override val getConnection: () => Connection,
override val getTimeMeasurement: () => AsyncExecutionTimeMeasurement,
params: Params
Expand All @@ -36,7 +35,6 @@ class DatabaseEnricherInvokerWithCache(
tableDef,
strategy,
queryArgumentsExtractor,
returnType,
getConnection,
getTimeMeasurement,
params
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,42 +241,37 @@ class DatabaseQueryEnricher(val dbPoolConfig: DBPoolConfig, val dbMetaDataProvid
dependencies: List[NodeDependencyValue],
finalState: Option[TransformationState]
): ServiceInvoker = {
val state = finalState.get
val cacheTTLOption = cacheTTLParamDeclaration.extractValue(params)

val state = finalState.get
val cacheTTLOption = cacheTTLParamDeclaration.extractValue(params)
val query = state.query
val argsCount = state.argsCount
val tableDef = state.tableDef
val strategy = state.strategy
val outputType = state.outputType
val getConnectionCallback = () => dataSource.getConnection()
val timeMeasurementCallback = () => timeMeasurement

cacheTTLOption match {
case Some(cacheTTL) =>
new DatabaseEnricherInvokerWithCache(
query,
argsCount,
tableDef,
strategy,
queryArgumentsExtractor,
cacheTTL,
outputType,
getConnectionCallback,
timeMeasurementCallback,
params
query = query,
argsCount = argsCount,
tableDef = tableDef,
strategy = strategy,
queryArgumentsExtractor = queryArgumentsExtractor,
cacheTTL = cacheTTL,
getConnection = getConnectionCallback,
getTimeMeasurement = timeMeasurementCallback,
params = params,
)
case None =>
new DatabaseEnricherInvoker(
query,
argsCount,
tableDef,
strategy,
queryArgumentsExtractor,
outputType,
getConnectionCallback,
timeMeasurementCallback,
params
query = query,
argsCount = argsCount,
tableDef = tableDef,
strategy = strategy,
queryArgumentsExtractor = queryArgumentsExtractor,
getConnection = getConnectionCallback,
getTimeMeasurement = timeMeasurementCallback,
params = params
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import pl.touk.nussknacker.sql.db.query.ResultSetStrategy
import pl.touk.nussknacker.sql.db.schema.{JdbcMetaDataProvider, MetaDataProviderFactory, TableDefinition}
import pl.touk.nussknacker.sql.utils.BaseHsqlQueryEnricherTest

import java.time.LocalDate
import scala.concurrent.Await

class DatabaseLookupEnricherTest extends BaseHsqlQueryEnricherTest {
Expand All @@ -18,8 +19,8 @@ class DatabaseLookupEnricherTest extends BaseHsqlQueryEnricherTest {
import scala.concurrent.duration._

override val prepareHsqlDDLs: List[String] = List(
"CREATE TABLE persons (id INT, name VARCHAR(40));",
"INSERT INTO persons (id, name) VALUES (1, 'John')"
"CREATE TABLE persons (id INT, name VARCHAR(40), birth_date DATE);",
"INSERT INTO persons (id, name, birth_date) VALUES (1, 'John', '1990-08-12')"
)

private val notExistingDbUrl = s"jdbc:hsqldb:mem:dummy"
Expand Down Expand Up @@ -52,20 +53,20 @@ class DatabaseLookupEnricherTest extends BaseHsqlQueryEnricherTest {
dependencies = Nil,
finalState = Some(state)
)
returnType(service, state).display shouldBe "List[Record{ID: Integer, NAME: String}]"
returnType(service, state).display shouldBe "List[Record{BIRTH_DATE: LocalDate, ID: Integer, NAME: String}]"
val resultF =
implementation.invoke(Context.withInitialId)
val result = Await.result(resultF, 5 seconds).asInstanceOf[java.util.List[TypedMap]].asScala.toList
result shouldBe List(
TypedMap(Map("ID" -> 1, "NAME" -> "John"))
TypedMap(Map("ID" -> 1, "NAME" -> "John", "BIRTH_DATE" -> LocalDate.parse("1990-08-12")))
)

conn.prepareStatement("UPDATE persons SET name = 'Alex' WHERE id = 1").execute()
val resultF2 =
implementation.invoke(Context.withInitialId)
val result2 = Await.result(resultF2, 5 seconds).asInstanceOf[java.util.List[TypedMap]].asScala.toList
result2 shouldBe List(
TypedMap(Map("ID" -> 1, "NAME" -> "Alex"))
TypedMap(Map("ID" -> 1, "NAME" -> "Alex", "BIRTH_DATE" -> LocalDate.parse("1990-08-12")))
)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package pl.touk.nussknacker.sql.service

import org.scalatest.BeforeAndAfterEach
import pl.touk.nussknacker.engine.api.typed.TypedMap
import pl.touk.nussknacker.sql.db.schema.{MetaDataProviderFactory, TableDefinition}
import pl.touk.nussknacker.sql.db.schema.MetaDataProviderFactory
import pl.touk.nussknacker.sql.utils.BaseHsqlQueryEnricherTest
import org.scalatest.BeforeAndAfterEach

class DatabaseQueryEnricherHsqlTest
extends BaseHsqlQueryEnricherTest
Expand All @@ -15,7 +15,9 @@ class DatabaseQueryEnricherHsqlTest

override val prepareHsqlDDLs: List[String] = List(
"CREATE TABLE people (id INT, name VARCHAR(40));",
"INSERT INTO people (id, name) VALUES (1, 'John')"
"INSERT INTO people (id, name) VALUES (1, 'John');",
"CREATE TABLE types_test(t_clob CLOB);",
"INSERT INTO types_test(t_clob) values ('very long text');"
)

override protected def afterEach(): Unit = {
Expand Down Expand Up @@ -65,4 +67,17 @@ class DatabaseQueryEnricherHsqlTest
queryResultSet.getObject("name") shouldBe "Don"
}

test("DatabaseQueryEnricher#type conversions") {
val result = queryWithEnricher(
"select * from types_test",
Map(),
conn,
service,
"List[Record{T_CLOB: String}]"
)
result shouldBe List(
TypedMap(Map("T_CLOB" -> "very long text"))
)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package pl.touk.nussknacker.sql.service

import org.scalatest.BeforeAndAfterEach
import pl.touk.nussknacker.engine.api.typed.TypedMap
import pl.touk.nussknacker.sql.db.schema.{MetaDataProviderFactory, TableDefinition}
import pl.touk.nussknacker.sql.db.schema.MetaDataProviderFactory
import pl.touk.nussknacker.sql.utils.BasePostgresqlQueryEnricherTest

import java.time.{LocalDate, LocalDateTime, LocalTime, ZoneId, ZonedDateTime}

class DatabaseQueryEnricherPostgresqlTest
extends BasePostgresqlQueryEnricherTest
with DatabaseQueryEnricherQueryWithEnricher
Expand All @@ -15,7 +17,11 @@ class DatabaseQueryEnricherPostgresqlTest

override val preparePostgresqlDDLs: List[String] = List(
"CREATE TABLE people (id INT, name VARCHAR(40));",
"INSERT INTO people (id, name) VALUES (1, 'John')"
"INSERT INTO people (id, name) VALUES (1, 'John');",
"CREATE TABLE types_test(t_time TIME, t_timestamp TIMESTAMP, t_timestamptz TIMESTAMPTZ, t_date DATE, " +
"t_array INT[], t_boolean BOOLEAN, t_text TEXT);",
"INSERT INTO types_test(t_time, t_timestamp, t_timestamptz, t_date, t_array, t_boolean, t_text) VALUES (" +
"'08:00:00', '2024-08-12 08:00:00', '2024-08-12 09:00:00+01:00', '2024-08-12', '{1,2,3,4,5}', true, 'long text');"
)

override protected def afterEach(): Unit = {
Expand Down Expand Up @@ -67,4 +73,30 @@ class DatabaseQueryEnricherPostgresqlTest
queryResultSet.getObject("name") shouldBe "Don"
}

test("DatabaseQueryEnricherPostgresqlTest#type conversions") {
import scala.jdk.CollectionConverters._
val result = queryWithEnricher(
"select * from types_test",
Map(),
conn,
service,
"List[Record{t_array: List[Unknown], t_boolean: Boolean, t_date: LocalDate, t_text: String, t_time: LocalTime, " +
"t_timestamp: Instant, t_timestamptz: Instant}]"
)

result shouldBe List(
TypedMap(
Map(
"t_boolean" -> true,
"t_timestamp" -> LocalDateTime.parse("2024-08-12T08:00:00").atZone(ZoneId.systemDefault()).toInstant,
"t_date" -> LocalDate.parse("2024-08-12"),
"t_array" -> List(1, 2, 3, 4, 5).asJava,
"t_text" -> "long text",
"t_time" -> LocalTime.parse("08:00"),
"t_timestamptz" -> ZonedDateTime.parse("2024-08-12T09:00:00+01:00").toInstant
)
)
)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ trait DatabaseQueryEnricherQueryWithEnricher extends BaseDatabaseQueryEnricherTe
val meta = st.getMetaData
val state = DatabaseQueryEnricher.TransformationState(
query = query,
argsCount = 1,
argsCount = parameters.size,
tableDef = TableDefinition(meta),
strategy = ResultSetStrategy
)
Expand Down
Loading
Loading