Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mraszyk committed Jan 12, 2025
1 parent 089487d commit e9ee36f
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 239 deletions.
2 changes: 0 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 0 additions & 14 deletions rs/pocket_ic_server/src/pocket_ic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,20 +401,6 @@ impl Subnets for SubnetsImpl {
.as_ref()
.map(|subnet| subnet.state_machine.clone())
}
fn get_from_node(&self, node_id: NodeId) -> Option<Arc<StateMachine>> {
self.subnets
.read()
.unwrap()
.iter()
.find(|(_, subnet)| {
subnet
.state_machine
.nodes
.iter()
.any(|n| n.node_id == node_id)
})
.map(|(_, subnet)| subnet.state_machine.clone())
}
}

pub struct PocketIc {
Expand Down
8 changes: 2 additions & 6 deletions rs/state_machine_tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,9 @@ DEPENDENCIES = [
"@crate_index//:tokio",
"@crate_index//:tokio-util",
"@crate_index//:tower",
"@crate_index//:url",
"@crate_index//:wat",
]

MACRO_DEPENDENCIES = [
"@crate_index//:async-trait",
]

rust_library(
name = "state_machine_tests",
testonly = True,
Expand All @@ -82,7 +77,6 @@ rust_library(
"src/tests.rs",
],
crate_name = "ic_state_machine_tests",
proc_macro_deps = MACRO_DEPENDENCIES,
version = "0.9.0",
deps = DEPENDENCIES,
)
Expand Down Expand Up @@ -112,6 +106,8 @@ DEV_DEPENDENCIES = [
"//rs/universal_canister/lib",
]

MACRO_DEPENDENCIES = []

rust_binary(
name = "ic-test-state-machine",
testonly = True,
Expand Down
2 changes: 0 additions & 2 deletions rs/state_machine_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ documentation.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = { workspace = true }
candid = { workspace = true }
ciborium = { workspace = true }
clap = { workspace = true}
Expand Down Expand Up @@ -76,7 +75,6 @@ tempfile = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
tower = { workspace = true }
url = { workspace = true }
wat = { workspace = true }

[dev-dependencies]
Expand Down
159 changes: 67 additions & 92 deletions rs/state_machine_tests/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use async_trait::async_trait;
use candid::Decode;
use core::sync::atomic::Ordering;
use ic_artifact_pool::canister_http_pool::CanisterHttpPoolImpl;
Expand Down Expand Up @@ -154,10 +153,8 @@ use ic_types::{
CanisterId, CryptoHashOfState, Cycles, NumBytes, PrincipalId, SubnetId, UserId,
};
use ic_xnet_payload_builder::{
certified_slice_pool::CertifiedSlicePool, EndpointLocator, PoolRefillTask, ProximityMap,
RefillTaskHandle, XNetClient, XNetClientError, XNetEndpointResolver, XNetPayloadBuilderImpl,
XNetPayloadBuilderMetrics, XNetSlicePoolImpl, POOL_BYTE_SIZE_SOFT_CAP,
POOL_SLICE_BYTE_SIZE_MAX,
certified_slice_pool::CertifiedSlicePool, RefillTaskHandle, XNetPayloadBuilderImpl,
XNetPayloadBuilderMetrics, XNetSlicePoolImpl,
};
use rcgen::{CertificateParams, KeyPair};
use serde::Deserialize;
Expand Down Expand Up @@ -599,60 +596,64 @@ impl PocketIngressPool {
pub trait Subnets: Send + Sync {
fn insert(&self, state_machine: Arc<StateMachine>);
fn get(&self, subnet_id: SubnetId) -> Option<Arc<StateMachine>>;
fn get_from_node(&self, node_id: NodeId) -> Option<Arc<StateMachine>>;
}

/// Struct mocking the xnet layer.
struct PocketXNetClientImpl {
/// Pool of `StateMachine`s from which the XNet messages are fetched.
/// Struct mocking the XNet layer.
struct PocketXNetImpl {
/// Pool of `StateMachine`s from which XNet messages are fetched.
subnets: Arc<dyn Subnets>,
/// The certified slice pool of the `StateMachine` for which the XNet layer is mocked.
pool: Arc<Mutex<CertifiedSlicePool>>,
/// The subnet ID of the `StateMachine` for which the XNet layer is mocked.
own_subnet_id: SubnetId,
}

impl PocketXNetClientImpl {
fn new(subnets: Arc<dyn Subnets>) -> Self {
Self { subnets }
impl PocketXNetImpl {
fn new(
subnets: Arc<dyn Subnets>,
pool: Arc<Mutex<CertifiedSlicePool>>,
own_subnet_id: SubnetId,
) -> Self {
Self {
subnets,
pool,
own_subnet_id,
}
}
}

#[async_trait]
impl XNetClient for PocketXNetClientImpl {
async fn query(
&self,
endpoint: &EndpointLocator,
) -> Result<CertifiedStreamSlice, XNetClientError> {
const API_URL_STREAM_PREFIX: &str = "/api/v1/stream/";
let url = url::Url::parse(&endpoint.url.to_string()).unwrap();
let stream_url = url.path();
assert!(stream_url.starts_with(API_URL_STREAM_PREFIX));
let subnet_id_str = &stream_url[API_URL_STREAM_PREFIX.len()..];
let subnet_id = PrincipalId::from_str(subnet_id_str).unwrap();
let mut witness_begin = None;
let mut msg_begin = None;
let mut msg_limit = None;
let mut byte_limit = None;
for (param, value) in url.query_pairs() {
let value = value.parse::<u64>().unwrap();
match param.as_ref() {
"witness_begin" => witness_begin = Some(StreamIndex::new(value)),
"index" => msg_begin = Some(StreamIndex::new(value)),
"msg_begin" => msg_begin = Some(StreamIndex::new(value)),
"msg_limit" => msg_limit = Some(value as usize),
"byte_limit" => byte_limit = Some(value as usize),
_ => panic!("Unexpected XNet request param: {}", param),
fn refill(&self, registry_version: RegistryVersion, log: ReplicaLogger) {
let refill_stream_slice_indices = self.pool.lock().unwrap().refill_stream_slice_indices();

for (subnet_id, indices) in refill_stream_slice_indices {
let sm = self.subnets.get(subnet_id).unwrap();
match sm.generate_certified_stream_slice(
self.own_subnet_id,
Some(indices.witness_begin),
Some(indices.msg_begin),
None,
Some(indices.byte_limit),
) {
Ok(slice) => {
if indices.witness_begin != indices.msg_begin {
// Pulled a stream suffix, append to pooled slice.
self.pool
.lock()
.unwrap()
.append(subnet_id, slice, registry_version, log.clone())
.unwrap();
} else {
// Pulled a complete stream, replace pooled slice (if any).
self.pool
.lock()
.unwrap()
.put(subnet_id, slice, registry_version, log.clone())
.unwrap();
}
}
Err(EncodeStreamError::NoStreamForSubnet(_)) => (),
Err(err) => panic!("Unexpected XNetClient error: {}", err),
}
}
let sm = self.subnets.get_from_node(endpoint.node_id()).unwrap();
match sm.generate_certified_stream_slice(
subnet_id.into(),
witness_begin,
msg_begin,
msg_limit,
byte_limit,
) {
Ok(stream) => Ok(stream),
Err(EncodeStreamError::NoStreamForSubnet(_)) => Err(XNetClientError::NoContent),
Err(err) => panic!("Unexpected XNetClient error: {}", err),
}
}
}

Expand Down Expand Up @@ -814,7 +815,7 @@ pub struct StateMachine {
ingress_pool: Arc<RwLock<PocketIngressPool>>,
ingress_manager: Arc<IngressManager>,
pub ingress_filter: Arc<Mutex<IngressFilterService>>,
pool_refill_task: Arc<RwLock<Option<PoolRefillTask>>>,
pocket_xnet: Arc<RwLock<Option<PocketXNetImpl>>>,
payload_builder: Arc<RwLock<Option<PayloadBuilderImpl>>>,
message_routing: SyncMessageRouting,
pub metrics_registry: MetricsRegistry,
Expand Down Expand Up @@ -1217,41 +1218,19 @@ impl StateMachineBuilder {
let certified_stream_store: Arc<dyn CertifiedStreamStore> = sm.state_manager.clone();
let certified_slice_pool = Arc::new(Mutex::new(CertifiedSlicePool::new(
certified_stream_store,
subnet_id,
&sm.metrics_registry,
)));
let xnet_slice_pool_impl = Box::new(XNetSlicePoolImpl::new(certified_slice_pool.clone()));
let node_id = sm.nodes[0].node_id;
let proximity_map = Arc::new(ProximityMap::new(
node_id,
sm.registry_client.clone(),
&sm.metrics_registry,
sm.replica_logger.clone(),
));
let endpoint_resolver: XNetEndpointResolver = XNetEndpointResolver::new(
sm.registry_client.clone(),
node_id,
sm.subnet_id,
proximity_map,
sm.replica_logger.clone(),
);
let xnet_client: Arc<dyn XNetClient> = Arc::new(PocketXNetClientImpl::new(subnets));
let xnet_metrics = Arc::new(XNetPayloadBuilderMetrics::new(&sm.metrics_registry));
let pool_refill_task = PoolRefillTask::new_for_testing(
certified_slice_pool.clone(),
endpoint_resolver,
xnet_client,
sm.runtime.handle().clone(),
xnet_metrics.clone(),
sm.replica_logger.clone(),
);
let metrics = Arc::new(XNetPayloadBuilderMetrics::new(&sm.metrics_registry));
let xnet_payload_builder = Arc::new(XNetPayloadBuilderImpl::new_from_components(
sm.state_manager.clone(),
sm.state_manager.clone(),
sm.registry_client.clone(),
rng,
xnet_slice_pool_impl,
refill_task_handle,
xnet_metrics,
metrics,
sm.replica_logger.clone(),
));

Expand Down Expand Up @@ -1280,9 +1259,10 @@ impl StateMachineBuilder {
sm.replica_logger.clone(),
));

// Put pool refill task into `StateMachine`
// which contains no `PoolRefillTask` after creation.
*sm.pool_refill_task.write().unwrap() = Some(pool_refill_task);
// Put `PocketXNetImpl` into `StateMachine`
// which contains no `PocketXNetImpl` after creation.
let pocket_xnet_impl = PocketXNetImpl::new(subnets, certified_slice_pool, subnet_id);
*sm.pocket_xnet.write().unwrap() = Some(pocket_xnet_impl);
// Instantiate a `PayloadBuilderImpl` and put it into `StateMachine`
// which contains no `PayloadBuilderImpl` after creation.
*sm.payload_builder.write().unwrap() = Some(PayloadBuilderImpl::new(
Expand Down Expand Up @@ -1326,7 +1306,7 @@ impl StateMachine {
/// because the payload builder contains an `Arc` of this `StateMachine`
/// which creates a circular dependency preventing this `StateMachine`s from being dropped.
pub fn drop_payload_builder(&self) {
self.pool_refill_task.write().unwrap().take();
self.pocket_xnet.write().unwrap().take();
self.payload_builder.write().unwrap().take();
}

Expand Down Expand Up @@ -1371,17 +1351,12 @@ impl StateMachine {
membership_version: subnet_record.clone(),
context_version: subnet_record,
};
let pool_refill_task = self.pool_refill_task.read().unwrap();
let pool_refill_task = pool_refill_task.as_ref().unwrap();
self.runtime.block_on(async move {
pool_refill_task
.refill_pool(
POOL_BYTE_SIZE_SOFT_CAP,
POOL_SLICE_BYTE_SIZE_MAX,
registry_version,
)
.await;
});
self.pocket_xnet
.read()
.unwrap()
.as_ref()
.unwrap()
.refill(registry_version, self.replica_logger.clone());
let payload_builder = self.payload_builder.read().unwrap();
let payload_builder = payload_builder.as_ref().unwrap();
let batch_payload = payload_builder.get_payload(
Expand Down Expand Up @@ -1808,7 +1783,7 @@ impl StateMachine {
ingress_pool,
ingress_manager: ingress_manager.clone(),
ingress_filter: Arc::new(Mutex::new(execution_services.ingress_filter)),
pool_refill_task: Arc::new(RwLock::new(None)), // set by `StateMachineBuilder::build_with_subnets`
pocket_xnet: Arc::new(RwLock::new(None)), // set by `StateMachineBuilder::build_with_subnets`
payload_builder: Arc::new(RwLock::new(None)), // set by `StateMachineBuilder::build_with_subnets`
ingress_history_reader: execution_services.ingress_history_reader,
message_routing,
Expand Down
10 changes: 1 addition & 9 deletions rs/state_machine_tests/tests/multi_subnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use ic_state_machine_tests::{
use ic_test_utilities_types::ids::user_test_id;
use ic_types::{
ingress::{IngressStatus, WasmResult},
CanisterId, Cycles, NodeId, SubnetId,
CanisterId, Cycles, SubnetId,
};
use ic_universal_canister::{wasm, CallArgs, UNIVERSAL_CANISTER_WASM};
use std::collections::BTreeMap;
Expand Down Expand Up @@ -38,14 +38,6 @@ impl Subnets for SubnetsImpl {
fn get(&self, subnet_id: SubnetId) -> Option<Arc<StateMachine>> {
self.subnets.read().unwrap().get(&subnet_id).cloned()
}
fn get_from_node(&self, node_id: NodeId) -> Option<Arc<StateMachine>> {
self.subnets
.read()
.unwrap()
.iter()
.find(|(_, subnet)| subnet.nodes.iter().any(|n| n.node_id == node_id))
.map(|(_, subnet)| subnet.clone())
}
}

fn test_setup(
Expand Down
Loading

0 comments on commit e9ee36f

Please sign in to comment.