diff --git a/src/middlewares/methods/inject_params.rs b/src/middlewares/methods/inject_params.rs index b07ecad..d645372 100644 --- a/src/middlewares/methods/inject_params.rs +++ b/src/middlewares/methods/inject_params.rs @@ -122,42 +122,44 @@ impl Middleware for InjectParamsMiddleware { let idx = self.get_index(); match request.params.len() { - len if len == idx + 1 => { - // full params with current block + len if len > idx + 1 => { + // unexpected number of params return handle_request(request).await; } len if len <= idx => { - async move { - // without current block - let to_inject = self.get_parameter().await; - tracing::trace!("Injected param {} to method {}", &to_inject, request.method); - let params_passed = request.params.len(); - while request.params.len() < idx { - let current = request.params.len(); - if self.params[current].optional { - request.params.push(JsonValue::Null); - } else { - let (required, optional) = self.params_count(); - return Err(errors::invalid_params(format!( - "Expected {:?} parameters ({:?} optional), {:?} found instead", - required + optional, - optional, - params_passed - ))); - } + // without current block + let params_passed = request.params.len(); + while request.params.len() < idx { + let current = request.params.len(); + if self.params[current].optional { + request.params.push(JsonValue::Null); + } else { + let (required, optional) = self.params_count(); + return Err(errors::invalid_params(format!( + "Expected {:?} parameters ({:?} optional), {:?} found instead", + required + optional, + optional, + params_passed + ))); } - request.params.push(to_inject); - - handle_request(request).await } - .with_context(TRACER.context("inject_params")) - .await + // Set param to null, it will be replaced later + request.params.push(JsonValue::Null); } - _ => { - // unexpected number of params - handle_request(request).await + _ => {} // full params, block potentially might be null + }; + + // Here we are sure we have full params in the request, but it still might be set to null + async move { + if request.params[idx] == JsonValue::Null { + let to_inject = self.get_parameter().await; + tracing::trace!("Injected param {} to method {}", &to_inject, request.method); + request.params[idx] = to_inject; } + handle_request(request).await } + .with_context(TRACER.context("inject_params")) + .await } } @@ -286,6 +288,44 @@ mod tests { assert_eq!(result, json!("0x1111")); } + #[tokio::test] + async fn inject_if_param_is_null() { + let params = vec![json!("0x1234"), json!(None::<()>)]; + let (middleware, _) = create_inject_middleware( + InjectType::BlockHashAt(1), + vec![ + MethodParam { + name: "key".to_string(), + ty: "StorageKey".to_string(), + optional: false, + inject: false, + }, + MethodParam { + name: "at".to_string(), + ty: "BlockHash".to_string(), + optional: true, + inject: true, + }, + ], + ) + .await; + let result = middleware + .call( + CallRequest::new("state_getStorage", params.clone()), + Default::default(), + Box::new(move |req: CallRequest, _| { + async move { + assert_eq!(req.params, vec![json!("0x1234"), json!("0xabcd")]); + Ok(json!("0x1111")) + } + .boxed() + }), + ) + .await + .unwrap(); + assert_eq!(result, json!("0x1111")); + } + #[tokio::test] async fn inject_if_without_current_block_hash() { let (middleware, _context) = create_inject_middleware(