From eb2e86997d47249aa31b703598de13ab2eb96caa Mon Sep 17 00:00:00 2001 From: rtkay123 Date: Tue, 3 Feb 2026 13:45:46 +0200 Subject: feat: add cache --- src/config/cache.rs | 57 +++++++++ src/config/cli.rs | 113 ----------------- src/config/cli/cache/mod.rs | 46 +++++++ src/config/cli/mod.rs | 89 ++++++++++++++ src/config/cli/oauth/mod.rs | 36 ++++++ src/config/mod.rs | 35 +++++- src/main.rs | 4 +- src/server/driver/mod.rs | 16 +-- src/server/routes/auth/mod.rs | 17 ++- src/server/state/cache/cluster.rs | 58 +++++++++ src/server/state/cache/mod.rs | 240 +++++++++++++++++++++++++++++++++++++ src/server/state/cache/sentinel.rs | 56 +++++++++ src/server/state/mod.rs | 1 + 13 files changed, 636 insertions(+), 132 deletions(-) create mode 100644 src/config/cache.rs delete mode 100644 src/config/cli.rs create mode 100644 src/config/cli/cache/mod.rs create mode 100644 src/config/cli/mod.rs create mode 100644 src/config/cli/oauth/mod.rs create mode 100644 src/server/state/cache/cluster.rs create mode 100644 src/server/state/cache/mod.rs create mode 100644 src/server/state/cache/sentinel.rs (limited to 'src') diff --git a/src/config/cache.rs b/src/config/cache.rs new file mode 100644 index 0000000..96f3a9b --- /dev/null +++ b/src/config/cache.rs @@ -0,0 +1,57 @@ +use serde::Deserialize; +use url::Url; + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "lowercase")] +pub struct CacheConfig { + #[serde(rename = "dsn")] + pub redis_dsn: Url, + #[serde(default)] + pub pooled: bool, + #[serde(rename = "type")] + pub kind: RedisVariant, + #[serde(default = "default_max_conns")] + #[serde(rename = "max-connections")] + pub max_connections: u16, +} + +#[derive(Debug, Deserialize, Clone, Default)] +#[serde(rename_all = "kebab-case")] +pub enum RedisVariant { + Clustered, + #[default] + NonClustered, + Sentinel(SentinelConfig), +} + +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +pub struct SentinelConfig { + #[serde(rename = "sentinel_service_name")] + pub service_name: String, + #[serde(default)] + pub redis_tls_mode_secure: bool, + pub redis_db: Option, + pub redis_username: Option, + pub redis_password: Option, + #[serde(default)] + pub redis_use_resp3: bool, +} + +fn default_max_conns() -> u16 { + 100 +} + +fn default_cache() -> Url { + Url::parse("redis://localhost:6379").expect("valid default DATABASE url") +} + +impl Default for CacheConfig { + fn default() -> Self { + Self { + redis_dsn: default_cache(), + pooled: Default::default(), + kind: Default::default(), + max_connections: default_max_conns(), + } + } +} diff --git a/src/config/cli.rs b/src/config/cli.rs deleted file mode 100644 index be1b913..0000000 --- a/src/config/cli.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::path::PathBuf; - -use clap::Parser; -#[cfg(feature = "oauth-discord")] -use secrecy::SecretString; -use serde::Deserialize; -use url::Url; - -use crate::config::{logging::LogLevel, port::port_in_range}; - -#[derive(Parser, Debug)] -#[command(version, about, long_about = None, name = env!("CARGO_PKG_NAME"))] -pub struct Cli { - /// Sets the port the server listens on - #[arg(short, long, default_value = "2210", env = "PORT", value_parser = port_in_range)] - pub port: Option, - - /// Server's domain - #[arg(short, long, default_value = "localhost", env = "DOMAIN")] - pub domain: Option, - - /// Sets a custom config file - #[arg(short, long, value_name = "FILE")] - pub config: Option, - - /// Sets the application log level - #[arg(short, long, value_enum, env = "LOG_LEVEL", default_value = "debug")] - pub log_level: Option, - - /// Request timeout duration (in seconds) - #[arg(short, long, env = "TIMEOUT_SECONDS", default_value = "10")] - pub timeout_duration: Option, - - /// Users database connection string - #[arg( - long, - env = "USERS_DATABASE_URL", - default_value = "postgres://postgres:password@localhost:5432/sellershut" - )] - pub db: Option, - - /// Server's system name - #[arg(short, long, default_value = "sellershut", env = "SYSTEM_NAME")] - pub system_name: Option, - - /// Server's system name - #[arg(short, long, default_value = "prod", env = "SYSTEM_NAME")] - pub environment: Option, - - /// Oauth optionas - #[command(flatten)] - #[cfg(feature = "oauth")] - pub oauth: OAuth, -} - -#[derive(Debug, Clone, Parser, Deserialize)] -pub struct OAuth { - #[cfg(feature = "oauth-discord")] - #[command(flatten)] - discord: DiscordOauth, - #[arg(long, env = "OAUTH_REDIRECT_URL")] - oauth_redirect_url: Option, -} - -#[cfg(feature = "oauth-discord")] -#[derive(Debug, Clone, Parser, Deserialize, Default)] -pub struct DiscordOauth { - #[arg(long, env = "OAUTH_DISCORD_CLIENT_ID")] - discord_client_id: Option, - #[arg(long, env = "OAUTH_DISCORD_CLIENT_SECRET")] - discord_client_secret: Option, - #[arg( - long, - env = "OAUTH_DISCORD_TOKEN_URL", - default_value = "https://discord.com/api/oauth2/token" - )] - discord_token_url: Option, - #[arg( - long, - env = "OAUTH_DISCORD_AUTH_URL", - default_value = "https://discord.com/api/oauth2/authorize?response_type=code" - )] - discord_auth_url: Option, -} - -#[cfg(test)] -impl Default for Cli { - fn default() -> Self { - let url = Url::parse("postgres://postgres:password@localhost:5432/sellershut").ok(); - Self { - port: Default::default(), - config: Default::default(), - log_level: Default::default(), - timeout_duration: Some(10), - domain: Default::default(), - system_name: Default::default(), - environment: Default::default(), - oauth: None, - db: url, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn verify_cli() { - use clap::CommandFactory; - Cli::command().debug_assert(); - } -} diff --git a/src/config/cli/cache/mod.rs b/src/config/cli/cache/mod.rs new file mode 100644 index 0000000..04b36bc --- /dev/null +++ b/src/config/cli/cache/mod.rs @@ -0,0 +1,46 @@ +use clap::{Parser, ValueEnum}; +use serde::Deserialize; +use url::Url; + +#[derive(Debug, Clone, Parser, Deserialize, Default)] +pub struct Cache { + /// Cache connection string + #[arg(long, env = "CACHE_URL", default_value = "redis://localhost:6379")] + pub cache_url: Option, + #[arg(long, env = "CACHE_POOL_ENABLED", default_value = "true")] + pub cache_pooled: Option, + #[serde(rename = "type")] + #[arg(long, env = "CACHE_TYPE", default_value = "non-clustered")] + pub cache_type: Option, + #[serde(default = "default_max_conns")] + #[serde(rename = "max-connections")] + #[arg(long, env = "CACHE_MAX_CONNECTIONS", default_value = "100")] + pub cache_max_conn: Option, + #[command(flatten)] + pub sentinel_config: SentinelConfig, +} + +#[derive(Debug, Deserialize, Clone, ValueEnum)] +#[serde(rename_all = "kebab-case")] +pub enum RedisVariant { + Clustered, + NonClustered, + Sentinel, +} + +fn default_max_conns() -> Option { + Some(100) +} + +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Parser, Default)] +pub struct SentinelConfig { + #[serde(rename = "sentinel_service_name")] + #[arg(long, env = "CACHE_SENTINEL_NAME", default_value = "true")] + pub service_name: Option, + #[serde(default)] + #[arg(long, env = "CACHE_TLS_MODE_SECURE")] + pub cache_tls_mode_secure: bool, + #[serde(default)] + #[arg(long, env = "CACHE_USE_RESP3")] + pub cache_use_resp3: bool, +} diff --git a/src/config/cli/mod.rs b/src/config/cli/mod.rs new file mode 100644 index 0000000..81eb2fe --- /dev/null +++ b/src/config/cli/mod.rs @@ -0,0 +1,89 @@ +pub mod cache; + +#[cfg(feature = "oauth")] +pub mod oauth; + +use std::path::PathBuf; + +use clap::Parser; +use url::Url; + +use crate::config::{logging::LogLevel, port::port_in_range}; + +#[derive(Parser, Debug)] +#[command(version, about, long_about = None, name = env!("CARGO_PKG_NAME"))] +pub struct Cli { + /// Sets the port the server listens on + #[arg(short, long, default_value = "2210", env = "PORT", value_parser = port_in_range)] + pub port: Option, + + /// Server's domain + #[arg(short, long, default_value = "localhost", env = "DOMAIN")] + pub domain: Option, + + /// Sets a custom config file + #[arg(short, long, value_name = "FILE")] + pub config: Option, + + /// Sets the application log level + #[arg(short, long, value_enum, env = "LOG_LEVEL", default_value = "debug")] + pub log_level: Option, + + /// Request timeout duration (in seconds) + #[arg(short, long, env = "TIMEOUT_SECONDS", default_value = "10")] + pub timeout_duration: Option, + + /// Database connection string + #[arg( + long, + env = "DATABASE_URL", + default_value = "postgres://postgres:password@localhost:5432/sellershut" + )] + pub db: Option, + + #[command(flatten)] + pub cache: Option, + + /// Server's system name + #[arg(short, long, default_value = "sellershut", env = "SYSTEM_NAME")] + pub system_name: Option, + + /// Server's system name + #[arg(short, long, default_value = "prod", env = "SYSTEM_NAME")] + pub environment: Option, + + /// Oauth optionas + #[command(flatten)] + #[cfg(feature = "oauth")] + pub oauth: oauth::OAuth, +} + +#[cfg(test)] +impl Default for Cli { + fn default() -> Self { + let url = Url::parse("postgres://postgres:password@localhost:5432/sellershut").ok(); + Self { + port: Default::default(), + config: Default::default(), + log_level: Default::default(), + timeout_duration: Some(10), + domain: Default::default(), + system_name: Default::default(), + environment: Default::default(), + oauth: Default::default(), + cache: Default::default(), + db: url, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn verify_cli() { + use clap::CommandFactory; + Cli::command().debug_assert(); + } +} diff --git a/src/config/cli/oauth/mod.rs b/src/config/cli/oauth/mod.rs new file mode 100644 index 0000000..4bf1c34 --- /dev/null +++ b/src/config/cli/oauth/mod.rs @@ -0,0 +1,36 @@ +use clap::Parser; +#[cfg(feature = "oauth-discord")] +use secrecy::SecretString; +use serde::Deserialize; +#[cfg(feature = "oauth")] +use url::Url; + +#[derive(Debug, Clone, Parser, Deserialize, Default)] +pub struct OAuth { + #[cfg(feature = "oauth-discord")] + #[command(flatten)] + discord: DiscordOauth, + #[arg(long, env = "OAUTH_REDIRECT_URL")] + oauth_redirect_url: Option, +} + +#[cfg(feature = "oauth-discord")] +#[derive(Debug, Clone, Parser, Deserialize, Default)] +pub struct DiscordOauth { + #[arg(long, env = "OAUTH_DISCORD_CLIENT_ID")] + discord_client_id: Option, + #[arg(long, env = "OAUTH_DISCORD_CLIENT_SECRET")] + discord_client_secret: Option, + #[arg( + long, + env = "OAUTH_DISCORD_TOKEN_URL", + default_value = "https://discord.com/api/oauth2/token" + )] + discord_token_url: Option, + #[arg( + long, + env = "OAUTH_DISCORD_AUTH_URL", + default_value = "https://discord.com/api/oauth2/authorize?response_type=code" + )] + discord_auth_url: Option, +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 7495b22..e64ae5c 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,13 +1,15 @@ +pub mod cache; mod cli; mod logging; mod port; pub use cli::Cli; -#[cfg(feature = "oauth")] -use secrecy::SecretString; use serde::Deserialize; use url::Url; -use crate::config::logging::LogLevel; +use crate::config::{ + cache::{CacheConfig, RedisVariant}, + logging::LogLevel, +}; #[derive(Default, Deserialize, Debug, PartialEq, Eq)] #[serde(rename_all = "kebab-case")] @@ -23,6 +25,8 @@ pub struct Config { #[serde(default)] pub database: DatabaseOptions, #[serde(default)] + pub cache: CacheConfig, + #[serde(default)] pub server: Api, #[serde(default)] #[cfg(feature = "oauth")] @@ -65,7 +69,7 @@ pub struct OAuth { #[serde(rename_all = "kebab-case")] pub struct DiscordOauth { pub client_id: String, - pub client_secret: SecretString, + pub client_secret: secrecy::SecretString, #[serde(default = "discord_token_url")] pub token_url: Url, #[serde(default = "discord_auth_url")] @@ -94,7 +98,7 @@ impl Default for OAuth { #[cfg(feature = "oauth-discord")] discord: DiscordOauth { client_id: String::default(), - client_secret: SecretString::default(), + client_secret: secrecy::SecretString::default(), token_url: discord_token_url(), auth_url: discord_auth_url(), }, @@ -175,6 +179,7 @@ impl Config { pub fn merge_with_cli(&mut self, cli: &Cli) { let server = &mut self.server; let dsn = &mut self.database; + let cache = &mut self.cache; if let Some(port) = cli.port { server.port = port; @@ -195,6 +200,26 @@ impl Config { if let Some(db_url) = &cli.db { dsn.url = db_url.clone(); } + + if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_url.clone()) { + cache.redis_dsn = c; + } + + if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_pooled) { + cache.pooled = c; + } + + if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_max_conn) { + cache.max_connections = c; + } + + if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_type.clone()) { + cache.kind = match c { + cli::cache::RedisVariant::Clustered => RedisVariant::Clustered, + cli::cache::RedisVariant::NonClustered => RedisVariant::NonClustered, + cli::cache::RedisVariant::Sentinel => cache.kind.clone(), + }; + } } } diff --git a/src/main.rs b/src/main.rs index 2018956..971c08c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -mod config; +pub mod config; mod logging; mod server; @@ -29,7 +29,7 @@ async fn main() -> anyhow::Result<()> { initialise_logging(&config); - let driver = Services::new(&config.database).await?; + let driver = Services::new(&config.database, &config.cache).await?; let state = AppState::new(&config, driver).await?; let router = server::router(&config, state).await?; diff --git a/src/server/driver/mod.rs b/src/server/driver/mod.rs index 68bd18c..2eaf7dc 100644 --- a/src/server/driver/mod.rs +++ b/src/server/driver/mod.rs @@ -5,26 +5,28 @@ pub mod auth; use async_session::Session; use async_trait::async_trait; #[cfg(feature = "oauth")] -use axum::{ - http::HeaderMap, - response::{IntoResponse, Redirect}, -}; +use axum::{http::HeaderMap, response::Redirect}; #[cfg(feature = "oauth")] use oauth2::CsrfToken; use sqlx::PgPool; -use crate::{config::DatabaseOptions, server::state::database}; +use crate::{ + config::{DatabaseOptions, cache::CacheConfig}, + server::state::{cache::RedisManager, database}, +}; #[derive(Debug, Clone)] pub struct Services { database: PgPool, + cache: RedisManager, } impl Services { - pub async fn new(database: &DatabaseOptions) -> anyhow::Result { + pub async fn new(database: &DatabaseOptions, cache: &CacheConfig) -> anyhow::Result { let database = database::connect(database).await?; + let cache = RedisManager::new(cache).await?; - Ok(Self { database }) + Ok(Self { database, cache }) } } diff --git a/src/server/routes/auth/mod.rs b/src/server/routes/auth/mod.rs index 485983a..7d7ecf3 100644 --- a/src/server/routes/auth/mod.rs +++ b/src/server/routes/auth/mod.rs @@ -55,11 +55,14 @@ pub async fn auth( Query(params): Query, data: Data, ) -> Result { - match params.provider { - #[cfg(feature = "oauth-discord")] + #[cfg(feature = "oauth-discord")] + return match params.provider { OauthProvider::Discord => discord::discord_auth(data), } - .await + .await; + + #[cfg(not(feature = "oauth-discord"))] + Ok(axum::http::StatusCode::INTERNAL_SERVER_ERROR) } #[utoipa::path( @@ -79,9 +82,13 @@ pub async fn authorised( Query(params): Query, data: Data, ) -> Result { - match params.provider { + #[cfg(feature = "oauth-discord")] + return match params.provider { #[cfg(feature = "oauth-discord")] OauthProvider::Discord => discord::discord_auth(data), } - .await + .await; + + #[cfg(not(feature = "oauth-discord"))] + Ok(axum::http::StatusCode::INTERNAL_SERVER_ERROR) } diff --git a/src/server/state/cache/cluster.rs b/src/server/state/cache/cluster.rs new file mode 100644 index 0000000..ea71954 --- /dev/null +++ b/src/server/state/cache/cluster.rs @@ -0,0 +1,58 @@ +use bb8_redis::bb8; +use redis::{ + ErrorKind, FromRedisValue, IntoConnectionInfo, RedisError, + cluster::{ClusterClient, ClusterClientBuilder}, + cluster_routing::{MultipleNodeRoutingInfo, ResponsePolicy, RoutingInfo}, +}; + +/// ConnectionManager that implements `bb8::ManageConnection` and supports +/// asynchronous clustered connections via `redis_cluster_async::Connection` +#[derive(Clone)] +pub struct RedisClusterConnectionManager { + client: ClusterClient, +} + +impl RedisClusterConnectionManager { + pub fn new( + info: T, + ) -> Result { + Ok(RedisClusterConnectionManager { + client: ClusterClientBuilder::new(vec![info]).retries(0).build()?, + }) + } +} + +impl bb8::ManageConnection for RedisClusterConnectionManager { + type Connection = redis::cluster_async::ClusterConnection; + type Error = RedisError; + + async fn connect(&self) -> Result { + self.client.get_async_connection().await + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + let cmd = redis::cmd("PING"); + let pong = conn + .route_command( + cmd, + RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::OneSucceeded), + )), + ) + .await + .and_then(|v| Ok(String::from_redis_value(v)?))?; + match pong.as_str() { + "PONG" => Ok(()), + _ => Err(( + ErrorKind::Server(redis::ServerErrorKind::ResponseError), + "ping request", + ) + .into()), + } + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + false + } +} diff --git a/src/server/state/cache/mod.rs b/src/server/state/cache/mod.rs new file mode 100644 index 0000000..09af5f7 --- /dev/null +++ b/src/server/state/cache/mod.rs @@ -0,0 +1,240 @@ +mod cluster; +mod sentinel; + +use anyhow::Result; +use redis::{ + AsyncConnectionConfig, ProtocolVersion, RedisConnectionInfo, RedisError, TlsMode, + aio::ConnectionManagerConfig, sentinel::SentinelNodeConnectionInfo, +}; +use std::{fmt::Debug, sync::Arc}; + +use bb8_redis::{ + RedisConnectionManager, + bb8::{self, Pool, RunError}, +}; +use tokio::sync::Mutex; + +use crate::{ + config::cache::{CacheConfig, RedisVariant}, + server::state::cache::{ + cluster::RedisClusterConnectionManager, sentinel::RedisSentinelConnectionManager, + }, +}; + +const REDIS_CONN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2); + +#[derive(Clone)] +pub enum RedisManager { + Clustered(Pool), + NonClustered(Pool), + Sentinel(Pool), + ClusteredUnpooled(redis::cluster_async::ClusterConnection), + NonClusteredUnpooled(redis::aio::ConnectionManager), + SentinelUnpooled(Arc>), +} + +impl Debug for RedisManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Clustered(arg0) => f.debug_tuple("Clustered").field(arg0).finish(), + Self::NonClustered(arg0) => f.debug_tuple("NonClustered").field(arg0).finish(), + Self::Sentinel(arg0) => f.debug_tuple("Sentinel").field(arg0).finish(), + Self::ClusteredUnpooled(_arg0) => f.debug_tuple("ClusteredUnpooled").finish(), + Self::NonClusteredUnpooled(arg0) => { + f.debug_tuple("NonClusteredUnpooled").field(arg0).finish() + } + Self::SentinelUnpooled(_arg0) => f.debug_tuple("SentinelUnpooled").finish(), + } + } +} + +pub enum RedisConnection<'a> { + Clustered(bb8::PooledConnection<'a, RedisClusterConnectionManager>), + NonClustered(bb8::PooledConnection<'a, RedisConnectionManager>), + SentinelPooled(bb8::PooledConnection<'a, RedisSentinelConnectionManager>), + ClusteredUnpooled(redis::cluster_async::ClusterConnection), + NonClusteredUnpooled(redis::aio::ConnectionManager), + SentinelUnpooled(redis::aio::MultiplexedConnection), +} + +impl redis::aio::ConnectionLike for RedisConnection<'_> { + fn req_packed_command<'a>( + &'a mut self, + cmd: &'a redis::Cmd, + ) -> redis::RedisFuture<'a, redis::Value> { + match self { + RedisConnection::Clustered(conn) => conn.req_packed_command(cmd), + RedisConnection::NonClustered(conn) => conn.req_packed_command(cmd), + RedisConnection::ClusteredUnpooled(conn) => conn.req_packed_command(cmd), + RedisConnection::NonClusteredUnpooled(conn) => conn.req_packed_command(cmd), + RedisConnection::SentinelPooled(conn) => conn.req_packed_command(cmd), + RedisConnection::SentinelUnpooled(conn) => conn.req_packed_command(cmd), + } + } + + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a redis::Pipeline, + offset: usize, + count: usize, + ) -> redis::RedisFuture<'a, Vec> { + match self { + RedisConnection::Clustered(conn) => conn.req_packed_commands(cmd, offset, count), + RedisConnection::NonClustered(conn) => conn.req_packed_commands(cmd, offset, count), + RedisConnection::ClusteredUnpooled(conn) => { + conn.req_packed_commands(cmd, offset, count) + } + RedisConnection::NonClusteredUnpooled(conn) => { + conn.req_packed_commands(cmd, offset, count) + } + RedisConnection::SentinelPooled(conn) => conn.req_packed_commands(cmd, offset, count), + RedisConnection::SentinelUnpooled(conn) => conn.req_packed_commands(cmd, offset, count), + } + } + + fn get_db(&self) -> i64 { + match self { + RedisConnection::Clustered(conn) => conn.get_db(), + RedisConnection::NonClustered(conn) => conn.get_db(), + RedisConnection::ClusteredUnpooled(conn) => conn.get_db(), + RedisConnection::NonClusteredUnpooled(conn) => conn.get_db(), + RedisConnection::SentinelPooled(conn) => conn.get_db(), + RedisConnection::SentinelUnpooled(conn) => conn.get_db(), + } + } +} + +impl RedisManager { + pub async fn new(config: &CacheConfig) -> Result { + if config.pooled { + Self::new_pooled( + config.redis_dsn.as_ref(), + &config.kind, + config.max_connections, + ) + .await + } else { + Self::new_unpooled(config.redis_dsn.as_ref(), &config.kind).await + } + } + async fn new_pooled(dsn: &str, variant: &RedisVariant, max_conns: u16) -> Result { + match variant { + RedisVariant::Clustered => { + let mgr = RedisClusterConnectionManager::new(dsn)?; + let pool = bb8::Pool::builder() + .max_size(max_conns.into()) + .build(mgr) + .await?; + Ok(RedisManager::Clustered(pool)) + } + RedisVariant::NonClustered => { + let mgr = RedisConnectionManager::new(dsn)?; + let pool = bb8::Pool::builder() + .max_size(max_conns.into()) + .build(mgr) + .await?; + Ok(RedisManager::NonClustered(pool)) + } + RedisVariant::Sentinel(cfg) => { + let mgr = RedisSentinelConnectionManager::new( + vec![dsn], + cfg.service_name.clone(), + Some(create_config(cfg)), + )?; + let pool = bb8::Pool::builder() + .max_size(max_conns.into()) + .build(mgr) + .await?; + Ok(RedisManager::Sentinel(pool)) + } + } + } + + async fn new_unpooled(dsn: &str, variant: &RedisVariant) -> Result { + match variant { + RedisVariant::Clustered => { + let cli = redis::cluster::ClusterClient::builder(vec![dsn]) + .retries(1) + .connection_timeout(REDIS_CONN_TIMEOUT) + .build()?; + let con = cli.get_async_connection().await?; + Ok(RedisManager::ClusteredUnpooled(con)) + } + RedisVariant::NonClustered => { + let cli = redis::Client::open(dsn)?; + let con = redis::aio::ConnectionManager::new_with_config( + cli, + ConnectionManagerConfig::new() + .set_number_of_retries(1) + .set_connection_timeout(Some(REDIS_CONN_TIMEOUT)), + ) + .await?; + Ok(RedisManager::NonClusteredUnpooled(con)) + } + RedisVariant::Sentinel(cfg) => { + let cli = redis::sentinel::SentinelClient::build( + vec![dsn], + cfg.service_name.clone(), + Some(create_config(cfg)), + redis::sentinel::SentinelServerType::Master, + )?; + + Ok(RedisManager::SentinelUnpooled(Arc::new(Mutex::new(cli)))) + } + } + } + + pub async fn get(&self) -> Result, RunError> { + match self { + Self::Clustered(pool) => Ok(RedisConnection::Clustered(pool.get().await?)), + Self::NonClustered(pool) => Ok(RedisConnection::NonClustered(pool.get().await?)), + Self::Sentinel(pool) => Ok(RedisConnection::SentinelPooled(pool.get().await?)), + Self::ClusteredUnpooled(conn) => Ok(RedisConnection::ClusteredUnpooled(conn.clone())), + Self::NonClusteredUnpooled(conn) => { + Ok(RedisConnection::NonClusteredUnpooled(conn.clone())) + } + Self::SentinelUnpooled(conn) => { + let mut conn = conn.lock().await; + let con = conn + .get_async_connection_with_config( + &AsyncConnectionConfig::new() + .set_response_timeout(Some(REDIS_CONN_TIMEOUT)), + ) + .await?; + Ok(RedisConnection::SentinelUnpooled(con)) + } + } + } +} + +fn create_config(cfg: &crate::config::cache::SentinelConfig) -> SentinelNodeConnectionInfo { + let tls_mode = cfg.redis_tls_mode_secure.then_some(TlsMode::Secure); + let protocol = if cfg.redis_use_resp3 { + ProtocolVersion::RESP3 + } else { + ProtocolVersion::default() + }; + let info = RedisConnectionInfo::default(); + let info = if let Some(pass) = &cfg.redis_password { + info.set_password(pass.clone()) + } else { + info + }; + + let info = if let Some(user) = &cfg.redis_username { + info.set_username(user.clone()) + } else { + info + } + .set_protocol(protocol.clone()) + .set_db(cfg.redis_db.unwrap_or(0)); + + let sent_info = SentinelNodeConnectionInfo::default(); + + if let Some(tls) = tls_mode { + sent_info.set_tls_mode(tls) + } else { + sent_info + } + .set_redis_connection_info(info) +} diff --git a/src/server/state/cache/sentinel.rs b/src/server/state/cache/sentinel.rs new file mode 100644 index 0000000..8dcf394 --- /dev/null +++ b/src/server/state/cache/sentinel.rs @@ -0,0 +1,56 @@ +use bb8_redis::bb8; +use redis::{ + ErrorKind, IntoConnectionInfo, RedisError, + sentinel::{SentinelClient, SentinelNodeConnectionInfo, SentinelServerType}, +}; +use tokio::sync::Mutex; + +struct LockedSentinelClient(pub(crate) Mutex); + +/// ConnectionManager that implements `bb8::ManageConnection` and supports +/// asynchronous Sentinel connections via `redis::sentinel::SentinelClient` +pub struct RedisSentinelConnectionManager { + client: LockedSentinelClient, +} + +impl RedisSentinelConnectionManager { + pub fn new( + info: Vec, + service_name: String, + node_connection_info: Option, + ) -> Result { + Ok(RedisSentinelConnectionManager { + client: LockedSentinelClient(Mutex::new(SentinelClient::build( + info, + service_name, + node_connection_info, + SentinelServerType::Master, + )?)), + }) + } +} + +impl bb8::ManageConnection for RedisSentinelConnectionManager { + type Connection = redis::aio::MultiplexedConnection; + type Error = RedisError; + + async fn connect(&self) -> Result { + self.client.0.lock().await.get_async_connection().await + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + let pong: String = redis::cmd("PING").query_async(conn).await?; + match pong.as_str() { + "PONG" => Ok(()), + _ => Err(( + ErrorKind::Server(redis::ServerErrorKind::ResponseError), + "ping request", + ) + .into()), + } + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + false + } +} diff --git a/src/server/state/mod.rs b/src/server/state/mod.rs index c86052d..f256949 100644 --- a/src/server/state/mod.rs +++ b/src/server/state/mod.rs @@ -1,3 +1,4 @@ +pub mod cache; pub mod database; pub mod federation; -- cgit v1.2.3