From 3f708c5fffed105b27965f8e844a26de6bdf9662 Mon Sep 17 00:00:00 2001 From: rtkay123 Date: Sun, 5 Apr 2026 15:17:55 +0200 Subject: feat(cli): cache --- Cargo.lock | 165 ++++++++++++++- Cargo.toml | 3 + crates/api-auth/Cargo.toml | 3 +- crates/api-auth/src/discord/mod.rs | 16 +- crates/api-auth/src/lib.rs | 6 +- crates/sellershut/Cargo.toml | 1 + crates/sellershut/src/config/cache/mod.rs | 231 +++++++++++++++++++++ crates/sellershut/src/config/mod.rs | 7 + crates/sellershut/src/main.rs | 8 +- .../src/server/api/routes/auth/discord.rs | 2 +- crates/sh-util/Cargo.toml | 26 +++ crates/sh-util/src/cache/cluster.rs | 56 +++++ crates/sh-util/src/cache/mod.rs | 176 ++++++++++++++++ crates/sh-util/src/cache/sentinel.rs | 66 ++++++ crates/sh-util/src/lib.rs | 2 + 15 files changed, 755 insertions(+), 13 deletions(-) create mode 100644 crates/sellershut/src/config/cache/mod.rs create mode 100644 crates/sh-util/Cargo.toml create mode 100644 crates/sh-util/src/cache/cluster.rs create mode 100644 crates/sh-util/src/cache/mod.rs create mode 100644 crates/sh-util/src/cache/sentinel.rs create mode 100644 crates/sh-util/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index e52f6ae..0e5cb2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -98,6 +98,7 @@ dependencies = [ "oauth2", "secrecy", "serde", + "sh-util", "sqlx", "thiserror 2.0.18", "url", @@ -124,6 +125,21 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "arc-swap" +version = "1.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a3a1fd6f75306b68087b831f025c712524bcb19aad54e557b1129cfa0a2b207" +dependencies = [ + "rustversion", +] + +[[package]] +name = "arcstr" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03918c3dbd7701a85c6b9887732e2921175f26c350b4563841d0958c21d57e6d" + [[package]] name = "arrayref" version = "0.3.9" @@ -145,6 +161,17 @@ dependencies = [ "event-listener 2.5.3", ] +[[package]] +name = "async-lock" +version = "3.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f7f2596bd5b78a9fec8088ccd89180d7f9f55b94b0576823bbbdc72ee8311" +dependencies = [ + "event-listener 5.4.1", + "event-listener-strategy", + "pin-project-lite", +] + [[package]] name = "async-session" version = "3.0.0" @@ -152,7 +179,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07da4ce523b4e2ebaaf330746761df23a465b951a83d84bbce4233dabedae630" dependencies = [ "anyhow", - "async-lock", + "async-lock 2.8.0", "async-trait", "base64 0.13.1", "bincode", @@ -262,6 +289,15 @@ dependencies = [ "syn", ] +[[package]] +name = "backon" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cffb0e931875b666fc4fcb20fee52e9bbd1ef836fd9e9e04ec21555f9f85f7ef" +dependencies = [ + "fastrand", +] + [[package]] name = "base64" version = "0.13.1" @@ -280,6 +316,18 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" +[[package]] +name = "bb8" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "457d7ed3f888dfd2c7af56d4975cade43c622f74bdcddfed6d4352f57acc6310" +dependencies = [ + "futures-util", + "parking_lot", + "portable-atomic", + "tokio", +] + [[package]] name = "bincode" version = "1.3.3" @@ -462,6 +510,20 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "tokio-util", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -513,6 +575,12 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crc16" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "338089f42c427b86394a5ee60ff321da23a5c89c9d89514c829687b26359fcff" + [[package]] name = "crc32fast" version = "1.5.0" @@ -722,6 +790,22 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener 5.4.1", + "pin-project-lite", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "find-msvc-tools" version = "0.1.9" @@ -808,6 +892,17 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.32" @@ -828,6 +923,7 @@ checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1564,6 +1660,12 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + [[package]] name = "potential_utf" version = "0.1.5" @@ -1736,6 +1838,36 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "redis" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d76e41a79ae5cbb41257d84cf4cf0db0bb5a95b11bf05c62c351de4fe748620d" +dependencies = [ + "arc-swap", + "arcstr", + "async-lock 3.4.2", + "backon", + "bb8", + "bytes", + "cfg-if 1.0.4", + "combine", + "crc16", + "futures-channel", + "futures-util", + "itoa", + "log", + "percent-encoding", + "pin-project-lite", + "rand 0.9.2", + "ryu", + "socket2", + "tokio", + "tokio-util", + "url", + "xxhash-rust", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -1980,6 +2112,7 @@ dependencies = [ "secrecy", "serde", "serde_json", + "sh-util", "sqlx", "tokio", "toml", @@ -2071,6 +2204,16 @@ dependencies = [ "serde", ] +[[package]] +name = "sh-util" +version = "0.0.0" +dependencies = [ + "bb8", + "futures-util", + "redis", + "serde", +] + [[package]] name = "sha1" version = "0.10.6" @@ -2536,6 +2679,7 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", "socket2", "tokio-macros", @@ -2563,6 +2707,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml" version = "1.1.2+spec-1.1.0" @@ -3349,6 +3506,12 @@ version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" +[[package]] +name = "xxhash-rust" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" + [[package]] name = "yoke" version = "0.8.2" diff --git a/Cargo.toml b/Cargo.toml index a9a8c27..a93b10e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,9 +12,12 @@ homepage = "https://git.kanjala.com/sellershut" api-core = { path = "./crates/api-core", version = "0.0.0" } async-trait = "0.1.89" axum = "0.8.8" +futures-util = "0.3.32" +redis = { version = "1.1.0", default-features = false } secrecy = "0.10.3" serde = "1.0.228" serde_json = "1.0.149" +sh-util = { path = "./crates/sh-util", version = "0.0.0" } tokio = "1.51.0" tower = "0.5.3" thiserror = "2.0.18" diff --git a/crates/api-auth/Cargo.toml b/crates/api-auth/Cargo.toml index 053bbb9..a0868a5 100644 --- a/crates/api-auth/Cargo.toml +++ b/crates/api-auth/Cargo.toml @@ -13,6 +13,7 @@ async-trait.workspace = true oauth2 = "5.0.0" secrecy.workspace = true serde.workspace = true +sh-util = { workspace = true, optional = true } sqlx.workspace = true thiserror.workspace = true utoipa = { workspace = true, optional = true } @@ -20,5 +21,5 @@ url.workspace = true async-session = "3.0.0" [features] -discord = [] +discord = ["sh-util/cache"] utoipa = ["dep:utoipa", "serde/derive"] diff --git a/crates/api-auth/src/discord/mod.rs b/crates/api-auth/src/discord/mod.rs index 29b9bc2..dbcb139 100644 --- a/crates/api-auth/src/discord/mod.rs +++ b/crates/api-auth/src/discord/mod.rs @@ -2,19 +2,25 @@ use api_core::models::user::User; use async_session::Session; use async_trait::async_trait; use oauth2::{CsrfToken, Scope}; +use sh_util::cache::RedisManager; use sqlx::PgPool; use crate::{BasicClient, CSRF_TOKEN, OauthDriver, error::AuthError}; -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct AuthServiceDiscord { database: PgPool, + cache: RedisManager, client: BasicClient, } impl AuthServiceDiscord { - pub fn new(database: PgPool, client: BasicClient) -> Self { - Self { database, client } + pub fn new(database: PgPool, client: BasicClient, cache: RedisManager) -> Self { + Self { + database, + client, + cache, + } } } @@ -26,7 +32,7 @@ impl OauthDriver for AuthServiceDiscord { async fn get_user(&self) -> Result { todo!() } - async fn create_oauth_session(&self)->Result { + async fn create_oauth_session(&self) -> Result { let (auth_url, csrf_token) = self .client .authorize_url(CsrfToken::new_random) @@ -38,7 +44,7 @@ impl OauthDriver for AuthServiceDiscord { Ok(String::default()) } - async fn save_session(&self, user: &User)->Result<(), AuthError>{ + async fn save_session(&self, user: &User) -> Result<(), AuthError> { todo!() } } diff --git a/crates/api-auth/src/lib.rs b/crates/api-auth/src/lib.rs index 95a04c4..367d395 100644 --- a/crates/api-auth/src/lib.rs +++ b/crates/api-auth/src/lib.rs @@ -20,11 +20,11 @@ type C = oauth2::basic::BasicClient< pub struct BasicClient(C); #[async_trait::async_trait] -pub trait OauthDriver: Send + Sync + std::fmt::Debug { +pub trait OauthDriver: Send + Sync { async fn get_auth_token(&self) -> Result; async fn get_user(&self) -> Result; - async fn create_oauth_session(&self)->Result; - async fn save_session(&self, user: &User)->Result<(), AuthError>; + async fn create_oauth_session(&self) -> Result; + async fn save_session(&self, user: &User) -> Result<(), AuthError>; } use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl}; diff --git a/crates/sellershut/Cargo.toml b/crates/sellershut/Cargo.toml index 14a686c..caf6fd0 100644 --- a/crates/sellershut/Cargo.toml +++ b/crates/sellershut/Cargo.toml @@ -18,6 +18,7 @@ clap = { version = "4.6.0", features = ["derive", "env"] } secrecy = { workspace = true, features = ["serde"] } serde = { workspace = true, features = ["derive"] } serde_json.workspace = true +sh-util = { workspace = true, features = ["cache"] } sqlx = { workspace = true, features = ["migrate"] } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } toml = "1.1.2" diff --git a/crates/sellershut/src/config/cache/mod.rs b/crates/sellershut/src/config/cache/mod.rs new file mode 100644 index 0000000..136c3a4 --- /dev/null +++ b/crates/sellershut/src/config/cache/mod.rs @@ -0,0 +1,231 @@ +use anyhow::Context; +use clap::{Args, ValueEnum}; +use serde::{Deserialize, Serialize}; +use sh_util::cache::{RedisVariant, SentinelConfig}; + +#[derive(Debug, Clone, Copy, ValueEnum, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum CacheMode { + Standalone, + Clustered, + Sentinel, +} + +#[derive(Debug, Clone, Args, Serialize, Deserialize, Default, PartialEq, Eq)] +#[serde(default, rename_all = "kebab-case")] +pub struct CacheConfig { + /// Cache mode: standalone, clustered, or sentinel. + #[arg(long, env = "HUT_CACHE_MODE", value_enum)] + #[serde(rename = "mode", skip_serializing_if = "Option::is_none")] + pub cache_mode: Option, + + /// Full Redis URL. Useful for standalone mode and can override host/port style inputs. + #[arg(long, env = "HUT_CACHE_URL")] + #[serde(rename = "url", skip_serializing_if = "Option::is_none")] + pub cache_url: Option, + + /// Redis host for standalone mode. + #[arg(long, env = "HUT_CACHE_HOST")] + #[serde(rename = "host", skip_serializing_if = "Option::is_none")] + pub cache_host: Option, + + /// Redis port for standalone mode. + #[arg(long, env = "HUT_CACHE_PORT")] + #[serde(rename = "port", skip_serializing_if = "Option::is_none")] + pub cache_port: Option, + + /// Comma-delimited node list for clustered or sentinel discovery, e.g. host1:6379,host2:6379. + #[arg(long, env = "HUT_CACHE_NODES", value_delimiter = ',')] + #[serde(rename = "nodes", skip_serializing_if = "Vec::is_empty")] + pub cache_nodes: Vec, + + /// Redis username. + #[arg(long, env = "HUT_CACHE_USERNAME")] + #[serde(rename = "username", skip_serializing_if = "Option::is_none")] + pub cache_username: Option, + + /// Redis password. + #[arg(long, env = "HUT_CACHE_PASSWORD")] + #[serde(rename = "password", skip_serializing_if = "Option::is_none")] + pub cache_password: Option, + + /// Redis logical database number. + #[arg(long, env = "HUT_CACHE_DB")] + #[serde(rename = "database", skip_serializing_if = "Option::is_none")] + pub cache_database: Option, + + /// Sentinel service name. Required for sentinel mode. + #[arg(long, env = "HUT_CACHE_SERVICE_NAME")] + #[serde(rename = "service_name", skip_serializing_if = "Option::is_none")] + pub cache_service_name: Option, + + /// Whether Redis TLS should use secure mode. + #[arg(long, env = "HUT_CACHE_TLS_MODE_SECURE")] + #[serde(rename = "tls_mode_secure", skip_serializing_if = "Option::is_none")] + pub cache_tls_mode_secure: Option, + + /// Whether the client should use RESP3. + #[arg(long, env = "HUT_CACHE_USE_RESP3")] + #[serde(rename = "use_resp3", skip_serializing_if = "Option::is_none")] + pub cache_use_resp3: Option, +} + +impl CacheConfig { + pub fn merge(self, higher: Self) -> Self { + Self { + cache_mode: higher.cache_mode.or(self.cache_mode), + cache_url: higher.cache_url.or(self.cache_url), + cache_host: higher.cache_host.or(self.cache_host), + cache_port: higher.cache_port.or(self.cache_port), + cache_nodes: if higher.cache_nodes.is_empty() { + self.cache_nodes + } else { + higher.cache_nodes + }, + cache_username: higher.cache_username.or(self.cache_username), + cache_password: higher.cache_password.or(self.cache_password), + cache_database: higher.cache_database.or(self.cache_database), + cache_service_name: higher.cache_service_name.or(self.cache_service_name), + cache_tls_mode_secure: higher + .cache_tls_mode_secure + .or(self.cache_tls_mode_secure), + cache_use_resp3: higher + .cache_use_resp3 + .or(self.cache_use_resp3), + } + } + + pub fn with_defaults(self) -> Self { + Self { + cache_mode: Some(self.cache_mode.unwrap_or(CacheMode::Standalone)), + cache_url: self.cache_url, + cache_host: Some(self.cache_host.unwrap_or_else(|| "127.0.0.1".to_string())), + cache_port: Some(self.cache_port.unwrap_or(6379)), + cache_nodes: self.cache_nodes, + cache_username: self.cache_username, + cache_password: self.cache_password, + cache_database: Some(self.cache_database.unwrap_or(0)), + cache_service_name: self.cache_service_name, + cache_tls_mode_secure: Some(self.cache_tls_mode_secure.unwrap_or(false)), + cache_use_resp3: Some(self.cache_use_resp3.unwrap_or(false)), + } + } + + pub fn defaults() -> Self { + Self::default().with_defaults() + } + + pub fn mode(&self) -> CacheMode { + self.cache_mode.unwrap_or(CacheMode::Standalone) + } + + pub fn url(&self) -> anyhow::Result { + if let Some(url) = &self.cache_url { + return Ok(url.clone()); + } + + match self.mode() { + CacheMode::Standalone => { + let host = self + .cache_host + .as_deref() + .context("cache.host")?; + let port = self + .cache_port + .context("cache.port")?; + let db = self.cache_database.unwrap_or(0); + + let auth = match (&self.cache_username, &self.cache_password) { + (Some(username), Some(password)) => format!("{username}:{password}@"), + (None, Some(password)) => format!(":{password}@"), + (Some(username), None) => format!("{username}@"), + (None, None) => String::new(), + }; + + Ok(format!("redis://{}{}:{}/{}", auth, host, port, db)) + } + CacheMode::Clustered | CacheMode::Sentinel => { + self.cache_nodes + .first() + .cloned() + .context("cache.nodes[0]") + } + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CacheConfigConversionError { + WrongMode(CacheMode), + MissingField(&'static str), +} + +impl std::fmt::Display for CacheConfigConversionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::WrongMode(mode) => write!(f, "cache mode must be sentinel, got {mode:?}"), + Self::MissingField(field) => write!(f, "missing required cache field: {field}"), + } + } +} + +impl std::error::Error for CacheConfigConversionError {} + +impl TryFrom<&CacheConfig> for SentinelConfig { + type Error = CacheConfigConversionError; + + fn try_from(value: &CacheConfig) -> Result { + if value.mode() != CacheMode::Sentinel { + return Err(CacheConfigConversionError::WrongMode(value.mode())); + } + + Ok(SentinelConfig { + service_name: value + .cache_service_name + .clone() + .ok_or(CacheConfigConversionError::MissingField("cache.service_name"))?, + redis_tls_mode_secure: value.cache_tls_mode_secure.unwrap_or(false), + redis_db: value.cache_database.map(i64::from), + redis_username: value + .cache_username + .clone() + .ok_or(CacheConfigConversionError::MissingField("cache.username"))?, + redis_password: value + .cache_password + .clone() + .ok_or(CacheConfigConversionError::MissingField("cache.password"))?, + redis_use_resp3: value.cache_use_resp3.unwrap_or(false), + }) + } +} + +impl TryFrom for SentinelConfig { + type Error = CacheConfigConversionError; + + fn try_from(value: CacheConfig) -> Result { + SentinelConfig::try_from(&value) + } +} + +impl TryFrom<&CacheConfig> for RedisVariant { + type Error = anyhow::Error; + + fn try_from(value: &CacheConfig) -> Result { + let s = SentinelConfig::try_from(value)?; + + match value.mode() { + CacheMode::Standalone => Ok(RedisVariant::NonClustered), + CacheMode::Clustered => Ok(RedisVariant::Clustered), + CacheMode::Sentinel => Ok(RedisVariant::Sentinel(s)), + } + } + +} + +impl TryFrom for sh_util::cache::RedisVariant { + type Error = anyhow::Error; + + fn try_from(value: CacheConfig) -> Result { + RedisVariant::try_from(&value) + } +} diff --git a/crates/sellershut/src/config/mod.rs b/crates/sellershut/src/config/mod.rs index 389b4bc..156ad0f 100644 --- a/crates/sellershut/src/config/mod.rs +++ b/crates/sellershut/src/config/mod.rs @@ -1,4 +1,5 @@ pub mod auth; +pub mod cache; pub mod cli; pub mod database; mod server; @@ -25,6 +26,9 @@ pub struct Config { /// Database configuration. #[command(flatten)] pub database: database::DatabaseConfig, + /// Cache configuration. + #[command(flatten)] + pub cache: cache::CacheConfig, } impl Config { pub fn load(cli: Self) -> Result { @@ -43,6 +47,7 @@ impl Config { server: self.server.merge(higher.server), auth: self.auth.merge(higher.auth), database: self.database.merge(higher.database), + cache: self.cache.merge(higher.cache), } } @@ -52,6 +57,7 @@ impl Config { server: self.server.with_defaults(), auth: self.auth.with_defaults(), database: self.database.with_defaults(), + cache: self.cache.with_defaults(), } } @@ -61,6 +67,7 @@ impl Config { server: server::ServerConfig::defaults(), auth: auth::OauthConfig::defaults(), database: database::DatabaseConfig::defaults(), + cache: cache::CacheConfig::defaults(), } } } diff --git a/crates/sellershut/src/main.rs b/crates/sellershut/src/main.rs index ebae4ed..a46cf3e 100644 --- a/crates/sellershut/src/main.rs +++ b/crates/sellershut/src/main.rs @@ -15,6 +15,7 @@ use api_core::{ health::BaseService, }; use clap::Parser; +use sh_util::cache::{RedisManager, RedisVariant}; use sqlx::PgPool; use tokio::net::TcpListener; use tracing::info; @@ -40,8 +41,10 @@ async fn main() -> Result<()> { )?; let database = state::postgres(&cfg.database.connection_url(), 100).await?; + let variant = RedisVariant::try_from(cfg.cache.clone())?; + let cache = RedisManager::new(&cfg.cache.url()?, variant).await; - let auth_clients = build_oauth_client(&cfg.auth, database)?; + let auth_clients = build_oauth_client(&cfg.auth, database, cache)?; let state = AppState::builder() .log_handle(log_handle) @@ -67,6 +70,7 @@ async fn main() -> Result<()> { fn build_oauth_client( config: &OauthConfig, database: PgPool, + cache: RedisManager, ) -> Result>> { let auth = config.to_owned(); let mut collection: HashMap> = HashMap::new(); @@ -77,7 +81,7 @@ fn build_oauth_client( let c = AuthClientConfig::try_from(auth.discord.context("missing discord config")?)?; let client = BasicClient::try_from(c)?; - let auth_service = Arc::new(AuthServiceDiscord::new(database, client)); + let auth_service = Arc::new(AuthServiceDiscord::new(database, client, cache)); collection.insert(OauthProvider::Discord, auth_service); } diff --git a/crates/sellershut/src/server/api/routes/auth/discord.rs b/crates/sellershut/src/server/api/routes/auth/discord.rs index 163619b..0296e48 100644 --- a/crates/sellershut/src/server/api/routes/auth/discord.rs +++ b/crates/sellershut/src/server/api/routes/auth/discord.rs @@ -32,7 +32,7 @@ pub async fn discord_auth(State(state): State) -> Result( + info: T, + ) -> Result { + Ok(RedisClusterConnectionManager { + client: ClusterClientBuilder::new(vec![info]).retries(0).build()?, + }) + } +} + +impl bb8::ManageConnection for RedisClusterConnectionManager { + type Connection = redis::cluster_async::ClusterConnection; + type Error = RedisError; + + async fn connect(&self) -> Result { + self.client.get_async_connection().await + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + let pong = conn + .route_command( + redis::cmd("PING"), + RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::OneSucceeded), + )), + ) + .await + .and_then(|v| Ok(String::from_redis_value(v)?))?; + match pong.as_str() { + "PONG" => Ok(()), + _ => Err(( + ErrorKind::Server(redis::ServerErrorKind::ResponseError), + "ping request", + ) + .into()), + } + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + false + } +} diff --git a/crates/sh-util/src/cache/mod.rs b/crates/sh-util/src/cache/mod.rs new file mode 100644 index 0000000..67a5121 --- /dev/null +++ b/crates/sh-util/src/cache/mod.rs @@ -0,0 +1,176 @@ +mod cluster; +mod sentinel; +pub use sentinel::SentinelConfig; + +use std::{sync::Arc, time::Duration}; + +use bb8::RunError; +// use bb8_redis::RedisConnectionManager; +use futures_util::lock::Mutex; +use redis::{ + AsyncConnectionConfig, ProtocolVersion, RedisConnectionInfo, RedisError, TlsMode, + aio::ConnectionManagerConfig, sentinel::SentinelNodeConnectionInfo, +}; + +pub use self::cluster::RedisClusterConnectionManager; + +pub const REDIS_CONN_TIMEOUT: Duration = Duration::from_secs(2); + +pub enum RedisVariant { + Clustered, + NonClustered, + Sentinel(sentinel::SentinelConfig), +} + +#[derive(Clone)] +pub enum RedisManager { + Clustered(redis::cluster_async::ClusterConnection), + NonClustered(redis::aio::ConnectionManager), + Sentinel(Arc>), +} + +impl RedisManager { + pub async fn new(dsn: &str, variant: RedisVariant) -> Self { + match variant { + RedisVariant::Clustered => { + let cli = redis::cluster::ClusterClient::builder(vec![dsn]) + .retries(1) + .connection_timeout(REDIS_CONN_TIMEOUT) + .build() + .expect("Error initializing redis-unpooled cluster client"); + let con = cli + .get_async_connection() + .await + .expect("Failed to get redis-cluster-unpooled connection"); + RedisManager::Clustered(con) + } + RedisVariant::NonClustered => { + let cli = + redis::Client::open(dsn).expect("Error initializing redis unpooled client"); + let con = redis::aio::ConnectionManager::new_with_config( + cli, + ConnectionManagerConfig::new() + .set_number_of_retries(1) + .set_connection_timeout(Some(REDIS_CONN_TIMEOUT)), + ) + .await + .expect("Failed to get redis-unpooled connection manager"); + RedisManager::NonClustered(con) + } + RedisVariant::Sentinel(cfg) => { + let tls_mode = if cfg.redis_tls_mode_secure { + TlsMode::Secure + } else { + TlsMode::Insecure + }; + let protocol = if cfg.redis_use_resp3 { + ProtocolVersion::RESP3 + } else { + ProtocolVersion::default() + }; + + let redis_connection_info = RedisConnectionInfo::default() + .set_db(cfg.redis_db.unwrap_or(0)) + .set_protocol(protocol) + .set_username(cfg.redis_username.clone()) + .set_password(cfg.redis_password.clone()); + let sentinel = SentinelNodeConnectionInfo::default() + .set_redis_connection_info(redis_connection_info) + .set_tls_mode(tls_mode); + + let cli = redis::sentinel::SentinelClient::build( + vec![dsn], + cfg.service_name.clone(), + Some(sentinel), + redis::sentinel::SentinelServerType::Master, + ) + .expect("Failed to build sentinel client"); + + RedisManager::Sentinel(Arc::new(Mutex::new(cli))) + } + } + } + + pub async fn get(&self) -> Result> { + match self { + Self::Clustered(conn) => Ok(RedisConnection::Clustered(conn.clone())), + Self::NonClustered(conn) => Ok(RedisConnection::NonClustere(conn.clone())), + Self::Sentinel(conn) => { + let mut conn = conn.lock().await; + let con = conn + .get_async_connection_with_config( + &AsyncConnectionConfig::new() + .set_response_timeout(Some(REDIS_CONN_TIMEOUT)), + ) + .await?; + Ok(RedisConnection::Sentinel(con)) + } + } + } +} + +pub enum RedisConnection { + Clustered(redis::cluster_async::ClusterConnection), + NonClustere(redis::aio::ConnectionManager), + Sentinel(redis::aio::MultiplexedConnection), +} + +impl redis::aio::ConnectionLike for RedisConnection { + fn req_packed_command<'a>( + &'a mut self, + cmd: &'a redis::Cmd, + ) -> redis::RedisFuture<'a, redis::Value> { + match self { + RedisConnection::Clustered(conn) => conn.req_packed_command(cmd), + RedisConnection::NonClustere(conn) => conn.req_packed_command(cmd), + RedisConnection::Sentinel(conn) => conn.req_packed_command(cmd), + } + } + + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a redis::Pipeline, + offset: usize, + count: usize, + ) -> redis::RedisFuture<'a, Vec> { + match self { + RedisConnection::Clustered(conn) => conn.req_packed_commands(cmd, offset, count), + RedisConnection::NonClustere(conn) => conn.req_packed_commands(cmd, offset, count), + RedisConnection::Sentinel(conn) => conn.req_packed_commands(cmd, offset, count), + } + } + + fn get_db(&self) -> i64 { + match self { + RedisConnection::Clustered(conn) => conn.get_db(), + RedisConnection::NonClustere(conn) => conn.get_db(), + RedisConnection::Sentinel(conn) => conn.get_db(), + } + } +} + +#[cfg(test)] +mod tests { + use redis::AsyncCommands; + + use super::RedisManager; + + // Ensure basic set/get works -- should test sharding as well: + #[tokio::test] + // run with `cargo test -- --ignored redis` only when redis is up and configured + #[ignore] + async fn test_set_read_random_keys() { + let mgr = RedisManager::new( + "redis://127.0.0.1:6379/0", + super::RedisVariant::NonClustered, + ) + .await; + let mut conn = mgr.get().await.unwrap(); + + for (val, key) in "abcdefghijklmnopqrstuvwxyz".chars().enumerate() { + let key = key.to_string(); + let _: () = conn.set(key.clone(), val).await.unwrap(); + assert_eq!(conn.get::<_, usize>(&key).await.unwrap(), val); + } + } +} diff --git a/crates/sh-util/src/cache/sentinel.rs b/crates/sh-util/src/cache/sentinel.rs new file mode 100644 index 0000000..e52b043 --- /dev/null +++ b/crates/sh-util/src/cache/sentinel.rs @@ -0,0 +1,66 @@ +use futures_util::lock::Mutex; +use redis::{ + ErrorKind, IntoConnectionInfo, RedisError, + sentinel::{SentinelClient, SentinelNodeConnectionInfo, SentinelServerType}, +}; +use serde::Deserialize; + +struct LockedSentinelClient(pub(crate) Mutex); + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct SentinelConfig { + pub service_name: String, + pub redis_tls_mode_secure: bool, + pub redis_db: Option, + pub redis_username: String, + pub redis_password: String, + pub redis_use_resp3: bool, +} + +/// ConnectionManager that implements `bb8::ManageConnection` and supports +/// asynchronous Sentinel connections via `redis::sentinel::SentinelClient` +pub struct RedisSentinelConnectionManager { + client: LockedSentinelClient, +} + +impl RedisSentinelConnectionManager { + pub fn new( + info: Vec, + service_name: String, + node_connection_info: Option, + ) -> Result { + Ok(RedisSentinelConnectionManager { + client: LockedSentinelClient(Mutex::new(SentinelClient::build( + info, + service_name, + node_connection_info, + SentinelServerType::Master, + )?)), + }) + } +} + +impl bb8::ManageConnection for RedisSentinelConnectionManager { + type Connection = redis::aio::MultiplexedConnection; + type Error = RedisError; + + async fn connect(&self) -> Result { + self.client.0.lock().await.get_async_connection().await + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + let pong: String = redis::cmd("PING").query_async(conn).await?; + match pong.as_str() { + "PONG" => Ok(()), + _ => Err(( + ErrorKind::Server(redis::ServerErrorKind::ResponseError), + "ping request", + ) + .into()), + } + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + false + } +} diff --git a/crates/sh-util/src/lib.rs b/crates/sh-util/src/lib.rs new file mode 100644 index 0000000..5501a81 --- /dev/null +++ b/crates/sh-util/src/lib.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "cache")] +pub mod cache; -- cgit v1.2.3