diff --git a/src/lib.rs b/src/lib.rs index 5d0c8d2..322143e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,9 +9,33 @@ use tracing::Level; mod error; pub use error::CacheError; +const LOCK_TTL: usize = 60_000; // milliseconds +const LOCK_RETRY_TIME: Duration = Duration::from_secs(1); + +#[derive(Serialize, Deserialize)] +struct CacheItem { + timestamp: DateTime, + payload: Item, +} + +#[tracing::instrument(skip_all, level = Level::TRACE)] +async fn query_cache( + conn: &mut redis::aio::MultiplexedConnection, + key_name: &[u8], +) -> Result>, CacheError> +where + Item: DeserializeOwned, +{ + Ok(conn + .get::<&[u8], Option>>(key_name) + .await? + .map(|s| ciborium::from_reader(&s[..])) + .transpose()?) +} + #[derive(Clone)] pub struct RedisClient { - redis: redis::Client, + client: redis::Client, } pub struct RedisCacheArgs<'a> { @@ -20,6 +44,15 @@ pub struct RedisCacheArgs<'a> { pub expiry: Duration, } +#[derive(PartialEq, Eq, Clone, Copy)] +pub enum CacheDecision { + Cache, + NoCache, +} + +type OutFunc<'f, Args, I, E> = + Box Pin> + Send + 'f>> + 'f>; + pub trait Client where Self: Sized, @@ -34,15 +67,87 @@ where &'c self, f: Func, cache_args: &'a RedisCacheArgs<'a>, - ) -> Box Pin> + Send + 'f>> + 'f> + ) -> OutFunc<'f, Args, Item, E> where Func: Fn(Args) -> Inner + Sync + Send + 'f, Inner: Future> + Send + 'f, Args: Send + 'f, - Item: Send + Serialize + DeserializeOwned, + Item: Serialize + DeserializeOwned + Send + Sync, E: From, 'a: 'f, 'c: 'f; + + fn wrap_const<'c, 'f, 'a, Func, Inner, Args, Item, E>( + &'c self, + f: Func, + cache_args: &'a RedisCacheArgs<'a>, + do_cache: CacheDecision, + ) -> OutFunc<'f, Args, Item, E> + where + Func: Fn(Args) -> Inner + Sync + Send + 'f, + Inner: Future> + Send + 'f, + Args: Send + 'f, + Item: Serialize + DeserializeOwned + Send + Sync, + E: From, + 'a: 'f, + 'c: 'f; + + fn wrap_opt<'c, 'f, 'a, Func, Inner, Args, Item, E>( + &'c self, + f: Func, + cache_args: &'a RedisCacheArgs<'a>, + ) -> OutFunc<'f, Args, Option, E> + where + Func: Fn(Args) -> Inner + Sync + Send + 'f, + Inner: Future, E>> + Send + 'f, + Args: Send + 'f, + Item: Serialize + DeserializeOwned + Send + Sync, + E: From, + 'a: 'f, + 'c: 'f; + + fn wrap_on<'c, 'f, 'ca, 'fa, 'cf, Func, CacheFunc, Inner, Args, Item, E>( + &'c self, + f: Func, + cache_args: &'ca RedisCacheArgs<'ca>, + do_cache: CacheFunc, + ) -> OutFunc<'f, Args, Item, E> + where + Func: Fn(Args) -> Inner + Sync + Send + 'f, + CacheFunc: for<'i> Fn(&'i Item) -> CacheDecision + Send + Sync + 'cf, + Inner: Future> + Send + 'f, + Args: Send + 'fa, + Item: Serialize + DeserializeOwned + Send + Sync, + E: From, + 'ca: 'f, + 'c: 'f, + 'cf: 'f, + 'fa: 'f; +} + +#[tracing::instrument(skip_all, level = Level::TRACE)] +async fn write_cache( + conn: &mut redis::aio::MultiplexedConnection, + key_name: &[u8], + payload: &Item, +) -> Result<(), CacheError> +where + Item: Serialize, +{ + let cache_item = CacheItem { + timestamp: Utc::now(), + payload, + }; + + let _: () = conn + .set(key_name, { + let mut buf = Vec::new(); + ciborium::into_writer(&cache_item, &mut buf)?; + buf + }) + .await?; + + Ok(()) } impl Client for RedisClient { @@ -50,52 +155,87 @@ impl Client for RedisClient { fn new((ip, port): (net::IpAddr, u16)) -> Result { Ok(Self { - redis: redis::Client::open((ip.to_string(), port))?, + client: redis::Client::open((ip.to_string(), port))?, }) } fn get(&self) -> &redis::Client { - &self.redis + &self.client } fn wrap<'c, 'f, 'a, Func, Inner, Args, Item, E>( &'c self, f: Func, cache_args: &'a RedisCacheArgs<'a>, - ) -> Box Pin> + Send + 'f>> + 'f> + ) -> OutFunc<'f, Args, Item, E> where Func: Fn(Args) -> Inner + Sync + Send + 'f, Inner: Future> + Send + 'f, Args: Send + 'f, - Item: Send + Serialize + DeserializeOwned, + Item: Serialize + DeserializeOwned + Send + Sync, E: From, 'a: 'f, 'c: 'f, { - const LOCK_TTL: usize = 60_000; // milliseconds - const LOCK_RETRY_TIME: Duration = Duration::from_secs(1); + self.wrap_const(f, cache_args, CacheDecision::Cache) + } - #[derive(Serialize, Deserialize)] - struct CacheItem { - timestamp: DateTime, - payload: Item, - } + fn wrap_opt<'c, 'f, 'a, Func, Inner, Args, Item, E>( + &'c self, + f: Func, + cache_args: &'a RedisCacheArgs<'a>, + ) -> OutFunc<'f, Args, Option, E> + where + Func: Fn(Args) -> Inner + Sync + Send + 'f, + Inner: Future, E>> + Send + 'f, + Args: Send + 'f, + Item: Serialize + DeserializeOwned + Send + Sync, + E: From, + 'a: 'f, + 'c: 'f, + { + self.wrap_on(f, cache_args, |item| { + item.as_ref() + .map_or(CacheDecision::NoCache, |_| CacheDecision::Cache) + }) + } - #[tracing::instrument(skip_all, level = Level::TRACE)] - async fn query_cache( - conn: &mut redis::aio::MultiplexedConnection, - key_name: &[u8], - ) -> Result>, CacheError> - where - Item: DeserializeOwned, - { - Ok(conn - .get::<&[u8], Option>(key_name) - .await? - .map(|s| serde_json::from_str(&s)) - .transpose()?) - } + fn wrap_const<'c, 'f, 'a, Func, Inner, Args, Item, E>( + &'c self, + f: Func, + cache_args: &'a RedisCacheArgs<'a>, + do_cache: CacheDecision, + ) -> OutFunc<'f, Args, Item, E> + where + Func: Fn(Args) -> Inner + Sync + Send + 'f, + Inner: Future> + Send + 'f, + Args: Send + 'f, + Item: Serialize + DeserializeOwned + Send + Sync, + E: From, + 'a: 'f, + 'c: 'f, + { + self.wrap_on(f, cache_args, move |_item| do_cache) + } + fn wrap_on<'c, 'f, 'ca, 'fa, 'cf, Func, CacheFunc, Inner, Args, Item, E>( + &'c self, + f: Func, + cache_args: &'ca RedisCacheArgs<'ca>, + do_cache: CacheFunc, + ) -> OutFunc<'f, Args, Item, E> + where + Func: Fn(Args) -> Inner + Sync + Send + 'f, + CacheFunc: for<'i> Fn(&'i Item) -> CacheDecision + Send + Sync + 'cf, + Inner: Future> + Send + 'f, + Args: Send + 'fa, + Item: Serialize + DeserializeOwned + Send + Sync, + E: From, + 'ca: 'f, + 'c: 'f, + 'cf: 'f, + 'fa: 'f, + { Box::new(move |args: Args| { Box::pin(async move { let expiry = TimeDelta::from_std(cache_args.expiry) @@ -133,25 +273,20 @@ impl Client for RedisClient { .in_scope(|| async { let payload = f(args).await?; - let cache_item = CacheItem { - timestamp: Utc::now(), - payload, - }; - let _: () = conn - .set( - cache_args.key_name, - serde_json::to_string(&cache_item) - .map_err(Into::into)?, - ) - .await - .map_err(Into::into)?; - lock_manager.unlock(&lock); - tracing::trace!("cache updated"); - Ok(cache_item.payload) + match do_cache(&payload) { + CacheDecision::NoCache => Ok(payload), + CacheDecision::Cache => { + lock_manager.unlock(&lock); + tracing::trace!("cache updated"); + write_cache(&mut conn, cache_args.key_name, &payload) + .await?; + Ok(payload) + } + } }) .await } else { - // Could not get lock because it's already. so some other process is already + // Could not get lock because it's already taken. So some other process is already // gathering data. Wait for it to finish and then just return // the cached response. tracing::trace!("could not acquire lock"); @@ -186,11 +321,18 @@ impl Client for RedisClient { break Ok(response.payload); } None => { - break Err(CacheError::Consistency( - "cached item expected but not found" - .to_owned(), - ) - .into()) + tracing::trace!("no cache item returned, generating own response"); + let payload = f(args).await?; + break match do_cache(&payload) { + CacheDecision::NoCache => Ok(payload), + CacheDecision::Cache => { + write_cache(&mut conn, cache_args.key_name, &payload) + .await?; + lock_manager.unlock(&lock); + tracing::trace!("cache updated"); + Ok(payload) + } + } } } }