Skip to content

Commit

Permalink
Add support for executing statements with custom encoders
Browse files Browse the repository at this point in the history
Introduced methods to execute and fetch all results using encoded statements across SQLite, PostgreSQL, and MySQL classes. Implemented Statement's `render` method to utilize a ValueEncoderRegistry, ensuring proper parameter encoding in SQL queries.
  • Loading branch information
smyrgeorge committed Oct 6, 2024
1 parent 4a54d63 commit 184d246
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package io.github.smyrgeorge.sqlx4k.mysql

import io.github.smyrgeorge.sqlx4k.Driver
import io.github.smyrgeorge.sqlx4k.ResultSet
import io.github.smyrgeorge.sqlx4k.RowMapper
import io.github.smyrgeorge.sqlx4k.Statement
import io.github.smyrgeorge.sqlx4k.Transaction
import io.github.smyrgeorge.sqlx4k.impl.extensions.rowsAffectedOrError
import io.github.smyrgeorge.sqlx4k.impl.extensions.sqlx
Expand Down Expand Up @@ -57,11 +59,20 @@ class MySQL(
sqlx { c -> sqlx4k_query(sql, c, Driver.fn) }.rowsAffectedOrError()
}

override suspend fun execute(statement: Statement): Result<Long> =
execute(statement.render(encoders))

override suspend fun fetchAll(sql: String): Result<ResultSet> {
val res = sqlx { c -> sqlx4k_fetch_all(sql, c, Driver.fn) }
return ResultSet(res).toKotlinResult()
}

override suspend fun fetchAll(statement: Statement): Result<ResultSet> =
fetchAll(statement.render(encoders))

override suspend fun <T> fetchAll(statement: Statement, rowMapper: RowMapper<T>): Result<List<T>> =
fetchAll(statement.render(encoders), rowMapper)

override suspend fun begin(): Result<Transaction> = runCatching {
val tx = sqlx { c -> sqlx4k_tx_begin(c, Driver.fn) }.tx()
Tx(tx.first)
Expand Down Expand Up @@ -90,6 +101,9 @@ class MySQL(
}
}

override suspend fun execute(statement: Statement): Result<Long> =
execute(statement.render(encoders))

override suspend fun fetchAll(sql: String): Result<ResultSet> {
val res = mutex.withLock {
val r = sqlx { c -> sqlx4k_tx_fetch_all(tx, sql, c, Driver.fn) }
Expand All @@ -99,5 +113,23 @@ class MySQL(
tx = res.getRaw().tx!!
return res.toKotlinResult()
}

override suspend fun fetchAll(statement: Statement): Result<ResultSet> =
fetchAll(statement.render(encoders))

override suspend fun <T> fetchAll(statement: Statement, rowMapper: RowMapper<T>): Result<List<T>> =
fetchAll(statement.render(encoders), rowMapper)
}

companion object {
/**
* The `ValueEncoderRegistry` instance used for encoding values supplied to SQL statements in the `MySQL` class.
* This registry maps data types to their corresponding encoders, which convert values into a format suitable for
* inclusion in SQL queries.
*
* This registry is utilized in methods like `execute`, `fetchAll`, and other database operation methods to ensure
* that parameters bound to SQL statements are correctly encoded before being executed.
*/
val encoders = Statement.ValueEncoderRegistry()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package io.github.smyrgeorge.sqlx4k.postgres

import io.github.smyrgeorge.sqlx4k.Driver
import io.github.smyrgeorge.sqlx4k.ResultSet
import io.github.smyrgeorge.sqlx4k.RowMapper
import io.github.smyrgeorge.sqlx4k.Statement
import io.github.smyrgeorge.sqlx4k.Transaction
import io.github.smyrgeorge.sqlx4k.impl.extensions.rowsAffectedOrError
import io.github.smyrgeorge.sqlx4k.impl.extensions.sqlx
Expand Down Expand Up @@ -74,11 +76,20 @@ class PostgreSQL(
sqlx { c -> sqlx4k_query(sql, c, Driver.fn) }.rowsAffectedOrError()
}

override suspend fun execute(statement: Statement): Result<Long> =
execute(statement.render(encoders))

override suspend fun fetchAll(sql: String): Result<ResultSet> {
val res = sqlx { c -> sqlx4k_fetch_all(sql, c, Driver.fn) }
return ResultSet(res).toKotlinResult()
}

override suspend fun fetchAll(statement: Statement): Result<ResultSet> =
fetchAll(statement.render(encoders))

override suspend fun <T> fetchAll(statement: Statement, rowMapper: RowMapper<T>): Result<List<T>> =
fetchAll(statement.render(encoders), rowMapper)

override suspend fun begin(): Result<Transaction> = runCatching {
val tx = sqlx { c -> sqlx4k_tx_begin(c, Driver.fn) }.tx()
Tx(tx.first)
Expand Down Expand Up @@ -146,6 +157,9 @@ class PostgreSQL(
}
}

override suspend fun execute(statement: Statement): Result<Long> =
execute(statement.render(encoders))

override suspend fun fetchAll(sql: String): Result<ResultSet> {
val res = mutex.withLock {
val r = sqlx { c -> sqlx4k_tx_fetch_all(tx, sql, c, Driver.fn) }
Expand All @@ -155,6 +169,12 @@ class PostgreSQL(
tx = res.getRaw().tx!!
return res.toKotlinResult()
}

override suspend fun fetchAll(statement: Statement): Result<ResultSet> =
fetchAll(statement.render(encoders))

override suspend fun <T> fetchAll(statement: Statement, rowMapper: RowMapper<T>): Result<List<T>> =
fetchAll(statement.render(encoders), rowMapper)
}

data class Notification(
Expand All @@ -163,6 +183,16 @@ class PostgreSQL(
)

companion object {
/**
* The `ValueEncoderRegistry` instance used for encoding values supplied to SQL statements in the `PostgreSQL` class.
* This registry maps data types to their corresponding encoders, which convert values into a format suitable for
* inclusion in SQL queries.
*
* This registry is utilized in methods like `execute`, `fetchAll`, and other database operation methods to ensure
* that parameters bound to SQL statements are correctly encoded before being executed.
*/
val encoders = Statement.ValueEncoderRegistry()

private val channels: MutableMap<Int, Channel<Notification>> by lazy { mutableMapOf() }
private val listenerMutex = Mutex()
private var listenerId: Int = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package io.github.smyrgeorge.sqlx4k.sqlite

import io.github.smyrgeorge.sqlx4k.Driver
import io.github.smyrgeorge.sqlx4k.ResultSet
import io.github.smyrgeorge.sqlx4k.RowMapper
import io.github.smyrgeorge.sqlx4k.Statement
import io.github.smyrgeorge.sqlx4k.Transaction
import io.github.smyrgeorge.sqlx4k.impl.extensions.rowsAffectedOrError
import io.github.smyrgeorge.sqlx4k.impl.extensions.sqlx
Expand Down Expand Up @@ -49,11 +51,20 @@ class SQLite(
sqlx { c -> sqlx4k_query(sql, c, Driver.fn) }.rowsAffectedOrError()
}

override suspend fun execute(statement: Statement): Result<Long> =
execute(statement.render(encoders))

override suspend fun fetchAll(sql: String): Result<ResultSet> {
val res = sqlx { c -> sqlx4k_fetch_all(sql, c, Driver.fn) }
return ResultSet(res).toKotlinResult()
}

override suspend fun fetchAll(statement: Statement): Result<ResultSet> =
fetchAll(statement.render(encoders))

override suspend fun <T> fetchAll(statement: Statement, rowMapper: RowMapper<T>): Result<List<T>> =
fetchAll(statement.render(encoders), rowMapper)

override suspend fun begin(): Result<Transaction> = runCatching {
val tx = sqlx { c -> sqlx4k_tx_begin(c, Driver.fn) }.tx()
Tx(tx.first)
Expand Down Expand Up @@ -82,6 +93,9 @@ class SQLite(
}
}

override suspend fun execute(statement: Statement): Result<Long> =
execute(statement.render(encoders))

override suspend fun fetchAll(sql: String): Result<ResultSet> {
val res = mutex.withLock {
val r = sqlx { c -> sqlx4k_tx_fetch_all(tx, sql, c, Driver.fn) }
Expand All @@ -91,5 +105,23 @@ class SQLite(
tx = res.getRaw().tx!!
return res.toKotlinResult()
}

override suspend fun fetchAll(statement: Statement): Result<ResultSet> =
fetchAll(statement.render(encoders))

override suspend fun <T> fetchAll(statement: Statement, rowMapper: RowMapper<T>): Result<List<T>> =
fetchAll(statement.render(encoders), rowMapper)
}

companion object {
/**
* The `ValueEncoderRegistry` instance used for encoding values supplied to SQL statements in the `SQLite` class.
* This registry maps data types to their corresponding encoders, which convert values into a format suitable for
* inclusion in SQL queries.
*
* This registry is utilized in methods like `execute`, `fetchAll`, and other database operation methods to ensure
* that parameters bound to SQL statements are correctly encoded before being executed.
*/
val encoders = Statement.ValueEncoderRegistry()
}
}
23 changes: 10 additions & 13 deletions sqlx4k/src/nativeMain/kotlin/io/github/smyrgeorge/sqlx4k/Driver.kt
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ interface Driver {
* @param statement the SQL statement to be executed.
* @return a result containing the number of affected rows.
*/
suspend fun execute(statement: Statement): Result<Long> =
execute(statement.render())
suspend fun execute(statement: Statement): Result<Long>

/**
* Fetches all results of the given SQL query asynchronously.
Expand All @@ -41,6 +40,14 @@ interface Driver {
*/
suspend fun fetchAll(sql: String): Result<ResultSet>

/**
* Fetches all results of the given SQL statement asynchronously.
*
* @param statement The SQL statement to be executed.
* @return A result containing the retrieved result set.
*/
suspend fun fetchAll(statement: Statement): Result<ResultSet>

/**
* Fetches all results of the given SQL query and maps each row using the provided RowMapper.
*
Expand All @@ -54,15 +61,6 @@ interface Driver {
rowMapper.map(res)
}

/**
* Fetches all results of the given SQL statement asynchronously.
*
* @param statement The SQL statement to be executed.
* @return A result containing the retrieved result set.
*/
suspend fun fetchAll(statement: Statement): Result<ResultSet> =
fetchAll(statement.render())

/**
* Fetches all results of the given SQL statement and maps each row using the provided RowMapper.
*
Expand All @@ -71,8 +69,7 @@ interface Driver {
* @param rowMapper The RowMapper to use for converting rows in the result set to instances of type T.
* @return A Result containing a list of instances of type T mapped from the query result set.
*/
suspend fun <T> fetchAll(statement: Statement, rowMapper: RowMapper<T>): Result<List<T>> =
fetchAll(statement.render(), rowMapper)
suspend fun <T> fetchAll(statement: Statement, rowMapper: RowMapper<T>): Result<List<T>>

/**
* Represents a general interface for managing connection pools.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ interface Statement {
* @return A string representing the rendered SQL statement with all positional and named
* parameters substituted by their bound values.
*/
fun render(): String
fun render(encoders: ValueEncoderRegistry = ValueEncoderRegistry.EMPTY): String

/**
* Converts the value of the receiver to a string representation suitable for database operations.
Expand Down Expand Up @@ -142,15 +142,11 @@ interface Statement {

companion object {
/**
* Creates a new `Statement` instance with the given SQL string and an optional `ValueEncoderRegistry`.
* Creates a new `Statement` instance with the provided SQL string.
*
* @param sql The SQL string to be used in the statement.
* @param encoders The `ValueEncoderRegistry` to be used for encoding values, default is `ValueEncoderRegistry.EMPTY`.
* @return The newly created `Statement` instance with the specified SQL and encoders.
* @param sql The SQL string used to create the statement.
* @return A new `Statement` instance initialized with the provided SQL string.
*/
fun create(
sql: String,
encoders: ValueEncoderRegistry = ValueEncoderRegistry.EMPTY
): Statement = SimpleStatement(sql, encoders)
fun create(sql: String, ): Statement = SimpleStatement(sql)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@ import io.github.smyrgeorge.sqlx4k.SQLError
import io.github.smyrgeorge.sqlx4k.Statement.ValueEncoderRegistry

/**
* An extension of the `SimpleStatement` class that adds support for positional parameter binding
* and custom SQL statement rendering.
* Represents an extended SQL statement that allows binding values to positional parameters
* specific to PostgreSQL.
*
* @property sql The SQL statement to be executed.
* @property encoders A `ValueEncoderRegistry` used for encoding values.
* @constructor Creates an `ExtendedStatement` with the given SQL string and value encoder registry.
* This class extends [SimpleStatement] by leveraging PostgreSQL's custom parameter syntax
* (e.g., $1, $2) and providing mechanisms to bind values to those parameters and render
* the final SQL with all parameters substituted.
*
* @property sql The SQL string containing the statement.
*/
@Suppress("unused")
class ExtendedStatement(
private val sql: String,
private val encoders: ValueEncoderRegistry = ValueEncoderRegistry.EMPTY
) : SimpleStatement(sql, encoders) {
class ExtendedStatement(private val sql: String) : SimpleStatement(sql) {

private val pgParameters: List<Int> by lazy {
extractPgParameters(sql)
Expand Down Expand Up @@ -43,33 +42,24 @@ class ExtendedStatement(
}

/**
* Renders the SQL statement by processing positional parameters.
*
* This method overrides the base class implementation to process
* positional parameters specific to the PgStatement class. It first
* delegates the rendering to the base class implementation, then
* further processes positional parameters.
* Renders the SQL statement, including encoding all positional parameters using the specified encoder registry.
*
* @return A string representing the rendered SQL statement with all
* positional parameters substituted by their bound values.
* @param encoders The `ValueEncoderRegistry` that provides the appropriate encoders for the parameter values.
* @return A string representing the fully rendered SQL statement with all parameters encoded.
*/
override fun render(): String =
override fun render(encoders: ValueEncoderRegistry): String =
super
.render()
.renderPositionalParameters()
.render(encoders)
.renderPositionalParameters(encoders)

/**
* Replaces positional parameters in the SQL string with their corresponding values.
*
* This function scans the SQL string for positional parameters indicated by
* placeholders such as `$1`, `$2`, etc., and replaces them with their corresponding
* bound values from the `pgParametersValues` map. If a required value is not supplied,
* an error is thrown.
* Replaces positional parameters in the SQL statement with their corresponding encoded values.
*
* @return A string where all positional parameters are replaced by their bound values.
* @throws SQLError if a positional parameter value is not supplied.
* @param encoders The `ValueEncoderRegistry` that provides the appropriate encoders for the parameter values.
* @return The SQL statement with all positional parameters replaced by their encoded values.
* @throws SQLError if a value for a positional parameter index is not supplied.
*/
private fun String.renderPositionalParameters(): String {
private fun String.renderPositionalParameters(encoders: ValueEncoderRegistry): String {
var res: String = this
pgParameters.forEach { index ->
if (!pgParametersValues.containsKey(index)) {
Expand Down
Loading

0 comments on commit 184d246

Please sign in to comment.