diff --git a/src/middlewares/methods/block_tag.rs b/src/middlewares/methods/block_tag.rs index 531e497..b71b1ba 100644 --- a/src/middlewares/methods/block_tag.rs +++ b/src/middlewares/methods/block_tag.rs @@ -60,6 +60,8 @@ impl BlockTagMiddleware { } } "latest" => { + // bypass cache for latest block to avoid caching forks + context.insert(BypassCache(true)); let (_, number) = self.api.get_head().read().await; Some(format!("0x{:x}", number).into()) } @@ -68,7 +70,23 @@ impl BlockTagMiddleware { context.insert(BypassCache(true)); None } - _ => None, + number => { + // bypass cache for block number to avoid caching forks unless it's a finalized block + let mut bypass_cache = true; + if let Some((_, finalized_number)) = self.api.current_finalized_head() { + if let Some(hex_number) = number.strip_prefix("0x") { + if let Ok(number) = u64::from_str_radix(hex_number, 16) { + if number <= finalized_number { + bypass_cache = false; + } + } + } + } + if bypass_cache { + context.insert(BypassCache(true)); + } + None + } } } else { None @@ -136,6 +154,10 @@ mod tests { } } + fn bypass_cache(context: &TypeRegistry) -> bool { + context.get::().map_or(false, |x| x.0) + } + async fn create_client() -> (ExecutionContext, EthApi) { let mut builder = TestServerBuilder::new(); @@ -169,7 +191,7 @@ mod tests { #[tokio::test] async fn skip_replacement_if_no_tag() { - let params = vec![json!("0x1234"), json!("0x5678")]; + let params = vec![json!("0x1234"), json!("0x4321")]; let (middleware, mut context) = create_block_tag_middleware(vec![ MethodParam { name: "key".to_string(), @@ -195,8 +217,11 @@ mod tests { .call( CallRequest::new("state_getStorage", params.clone()), Default::default(), - Box::new(move |req: CallRequest, _| { + Box::new(move |req: CallRequest, context| { async move { + // cache bypassed, cannot determine finalized block + assert!(bypass_cache(&context)); + // no replacement assert_eq!(req.params, params); Ok(json!("0x1111")) } @@ -236,7 +261,7 @@ mod tests { tokio::time::sleep(Duration::from_millis(10)).await; let sub = context.subscribe_rx.recv().await.unwrap(); if sub.params.as_array().unwrap().contains(&json!("newFinalizedHeads")) { - sub.run_sink_tasks(vec![SinkTask::Send(json!({ "number": "0x5430", "hash": "0x00" }))]) + sub.run_sink_tasks(vec![SinkTask::Send(json!({ "number": "0x4321", "hash": "0x01" }))]) .await } @@ -252,8 +277,11 @@ mod tests { .call( CallRequest::new("state_getStorage", vec![json!("0x1234"), json!("latest")]), Default::default(), - Box::new(move |req: CallRequest, _| { + Box::new(move |req: CallRequest, context| { async move { + // cache bypassed for latest + assert!(bypass_cache(&context)); + // latest block replaced with block number assert_eq!(req.params, vec![json!("0x1234"), json!("0x4321")]); Ok(json!("0x1111")) } @@ -270,8 +298,11 @@ mod tests { .call( CallRequest::new("state_getStorage", vec![json!("0x1234"), json!("finalized")],), Default::default(), - Box::new(move |req: CallRequest, _| { + Box::new(move |req: CallRequest, context| { async move { + // cache bypassed, block tag not replaced + assert!(bypass_cache(&context)); + // block tag not replaced assert_eq!(req.params, vec![json!("0x1234"), json!("finalized")]); Ok(json!("0x1111")) } @@ -291,9 +322,12 @@ mod tests { .call( CallRequest::new("state_getStorage", vec![json!("0x1234"), json!("finalized")],), Default::default(), - Box::new(move |req: CallRequest, _| { + Box::new(move |req: CallRequest, context| { async move { - assert_eq!(req.params, vec![json!("0x1234"), json!("0x5430")]); + // cache not bypassed, finalized replaced with block number + assert!(!bypass_cache(&context)); + // block tag replaced with block number + assert_eq!(req.params, vec![json!("0x1234"), json!("0x4321")]); Ok(json!("0x1111")) } .boxed() @@ -309,8 +343,11 @@ mod tests { .call( CallRequest::new("state_getStorage", vec![json!("0x1234"), json!("latest")]), Default::default(), - Box::new(move |req: CallRequest, _| { + Box::new(move |req: CallRequest, context| { async move { + // cache bypassed for latest + assert!(bypass_cache(&context)); + // latest block replaced with block number assert_eq!(req.params, vec![json!("0x1234"), json!("0x5432")]); Ok(json!("0x1111")) } diff --git a/src/middlewares/methods/inject_params.rs b/src/middlewares/methods/inject_params.rs index 6992fdc..b07ecad 100644 --- a/src/middlewares/methods/inject_params.rs +++ b/src/middlewares/methods/inject_params.rs @@ -6,7 +6,9 @@ use std::sync::Arc; use crate::{ config::MethodParam, extensions::api::{SubstrateApi, ValueHandle}, - middlewares::{CallRequest, CallResult, Middleware, MiddlewareBuilder, NextFn, RpcMethod, TRACER}, + middlewares::{ + methods::cache::BypassCache, CallRequest, CallResult, Middleware, MiddlewareBuilder, NextFn, RpcMethod, TRACER, + }, utils::errors, utils::{TypeRegistry, TypeRegistryRef}, }; @@ -18,6 +20,7 @@ pub enum InjectType { pub struct InjectParamsMiddleware { head: ValueHandle<(JsonValue, u64)>, + finalized: ValueHandle<(JsonValue, u64)>, inject: InjectType, params: Vec, } @@ -60,6 +63,7 @@ impl InjectParamsMiddleware { pub fn new(api: Arc, inject: InjectType, params: Vec) -> Self { Self { head: api.get_head(), + finalized: api.get_finalized_head(), inject, params, } @@ -99,14 +103,28 @@ impl Middleware for InjectParamsMiddleware { async fn call( &self, mut request: CallRequest, - context: TypeRegistry, + mut context: TypeRegistry, next: NextFn, ) -> CallResult { + let handle_request = |request: CallRequest| async { + for (idx, param) in self.params.iter().enumerate() { + if param.ty == "BlockNumber" { + if let Some(number) = request.params.get(idx).and_then(|x| x.as_u64()) { + let (_, finalized) = self.finalized.read().await; + if number > finalized { + context.insert(BypassCache(true)); + } + } + } + } + next(request, context).await + }; + let idx = self.get_index(); match request.params.len() { len if len == idx + 1 => { // full params with current block - return next(request, context).await; + return handle_request(request).await; } len if len <= idx => { async move { @@ -130,14 +148,14 @@ impl Middleware for InjectParamsMiddleware { } request.params.push(to_inject); - next(request, context).await + handle_request(request).await } .with_context(TRACER.context("inject_params")) .await } _ => { // unexpected number of params - next(request, context).await + handle_request(request).await } } } @@ -160,9 +178,14 @@ mod tests { api: Arc, _server: ServerHandle, head_rx: mpsc::Receiver, - _finalized_head_rx: mpsc::Receiver, + finalized_head_rx: mpsc::Receiver, block_hash_rx: mpsc::Receiver, head_sink: Option, + finalized_head_sink: Option, + } + + fn bypass_cache(context: &TypeRegistry) -> bool { + context.get::().map_or(false, |x| x.0) } async fn create_client() -> ExecutionContext { @@ -171,7 +194,7 @@ mod tests { let head_rx = builder.register_subscription("chain_subscribeNewHeads", "chain_newHead", "chain_unsubscribeNewHeads"); - let _finalized_head_rx = builder.register_subscription( + let finalized_head_rx = builder.register_subscription( "chain_subscribeFinalizedHeads", "chain_finalizedHead", "chain_unsubscribeFinalizedHeads", @@ -188,9 +211,10 @@ mod tests { api: Arc::new(api), _server, head_rx, - _finalized_head_rx, + finalized_head_rx, block_hash_rx, head_sink: None, + finalized_head_sink: None, } } @@ -208,7 +232,15 @@ mod tests { req.respond(json!("0xabcd")); } + let finalized_sub = context.finalized_head_rx.recv().await.unwrap(); + finalized_sub.send(json!({ "number": "0x4321" })).await; + { + let req = context.block_hash_rx.recv().await.unwrap(); + req.respond(json!("0xabcd")); + } + context.head_sink = Some(head_sub.sink); + context.finalized_head_sink = Some(finalized_sub.sink); ( InjectParamsMiddleware::new(context.api.clone(), inject_type, params), @@ -428,6 +460,18 @@ mod tests { let req = context.block_hash_rx.recv().await.unwrap(); req.respond(json!("0xbcde")); } + + // finalized updated + context + .finalized_head_sink + .unwrap() + .send(SubscriptionMessage::from_json(&json!({ "number": "0x5432" })).unwrap()) + .await + .unwrap(); + { + let req = context.block_hash_rx.recv().await.unwrap(); + req.respond(json!("0xbcde")); + } tokio::time::sleep(std::time::Duration::from_millis(1)).await; let result2 = middleware @@ -446,4 +490,176 @@ mod tests { .unwrap(); assert_eq!(result2, json!("0x1111")); } + + #[tokio::test] + async fn skip_cache_if_block_number_not_finalized() { + let (middleware, mut context) = create_inject_middleware( + InjectType::BlockNumberAt(1), + vec![ + MethodParam { + name: "key".to_string(), + ty: "StorageKey".to_string(), + optional: false, + inject: false, + }, + MethodParam { + name: "at".to_string(), + ty: "BlockNumber".to_string(), + optional: true, + inject: true, + }, + ], + ) + .await; + + // head is finalized, cache should not be skipped + { + let result = middleware + .call( + CallRequest::new("state_getStorage", vec![json!("0x1234")]), + Default::default(), + Box::new(move |req: CallRequest, context| { + async move { + // cache not bypassed + assert!(!bypass_cache(&context)); + // block number is not finalized + assert_eq!(req.params, vec![json!("0x1234"), json!(0x4321)]); + Ok(json!("0x1111")) + } + .boxed() + }), + ) + .await + .unwrap(); + assert_eq!(result, json!("0x1111")); + } + + // block head is updated but not finalized, cache should be skipped + { + // head updated but not finalized + context + .head_sink + .unwrap() + .send(SubscriptionMessage::from_json(&json!({ "number": "0x5432" })).unwrap()) + .await + .unwrap(); + { + let req = context.block_hash_rx.recv().await.unwrap(); + req.respond(json!("0xbcde")); + } + tokio::time::sleep(std::time::Duration::from_millis(1)).await; + + let result = middleware + .call( + CallRequest::new("state_getStorage", vec![json!("0x1234")]), + Default::default(), + Box::new(move |req: CallRequest, context| { + async move { + // cache bypassed + assert!(bypass_cache(&context)); + // block number is injected + assert_eq!(req.params, vec![json!("0x1234"), json!(0x5432)]); + Ok(json!("0x1111")) + } + .boxed() + }), + ) + .await + .unwrap(); + assert_eq!(result, json!("0x1111")); + } + + // request with head block number should skip cache + { + let result = middleware + .call( + CallRequest::new("state_getStorage", vec![json!("0x1234"), json!(0x5432)]), + Default::default(), + Box::new(move |req: CallRequest, context| { + async move { + // cache bypassed + assert!(bypass_cache(&context)); + // params not changed + assert_eq!(req.params, vec![json!("0x1234"), json!(0x5432)]); + Ok(json!("0x1111")) + } + .boxed() + }), + ) + .await + .unwrap(); + assert_eq!(result, json!("0x1111")); + } + + // request with finalized block number should not skip cache + { + let result = middleware + .call( + CallRequest::new("state_getStorage", vec![json!("0x1234"), json!(0x4321)]), + Default::default(), + Box::new(move |req: CallRequest, context| { + async move { + // cache not bypassed + assert!(!bypass_cache(&context)); + // params not changed + assert_eq!(req.params, vec![json!("0x1234"), json!(0x4321)]); + Ok(json!("0x1111")) + } + .boxed() + }), + ) + .await + .unwrap(); + assert_eq!(result, json!("0x1111")); + } + + // request with wrong params count will be handled + { + // block is finalized, cache should not be skipped + let result = middleware + .call( + CallRequest::new( + "state_getStorage", + vec![json!("0x1234"), json!(0x4321), json!("0xabcd")], + ), + Default::default(), + Box::new(move |req: CallRequest, context| { + async move { + // cache not bypassed + assert!(!bypass_cache(&context)); + // params not changed + assert_eq!(req.params, vec![json!("0x1234"), json!(0x4321), json!("0xabcd")]); + Ok(json!("0x1111")) + } + .boxed() + }), + ) + .await + .unwrap(); + assert_eq!(result, json!("0x1111")); + + // block is not finalized, cache should be skipped + let result = middleware + .call( + CallRequest::new( + "state_getStorage", + vec![json!("0x1234"), json!(0x5432), json!("0xabcd")], + ), + Default::default(), + Box::new(move |req: CallRequest, context| { + async move { + // cache bypassed + assert!(bypass_cache(&context)); + // params not changed + assert_eq!(req.params, vec![json!("0x1234"), json!(0x5432), json!("0xabcd")]); + Ok(json!("0x1111")) + } + .boxed() + }), + ) + .await + .unwrap(); + assert_eq!(result, json!("0x1111")); + } + } }