Skip to content

Commit

Permalink
feat: created mock spawner and added more test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
sifnoc committed Nov 27, 2023
1 parent 4d1fee7 commit b30019c
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 88 deletions.
67 changes: 2 additions & 65 deletions bin/mini_tree_server.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
use axum::{extract::Json, http::StatusCode, response::IntoResponse, routing::post, Router};
use const_env::from_env;
use num_bigint::BigUint;
use axum::{routing::post, Router};
use std::net::SocketAddr;

use summa_aggregation::{JsonEntry, JsonMerkleSumTree, JsonNode};
use summa_backend::merkle_sum_tree::{Entry, MerkleSumTree, Node, Tree};

#[from_env]
const N_ASSETS: usize = 2;
#[from_env]
const N_BYTES: usize = 14;
use summa_aggregation::mini_tree_generator::create_mst;

#[tokio::main]
async fn main() {
Expand All @@ -25,58 +17,3 @@ async fn main() {
.await
.unwrap();
}

/// Serializes a merkle-sum-tree node into its JSON wire representation,
/// rendering the hash and each balance with their `Debug` formatting ("{:?}").
fn convert_node_to_json(node: &Node<N_ASSETS>) -> JsonNode {
    JsonNode {
        hash: format!("{:?}", node.hash),
        balances: node.balances.iter().map(|b| format!("{:?}", b)).collect(),
    }
}

/// Axum POST handler: parses the posted entries, builds a
/// `MerkleSumTree<N_ASSETS, N_BYTES>`, and returns it as `JsonMerkleSumTree`.
///
/// NOTE(review): malformed balance strings and invalid entries panic the
/// handler task via `unwrap` instead of producing an error response —
/// acceptable for a local worker binary, but worth hardening.
async fn create_mst(
    Json(json_entries): Json<Vec<JsonEntry>>,
) -> Result<impl IntoResponse, (StatusCode, Json<JsonMerkleSumTree>)> {
    // Convert `JsonEntry` -> `Entry<N_ASSETS>`; balances not present in the
    // request default to zero (array is pre-filled with BigUint(0)).
    let entries = json_entries
        .iter()
        .map(|entry| {
            let mut balances: [BigUint; N_ASSETS] = std::array::from_fn(|_| BigUint::from(0u32));
            entry.balances.iter().enumerate().for_each(|(i, balance)| {
                balances[i] = balance.parse::<BigUint>().unwrap();
            });
            Entry::new(entry.username.clone(), balances).unwrap()
        })
        .collect::<Vec<Entry<N_ASSETS>>>();

    let entries_length = entries.len();
    let starting_time = std::time::Instant::now();
    // Create `MerkleSumTree<N_ASSETS, N_BYTES>` from `parsed_entries`.
    // The `false` flag skips sorting, preserving the submitted entry order.
    let tree = MerkleSumTree::<N_ASSETS, N_BYTES>::from_entries(entries, false).unwrap();
    println!(
        "Time to create tree({} entries): {}ms",
        entries_length,
        starting_time.elapsed().as_millis()
    );

    // Convert `MerkleSumTree<N_ASSETS, N_BYTES>` to `JsonMerkleSumTree`,
    // stringifying hashes/balances layer by layer for transport.
    let json_tree = JsonMerkleSumTree {
        root: convert_node_to_json(&tree.root()),
        nodes: tree
            .nodes()
            .iter()
            .map(|layer| layer.iter().map(convert_node_to_json).collect())
            .collect(),
        depth: tree.depth().clone(),
        entries: tree
            .entries()
            .iter()
            .map(|entry| JsonEntry {
                balances: entry.balances().iter().map(|b| b.to_string()).collect(),
                username: entry.username().to_string(),
            })
            .collect(),
        // Always false: sorting entries within a mini tree carries no meaning.
        is_sorted: false,
    };

    Ok((StatusCode::OK, Json(json_tree)))
}
4 changes: 4 additions & 0 deletions src/aggregation_merkle_sum_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> AggregationMerkleSumTree<N_ASS
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
{
if mini_trees.is_empty() {
return Err("Empty mini tree inputs".into());
}

// assert that all mini trees have the same depth
let depth = mini_trees[0].depth();
assert!(mini_trees.iter().all(|x| x.depth() == depth));
Expand Down
12 changes: 5 additions & 7 deletions src/executor/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,18 @@ impl Executor {
&self,
json_entries: Vec<JsonEntry>,
) -> Result<MerkleSumTree<N_ASSETS, N_BYTES>, Box<dyn Error + Send>> {
// Parse the response body into a MerkleSumTree
let json_tree = self
let response = self
.client
.post(&self.url)
.json(&json_entries)
.send()
.await
.map_err(|err| Box::new(err) as Box<dyn Error + Send>)
.unwrap()
.map_err(|err| Box::new(err) as Box<dyn Error + Send>)?;

let json_tree = response
.json::<JsonMerkleSumTree>()
.await
.map_err(|err| Box::new(err) as Box<dyn Error + Send>)
.unwrap();
.map_err(|err| Box::new(err) as Box<dyn Error + Send>)?;

let entries = json_entries
.iter()
Expand Down Expand Up @@ -109,7 +108,6 @@ mod test {
use crate::executor::spawner::ExecutorSpawner;
use crate::executor::ContainerSpawner;
use crate::orchestrator::entry_parser;
use bollard::Docker;

#[tokio::test]
async fn test_executor() -> Result<(), Box<dyn Error>> {
Expand Down
108 changes: 108 additions & 0 deletions src/executor/mock_spawner.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use axum::{routing::post, Router};
use std::{
future::Future,
net::SocketAddr,
pin::Pin,
str::FromStr,
sync::atomic::{AtomicUsize, Ordering},
};
use tokio;
use tokio::sync::oneshot;

use crate::executor::{Executor, ExecutorSpawner};
use crate::mini_tree_generator::create_mst;

/// Test-only executor spawner that serves mini-tree requests from in-process
/// axum servers instead of Docker containers or remote services.
pub struct MockSpawner {
    // Optional pre-existing worker addresses ("host:port"), handed out first.
    urls: Option<Vec<String>>,
    // Monotonic counter: indexes into `urls` and, once those are exhausted,
    // derives the listening port (4000 + id) for freshly spawned servers.
    worker_counter: AtomicUsize,
}

impl MockSpawner {
pub fn new(urls: Option<Vec<String>>) -> Self {
MockSpawner {
urls,
worker_counter: AtomicUsize::new(0),
}
}
}

impl ExecutorSpawner for MockSpawner {
    /// Resolves to an `Executor` pointing at a worker URL.
    ///
    /// Each call takes a unique id from `worker_counter`. If a caller-supplied
    /// URL exists for that id it is used directly; otherwise an in-process
    /// axum server serving `create_mst` is spawned on port `4000 + id`.
    fn spawn_executor(&self) -> Pin<Box<dyn Future<Output = Executor> + Send>> {
        let (tx, rx) = oneshot::channel();

        let id = self.worker_counter.fetch_add(1, Ordering::SeqCst);

        // Prefer a pre-configured URL when one is still available for this id;
        // `get` avoids the manual is_some()/unwrap()/len() dance and cannot
        // index out of bounds.
        if let Some(url) = self.urls.as_ref().and_then(|urls| urls.get(id)) {
            // Panics on an unparsable preset URL — a test-setup bug.
            let _ = tx.send(SocketAddr::from_str(url).unwrap());
        } else {
            // No URL left: spawn a fresh in-process mini-tree server.
            tokio::spawn(async move {
                let app = Router::new().route("/", post(create_mst));
                let addr = SocketAddr::from(([0, 0, 0, 0], 4000 + id as u16));

                // Report the address to the future below. NOTE(review): sent
                // before the server is actually listening, so callers may need
                // a short grace period (see the tests). `SocketAddr` is Copy,
                // so no clone is needed.
                let _ = tx.send(addr);

                // Serve until the runtime shuts the task down.
                axum::Server::bind(&addr)
                    .serve(app.into_make_service())
                    .await
                    .unwrap();
            });
        }

        // Single resolution path for both branches: wait for the worker
        // address, then wrap it in an Executor.
        Box::pin(async move {
            let url = rx.await.expect("Failed to receive worker URL");
            let worker_url = format!("http://{}", url);
            Executor::new(worker_url, None)
        })
    }

    fn terminate_executors(&self) -> Pin<Box<dyn Future<Output = ()> + Send>> {
        // Mock workers live in-process on the tokio runtime; nothing to stop.
        Box::pin(async {})
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_new_urls() {
        // Without preset URLs, every spawn starts an in-process worker.
        let spawner = MockSpawner::new(None);

        let first = spawner.spawn_executor().await;
        let second = spawner.spawn_executor().await;

        // Give the freshly spawned servers two seconds to come up.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        assert!(!first.get_url().is_empty());
        assert!(!second.get_url().is_empty());
    }

    #[tokio::test]
    async fn test_with_given_url() {
        // One preset URL: the first spawn consumes it, the second falls back
        // to an in-process server on port 4000 + 1.
        let spawner = MockSpawner::new(Some(vec!["127.0.0.1:7878".to_string()]));

        let first = spawner.spawn_executor().await;
        let second = spawner.spawn_executor().await;

        assert_eq!(first.get_url(), "http://127.0.0.1:7878");
        assert_eq!(second.get_url(), "http://0.0.0.0:4001");
    }
}
2 changes: 2 additions & 0 deletions src/executor/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
mod container_spawner;
mod executor;
mod mock_spawner;
mod service_spawner;
mod spawner;

pub use container_spawner::ContainerSpawner;
pub use executor::Executor;
pub use mock_spawner::MockSpawner;
pub use service_spawner::ServiceSpawner;
pub use spawner::ExecutorSpawner;
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![feature(generic_const_exprs)]
pub mod aggregation_merkle_sum_tree;
pub mod executor;
pub mod mini_tree_generator;
pub mod orchestrator;

use serde::{Deserialize, Serialize};
Expand Down
71 changes: 71 additions & 0 deletions src/mini_tree_generator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use axum::{extract::Json, http::StatusCode, response::IntoResponse};
use const_env::from_env;
use num_bigint::BigUint;

use crate::{JsonEntry, JsonMerkleSumTree, JsonNode};
use summa_backend::merkle_sum_tree::{Entry, MerkleSumTree, Node, Tree};

#[from_env]
const N_ASSETS: usize = 2;
#[from_env]
const N_BYTES: usize = 14;

/// Serializes a merkle-sum-tree node into its JSON wire representation,
/// rendering the hash and each balance with their `Debug` formatting ("{:?}").
fn convert_node_to_json(node: &Node<N_ASSETS>) -> JsonNode {
    let balances = node
        .balances
        .iter()
        .map(|balance| format!("{:?}", balance))
        .collect();

    JsonNode {
        hash: format!("{:?}", node.hash),
        balances,
    }
}

/// Axum POST handler: parses the posted entries, builds a
/// `MerkleSumTree<N_ASSETS, N_BYTES>`, and returns it as `JsonMerkleSumTree`.
///
/// NOTE(review): malformed balance strings and invalid entries panic the
/// handler task via `unwrap` instead of producing an error response —
/// acceptable for a mock/test worker, but worth hardening for real input.
pub async fn create_mst(
    Json(json_entries): Json<Vec<JsonEntry>>,
) -> Result<impl IntoResponse, (StatusCode, Json<JsonMerkleSumTree>)> {
    // Convert `JsonEntry` -> `Entry<N_ASSETS>`; balances not present in the
    // request default to zero (array is pre-filled with BigUint(0)).
    let entries = json_entries
        .iter()
        .map(|entry| {
            let mut balances: [BigUint; N_ASSETS] = std::array::from_fn(|_| BigUint::from(0u32));
            entry.balances.iter().enumerate().for_each(|(i, balance)| {
                balances[i] = balance.parse::<BigUint>().unwrap();
            });
            Entry::new(entry.username.clone(), balances).unwrap()
        })
        .collect::<Vec<Entry<N_ASSETS>>>();

    // Timing instrumentation is compiled out of test builds; both bindings
    // must share the same cfg so the `println!` below always has its inputs.
    #[cfg(not(test))]
    let entries_length = entries.len();
    #[cfg(not(test))]
    let starting_time = std::time::Instant::now();

    // Create `MerkleSumTree<N_ASSETS, N_BYTES>` from `parsed_entries`.
    // The `false` flag skips sorting, preserving the submitted entry order.
    let tree = MerkleSumTree::<N_ASSETS, N_BYTES>::from_entries(entries, false).unwrap();

    #[cfg(not(test))]
    println!(
        "Time to create tree({} entries): {}ms",
        entries_length,
        starting_time.elapsed().as_millis()
    );

    // Convert `MerkleSumTree<N_ASSETS, N_BYTES>` to `JsonMerkleSumTree`,
    // stringifying hashes/balances layer by layer for transport.
    let json_tree = JsonMerkleSumTree {
        root: convert_node_to_json(&tree.root()),
        nodes: tree
            .nodes()
            .iter()
            .map(|layer| layer.iter().map(convert_node_to_json).collect())
            .collect(),
        // presumably `depth()` returns `&usize`, making this a cheap copy —
        // TODO confirm against summa_backend's `Tree` trait.
        depth: tree.depth().clone(),
        entries: tree
            .entries()
            .iter()
            .map(|entry| JsonEntry {
                balances: entry.balances().iter().map(|b| b.to_string()).collect(),
                username: entry.username().to_string(),
            })
            .collect(),
        // Always false: sorting entries within a mini tree carries no meaning.
        is_sorted: false,
    };

    Ok((StatusCode::OK, Json(json_tree)))
}
Loading

0 comments on commit b30019c

Please sign in to comment.