Skip to content

Commit

Permalink
Atomic cache initialization (#110)
Browse files Browse the repository at this point in the history
Only run cache refresh request once in case of concurrent requests. Cache `Vec<u8>` instead of strings to simplify cache request design.
  • Loading branch information
lstrojny authored Feb 17, 2023
1 parent 664f0d7 commit 395cd56
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 130 deletions.
31 changes: 13 additions & 18 deletions src/providers/deutscher_wetterdienst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,22 +187,14 @@ fn reqwest_cached_measurement_csv(
"{BASE_URL}/10minutenwerte_TU_{station_id}_now.zip"
))?;

let key = (method.clone(), url.clone());
let value = cache.get(&key);

if let Some(csv) = value {
debug!("Found cached measurement data for {}", station_id);
return Ok(csv);
}

debug!("No cached measurement data found for {}", station_id);

let zip = client.request(method, url).send()?.bytes();
let csv = read_measurement_data_zip(&zip?)?;

cache.insert(key, csv.clone());

Ok(csv)
request_cached(&HttpCacheRequest::new(
SOURCE_URI,
client,
cache,
&method,
&url,
|body| read_measurement_data_zip(body),
))
}

impl WeatherProvider for DeutscherWetterdienst {
Expand All @@ -222,8 +214,11 @@ impl WeatherProvider for DeutscherWetterdienst {
cache,
&Method::GET,
&Url::parse(STATION_LIST_URL)?,
|r| Ok(r.text_with_charset("iso-8859-15")?),
|b| Ok(parse_weather_station_list_csv(b)),
|body| {
let str: String = body.iter().map(|&c| c as char).collect();

Ok(parse_weather_station_list_csv(&str))
},
))?;

let closest_station = find_closest_weather_station(&request.query, &stations)?;
Expand Down
184 changes: 75 additions & 109 deletions src/providers/http_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::fmt::Debug;
use std::sync::RwLock;
use std::time::Duration;

pub type Cache = MokaCache<(Method, Url), String>;
pub type Cache = MokaCache<(Method, Url), Vec<u8>>;

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Configuration {
Expand All @@ -34,8 +34,7 @@ pub struct HttpCacheRequest<'a, R: Debug = String> {
cache: &'a HttpRequestCache,
method: &'a Method,
url: &'a Url,
to_string: fn(response: Response) -> anyhow::Result<String>,
deserialize: fn(string: &str) -> anyhow::Result<R>,
deserialize: fn(body: &Vec<u8>) -> anyhow::Result<R>,
}

const CONSECUTIVE_FAILURE_COUNT: u32 = 3;
Expand All @@ -54,16 +53,14 @@ impl HttpCacheRequest<'_> {
cache: &'a HttpRequestCache,
method: &'a Method,
url: &'a Url,
to_string: fn(response: Response) -> anyhow::Result<String>,
deserialize: fn(string: &str) -> anyhow::Result<T>,
deserialize: fn(body: &Vec<u8>) -> anyhow::Result<T>,
) -> HttpCacheRequest<'a, T> {
HttpCacheRequest {
source,
client,
cache,
method,
url,
to_string,
deserialize,
}
}
Expand All @@ -75,144 +72,119 @@ impl HttpCacheRequest<'_> {
method: &'a Method,
url: &'a Url,
) -> HttpCacheRequest<'a, T> {
HttpCacheRequest::new::<T>(
source,
client,
cache,
method,
url,
response_to_string,
serde_deserialize_body,
)
HttpCacheRequest::new::<T>(source, client, cache, method, url, serde_deserialize_body)
}
}

fn response_to_string(response: Response) -> anyhow::Result<String> {
Ok(response.text()?)
}

fn serde_deserialize_body<T: Debug + DeserializeOwned>(body: &str) -> anyhow::Result<T> {
fn serde_deserialize_body<T: Debug + DeserializeOwned>(body: &Vec<u8>) -> anyhow::Result<T> {
trace!("Deserializing body {body:?}");
Ok(serde_json::from_str(body)?)
Ok(serde_json::from_slice(body)?)
}

pub fn request_cached<R: Debug>(request: &HttpCacheRequest<R>) -> anyhow::Result<R> {
let key = (request.method.clone(), request.url.clone());
let value = request.cache.get(&key);

debug!(
"Checking cache item for request \"{:#} {:#}\" for {:?} with lifetime {:?}",
request.method,
request.url,
request.source,
request
.cache
.policy()
.time_to_live()
.unwrap_or(Duration::from_secs(0))
);

if let Some(value) = value {
let value = request.cache.try_get_with_by_ref(&key, || {
debug!(
"Found cached item for \"{:#} {:#}\"",
request.method, request.url
"Generating cache item for request \"{:#} {:#}\" for {:?} with lifetime {:?}",
request.method,
request.url,
request.source,
request
.cache
.policy()
.time_to_live()
.unwrap_or(Duration::from_secs(0))
);

let des = (request.deserialize)(&value)?;
let cicruit_breaker_scope = request
.url
.host_str()
.ok_or_else(|| anyhow!("Could not extract host from URL"))?;

return Ok(des);
}
// Separate scope so read lock is dropped at the end if circuit breaker does not yet exist
{
let circuit_breaker_registry_ro =
CIRCUIT_BREAKER_REGISTRY.read().expect("Poisoned lock");

debug!(
"No cache item found for \"{:#} {:#}\". Requesting",
request.method, request.url
);

let cicruit_breaker_scope = request
.url
.host_str()
.ok_or_else(|| anyhow!("Could not extract host from URL"))?;

// Separate scope so read lock is dropped at the end if circuit breaker does not yet exist
{
let circuit_breaker_registry_ro = CIRCUIT_BREAKER_REGISTRY.read().expect("Poisoned lock");
trace!("Read lock acquired for {:?}", cicruit_breaker_scope);

trace!("Read lock acquired for {:?}", cicruit_breaker_scope);
if let Some(cb) = circuit_breaker_registry_ro.get(cicruit_breaker_scope) {
return request_url_with_circuit_breaker(cicruit_breaker_scope, cb, request);
}

if let Some(cb) = circuit_breaker_registry_ro.get(cicruit_breaker_scope) {
return request_url_with_circuit_breaker(cicruit_breaker_scope, cb, request, &key);
drop(circuit_breaker_registry_ro);
}

drop(circuit_breaker_registry_ro);
}
ensure_circuit_breaker(cicruit_breaker_scope);

// Separate scope so write lock is dropped at the end
{
trace!(
"Trying to acquire write lock to instantiate circuit breaker {:?}",
"Trying to acquire read lock after circuit breaker {:?} was instantiated",
cicruit_breaker_scope
);

let mut circuit_breaker_registry_rw =
CIRCUIT_BREAKER_REGISTRY.write().expect("Poisoned lock");
let circuit_breaker_registry_ro = CIRCUIT_BREAKER_REGISTRY
.read()
.expect("Lock should not be poisoned");
trace!(
"Write lock acquired to instantiate circuit breaker {:?}",
"Read lock acquired after circuit breaker {:?} was instantiated",
cicruit_breaker_scope
);
let circuit_breaker = circuit_breaker_registry_ro
.get(cicruit_breaker_scope)
.expect("Circuit breaker must now exist");

if circuit_breaker_registry_rw.contains_key(cicruit_breaker_scope) {
trace!(
"Circuit breaker {:?} already instantiated, skipping",
cicruit_breaker_scope
);
} else {
trace!(
"Circuit breaker {:?} not yet instantiated, instantiating",
cicruit_breaker_scope
);
request_url_with_circuit_breaker(cicruit_breaker_scope, circuit_breaker, request)
});

let circuit_breaker = Config::new()
.failure_policy(consecutive_failures(
CONSECUTIVE_FAILURE_COUNT,
exponential(
Duration::from_secs(EXPONENTIAL_BACKOFF_START_SECS),
Duration::from_secs(EXPONENTIAL_BACKOFF_MAX_SECS),
),
))
.build();

circuit_breaker_registry_rw.insert(cicruit_breaker_scope.to_string(), circuit_breaker);

trace!("Circuit breaker {:?} instantiated", cicruit_breaker_scope);
}

drop(circuit_breaker_registry_rw);
match value {
Ok(v) => Ok((request.deserialize)(&v)?),
Err(e) => Err(anyhow!(e)),
}
}

fn ensure_circuit_breaker(cicruit_breaker_scope: &str) {
trace!(
"Trying to acquire read lock after circuit breaker {:?} was instantiated",
"Trying to acquire write lock to instantiate circuit breaker {:?}",
cicruit_breaker_scope
);
let circuit_breaker_registry_ro = CIRCUIT_BREAKER_REGISTRY
.read()
.expect("Lock should not be poisoned");

let mut circuit_breaker_registry_rw = CIRCUIT_BREAKER_REGISTRY.write().expect("Poisoned lock");
trace!(
"Read lock acquired after circuit breaker {:?} was instantiated",
"Write lock acquired to instantiate circuit breaker {:?}",
cicruit_breaker_scope
);
let circuit_breaker = circuit_breaker_registry_ro
.get(cicruit_breaker_scope)
.expect("Circuit breaker must now exist");

request_url_with_circuit_breaker(cicruit_breaker_scope, circuit_breaker, request, &key)
if !circuit_breaker_registry_rw.contains_key(cicruit_breaker_scope) {
trace!(
"Circuit breaker {:?} not yet instantiated, instantiating",
cicruit_breaker_scope
);

let circuit_breaker = create_circuit_breaker();

circuit_breaker_registry_rw.insert(cicruit_breaker_scope.to_string(), circuit_breaker);

trace!("Circuit breaker {:?} instantiated", cicruit_breaker_scope);
}
}

fn create_circuit_breaker() -> StateMachine<ConsecutiveFailures<Exponential>, ()> {
Config::new()
.failure_policy(consecutive_failures(
CONSECUTIVE_FAILURE_COUNT,
exponential(
Duration::from_secs(EXPONENTIAL_BACKOFF_START_SECS),
Duration::from_secs(EXPONENTIAL_BACKOFF_MAX_SECS),
),
))
.build()
}

fn request_url_with_circuit_breaker<R: Debug>(
circuit_breaker_scope: &str,
circuit_breaker: &HttpCircuitBreaker,
request: &HttpCacheRequest<R>,
key: &(Method, Url),
) -> anyhow::Result<R> {
) -> anyhow::Result<Vec<u8>> {
match circuit_breaker.call(|| request_url(request)) {
Err(Error::Inner(e)) => Err(anyhow!(e)),
Err(Error::Rejected) => Err(anyhow!(
Expand All @@ -226,13 +198,7 @@ fn request_url_with_circuit_breaker<R: Debug>(
response.status()
);

let body = (request.to_string)(response)?;

request.cache.insert(key.clone(), body.clone());

let des = (request.deserialize)(&body)?;

Ok(des)
Ok(response.bytes().map(|v| v.to_vec())?)
}
}
}
Expand Down
4 changes: 1 addition & 3 deletions src/providers/nogoodnik.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,12 @@ impl WeatherProvider for Nogoodnik {
cache: &HttpRequestCache,
_request: &WeatherRequest<Coordinates>,
) -> anyhow::Result<Weather> {
let _response = request_cached(&HttpCacheRequest::new(
request_cached(&HttpCacheRequest::new_json_request(
SOURCE_URI,
client,
cache,
&Method::GET,
&Url::parse("http://example.org/404")?,
|r| Ok(r.text()?),
|v| Ok(v.to_string()),
))?;

Err(format_err!("This provider is no good and always fails"))
Expand Down

0 comments on commit 395cd56

Please sign in to comment.