diff options
Diffstat (limited to 'crates')
| -rw-r--r-- | crates/api-auth/Cargo.toml | 3 | ||||
| -rw-r--r-- | crates/api-auth/src/discord/mod.rs | 16 | ||||
| -rw-r--r-- | crates/api-auth/src/lib.rs | 6 | ||||
| -rw-r--r-- | crates/sellershut/Cargo.toml | 1 | ||||
| -rw-r--r-- | crates/sellershut/src/config/cache/mod.rs | 231 | ||||
| -rw-r--r-- | crates/sellershut/src/config/mod.rs | 7 | ||||
| -rw-r--r-- | crates/sellershut/src/main.rs | 8 | ||||
| -rw-r--r-- | crates/sellershut/src/server/api/routes/auth/discord.rs | 2 | ||||
| -rw-r--r-- | crates/sh-util/Cargo.toml | 26 | ||||
| -rw-r--r-- | crates/sh-util/src/cache/cluster.rs | 56 | ||||
| -rw-r--r-- | crates/sh-util/src/cache/mod.rs | 176 | ||||
| -rw-r--r-- | crates/sh-util/src/cache/sentinel.rs | 66 | ||||
| -rw-r--r-- | crates/sh-util/src/lib.rs | 2 |
13 files changed, 588 insertions, 12 deletions
diff --git a/crates/api-auth/Cargo.toml b/crates/api-auth/Cargo.toml index 053bbb9..a0868a5 100644 --- a/crates/api-auth/Cargo.toml +++ b/crates/api-auth/Cargo.toml @@ -13,6 +13,7 @@ async-trait.workspace = true oauth2 = "5.0.0" secrecy.workspace = true serde.workspace = true +sh-util = { workspace = true, optional = true } sqlx.workspace = true thiserror.workspace = true utoipa = { workspace = true, optional = true } @@ -20,5 +21,5 @@ url.workspace = true async-session = "3.0.0" [features] -discord = [] +discord = ["sh-util/cache"] utoipa = ["dep:utoipa", "serde/derive"] diff --git a/crates/api-auth/src/discord/mod.rs b/crates/api-auth/src/discord/mod.rs index 29b9bc2..dbcb139 100644 --- a/crates/api-auth/src/discord/mod.rs +++ b/crates/api-auth/src/discord/mod.rs @@ -2,19 +2,25 @@ use api_core::models::user::User; use async_session::Session; use async_trait::async_trait; use oauth2::{CsrfToken, Scope}; +use sh_util::cache::RedisManager; use sqlx::PgPool; use crate::{BasicClient, CSRF_TOKEN, OauthDriver, error::AuthError}; -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct AuthServiceDiscord { database: PgPool, + cache: RedisManager, client: BasicClient, } impl AuthServiceDiscord { - pub fn new(database: PgPool, client: BasicClient) -> Self { - Self { database, client } + pub fn new(database: PgPool, client: BasicClient, cache: RedisManager) -> Self { + Self { + database, + client, + cache, + } } } @@ -26,7 +32,7 @@ impl OauthDriver for AuthServiceDiscord { async fn get_user(&self) -> Result<User, AuthError> { todo!() } - async fn create_oauth_session(&self)->Result<String,AuthError> { + async fn create_oauth_session(&self) -> Result<String, AuthError> { let (auth_url, csrf_token) = self .client .authorize_url(CsrfToken::new_random) @@ -38,7 +44,7 @@ impl OauthDriver for AuthServiceDiscord { Ok(String::default()) } - async fn save_session(&self, user: &User)->Result<(), AuthError>{ + async fn save_session(&self, user: &User) -> Result<(), AuthError> { todo!() } } diff --git a/crates/api-auth/src/lib.rs b/crates/api-auth/src/lib.rs index 95a04c4..367d395 100644 --- a/crates/api-auth/src/lib.rs +++ b/crates/api-auth/src/lib.rs @@ -20,11 +20,11 @@ type C = oauth2::basic::BasicClient< pub struct BasicClient(C); #[async_trait::async_trait] -pub trait OauthDriver: Send + Sync + std::fmt::Debug { +pub trait OauthDriver: Send + Sync { async fn get_auth_token(&self) -> Result<String, AuthError>; async fn get_user(&self) -> Result<User, AuthError>; - async fn create_oauth_session(&self)->Result<String, AuthError>; - async fn save_session(&self, user: &User)->Result<(), AuthError>; + async fn create_oauth_session(&self) -> Result<String, AuthError>; + async fn save_session(&self, user: &User) -> Result<(), AuthError>; } use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl}; diff --git a/crates/sellershut/Cargo.toml b/crates/sellershut/Cargo.toml index 14a686c..caf6fd0 100644 --- a/crates/sellershut/Cargo.toml +++ b/crates/sellershut/Cargo.toml @@ -18,6 +18,7 @@ clap = { version = "4.6.0", features = ["derive", "env"] } secrecy = { workspace = true, features = ["serde"] } serde = { workspace = true, features = ["derive"] } serde_json.workspace = true +sh-util = { workspace = true, features = ["cache"] } sqlx = { workspace = true, features = ["migrate"] } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } toml = "1.1.2" diff --git a/crates/sellershut/src/config/cache/mod.rs b/crates/sellershut/src/config/cache/mod.rs new file mode 100644 index 0000000..136c3a4 --- /dev/null +++ b/crates/sellershut/src/config/cache/mod.rs @@ -0,0 +1,231 @@ +use anyhow::Context; +use clap::{Args, ValueEnum}; +use serde::{Deserialize, Serialize}; +use sh_util::cache::{RedisVariant, SentinelConfig}; + +#[derive(Debug, Clone, Copy, ValueEnum, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum CacheMode { + Standalone, + Clustered, + Sentinel, +} + +#[derive(Debug, Clone, Args, Serialize, Deserialize, Default, PartialEq, Eq)] +#[serde(default, rename_all = "kebab-case")] +pub struct CacheConfig { + /// Cache mode: standalone, clustered, or sentinel. + #[arg(long, env = "HUT_CACHE_MODE", value_enum)] + #[serde(rename = "mode", skip_serializing_if = "Option::is_none")] + pub cache_mode: Option<CacheMode>, + + /// Full Redis URL. Useful for standalone mode and can override host/port style inputs. + #[arg(long, env = "HUT_CACHE_URL")] + #[serde(rename = "url", skip_serializing_if = "Option::is_none")] + pub cache_url: Option<String>, + + /// Redis host for standalone mode. + #[arg(long, env = "HUT_CACHE_HOST")] + #[serde(rename = "host", skip_serializing_if = "Option::is_none")] + pub cache_host: Option<String>, + + /// Redis port for standalone mode. + #[arg(long, env = "HUT_CACHE_PORT")] + #[serde(rename = "port", skip_serializing_if = "Option::is_none")] + pub cache_port: Option<u16>, + + /// Comma-delimited node list for clustered or sentinel discovery, e.g. host1:6379,host2:6379. + #[arg(long, env = "HUT_CACHE_NODES", value_delimiter = ',')] + #[serde(rename = "nodes", skip_serializing_if = "Vec::is_empty")] + pub cache_nodes: Vec<String>, + + /// Redis username. + #[arg(long, env = "HUT_CACHE_USERNAME")] + #[serde(rename = "username", skip_serializing_if = "Option::is_none")] + pub cache_username: Option<String>, + + /// Redis password. + #[arg(long, env = "HUT_CACHE_PASSWORD")] + #[serde(rename = "password", skip_serializing_if = "Option::is_none")] + pub cache_password: Option<String>, + + /// Redis logical database number. + #[arg(long, env = "HUT_CACHE_DB")] + #[serde(rename = "database", skip_serializing_if = "Option::is_none")] + pub cache_database: Option<u32>, + + /// Sentinel service name. Required for sentinel mode. + #[arg(long, env = "HUT_CACHE_SERVICE_NAME")] + #[serde(rename = "service_name", skip_serializing_if = "Option::is_none")] + pub cache_service_name: Option<String>, + + /// Whether Redis TLS should use secure mode. + #[arg(long, env = "HUT_CACHE_TLS_MODE_SECURE")] + #[serde(rename = "tls_mode_secure", skip_serializing_if = "Option::is_none")] + pub cache_tls_mode_secure: Option<bool>, + + /// Whether the client should use RESP3. + #[arg(long, env = "HUT_CACHE_USE_RESP3")] + #[serde(rename = "use_resp3", skip_serializing_if = "Option::is_none")] + pub cache_use_resp3: Option<bool>, +} + +impl CacheConfig { + pub fn merge(self, higher: Self) -> Self { + Self { + cache_mode: higher.cache_mode.or(self.cache_mode), + cache_url: higher.cache_url.or(self.cache_url), + cache_host: higher.cache_host.or(self.cache_host), + cache_port: higher.cache_port.or(self.cache_port), + cache_nodes: if higher.cache_nodes.is_empty() { + self.cache_nodes + } else { + higher.cache_nodes + }, + cache_username: higher.cache_username.or(self.cache_username), + cache_password: higher.cache_password.or(self.cache_password), + cache_database: higher.cache_database.or(self.cache_database), + cache_service_name: higher.cache_service_name.or(self.cache_service_name), + cache_tls_mode_secure: higher + .cache_tls_mode_secure + .or(self.cache_tls_mode_secure), + cache_use_resp3: higher + .cache_use_resp3 + .or(self.cache_use_resp3), + } + } + + pub fn with_defaults(self) -> Self { + Self { + cache_mode: Some(self.cache_mode.unwrap_or(CacheMode::Standalone)), + cache_url: self.cache_url, + cache_host: Some(self.cache_host.unwrap_or_else(|| "127.0.0.1".to_string())), + cache_port: Some(self.cache_port.unwrap_or(6379)), + cache_nodes: self.cache_nodes, + cache_username: self.cache_username, + cache_password: self.cache_password, + cache_database: Some(self.cache_database.unwrap_or(0)), + cache_service_name: self.cache_service_name, + cache_tls_mode_secure: Some(self.cache_tls_mode_secure.unwrap_or(false)), + cache_use_resp3: Some(self.cache_use_resp3.unwrap_or(false)), + } + } + + pub fn defaults() -> Self { + Self::default().with_defaults() + } + + pub fn mode(&self) -> CacheMode { + self.cache_mode.unwrap_or(CacheMode::Standalone) + } + + pub fn url(&self) -> anyhow::Result<String> { + if let Some(url) = &self.cache_url { + return Ok(url.clone()); + } + + match self.mode() { + CacheMode::Standalone => { + let host = self + .cache_host + .as_deref() + .context("cache.host")?; + let port = self + .cache_port + .context("cache.port")?; + let db = self.cache_database.unwrap_or(0); + + let auth = match (&self.cache_username, &self.cache_password) { + (Some(username), Some(password)) => format!("{username}:{password}@"), + (None, Some(password)) => format!(":{password}@"), + (Some(username), None) => format!("{username}@"), + (None, None) => String::new(), + }; + + Ok(format!("redis://{}{}:{}/{}", auth, host, port, db)) + } + CacheMode::Clustered | CacheMode::Sentinel => { + self.cache_nodes + .first() + .cloned() + .context("cache.nodes[0]") + } + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CacheConfigConversionError { + WrongMode(CacheMode), + MissingField(&'static str), +} + +impl std::fmt::Display for CacheConfigConversionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::WrongMode(mode) => write!(f, "cache mode must be sentinel, got {mode:?}"), + Self::MissingField(field) => write!(f, "missing required cache field: {field}"), + } + } +} + +impl std::error::Error for CacheConfigConversionError {} + +impl TryFrom<&CacheConfig> for SentinelConfig { + type Error = CacheConfigConversionError; + + fn try_from(value: &CacheConfig) -> Result<Self, Self::Error> { + if value.mode() != CacheMode::Sentinel { + return Err(CacheConfigConversionError::WrongMode(value.mode())); + } + + Ok(SentinelConfig { + service_name: value + .cache_service_name + .clone() + .ok_or(CacheConfigConversionError::MissingField("cache.service_name"))?, + redis_tls_mode_secure: value.cache_tls_mode_secure.unwrap_or(false), + redis_db: value.cache_database.map(i64::from), + redis_username: value + .cache_username + .clone() + .ok_or(CacheConfigConversionError::MissingField("cache.username"))?, + redis_password: value + .cache_password + .clone() + .ok_or(CacheConfigConversionError::MissingField("cache.password"))?, + redis_use_resp3: value.cache_use_resp3.unwrap_or(false), + }) + } +} + +impl TryFrom<CacheConfig> for SentinelConfig { + type Error = CacheConfigConversionError; + + fn try_from(value: CacheConfig) -> Result<Self, Self::Error> { + SentinelConfig::try_from(&value) + } +} + +impl TryFrom<&CacheConfig> for RedisVariant { + type Error = anyhow::Error; + + fn try_from(value: &CacheConfig) -> Result<Self, Self::Error> { + let s = SentinelConfig::try_from(value)?; + + match value.mode() { + CacheMode::Standalone => Ok(RedisVariant::NonClustered), + CacheMode::Clustered => Ok(RedisVariant::Clustered), + CacheMode::Sentinel => Ok(RedisVariant::Sentinel(s)), + } + } + +} + +impl TryFrom<CacheConfig> for sh_util::cache::RedisVariant { + type Error = anyhow::Error; + + fn try_from(value: CacheConfig) -> Result<Self, Self::Error> { + RedisVariant::try_from(&value) + } +} diff --git a/crates/sellershut/src/config/mod.rs b/crates/sellershut/src/config/mod.rs index 389b4bc..156ad0f 100644 --- a/crates/sellershut/src/config/mod.rs +++ b/crates/sellershut/src/config/mod.rs @@ -1,4 +1,5 @@ pub mod auth; +pub mod cache; pub mod cli; pub mod database; mod server; @@ -25,6 +26,9 @@ pub struct Config { /// Database configuration. #[command(flatten)] pub database: database::DatabaseConfig, + /// Cache configuration. + #[command(flatten)] + pub cache: cache::CacheConfig, } impl Config { pub fn load(cli: Self) -> Result<Self> { @@ -43,6 +47,7 @@ impl Config { server: self.server.merge(higher.server), auth: self.auth.merge(higher.auth), database: self.database.merge(higher.database), + cache: self.cache.merge(higher.cache), } } @@ -52,6 +57,7 @@ impl Config { server: self.server.with_defaults(), auth: self.auth.with_defaults(), database: self.database.with_defaults(), + cache: self.cache.with_defaults(), } } @@ -61,6 +67,7 @@ impl Config { server: server::ServerConfig::defaults(), auth: auth::OauthConfig::defaults(), database: database::DatabaseConfig::defaults(), + cache: cache::CacheConfig::defaults(), } } } diff --git a/crates/sellershut/src/main.rs b/crates/sellershut/src/main.rs index ebae4ed..a46cf3e 100644 --- a/crates/sellershut/src/main.rs +++ b/crates/sellershut/src/main.rs @@ -15,6 +15,7 @@ use api_core::{ health::BaseService, }; use clap::Parser; +use sh_util::cache::{RedisManager, RedisVariant}; use sqlx::PgPool; use tokio::net::TcpListener; use tracing::info; @@ -40,8 +41,10 @@ async fn main() -> Result<()> { )?; let database = state::postgres(&cfg.database.connection_url(), 100).await?; + let variant = RedisVariant::try_from(cfg.cache.clone())?; + let cache = RedisManager::new(&cfg.cache.url()?, variant).await; - let auth_clients = build_oauth_client(&cfg.auth, database)?; + let auth_clients = build_oauth_client(&cfg.auth, database, cache)?; let state = AppState::builder() .log_handle(log_handle) @@ -67,6 +70,7 @@ async fn main() -> Result<()> { fn build_oauth_client( config: &OauthConfig, database: PgPool, + cache: RedisManager, ) -> Result<HashMap<OauthProvider, Arc<dyn OauthDriver>>> { let auth = config.to_owned(); let mut collection: HashMap<OauthProvider, Arc<dyn OauthDriver>> = HashMap::new(); @@ -77,7 +81,7 @@ fn build_oauth_client( let c = AuthClientConfig::try_from(auth.discord.context("missing discord config")?)?; let client = BasicClient::try_from(c)?; - let auth_service = Arc::new(AuthServiceDiscord::new(database, client)); + let auth_service = Arc::new(AuthServiceDiscord::new(database, client, cache)); collection.insert(OauthProvider::Discord, auth_service); } diff --git a/crates/sellershut/src/server/api/routes/auth/discord.rs b/crates/sellershut/src/server/api/routes/auth/discord.rs index 163619b..0296e48 100644 --- a/crates/sellershut/src/server/api/routes/auth/discord.rs +++ b/crates/sellershut/src/server/api/routes/auth/discord.rs @@ -32,7 +32,7 @@ pub async fn discord_auth(State(state): State<AppState>) -> Result<impl IntoResp .context("missing discord driver")?; let headers = HeaderMap::new(); - Ok((headers, Redirect::to(redirect_url))) + Ok((headers, Redirect::to("/"))) // let (auth_url, csrf_token) = client // .authorize_url(CsrfToken::new_random) diff --git a/crates/sh-util/Cargo.toml b/crates/sh-util/Cargo.toml new file mode 100644 index 0000000..12bf7a4 --- /dev/null +++ b/crates/sh-util/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "sh-util" +version = "0.0.0" +edition = "2024" +license.workspace = true +readme.workspace = true +documentation.workspace = true +homepage.workspace = true + +[dependencies] +bb8 = { version = "0.9.1", optional = true } +futures-util = { workspace = true, optional = true } +redis = { workspace = true, optional = true } +serde = { workspace = true, features = ["derive"] } + +[features] +cache = [ + "dep:redis", + "redis/cluster-async", + "redis/connection-manager", + "redis/tokio-comp", + "redis/sentinel", + "redis/bb8", + "dep:bb8", + "dep:futures-util", +] diff --git a/crates/sh-util/src/cache/cluster.rs b/crates/sh-util/src/cache/cluster.rs new file mode 100644 index 0000000..de13629 --- /dev/null +++ b/crates/sh-util/src/cache/cluster.rs @@ -0,0 +1,56 @@ +use redis::{ + ErrorKind, FromRedisValue, IntoConnectionInfo, RedisError, + cluster::{ClusterClient, ClusterClientBuilder}, + cluster_routing::{MultipleNodeRoutingInfo, ResponsePolicy, RoutingInfo}, +}; + +/// ConnectionManager that implements `bb8::ManageConnection` and supports +/// asynchronous clustered connections via `redis_cluster_async::Connection` +#[derive(Clone)] +pub struct RedisClusterConnectionManager { + client: ClusterClient, +} + +impl RedisClusterConnectionManager { + pub fn new<T: IntoConnectionInfo>( + info: T, + ) -> Result<RedisClusterConnectionManager, RedisError> { + Ok(RedisClusterConnectionManager { + client: ClusterClientBuilder::new(vec![info]).retries(0).build()?, + }) + } +} + +impl bb8::ManageConnection for RedisClusterConnectionManager { + type Connection = redis::cluster_async::ClusterConnection; + type Error = RedisError; + + async fn connect(&self) -> Result<Self::Connection, Self::Error> { + self.client.get_async_connection().await + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + let pong = conn + .route_command( + redis::cmd("PING"), + RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::OneSucceeded), + )), + ) + .await + .and_then(|v| Ok(String::from_redis_value(v)?))?; + match pong.as_str() { + "PONG" => Ok(()), + _ => Err(( + ErrorKind::Server(redis::ServerErrorKind::ResponseError), + "ping request", + ) + .into()), + } + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + false + } +} diff --git a/crates/sh-util/src/cache/mod.rs b/crates/sh-util/src/cache/mod.rs new file mode 100644 index 0000000..67a5121 --- /dev/null +++ b/crates/sh-util/src/cache/mod.rs @@ -0,0 +1,176 @@ +mod cluster; +mod sentinel; +pub use sentinel::SentinelConfig; + +use std::{sync::Arc, time::Duration}; + +use bb8::RunError; +// use bb8_redis::RedisConnectionManager; +use futures_util::lock::Mutex; +use redis::{ + AsyncConnectionConfig, ProtocolVersion, RedisConnectionInfo, RedisError, TlsMode, + aio::ConnectionManagerConfig, sentinel::SentinelNodeConnectionInfo, +}; + +pub use self::cluster::RedisClusterConnectionManager; + +pub const REDIS_CONN_TIMEOUT: Duration = Duration::from_secs(2); + +pub enum RedisVariant { + Clustered, + NonClustered, + Sentinel(sentinel::SentinelConfig), +} + +#[derive(Clone)] +pub enum RedisManager { + Clustered(redis::cluster_async::ClusterConnection), + NonClustered(redis::aio::ConnectionManager), + Sentinel(Arc<Mutex<redis::sentinel::SentinelClient>>), +} + +impl RedisManager { + pub async fn new(dsn: &str, variant: RedisVariant) -> Self { + match variant { + RedisVariant::Clustered => { + let cli = redis::cluster::ClusterClient::builder(vec![dsn]) + .retries(1) + .connection_timeout(REDIS_CONN_TIMEOUT) + .build() + .expect("Error initializing redis-unpooled cluster client"); + let con = cli + .get_async_connection() + .await + .expect("Failed to get redis-cluster-unpooled connection"); + RedisManager::Clustered(con) + } + RedisVariant::NonClustered => { + let cli = + redis::Client::open(dsn).expect("Error initializing redis unpooled client"); + let con = redis::aio::ConnectionManager::new_with_config( + cli, + ConnectionManagerConfig::new() + .set_number_of_retries(1) + .set_connection_timeout(Some(REDIS_CONN_TIMEOUT)), + ) + .await + .expect("Failed to get redis-unpooled connection manager"); + RedisManager::NonClustered(con) + } + RedisVariant::Sentinel(cfg) => { + let tls_mode = if cfg.redis_tls_mode_secure { + TlsMode::Secure + } else { + TlsMode::Insecure + }; + let protocol = if cfg.redis_use_resp3 { + ProtocolVersion::RESP3 + } else { + ProtocolVersion::default() + }; + + let redis_connection_info = RedisConnectionInfo::default() + .set_db(cfg.redis_db.unwrap_or(0)) + .set_protocol(protocol) + .set_username(cfg.redis_username.clone()) + .set_password(cfg.redis_password.clone()); + let sentinel = SentinelNodeConnectionInfo::default() + .set_redis_connection_info(redis_connection_info) + .set_tls_mode(tls_mode); + + let cli = redis::sentinel::SentinelClient::build( + vec![dsn], + cfg.service_name.clone(), + Some(sentinel), + redis::sentinel::SentinelServerType::Master, + ) + .expect("Failed to build sentinel client"); + + RedisManager::Sentinel(Arc::new(Mutex::new(cli))) + } + } + } + + pub async fn get(&self) -> Result<RedisConnection, RunError<RedisError>> { + match self { + Self::Clustered(conn) => Ok(RedisConnection::Clustered(conn.clone())), + Self::NonClustered(conn) => Ok(RedisConnection::NonClustere(conn.clone())), + Self::Sentinel(conn) => { + let mut conn = conn.lock().await; + let con = conn + .get_async_connection_with_config( + &AsyncConnectionConfig::new() + .set_response_timeout(Some(REDIS_CONN_TIMEOUT)), + ) + .await?; + Ok(RedisConnection::Sentinel(con)) + } + } + } +} + +pub enum RedisConnection { + Clustered(redis::cluster_async::ClusterConnection), + NonClustere(redis::aio::ConnectionManager), + Sentinel(redis::aio::MultiplexedConnection), +} + +impl redis::aio::ConnectionLike for RedisConnection { + fn req_packed_command<'a>( + &'a mut self, + cmd: &'a redis::Cmd, + ) -> redis::RedisFuture<'a, redis::Value> { + match self { + RedisConnection::Clustered(conn) => conn.req_packed_command(cmd), + RedisConnection::NonClustere(conn) => conn.req_packed_command(cmd), + RedisConnection::Sentinel(conn) => conn.req_packed_command(cmd), + } + } + + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a redis::Pipeline, + offset: usize, + count: usize, + ) -> redis::RedisFuture<'a, Vec<redis::Value>> { + match self { + RedisConnection::Clustered(conn) => conn.req_packed_commands(cmd, offset, count), + RedisConnection::NonClustere(conn) => conn.req_packed_commands(cmd, offset, count), + RedisConnection::Sentinel(conn) => conn.req_packed_commands(cmd, offset, count), + } + } + + fn get_db(&self) -> i64 { + match self { + RedisConnection::Clustered(conn) => conn.get_db(), + RedisConnection::NonClustere(conn) => conn.get_db(), + RedisConnection::Sentinel(conn) => conn.get_db(), + } + } +} + +#[cfg(test)] +mod tests { + use redis::AsyncCommands; + + use super::RedisManager; + + // Ensure basic set/get works -- should test sharding as well: + #[tokio::test] + // run with `cargo test -- --ignored redis` only when redis is up and configured + #[ignore] + async fn test_set_read_random_keys() { + let mgr = RedisManager::new( + "redis://127.0.0.1:6379/0", + super::RedisVariant::NonClustered, + ) + .await; + let mut conn = mgr.get().await.unwrap(); + + for (val, key) in "abcdefghijklmnopqrstuvwxyz".chars().enumerate() { + let key = key.to_string(); + let _: () = conn.set(key.clone(), val).await.unwrap(); + assert_eq!(conn.get::<_, usize>(&key).await.unwrap(), val); + } + } +} diff --git a/crates/sh-util/src/cache/sentinel.rs b/crates/sh-util/src/cache/sentinel.rs new file mode 100644 index 0000000..e52b043 --- /dev/null +++ b/crates/sh-util/src/cache/sentinel.rs @@ -0,0 +1,66 @@ +use futures_util::lock::Mutex; +use redis::{ + ErrorKind, IntoConnectionInfo, RedisError, + sentinel::{SentinelClient, SentinelNodeConnectionInfo, SentinelServerType}, +}; +use serde::Deserialize; + +struct LockedSentinelClient(pub(crate) Mutex<SentinelClient>); + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct SentinelConfig { + pub service_name: String, + pub redis_tls_mode_secure: bool, + pub redis_db: Option<i64>, + pub redis_username: String, + pub redis_password: String, + pub redis_use_resp3: bool, +} + +/// ConnectionManager that implements `bb8::ManageConnection` and supports +/// asynchronous Sentinel connections via `redis::sentinel::SentinelClient` +pub struct RedisSentinelConnectionManager { + client: LockedSentinelClient, +} + +impl RedisSentinelConnectionManager { + pub fn new<T: IntoConnectionInfo>( + info: Vec<T>, + service_name: String, + node_connection_info: Option<SentinelNodeConnectionInfo>, + ) -> Result<RedisSentinelConnectionManager, RedisError> { + Ok(RedisSentinelConnectionManager { + client: LockedSentinelClient(Mutex::new(SentinelClient::build( + info, + service_name, + node_connection_info, + SentinelServerType::Master, + )?)), + }) + } +} + +impl bb8::ManageConnection for RedisSentinelConnectionManager { + type Connection = redis::aio::MultiplexedConnection; + type Error = RedisError; + + async fn connect(&self) -> Result<Self::Connection, Self::Error> { + self.client.0.lock().await.get_async_connection().await + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + let pong: String = redis::cmd("PING").query_async(conn).await?; + match pong.as_str() { + "PONG" => Ok(()), + _ => Err(( + ErrorKind::Server(redis::ServerErrorKind::ResponseError), + "ping request", + ) + .into()), + } + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + false + } +} diff --git a/crates/sh-util/src/lib.rs b/crates/sh-util/src/lib.rs new file mode 100644 index 0000000..5501a81 --- /dev/null +++ b/crates/sh-util/src/lib.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "cache")] +pub mod cache; |
