Skip to content

Commit

Permalink
Invoice callback -> channel, log removal
Browse files Browse the repository at this point in the history
- Deprecate invoice callback function in favor of async-std channels
- Remove global log configuration
- Remove serializable type
  • Loading branch information
saefstroem committed May 16, 2024
1 parent 8def1e4 commit 0d82cfb
Show file tree
Hide file tree
Showing 8 changed files with 467 additions and 311 deletions.
577 changes: 381 additions & 196 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ sha2 = "0.10.8"
tokio = "1.37.0"
uuid = {version="1.8.0",features=["v4"]}
reqwest = "0.12.4"
log4rs = "1.3.0"
log = "0.4.21"
zeroize = {version="1.7.0",features=["zeroize_derive"]}
async-std = "1.12.0"
24 changes: 10 additions & 14 deletions src/db/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::common::DatabaseError;
use crate::types::Serializable;
use crate::{common::DatabaseError, types::Invoice};
use sled::Tree;

/// Retrieve a value by key from a tree.
Expand All @@ -24,11 +23,10 @@ async fn get_last_from_tree(db: &Tree) -> Result<(Vec<u8>, Vec<u8>), DatabaseErr
db.last()?
.map(|(key, value)| (key.to_vec(), value.to_vec()))
.ok_or(DatabaseError::NotFound)

}

/// Wrapper for retrieving the last added item to the tree
pub async fn get_last<T: Serializable>(tree: &sled::Tree) -> Result<(String, T), DatabaseError> {
pub async fn get_last(tree: &sled::Tree) -> Result<(String, Invoice), DatabaseError> {
let binary_data = get_last_from_tree(tree).await?;
// Convert binary key to String
let key = String::from_utf8(binary_data.0).map_err(|error| {
Expand All @@ -37,17 +35,15 @@ pub async fn get_last<T: Serializable>(tree: &sled::Tree) -> Result<(String, T),
})?;

// Deserialize binary value to T
let value = T::from_bin(binary_data.1).map_err(|error| {
let value = bincode::deserialize::<Invoice>(&binary_data.1).map_err(|error| {
log::error!("Db Interaction Error: {}", error);
DatabaseError::Deserialize
})?;
Ok((key, value))
}

/// Wrapper for retrieving all key value pairs from a tree
pub async fn get_all<T: Serializable>(
tree: &sled::Tree,
) -> Result<Vec<(String, T)>, DatabaseError> {
pub async fn get_all(tree: &sled::Tree) -> Result<Vec<(String, Invoice)>, DatabaseError> {
let binary_data = get_all_from_tree(tree).await?;
let mut all = Vec::with_capacity(binary_data.len());
for (binary_key, binary_value) in binary_data {
Expand All @@ -57,8 +53,8 @@ pub async fn get_all<T: Serializable>(
DatabaseError::Deserialize
})?;

// Deserialize binary value to T
let value = T::from_bin(binary_value).map_err(|error| {
// Deserialize binary value to invoice
let value = bincode::deserialize::<Invoice>(&binary_value).map_err(|error| {
log::error!("Db Interaction Error: {}", error);
DatabaseError::Deserialize
})?;
Expand All @@ -69,9 +65,9 @@ pub async fn get_all<T: Serializable>(
}

/// Wrapper for retrieving a value from a tree
pub async fn get<T: Serializable>(tree: &Tree, key: &str) -> Result<T, DatabaseError> {
pub async fn get(tree: &Tree, key: &str) -> Result<Invoice, DatabaseError> {
let binary_data = get_from_tree(tree, key).await?;
T::from_bin(binary_data).map_err(|error| {
bincode::deserialize::<Invoice>(&binary_data).map_err(|error| {
log::error!("Db Interaction Error: {}", error);
DatabaseError::Deserialize
})
Expand All @@ -89,8 +85,8 @@ async fn set_to_tree(db: &Tree, key: &str, bin: Vec<u8>) -> Result<(), DatabaseE
}

/// Wrapper for setting a value to a tree
pub async fn set<T: Serializable>(tree: &Tree, key: &str, data: &T) -> Result<(), DatabaseError> {
let binary_data = T::to_bin(data).map_err(|error| {
pub async fn set(tree: &Tree, key: &str, data: &Invoice) -> Result<(), DatabaseError> {
let binary_data = bincode::serialize::<Invoice>(data).map_err(|error| {
log::error!("Db Interaction Error: {}", error);
DatabaseError::Serialize
})?;
Expand Down
103 changes: 56 additions & 47 deletions src/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,8 @@ use alloy::{
signers::wallet::LocalWallet,
transports::http::Http,
};
use log::LevelFilter;
use log4rs::{
append::file::FileAppender,
config::{Appender, Root},
encode::pattern::PatternEncoder,
Config,
};

use async_std::channel::Sender;
use reqwest::{Client, Url};
use sled::Tree;

Expand Down Expand Up @@ -44,10 +39,32 @@ pub struct PaymentGatewayConfiguration {
pub provider: RootProvider<Http<Client>>,
pub treasury_address: Address,
pub invoice_delay_millis: u64,
pub callback: AsyncCallback,
pub reflector: Reflector,
pub transfer_gas_limit: Option<u128>,
}

/// ## Reflector
/// The reflector allows your payment gateway to be used in a more flexible way.
///
/// In its current state you can pass a Sender from an unbound async-std channel
/// which you can create by doing:
/// ```rust
/// use async_std::channel::unbounded;
/// use acceptevm::gateway::Reflector;
///
/// let (sender, receiver) = unbounded();
///
/// let reflector=Reflector::Sender(sender);
/// ```
///
/// You may clone the receiver as many times as you want but do not use the sender
/// for anything other than passing it to the try_new() method.
#[derive(Clone)]
pub enum Reflector {
/// A sender from async-std
Sender(Sender<Invoice>),
}

// Type alias for the underlying Web3 type.
pub type Wei = U256;

Expand All @@ -62,57 +79,49 @@ impl PaymentGateway {
/// - `treasury_address`: the address of the treasury for all paid invoices, on this EVM network.
/// - `invoice_delay_millis`: how long to wait before checking the next invoice in milliseconds.
/// This is used to prevent potential rate limits from the node.
/// - `callback`: an async function that is called when an invoice is paid.
/// - `reflector`: The reflector is an enum that allows you to receive the paid invoices.
/// At the moment, the only reflector available is the `Sender` from the async-std channel.
/// This means that you will need to create a channel and pass the sender as the reflector.
/// - `sled_path`: The path of the sled database where the pending invoices will
/// be stored. In the event of a crash the invoices are saved and will be
/// checked on reboot.
/// - `name`: A name that describes this gateway. Perhaps the EVM network used?
/// - `transfer_gas_limit`: An optional gas limit used when transferring gas from paid invoices to
/// the treasury. Useful in case your treasury address is a contract address
/// that implements custom functionality for handling incoming gas.
pub fn new<F, Fut>(
///
/// Example:
/// ```rust
/// use acceptevm::gateway::{PaymentGateway, Reflector};
/// use async_std::channel::unbounded;
/// let (sender, _receiver) = unbounded();
/// let reflector = Reflector::Sender(sender);
///
/// PaymentGateway::new(
/// "https://123.com",
/// "0xdac17f958d2ee523a2206206994597c13d831ec7".to_string(),
/// 10,
/// reflector,
/// "./your-wanted-db-path",
/// "test".to_string(),
/// Some(21000),
/// );
/// ```

pub fn new(
rpc_url: &str,
treasury_address: String,
invoice_delay_millis: u64,
callback: F,
reflector: Reflector,
sled_path: &str,
name: String,
transfer_gas_limit: Option<u128>,
) -> PaymentGateway
where
F: Fn(Invoice) -> Fut + 'static + Send + Sync,
Fut: Future<Output = ()> + 'static + Send,
{
// Send allows ownership to be transferred across threads
// Sync allows references to be shared

) -> PaymentGateway {
let db = sled::open(sled_path).unwrap();
let tree = db.open_tree("invoices").unwrap();
let provider = ProviderBuilder::new().on_http(Url::from_str(rpc_url).unwrap());

// Wrap the callback in Arc<Mutex<>> to allow sharing across threads and state mutation
// We have to create a pinned box to prevent the future from being moved around in heap memory.
let callback = Arc::new(move |invoice: Invoice| {
Box::pin(callback(invoice)) as Pin<Box<dyn Future<Output = ()> + Send>>
});

// Setup logging
let logfile = FileAppender::builder()
.encoder(Box::new(PatternEncoder::new("{l} - {m}\n")))
.build("./acceptevm.log")
.unwrap();

let config = Config::builder()
.appender(Appender::builder().build("logfile", Box::new(logfile)))
.build(Root::builder().appender("logfile").build(LevelFilter::Info))
.unwrap();

// Try to initialize and catch error silently if already initialized
// during tests this make this function throw error
if log4rs::init_config(config).is_err() {
println!("Logger already initialized.");
}

// TODO: When implementing token transfers allow the user to add their gas wallet here.

PaymentGateway {
Expand All @@ -122,7 +131,7 @@ impl PaymentGateway {
.parse()
.unwrap_or_else(|_| panic!("Invalid treasury address")),
invoice_delay_millis,
callback,
reflector,
transfer_gas_limit,
},
tree,
Expand All @@ -132,20 +141,20 @@ impl PaymentGateway {

/// Retrieves the last invoice
pub async fn get_last_invoice(&self) -> Result<(String, Invoice), DatabaseError> {
get_last::<Invoice>(&self.tree).await
get_last(&self.tree).await
}

/// Retrieves all invoices in the form of a tuple: String,Invoice
/// where the first element is the key that was used in the database
/// and the second part is the invoice. The key is a SHA256 hash of the
/// creation timestamp and the recipient address.
pub async fn get_all_invoices(&self) -> Result<Vec<(String, Invoice)>, DatabaseError> {
get_all::<Invoice>(&self.tree).await
get_all(&self.tree).await
}

/// Retrieve an invoice from the payment gateway
pub async fn get_invoice(&self, key: String) -> Result<Invoice, DatabaseError> {
get::<Invoice>(&self.tree, &key).await
get(&self.tree, &key).await
}

/// Spawns an asynchronous task that checks all the pending invoices
Expand Down Expand Up @@ -189,7 +198,7 @@ impl PaymentGateway {
let seed = format!("{}{}", signer.address(), get_unix_time_millis());
let invoice_id = hash_now(seed);
// Save the invoice in db.
set::<Invoice>(&self.tree, &invoice_id, &invoice).await?;
set(&self.tree, &invoice_id, &invoice).await?;
Ok(invoice)
}
}
27 changes: 7 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,23 @@ mod tests {
use std::{fs, path::Path, str::FromStr};

use alloy::primitives::U256;
use async_std::channel::unbounded;

use crate::{
common::DatabaseError,
gateway::PaymentGateway,
gateway::{PaymentGateway, Reflector},
types::{Invoice, PaymentMethod},
};

struct Foo {
bar: std::sync::Mutex<i64>,
}

impl Foo {
async fn increase(&self) {
*self.bar.lock().unwrap() += 1;
}
}

fn setup_test_gateway(db_path: &str) -> PaymentGateway {
let foo = std::sync::Arc::new(Foo {
bar: Default::default(),
});
let foo_clone = foo.clone();
let callback = move |_| {
let foo = foo_clone.clone();
async move { foo.increase().await }
};
let (sender, _receiver) = unbounded();
let reflector = Reflector::Sender(sender);

PaymentGateway::new(
"https://123.com",
"0xdac17f958d2ee523a2206206994597c13d831ec7".to_string(),
10,
callback,
reflector,
db_path,
"test".to_string(),
Some(21000),
Expand Down Expand Up @@ -86,4 +72,5 @@ mod tests {
assert_eq!(address_length, 42);
remove_test_db("./test-assert-valid-address-length");
}

}
19 changes: 12 additions & 7 deletions src/poller/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use alloy::{
rpc::types::eth::TransactionReceipt,
transports::http::Http,
};
use crate::gateway::Reflector::Sender;
use reqwest::Client;
use sled::Tree;

Expand Down Expand Up @@ -70,11 +71,8 @@ async fn check_and_process(provider: RootProvider<Http<Client>>, invoice: &Invoi

async fn delete_invoice(tree: &Tree, key: String) {
// Optimistically delete the old invoice.
match delete(tree, &key).await {
Ok(()) => {}
Err(error) => {
log::error!("Could not remove invoice, did not callback: {}", error);
}
if let Err(delete_error) = delete(tree, &key).await {
log::error!("Could not remove invoice: {}", delete_error);
}
}

Expand All @@ -89,7 +87,7 @@ async fn transfer_to_treasury(
/// to the specified polling interval.
pub async fn poll_payments(gateway: PaymentGateway) {
loop {
match get_all::<Invoice>(&gateway.tree).await {
match get_all(&gateway.tree).await {
Ok(all) => {
// Loop through all invoices
for (key, mut invoice) in all {
Expand Down Expand Up @@ -121,7 +119,14 @@ pub async fn poll_payments(gateway: PaymentGateway) {
// lock to the callback function.
delete_invoice(&gateway.tree, key).await;
invoice.paid_at_timestamp = get_unix_time_seconds();
(gateway.config.callback)(invoice).await;// Execute callback function
match gateway.config.reflector {
Sender(ref sender) => {
// Attempt to send the PriceData through the channel.
if let Err(error) = sender.send(invoice).await {
log::error!("Failed sending data: {}", error);
}
}
}
}
// To prevent rate limitations on certain Web3 RPC's we sleep here for the specified amount.
tokio::time::sleep(std::time::Duration::from_millis(
Expand Down
7 changes: 0 additions & 7 deletions src/types/errors.rs

This file was deleted.

19 changes: 0 additions & 19 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
mod errors;
use std::ops::{Deref, DerefMut};

use self::errors::SerializableError;
use alloy::{
primitives::{B256, U256},
rpc::types::eth::TransactionReceipt,
};
use serde::{Deserialize, Serialize};
use zeroize::ZeroizeOnDrop;
pub trait Serializable {
fn to_bin(&self) -> Result<Vec<u8>, Box<bincode::ErrorKind>>;
fn from_bin(data: Vec<u8>) -> Result<Self, SerializableError>
where
Self: Sized;
}

/// Describes the structure of a payment method in
/// a gateway
Expand Down Expand Up @@ -62,14 +54,3 @@ pub struct Invoice {
pub receipt: Option<TransactionReceipt>,
}

impl Serializable for Invoice {
/// Serializes invoice to bytes
fn to_bin(&self) -> Result<Vec<u8>, Box<bincode::ErrorKind>> {
bincode::serialize(&self)
}

/// Deserializes invoice from bytes
fn from_bin(data: Vec<u8>) -> Result<Self, SerializableError> {
bincode::deserialize(&data).map_err(|_| SerializableError::Deserialize)
}
}

0 comments on commit 0d82cfb

Please sign in to comment.