diff --git a/.sqlx/query-14459c299f5108ac575fdb60f3ad4e2249b3e1b6415ba97febf671770cdf47b4.json b/.sqlx/query-14459c299f5108ac575fdb60f3ad4e2249b3e1b6415ba97febf671770cdf47b4.json new file mode 100644 index 0000000..20fb277 --- /dev/null +++ b/.sqlx/query-14459c299f5108ac575fdb60f3ad4e2249b3e1b6415ba97febf671770cdf47b4.json @@ -0,0 +1,58 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, description, price, category as \"category: ExpenseCategory\"\n FROM expenses\n WHERE category IS NULL\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int4" + }, + { + "ordinal": 1, + "name": "description", + "type_info": "Varchar" + }, + { + "ordinal": 2, + "name": "price", + "type_info": "Float4" + }, + { + "ordinal": 3, + "name": "category: ExpenseCategory", + "type_info": { + "Custom": { + "name": "expense_category", + "kind": { + "Enum": [ + "restaurants", + "shopping", + "services", + "entertainment", + "groceries", + "salary", + "interest Income", + "utilities", + "pharmacy", + "transfer", + "transport", + "others" + ] + } + } + } + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false, + false, + false, + true + ] + }, + "hash": "14459c299f5108ac575fdb60f3ad4e2249b3e1b6415ba97febf671770cdf47b4" +} diff --git a/.sqlx/query-d507e0c62b31fb31803767486b16c17a6887eb80663c7a2606b7caab4948c884.json b/.sqlx/query-d507e0c62b31fb31803767486b16c17a6887eb80663c7a2606b7caab4948c884.json new file mode 100644 index 0000000..0cbe3ff --- /dev/null +++ b/.sqlx/query-d507e0c62b31fb31803767486b16c17a6887eb80663c7a2606b7caab4948c884.json @@ -0,0 +1,35 @@ +{ + "db_name": "PostgreSQL", + "query": "\n update expenses\n set category = $1\n WHERE id = $2\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "Custom": { + "name": "expense_category", + "kind": { + "Enum": [ + "restaurants", + "shopping", + "services", + "entertainment", + "groceries", + "salary", + "interest Income", + "utilities", + "pharmacy", + "transfer", + "transport", + "others" + ] + } + } + }, + "Int4" + ] + }, + "nullable": [] + }, + "hash": "d507e0c62b31fb31803767486b16c17a6887eb80663c7a2606b7caab4948c884" +} diff --git a/src/client/classifier.rs b/src/client/classifier.rs new file mode 100644 index 0000000..f84d82e --- /dev/null +++ b/src/client/classifier.rs @@ -0,0 +1,68 @@ +use std::str::FromStr; + +use crate::schema::ExpenseCategory; +use anyhow::bail; +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize)] +pub struct TransactionToCategorize { + pub id: String, + pub description: String, + pub amount: f32, + pub balance: f32, + pub category: String, +} + +#[derive(Deserialize)] +pub struct PredictionResult { + pub id: String, + pub category: String, + pub confidence_level: f32, +} + +#[derive(Deserialize)] +pub struct PredictionResponse { + pub predictions: Vec, +} + +#[derive(Debug)] +pub struct PredictRustResponse { + pub id: i32, + pub category: ExpenseCategory, + pub confidence_level: f32, +} + +pub async fn predict( + transactions: &[TransactionToCategorize], +) -> anyhow::Result> { + let client = reqwest::Client::new(); + let resp = client + .post("https://model.fina.center/transactions") + .json(transactions) + .send() + .await?; + + let response: PredictionResponse = match resp.status() { + StatusCode::OK => resp.json().await?, + _ => bail!("unknown error: {}", resp.text().await?), + }; + + let mut results = Vec::new(); + for prediction in response.predictions { + let expense_category = match ExpenseCategory::from_str(&prediction.category) { + Ok(c) => c, + Err(e) => { + tracing::error!(?e, "error categorizing transaction in model"); + bail!("error converting category string into enum: {e:?}") + } + }; + results.push(PredictRustResponse { + id: prediction.id.parse::()?, + category: expense_category, + confidence_level: prediction.confidence_level, + }); + } + + Ok(results) +} diff --git a/src/client/mod.rs b/src/client/mod.rs index 415a1e2..563e536 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,3 +1,4 @@ +pub mod classifier; pub mod frc; pub mod mail; pub mod pluggy; diff --git a/src/main.rs b/src/main.rs index f2d2246..69734f2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,7 @@ mod hypermedia; mod queries; /// Module containing the database schemas and i/o schemas for hypermedia and data apis. mod schema; +mod tasks; /// Module containing the askama html templates to be rendered. mod templates; /// Module containing time and crypto utility functions. @@ -134,12 +135,6 @@ async fn main() -> anyhow::Result<()> { .await .context("couldn't migrate session store")?; - let deletion_task = tokio::task::spawn( - session_store - .clone() - .continuously_delete_expired(tokio::time::Duration::from_secs(60)), - ); - let client::pluggy::auth::CreateApiKeyOutcome::Success(pluggy_api_key) = client::pluggy::auth::create_api_key(&env.pluggy_client_id, &env.pluggy_client_secret) .await? @@ -154,20 +149,42 @@ async fn main() -> anyhow::Result<()> { pluggy_api_key, }); - let renew_pluggy = renew_pluggy_task( + let deletion_task = tokio::task::spawn( + session_store + .clone() + .continuously_delete_expired(tokio::time::Duration::from_secs(60)), + ); + + let renew_pluggy = tasks::renew_pluggy_task( shared_state.pluggy_api_key.clone(), shared_state.env.pluggy_client_id.clone(), shared_state.env.pluggy_client_secret.clone(), ); + let categorize_transactions_task = + tasks::categorize_transactions_task(shared_state.pool.clone()); tracing::info!("Server started"); - let rest = rest(shared_state.clone(), session_store); - rest.await??; - - tracing::info!("Starting deletion task"); // message won't show - deletion_task.await??; - renew_pluggy.await??; + tokio::select! { + rest_result = rest(shared_state.clone(), session_store) => { + rest_result??; + } + deletio_result = deletion_task => { + if let Err(e) = deletio_result { + tracing::error!(?e, "session deletion task failed"); + } + } + renew_result = renew_pluggy => { + if let Err(e) = renew_result { + tracing::error!(?e, "Pluggy renewal task failed"); + } + } + categorize_result = categorize_transactions_task => { + if let Err(e) = categorize_result { + tracing::error!(?e, "Transaction categorization task failed"); + } + } + } logger_provider.shutdown()?; Ok(()) @@ -274,30 +291,6 @@ fn rest( }) } -fn renew_pluggy_task( - pluggy_api_key: Arc>, - pluggy_client_id: String, - pluggy_client_secret: String, -) -> tokio::task::JoinHandle> { - tokio::task::spawn(async move { - loop { - tokio::time::sleep(Duration::from_secs(60 * 5)).await; - tracing::info!("waky waky renew pluggy"); - - let client::pluggy::auth::CreateApiKeyOutcome::Success(new_pluggy_api_key) = - client::pluggy::auth::create_api_key(&pluggy_client_id, &pluggy_client_secret) - .await? - else { - bail!("task couldn't renew pluggy api_key") - }; - - let mut pluggy_api_key = pluggy_api_key.lock().await; - *pluggy_api_key = new_pluggy_api_key.api_key; - tracing::info!("renewed pluggy api key"); - } - }) -} - /// Returns a default configuration of http security headers. fn generate_general_helmet_headers() -> Helmet { return Helmet::new() diff --git a/src/queries/expenses.rs b/src/queries/expenses.rs index ed6ab8b..6298178 100644 --- a/src/queries/expenses.rs +++ b/src/queries/expenses.rs @@ -264,3 +264,43 @@ pub async fn most_recent_for_account( .fetch_optional(conn) .await } + +pub struct UncategorizedTransactions { + pub id: i32, + pub description: String, + pub price: f32, + pub category: Option, +} + +pub async fn list_uncategorized( + conn: impl PgExecutor<'_>, +) -> Result, sqlx::Error> { + sqlx::query_as!( + UncategorizedTransactions, + r#" + SELECT id, description, price, category as "category: ExpenseCategory" + FROM expenses + WHERE category IS NULL + "# + ) + .fetch_all(conn) + .await +} + +pub async fn update_category( + conn: impl PgExecutor<'_>, + expense_id: i32, + category: ExpenseCategory, +) -> Result { + sqlx::query!( + r#" + update expenses + set category = $1 + WHERE id = $2 + "#, + category as ExpenseCategory, + expense_id, + ) + .execute(conn) + .await +} diff --git a/src/tasks.rs b/src/tasks.rs new file mode 100644 index 0000000..9a554c8 --- /dev/null +++ b/src/tasks.rs @@ -0,0 +1,108 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::bail; +use sqlx::PgPool; + +use crate::client::classifier::TransactionToCategorize; + +pub fn renew_pluggy_task( + pluggy_api_key: Arc>, + pluggy_client_id: String, + pluggy_client_secret: String, +) -> tokio::task::JoinHandle> { + tokio::task::spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(60 * 5)).await; + tracing::debug!("waky waky renew pluggy"); + + let crate::client::pluggy::auth::CreateApiKeyOutcome::Success(new_pluggy_api_key) = + crate::client::pluggy::auth::create_api_key( + &pluggy_client_id, + &pluggy_client_secret, + ) + .await? + else { + bail!("task couldn't renew pluggy api_key") + }; + + let mut pluggy_api_key = pluggy_api_key.lock().await; + *pluggy_api_key = new_pluggy_api_key.api_key; + tracing::info!("renewed pluggy api key"); + } + }) +} + +pub fn categorize_transactions_task( + db_pool: PgPool, +) -> tokio::task::JoinHandle> { + tokio::task::spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(60 * 5)).await; + tracing::debug!("starting transaction categorization task"); + + let uncategorized = match crate::queries::expenses::list_uncategorized(&db_pool).await { + Ok(rows) => rows, + Err(e) => { + tracing::error!(?e, "failed to fetch uncategorized transactions"); + continue; + } + }; + + if uncategorized.is_empty() { + tracing::debug!("no uncategorized transactions found"); + continue; + } + + tracing::info!( + count = uncategorized.len(), + "processing uncategorized transactions" + ); + + // Convert to TransactionToCategorize format + let transactions: Vec = uncategorized + .iter() + .map(|row| TransactionToCategorize { + id: row.id.to_string(), + description: row.description.clone(), + amount: row.price, + balance: 0.0, + // FIX: use external category + category: row.category.clone().unwrap_or_default().to_string(), + }) + .collect(); + + match crate::client::classifier::predict(&transactions).await { + Ok(transactions_to_update) => { + for tx in &transactions_to_update { + match crate::queries::expenses::update_category( + &db_pool, + tx.id, + tx.category.clone(), + ) + .await + { + Ok(c) => { + if c.rows_affected() > 1 { + tracing::error!( + "i really need a macro that cancels the transaction" + ); + } + tracing::debug!("confidence_level: {}", tx.confidence_level); + } + Err(e) => { + tracing::error!(?e, tx_id = ?tx.id, "failed to update transaction category"); + } + } + } + tracing::info!( + count = transactions_to_update.len(), + "updated transaction categories" + ); + } + Err(e) => { + tracing::error!(?e, "failed to predict transaction categories"); + } + } + } + }) +}