Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the VTab interface safe #416

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
10 changes: 7 additions & 3 deletions crates/duckdb-loadable-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,20 @@ pub fn duckdb_entrypoint(_attr: TokenStream, item: TokenStream) -> TokenStream {
/// Will be called by duckdb
#[no_mangle]
pub unsafe extern "C" fn #c_entrypoint(db: *mut c_void) {
let connection = Connection::open_from_raw(db.cast()).expect("can't open db connection");
#prefixed_original_function(connection).expect("init failed");
unsafe {
let connection = Connection::open_from_raw(db.cast()).expect("can't open db connection");
#prefixed_original_function(connection).expect("init failed");
}
}

/// # Safety
///
/// Predefined function, don't need to change unless you are sure
#[no_mangle]
pub unsafe extern "C" fn #c_entrypoint_version() -> *const c_char {
ffi::duckdb_library_version()
unsafe {
ffi::duckdb_library_version()
}
}


Expand Down
69 changes: 24 additions & 45 deletions crates/duckdb/examples/hello-ext/main.rs
Original file line number Diff line number Diff line change
@@ -1,81 +1,60 @@
#![warn(unsafe_op_in_unsafe_fn)]
#![warn(unsafe)] // extensions can be safe

extern crate duckdb;
extern crate duckdb_loadable_macros;
extern crate libduckdb_sys;

use duckdb::{
core::{DataChunkHandle, Inserter, LogicalTypeHandle, LogicalTypeId},
vtab::{BindInfo, Free, FunctionInfo, InitInfo, VTab},
vtab::{BindInfo, FunctionInfo, InitInfo, VTab},
Connection, Result,
};
use duckdb_loadable_macros::duckdb_entrypoint;
use libduckdb_sys as ffi;
use std::{
error::Error,
ffi::{c_char, c_void, CString},
sync::atomic::{AtomicBool, Ordering},
};

#[repr(C)]
struct HelloBindData {
name: *mut c_char,
}

impl Free for HelloBindData {
fn free(&mut self) {
unsafe {
if self.name.is_null() {
return;
}
drop(CString::from_raw(self.name));
}
}
name: String,
}

#[repr(C)]
struct HelloInitData {
done: bool,
done: AtomicBool,
}

struct HelloVTab;

impl Free for HelloInitData {}

impl VTab for HelloVTab {
type InitData = HelloInitData;
type BindData = HelloBindData;

unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn std::error::Error>> {
fn bind(bind: &BindInfo) -> Result<Self::BindData, Box<dyn std::error::Error>> {
bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar));
let param = bind.get_parameter(0).to_string();
unsafe {
(*data).name = CString::new(param).unwrap().into_raw();
}
Ok(())
let name = bind.get_parameter(0).to_string();
Ok(HelloBindData { name })
}

unsafe fn init(_: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn std::error::Error>> {
unsafe {
(*data).done = false;
}
Ok(())
fn init(_: &InitInfo) -> Result<Self::InitData, Box<dyn std::error::Error>> {
Ok(HelloInitData {
done: AtomicBool::new(false),
})
}

unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
let init_info = func.get_init_data::<HelloInitData>();
let bind_info = func.get_bind_data::<HelloBindData>();
fn func(func: &FunctionInfo<Self>, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
let init_data = func.get_init_data();
let bind_data = func.get_bind_data();

unsafe {
if (*init_info).done {
output.set_len(0);
} else {
(*init_info).done = true;
let vector = output.flat_vector(0);
let name = CString::from_raw((*bind_info).name);
let result = CString::new(format!("Hello {}", name.to_str()?))?;
// Can't consume the CString
(*bind_info).name = CString::into_raw(name);
vector.insert(0, result);
output.set_len(1);
}
if init_data.done.swap(true, Ordering::Relaxed) {
output.set_len(0);
} else {
let vector = output.flat_vector(0);
let result = CString::new(format!("Hello {}", bind_data.name))?;
vector.insert(0, result);
output.set_len(1);
}
Ok(())
}
Expand Down
59 changes: 40 additions & 19 deletions crates/duckdb/src/vtab/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ use super::{
duckdb_table_function_set_init, duckdb_table_function_set_local_init, duckdb_table_function_set_name,
duckdb_table_function_supports_projection_pushdown, idx_t,
},
LogicalTypeHandle, Value,
LogicalTypeHandle, VTab, Value,
};
use std::{
ffi::{c_void, CString},
marker::PhantomData,
os::raw::c_char,
};

Expand Down Expand Up @@ -138,7 +139,9 @@ impl From<duckdb_init_info> for InitInfo {
impl InitInfo {
/// # Safety
pub unsafe fn set_init_data(&self, data: *mut c_void, freeer: Option<unsafe extern "C" fn(*mut c_void)>) {
duckdb_init_set_init_data(self.0, data, freeer);
unsafe {
duckdb_init_set_init_data(self.0, data, freeer);
}
}

/// Returns the column indices of the projected columns at the specified positions.
Expand Down Expand Up @@ -188,7 +191,7 @@ impl InitInfo {
/// * `error`: The error message
pub fn set_error(&self, error: &str) {
let c_str = CString::new(error).unwrap();
unsafe { duckdb_init_set_error(self.0, c_str.as_ptr() as *const c_char) }
unsafe { duckdb_init_set_error(self.0, c_str.as_ptr()) }
}
}

Expand Down Expand Up @@ -309,7 +312,9 @@ impl TableFunction {
///
/// # Safety
pub unsafe fn set_extra_info(&self, extra_info: *mut c_void, destroy: duckdb_delete_callback_t) {
duckdb_table_function_set_extra_info(self.ptr, extra_info, destroy);
unsafe {
duckdb_table_function_set_extra_info(self.ptr, extra_info, destroy);
}
}

/// Sets the thread-local init function of the table function
Expand All @@ -334,54 +339,70 @@ use super::ffi::{

/// An interface to store and retrieve data during the function execution stage
#[derive(Debug)]
pub struct FunctionInfo(duckdb_function_info);
pub struct FunctionInfo<V: VTab> {
ptr: duckdb_function_info,
_vtab: PhantomData<V>,
}

impl FunctionInfo {
impl<V: VTab> FunctionInfo<V> {
/// Report that an error has occurred while executing the function.
///
/// # Arguments
/// * `error`: The error message
pub fn set_error(&self, error: &str) {
let c_str = CString::new(error).unwrap();
unsafe {
duckdb_function_set_error(self.0, c_str.as_ptr() as *const c_char);
duckdb_function_set_error(self.ptr, c_str.as_ptr());
}
}

/// Gets the bind data set by [`BindInfo::set_bind_data`] during the bind.
///
/// Note that the bind data should be considered as read-only.
/// For tracking state, use the init data instead.
///
/// # Arguments
/// * `returns`: The bind data object
pub fn get_bind_data<T>(&self) -> *mut T {
unsafe { duckdb_function_get_bind_data(self.0).cast() }
pub fn get_bind_data(&self) -> &V::BindData {
unsafe {
let bind_data: *const V::BindData = duckdb_function_get_bind_data(self.ptr).cast();
bind_data.as_ref().unwrap()
}
}
/// Gets the init data set by [`InitInfo::set_init_data`] during the init.

/// Get a reference to the init data set by [`InitInfo::set_init_data`] during the init.
///
/// This returns a shared reference because the init data is shared between multiple threads.
/// It may internally be mutable.
///
/// # Arguments
/// * `returns`: The init data object
pub fn get_init_data<T>(&self) -> *mut T {
unsafe { duckdb_function_get_init_data(self.0).cast() }
pub fn get_init_data(&self) -> &V::InitData {
// Safety: A pointer to a box of the init data is stored during vtab init.
unsafe {
let init_data: *const V::InitData = duckdb_function_get_init_data(self.ptr).cast();
init_data.as_ref().unwrap()
}
}

/// Retrieves the extra info of the function as set in [`TableFunction::set_extra_info`]
///
/// # Arguments
/// * `returns`: The extra info
pub fn get_extra_info<T>(&self) -> *mut T {
unsafe { duckdb_function_get_extra_info(self.0).cast() }
unsafe { duckdb_function_get_extra_info(self.ptr).cast() }
}
/// Gets the thread-local init data set by [`InitInfo::set_init_data`] during the local_init.
///
/// # Arguments
/// * `returns`: The init data object
pub fn get_local_init_data<T>(&self) -> *mut T {
unsafe { duckdb_function_get_local_init_data(self.0).cast() }
unsafe { duckdb_function_get_local_init_data(self.ptr).cast() }
}
}

impl From<duckdb_function_info> for FunctionInfo {
impl<V: VTab> From<duckdb_function_info> for FunctionInfo<V> {
fn from(ptr: duckdb_function_info) -> Self {
Self(ptr)
Self {
ptr,
_vtab: PhantomData,
}
}
}
Loading