Skip to content

Commit

Permalink
validate middleware (#156)
Browse files Browse the repository at this point in the history
* validate middleware

* validate extension for config

* rename
  • Loading branch information
ermalkaleci authored Apr 16, 2024
1 parent 354f917 commit 50b826d
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 2 deletions.
25 changes: 24 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,27 @@ It's also possible to run individual benchmarks by:

```
cargo bench --bench bench ws_round_trip
```
```

## Validate Middleware

This middleware will intercept all method request/responses and compare the result directly with healthy endpoint responses.
This is useful for debugging to make sure the returned values are as expected.
You can enable validate middleware on your config file.
```yml
middlewares:
methods:
- validate
```
NOTE: Keep in mind that if you place `validate` middleware before `inject_params` you may get false positive errors because the request will not be the same.

Ignored methods can be defined in extension config:
```yml
extensions:
validator:
ignore_methods:
- system_health
- system_name
- system_version
- author_pendingExtrinsics
```
7 changes: 7 additions & 0 deletions src/extensions/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ mod tests;
const TRACER: utils::telemetry::Tracer = utils::telemetry::Tracer::new("client");

pub struct Client {
endpoints: Vec<Arc<Endpoint>>,
sender: tokio::sync::mpsc::Sender<Message>,
rotation_notify: Arc<Notify>,
retries: u32,
Expand Down Expand Up @@ -187,6 +188,7 @@ impl Client {

let rotation_notify = Arc::new(Notify::new());
let rotation_notify_bg = rotation_notify.clone();
let endpoints_ = endpoints.clone();

let background_task = tokio::spawn(async move {
let request_backoff_counter = Arc::new(AtomicU32::new(0));
Expand Down Expand Up @@ -395,6 +397,7 @@ impl Client {
});

Ok(Self {
endpoints: endpoints_,
sender: message_tx,
rotation_notify,
retries: retries.unwrap_or(3),
Expand All @@ -406,6 +409,10 @@ impl Client {
Self::new(endpoints, None, None, None, None)
}

pub fn endpoints(&self) -> &Vec<Arc<Endpoint>> {
self.endpoints.as_ref()
}

pub async fn request(&self, method: &str, params: Vec<JsonValue>) -> CallResult {
async move {
let (tx, rx) = tokio::sync::oneshot::channel();
Expand Down
2 changes: 2 additions & 0 deletions src/extensions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub mod merge_subscription;
pub mod rate_limit;
pub mod server;
pub mod telemetry;
pub mod validator;

#[async_trait]
pub trait Extension: Sized {
Expand Down Expand Up @@ -138,4 +139,5 @@ define_all_extensions! {
server: server::SubwayServerBuilder,
event_bus: event_bus::EventBus,
rate_limit: rate_limit::RateLimitBuilder,
validator: validator::Validator,
}
67 changes: 67 additions & 0 deletions src/extensions/validator/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use crate::extensions::client::Client;
use crate::middlewares::{CallRequest, CallResult};
use crate::utils::errors;
use async_trait::async_trait;
use serde::Deserialize;
use std::sync::Arc;

use super::{Extension, ExtensionRegistry};

#[derive(Default)]
pub struct Validator {
pub config: ValidateConfig,
}

#[derive(Deserialize, Default, Debug, Clone)]
pub struct ValidateConfig {
pub ignore_methods: Vec<String>,
}

#[async_trait]
impl Extension for Validator {
type Config = ValidateConfig;

async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result<Self, anyhow::Error> {
Ok(Self::new(config.clone()))
}
}

impl Validator {
pub fn new(config: ValidateConfig) -> Self {
Self { config }
}

pub fn ignore(&self, method: &String) -> bool {
self.config.ignore_methods.contains(method)
}

pub fn validate(&self, client: Arc<Client>, request: CallRequest, response: CallResult) {
tokio::spawn(async move {
let healthy_endpoints = client.endpoints().iter().filter(|x| x.health().score() > 0);
futures::future::join_all(healthy_endpoints.map(|endpoint| async {
let expected = endpoint
.request(
&request.method,
request.params.clone(),
std::time::Duration::from_secs(30),
)
.await
.map_err(errors::map_error);

if response != expected {
let request = serde_json::to_string_pretty(&request).unwrap_or_default();
let actual = match &response {
Ok(value) => serde_json::to_string_pretty(&value).unwrap_or_default(),
Err(e) => e.to_string()
};
let expected = match &expected {
Ok(value) => serde_json::to_string_pretty(&value).unwrap_or_default(),
Err(e) => e.to_string()
};
let endpoint_url = endpoint.url();
tracing::error!("Response mismatch for request:\n{request}\nSubway response:\n{actual}\nEndpoint {endpoint_url} response:\n{expected}");
}
})).await;
});
}
}
1 change: 1 addition & 0 deletions src/middlewares/factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub async fn create_method_middleware(
"block_tag" => block_tag::BlockTagMiddleware::build(method, extensions).await,
"inject_params" => inject_params::InjectParamsMiddleware::build(method, extensions).await,
"delay" => delay::DelayMiddleware::build(method, extensions).await,
"validate" => validate::ValidateMiddleware::build(method, extensions).await,
#[cfg(test)]
"crazy" => testing::CrazyMiddleware::build(method, extensions).await,
_ => panic!("Unknown method middleware: {}", name),
Expand Down
1 change: 1 addition & 0 deletions src/middlewares/methods/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod delay;
pub mod inject_params;
pub mod response;
pub mod upstream;
pub mod validate;

#[cfg(test)]
pub mod testing;
52 changes: 52 additions & 0 deletions src/middlewares/methods/validate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use async_trait::async_trait;
use std::sync::Arc;

use crate::{
extensions::{client::Client, validator::Validator},
middlewares::{CallRequest, CallResult, Middleware, MiddlewareBuilder, NextFn, RpcMethod},
utils::{TypeRegistry, TypeRegistryRef},
};

pub struct ValidateMiddleware {
validator: Arc<Validator>,
client: Arc<Client>,
}

impl ValidateMiddleware {
pub fn new(validator: Arc<Validator>, client: Arc<Client>) -> Self {
Self { validator, client }
}
}

#[async_trait]
impl MiddlewareBuilder<RpcMethod, CallRequest, CallResult> for ValidateMiddleware {
async fn build(
_method: &RpcMethod,
extensions: &TypeRegistryRef,
) -> Option<Box<dyn Middleware<CallRequest, CallResult>>> {
let validate = extensions.read().await.get::<Validator>().unwrap_or_default();

let client = extensions
.read()
.await
.get::<Client>()
.expect("Client extension not found");
Some(Box::new(ValidateMiddleware::new(validate, client)))
}
}

#[async_trait]
impl Middleware<CallRequest, CallResult> for ValidateMiddleware {
async fn call(
&self,
request: CallRequest,
context: TypeRegistry,
next: NextFn<CallRequest, CallResult>,
) -> CallResult {
let result = next(request.clone(), context).await;
if !self.validator.ignore(&request.method) {
self.validator.validate(self.client.clone(), request, result.clone());
}
result
}
}
3 changes: 2 additions & 1 deletion src/middlewares/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use jsonrpsee::{
PendingSubscriptionSink,
};
use opentelemetry::trace::FutureExt as _;
use serde::Serialize;
use std::{
fmt::{Debug, Formatter},
sync::Arc,
Expand All @@ -20,7 +21,7 @@ pub mod factory;
pub mod methods;
pub mod subscriptions;

#[derive(Debug)]
#[derive(Clone, Debug, Serialize)]
/// Represents a RPC request made to a middleware function.
pub struct CallRequest {
pub method: String,
Expand Down

0 comments on commit 50b826d

Please sign in to comment.