Skip to content

Commit

Permalink
Added support for listen/notify pg commands.
Browse files Browse the repository at this point in the history
  • Loading branch information
smyrgeorge committed Jul 21, 2024
1 parent 3b120f7 commit 4286f95
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 4 deletions.
17 changes: 15 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ around well-tested libraries to provide the necessary functionality to the ecosy
### Async-io

The driver fully supports non-blocking io.
Bridges the kotlin-async (coroutines) with the rust-async (tokio) without blocking.
Bridges the kotlin-async (coroutines) with the rust-async (tokio) without blocking.

All the "magic" happens thanks to the build in kotlin function `suspendCoroutine`, take a
look [here](https://kotlinlang.org/api/latest/jvm/stdlib/kotlin.coroutines/suspend-coroutine.html).
Expand Down Expand Up @@ -94,14 +94,27 @@ pg.fetchAll("select * from sqlx4k;") {
tx1.commit().getOrThrow()
```

### Listen/Notify

```kotlin
pg.listen("chan0") { notification: Postgres.PgNotification ->
println(notification)
}

(1..10).forEach {
pg.notify("chan0", "Hello $it")
delay(1000)
}
```

## Todo

- [x] PostgresSQL
- [x] Try to "bridge" the 2 async worlds (kotlin-rust)
- [x] Use non-blocking io end to end, using the `suspendCoroutine` function
- [x] Transactions
- [x] Named parameters
- [ ] Listen/Notify Postgres commands.
- [ ] Listen/Notify Postgres (in progress).
- [ ] Transaction isolation level
- [x] Publish to maven central
- [x] Better error handling
Expand Down
19 changes: 19 additions & 0 deletions examples/src/nativeMain/kotlin/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.IO
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlin.coroutines.CoroutineContext
Expand Down Expand Up @@ -65,6 +66,17 @@ fun main() {
println(debug())
}

println("Connections: ${pg.poolSize()}, Idle: ${pg.poolIdleSize()}")
println("\n\n\n::: LISTEN/NOTIFY :::")
pg.listen("chan0") { notification: Postgres.PgNotification ->
println(notification)
}

(1..10).forEach {
pg.notify("chan0", "Hello $it")
delay(1000)
}

println("\n\n\n::: TX :::")

val tx1: Transaction = pg.begin().getOrThrow()
Expand Down Expand Up @@ -133,5 +145,12 @@ fun main() {
// 9.385897375s
// 9.351138833s
println(t2)

println("Connections: ${pg.poolSize()}, Idle: ${pg.poolIdleSize()}")
(1..10).forEach {
println("Notify: $it")
pg.notify("chan0", "Hello $it")
delay(1000)
}
}
}
2 changes: 2 additions & 0 deletions sqlx4k/rust_lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ panic = "abort"
[dependencies]
# https://crates.io/crates/once_cell
once_cell = { version = "1.19.0" }
# https://crates.io/crates/futures
futures = "0.3.30"
# https://crates.io/crates/tokio
tokio = { version = "1.38.0", features = ["rt-multi-thread"] }
# https://crates.io/crates/sqlx
Expand Down
79 changes: 78 additions & 1 deletion sqlx4k/rust_lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use sqlx::postgres::{PgPool, PgPoolOptions, PgRow, PgValueFormat, PgValueRef};
use futures::StreamExt;
use sqlx::postgres::{
PgListener, PgNotification, PgPool, PgPoolOptions, PgRow, PgValueFormat, PgValueRef,
};
use sqlx::{Column, Error, Executor};
use sqlx::{Row, TypeInfo, ValueRef};
use std::ffi::c_long;
use std::ptr::null_mut;
use std::{
ffi::{c_char, c_int, c_void, CStr, CString},
Expand Down Expand Up @@ -347,6 +351,39 @@ pub extern "C" fn sqlx4k_tx_fetch_all(
});
}

#[no_mangle]
pub extern "C" fn sqlx4k_listen(
channels: *const c_char,
notify_id: c_long,
notify: unsafe extern "C" fn(c_long, *mut Sqlx4kResult),
callback: *mut c_void,
fun: unsafe extern "C" fn(Ptr, *mut Sqlx4kResult),
) {
let callback = Ptr { ptr: callback };
let channels = unsafe { c_chars_to_str(channels).to_owned() };
let runtime = RUNTIME.get().unwrap();
let sqlx4k = unsafe { SQLX4K.get_mut().unwrap() };
runtime.spawn(async move {
let mut listener = PgListener::connect_with(&sqlx4k.pool).await.unwrap();
let channels: Vec<&str> = channels.split(',').collect();
listener.listen_all(channels).await.unwrap();
let mut stream = listener.into_stream();

// Return OK as soon as the stream is ready.
let result = Sqlx4kResult::default().leak();
unsafe { fun(callback, result) }

while let Some(item) = stream.next().await {
let item: PgNotification = item.unwrap();
let result = sqlx4k_result_of_pg_notification(item).leak();
unsafe { notify(notify_id, result) }
}

// TODO: remove this.
panic!("Consume from channel stoped.");
});
}

#[no_mangle]
pub extern "C" fn sqlx4k_free_result(ptr: *mut Sqlx4kResult) {
let ptr: Sqlx4kResult = unsafe { *Box::from_raw(ptr) };
Expand Down Expand Up @@ -375,6 +412,46 @@ pub extern "C" fn sqlx4k_free_result(ptr: *mut Sqlx4kResult) {
}
}

fn sqlx4k_result_of_pg_notification(item: PgNotification) -> Sqlx4kResult {
let bytes: &[u8] = item.payload().as_bytes();
let size: usize = bytes.len();
let bytes: Vec<u8> = bytes.iter().cloned().collect();
let bytes: Box<[u8]> = bytes.into_boxed_slice();
let bytes: &mut [u8] = Box::leak(bytes);
let bytes: *mut u8 = bytes.as_mut_ptr();
let value: *mut c_void = bytes as *mut c_void;

let column = Sqlx4kColumn {
ordinal: 0,
name: CString::new(item.channel()).unwrap().into_raw(),
kind: TYPE_TEXT,
size: size as c_int,
value,
};
let mut columns = vec![column];
// Make sure we're not wasting space.
columns.shrink_to_fit();
assert!(columns.len() == columns.capacity());
let columns: Box<[Sqlx4kColumn]> = columns.into_boxed_slice();
let columns: &mut [Sqlx4kColumn] = Box::leak(columns);
let columns: *mut Sqlx4kColumn = columns.as_mut_ptr();

let row = Sqlx4kRow { size: 1, columns };
let mut rows = vec![row];
// Make sure we're not wasting space.
rows.shrink_to_fit();
assert!(rows.len() == rows.capacity());
let rows: Box<[Sqlx4kRow]> = rows.into_boxed_slice();
let rows: &mut [Sqlx4kRow] = Box::leak(rows);
let rows: *mut Sqlx4kRow = rows.as_mut_ptr();

Sqlx4kResult {
size: 1,
rows,
..Default::default()
}
}

fn sqlx4k_result_of(result: Result<Vec<PgRow>, sqlx::Error>) -> Sqlx4kResult {
match result {
Ok(rows) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,40 @@ import io.github.smyrgeorge.sqlx4k.Transaction
import kotlinx.cinterop.CPointed
import kotlinx.cinterop.CPointer
import kotlinx.cinterop.ExperimentalForeignApi
import kotlinx.cinterop.get
import kotlinx.cinterop.pointed
import kotlinx.cinterop.readBytes
import kotlinx.cinterop.staticCFunction
import kotlinx.cinterop.toKString
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.IO
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.consumeEach
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import librust_lib.Sqlx4kResult
import librust_lib.Sqlx4kRow
import librust_lib.TYPE_TEXT
import librust_lib.sqlx4k_fetch_all
import librust_lib.sqlx4k_free_result
import librust_lib.sqlx4k_listen
import librust_lib.sqlx4k_of
import librust_lib.sqlx4k_query
import librust_lib.sqlx4k_tx_begin
import librust_lib.sqlx4k_tx_commit
import librust_lib.sqlx4k_tx_fetch_all
import librust_lib.sqlx4k_tx_query
import librust_lib.sqlx4k_tx_rollback
import librust_lib.sqlx4k_pool_size
import librust_lib.sqlx4k_pool_idle_size
import kotlin.experimental.ExperimentalNativeApi

@OptIn(ExperimentalForeignApi::class)
@Suppress("MemberVisibilityCanBePrivate")
@OptIn(ExperimentalForeignApi::class, ExperimentalNativeApi::class)
class Postgres(
host: String,
port: Int,
Expand All @@ -38,6 +60,9 @@ class Postgres(
).throwIfError()
}

fun poolSize(): Int = sqlx4k_pool_size()
fun poolIdleSize(): Int = sqlx4k_pool_idle_size()

override suspend fun query(sql: String): Result<Unit> = runCatching {
sqlx { c -> sqlx4k_query(sql, c, fn) }.throwIfError()
}
Expand All @@ -51,6 +76,45 @@ class Postgres(
Tx(tx)
}

suspend fun <T> listen(channel: String, f: (PgNotification) -> T) {
listen(listOf(channel), f)
}

suspend fun <T> listen(channels: List<String>, f: (PgNotification) -> T) {
val channelId: Long = listenerId()
val channel = Channel<PgNotification>(capacity = Channel.UNLIMITED)

// Store the channel.
Postgres.channels[channelId] = channel

// Start the channel consumer.
@OptIn(DelicateCoroutinesApi::class)
GlobalScope.launch(Dispatchers.IO) {
channel.consumeEach { f(it) }
}

// Create the listener.
sqlx { c ->
sqlx4k_listen(
// TODO: validate channels.
channels = channels.joinToString(","),
notify_id = channelId,
notify = notify,
callback = c,
`fun` = fn
)
}.throwIfError()
}

/**
* We accept only [String] values,
* because only the text type is supported by postgres.
* https://www.postgresql.org/docs/current/sql-notify.html
*/
suspend fun notify(channel: String, value: String) {
query("select pg_notify('$channel', '$value');").getOrThrow()
}

class Tx(override var tx: CPointer<out CPointed>) : Transaction {
private val mutex = Mutex()

Expand Down Expand Up @@ -81,4 +145,46 @@ class Postgres(
}
}
}

data class PgNotification(
val channel: String,
val value: String,
)

companion object {
private val channels: MutableMap<Long, Channel<PgNotification>> by lazy { mutableMapOf() }
private val listenerMutex = Mutex()
private var listenerId: Long = 0
private suspend fun listenerId(): Long = listenerMutex.withLock {
listenerId += 1
listenerId
}

private fun CPointer<Sqlx4kResult>?.notify(): PgNotification {
return try {
val result: Sqlx4kResult =
this?.pointed ?: error("Could not extract the value from the raw pointer (null).")

assert(result.size == 1)
val row: Sqlx4kRow = result.rows!![0]
assert(row.size == 1)
val column = row.columns!![0]
assert(column.kind == TYPE_TEXT)

PgNotification(
channel = column.name!!.toKString(),
value = column.value!!.readBytes(column.size).toKString()
)
} finally {
sqlx4k_free_result(this)
}
}

private val notify = staticCFunction<Long, CPointer<Sqlx4kResult>?, Unit> { c, r ->
channels[c]?.let {
val notification: PgNotification = r.notify()
runBlocking { it.send(notification) }
}
}
}
}

0 comments on commit 4286f95

Please sign in to comment.