Skip to content

Commit

Permalink
Modify todo example to use postgres and diesel_async
Browse files Browse the repository at this point in the history
  • Loading branch information
ThouCheese committed Nov 21, 2023
1 parent 89a2af1 commit 6a6b53c
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 63 deletions.
8 changes: 4 additions & 4 deletions examples/todo/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ publish = false

[dependencies]
rocket = { path = "../../core/lib" }
diesel = { version = "2.0.0", features = ["sqlite", "r2d2"] }
diesel = { version = "2.0.0", features = ["postgres", "r2d2"] }
diesel_migrations = "2.0.0"

[dev-dependencies]
parking_lot = "0.12"
rand = "0.8"

[dependencies.rocket_sync_db_pools]
path = "../../contrib/sync_db_pools/lib/"
features = ["diesel_sqlite_pool"]
[dependencies.rocket_db_pools]
path = "../../contrib/db_pools/lib/"
features = ["diesel_postgres"]

[dependencies.rocket_dyn_templates]
path = "../../contrib/dyn_templates"
Expand Down
6 changes: 4 additions & 2 deletions examples/todo/Rocket.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[default]
template_dir = "static"

[default.databases.sqlite_database]
url = "db/db.sqlite"
[default.databases.epic_todo_database]
url = "postgresql://postgres@localhost:5432/epic_todo_database"
max_connections = 1
connect_timeout = 5
1 change: 1 addition & 0 deletions examples/todo/db/DB_LIVES_HERE
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
db does not live here :(
52 changes: 30 additions & 22 deletions examples/todo/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
#[macro_use] extern crate rocket;
#[macro_use] extern crate rocket_sync_db_pools;
#[macro_use] extern crate diesel;

#[cfg(test)]
mod tests;
Expand All @@ -13,13 +11,15 @@ use rocket::response::{Flash, Redirect};
use rocket::serde::Serialize;
use rocket::form::Form;
use rocket::fs::{FileServer, relative};
use rocket_db_pools::{Connection, Database};

use rocket_dyn_templates::Template;

use crate::task::{Task, Todo};

#[database("sqlite_database")]
pub struct DbConn(diesel::SqliteConnection);
#[derive(Database)]
#[database("epic_todo_database")]
pub struct Db(rocket_db_pools::diesel::PgPool);

#[derive(Debug, Serialize)]
#[serde(crate = "rocket::serde")]
Expand All @@ -29,14 +29,14 @@ struct Context {
}

impl Context {
pub async fn err<M: std::fmt::Display>(conn: &DbConn, msg: M) -> Context {
pub async fn err<M: std::fmt::Display>(conn: &mut Connection<Db>, msg: M) -> Context {
Context {
flash: Some(("error".into(), msg.to_string())),
tasks: Task::all(conn).await.unwrap_or_default()
}
}

pub async fn raw(conn: &DbConn, flash: Option<(String, String)>) -> Context {
pub async fn raw(conn: &mut Connection<Db>, flash: Option<(String, String)>) -> Context {
match Task::all(conn).await {
Ok(tasks) => Context { flash, tasks },
Err(e) => {
Expand All @@ -51,11 +51,11 @@ impl Context {
}

#[post("/", data = "<todo_form>")]
async fn new(todo_form: Form<Todo>, conn: DbConn) -> Flash<Redirect> {
async fn new(todo_form: Form<Todo>, mut conn: Connection<Db>) -> Flash<Redirect> {
let todo = todo_form.into_inner();
if todo.description.is_empty() {
Flash::error(Redirect::to("/"), "Description cannot be empty.")
} else if let Err(e) = Task::insert(todo, &conn).await {
} else if let Err(e) = Task::insert(todo, &mut conn).await {
error_!("DB insertion error: {}", e);
Flash::error(Redirect::to("/"), "Todo could not be inserted due an internal error.")
} else {
Expand All @@ -64,50 +64,58 @@ async fn new(todo_form: Form<Todo>, conn: DbConn) -> Flash<Redirect> {
}

#[put("/<id>")]
async fn toggle(id: i32, conn: DbConn) -> Result<Redirect, Template> {
match Task::toggle_with_id(id, &conn).await {
async fn toggle(id: i32, mut conn: Connection<Db>) -> Result<Redirect, Template> {
match Task::toggle_with_id(id, &mut conn).await {
Ok(_) => Ok(Redirect::to("/")),
Err(e) => {
error_!("DB toggle({}) error: {}", id, e);
Err(Template::render("index", Context::err(&conn, "Failed to toggle task.").await))
Err(Template::render("index", Context::err(&mut conn, "Failed to toggle task.").await))
}
}
}

#[delete("/<id>")]
async fn delete(id: i32, conn: DbConn) -> Result<Flash<Redirect>, Template> {
match Task::delete_with_id(id, &conn).await {
async fn delete(id: i32, mut conn: Connection<Db>) -> Result<Flash<Redirect>, Template> {
match Task::delete_with_id(id, &mut conn).await {
Ok(_) => Ok(Flash::success(Redirect::to("/"), "Todo was deleted.")),
Err(e) => {
error_!("DB deletion({}) error: {}", id, e);
Err(Template::render("index", Context::err(&conn, "Failed to delete task.").await))
Err(Template::render("index", Context::err(&mut conn, "Failed to delete task.").await))
}
}
}

#[get("/")]
async fn index(flash: Option<FlashMessage<'_>>, conn: DbConn) -> Template {
async fn index(flash: Option<FlashMessage<'_>>, mut conn: Connection<Db>) -> Template {
let flash = flash.map(FlashMessage::into_inner);
Template::render("index", Context::raw(&conn, flash).await)
Template::render("index", Context::raw(&mut conn, flash).await)
}

async fn run_migrations(rocket: Rocket<Build>) -> Rocket<Build> {
use diesel::Connection;
use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};

const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");

DbConn::get_one(&rocket).await
.expect("database connection")
.run(|conn| { conn.run_pending_migrations(MIGRATIONS).expect("diesel migrations"); })
.await;
let config: rocket_db_pools::Config = rocket
.figment()
.extract_inner("databases.epic_todo_database")
.expect("Db not configured");

rocket::tokio::task::spawn_blocking(move || {
diesel::PgConnection::establish(&config.url)
.expect("No database")
.run_pending_migrations(MIGRATIONS)
.expect("Invalid migrations");
})
.await.expect("tokio doesn't work");

rocket
}

#[launch]
fn rocket() -> _ {
rocket::build()
.attach(DbConn::fairing())
.attach(Db::init())
.attach(Template::fairing())
.attach(AdHoc::on_ignite("Run Migrations", run_migrations))
.mount("/", FileServer::from(relative!("static")))
Expand Down
41 changes: 18 additions & 23 deletions examples/todo/src/task.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use rocket::serde::Serialize;
use diesel::{self, result::QueryResult, prelude::*};
use rocket_db_pools::diesel::RunQueryDsl;

mod schema {
table! {
diesel::table! {
tasks {
id -> Nullable<Integer>,
description -> Text,
Expand All @@ -13,7 +14,7 @@ mod schema {

use self::schema::tasks;

use crate::DbConn;
type DbConn = rocket_db_pools::diesel::AsyncPgConnection;

#[derive(Serialize, Queryable, Insertable, Debug, Clone)]
#[serde(crate = "rocket::serde")]
Expand All @@ -31,41 +32,35 @@ pub struct Todo {
}

impl Task {
pub async fn all(conn: &DbConn) -> QueryResult<Vec<Task>> {
conn.run(|c| {
tasks::table.order(tasks::id.desc()).load::<Task>(c)
}).await
pub async fn all(conn: &mut DbConn) -> QueryResult<Vec<Task>> {
tasks::table.order(tasks::id.desc()).load::<Task>(conn).await
}

/// Returns the number of affected rows: 1.
pub async fn insert(todo: Todo, conn: &DbConn) -> QueryResult<usize> {
conn.run(|c| {
let t = Task { id: None, description: todo.description, completed: false };
diesel::insert_into(tasks::table).values(&t).execute(c)
}).await
pub async fn insert(todo: Todo, conn: &mut DbConn) -> QueryResult<usize> {
let t = Task { id: None, description: todo.description, completed: false };
diesel::insert_into(tasks::table).values(&t).execute(conn).await
}

/// Returns the number of affected rows: 1.
pub async fn toggle_with_id(id: i32, conn: &DbConn) -> QueryResult<usize> {
conn.run(move |c| {
let task = tasks::table.filter(tasks::id.eq(id)).get_result::<Task>(c)?;
let new_status = !task.completed;
let updated_task = diesel::update(tasks::table.filter(tasks::id.eq(id)));
updated_task.set(tasks::completed.eq(new_status)).execute(c)
}).await
pub async fn toggle_with_id(id: i32, conn: &mut DbConn) -> QueryResult<usize> {
let task = tasks::table.filter(tasks::id.eq(id)).get_result::<Task>(conn).await?;
let new_status = !task.completed;
let updated_task = diesel::update(tasks::table.filter(tasks::id.eq(id)));
updated_task.set(tasks::completed.eq(new_status)).execute(conn).await
}

/// Returns the number of affected rows: 1.
pub async fn delete_with_id(id: i32, conn: &DbConn) -> QueryResult<usize> {
conn.run(move |c| diesel::delete(tasks::table)
pub async fn delete_with_id(id: i32, conn: &mut DbConn) -> QueryResult<usize> {
diesel::delete(tasks::table)
.filter(tasks::id.eq(id))
.execute(c))
.execute(conn)
.await
}

/// Returns the number of affected rows.
#[cfg(test)]
pub async fn delete_all(conn: &DbConn) -> QueryResult<usize> {
conn.run(|c| diesel::delete(tasks::table).execute(c)).await
pub async fn delete_all(conn: &mut DbConn) -> QueryResult<usize> {
diesel::delete(tasks::table).execute(conn).await
}
}
29 changes: 17 additions & 12 deletions examples/todo/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use rand::{Rng, thread_rng, distributions::Alphanumeric};

use rocket::local::asynchronous::Client;
use rocket::http::{Status, ContentType};
use rocket_db_pools::Database;

// We use a lock to synchronize between tests so DB operations don't collide.
// For now. In the future, we'll have a nice way to run each test in a DB
Expand All @@ -15,10 +16,14 @@ macro_rules! run_test {
let _lock = DB_LOCK.lock();

rocket::async_test(async move {
let $client = Client::tracked(super::rocket()).await.expect("Rocket client");
let db = super::DbConn::get_one($client.rocket()).await;
let $conn = db.expect("failed to get database connection for testing");
Task::delete_all(&$conn).await.expect("failed to delete all tasks for testing");
let rocket = super::rocket();
let mut $conn = super::Db::fetch(&rocket)
.expect("database")
.get()
.await
.expect("database connection");
let $client = Client::tracked(rocket).await.expect("Rocket client");
Task::delete_all(&mut $conn).await.expect("failed to delete all tasks for testing");

$block
})
Expand All @@ -39,7 +44,7 @@ fn test_index() {
fn test_insertion_deletion() {
run_test!(|client, conn| {
// Get the tasks before making changes.
let init_tasks = Task::all(&conn).await.unwrap();
let init_tasks = Task::all(&mut conn).await.unwrap();

// Issue a request to insert a new task.
client.post("/todo")
Expand All @@ -49,7 +54,7 @@ fn test_insertion_deletion() {
.await;

// Ensure we have one more task in the database.
let new_tasks = Task::all(&conn).await.unwrap();
let new_tasks = Task::all(&mut conn).await.unwrap();
assert_eq!(new_tasks.len(), init_tasks.len() + 1);

// Ensure the task is what we expect.
Expand All @@ -61,7 +66,7 @@ fn test_insertion_deletion() {
client.delete(format!("/todo/{}", id)).dispatch().await;

// Ensure it's gone.
let final_tasks = Task::all(&conn).await.unwrap();
let final_tasks = Task::all(&mut conn).await.unwrap();
assert_eq!(final_tasks.len(), init_tasks.len());
if final_tasks.len() > 0 {
assert_ne!(final_tasks[0].description, "My first task");
Expand All @@ -79,16 +84,16 @@ fn test_toggle() {
.dispatch()
.await;

let task = Task::all(&conn).await.unwrap()[0].clone();
let task = Task::all(&mut conn).await.unwrap()[0].clone();
assert_eq!(task.completed, false);

// Issue a request to toggle the task; ensure it is completed.
client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await;
assert_eq!(Task::all(&conn).await.unwrap()[0].completed, true);
assert_eq!(Task::all(&mut conn).await.unwrap()[0].completed, true);

// Issue a request to toggle the task; ensure it's not completed again.
client.put(format!("/todo/{}", task.id.unwrap())).dispatch().await;
assert_eq!(Task::all(&conn).await.unwrap()[0].completed, false);
assert_eq!(Task::all(&mut conn).await.unwrap()[0].completed, false);
})
}

Expand All @@ -98,7 +103,7 @@ fn test_many_insertions() {

run_test!(|client, conn| {
// Get the number of tasks initially.
let init_num = Task::all(&conn).await.unwrap().len();
let init_num = Task::all(&mut conn).await.unwrap().len();
let mut descs = Vec::new();

for i in 0..ITER {
Expand All @@ -119,7 +124,7 @@ fn test_many_insertions() {
descs.insert(0, desc);

// Ensure the task was inserted properly and all other tasks remain.
let tasks = Task::all(&conn).await.unwrap();
let tasks = Task::all(&mut conn).await.unwrap();
assert_eq!(tasks.len(), init_num + i + 1);

for j in 0..i {
Expand Down

0 comments on commit 6a6b53c

Please sign in to comment.