diff options
| author | rtkay123 <dev@kanjala.com> | 2026-02-03 13:45:46 +0200 |
|---|---|---|
| committer | rtkay123 <dev@kanjala.com> | 2026-02-03 13:45:46 +0200 |
| commit | eb2e86997d47249aa31b703598de13ab2eb96caa (patch) | |
| tree | 9a591adee7d027b305d07a04987b5559b99f4d37 | |
| parent | 0ea3cb1d4743b922fbc6e07037096e75caffba8f (diff) | |
| download | sellershut-eb2e86997d47249aa31b703598de13ab2eb96caa.tar.bz2 sellershut-eb2e86997d47249aa31b703598de13ab2eb96caa.zip | |
| -rw-r--r-- | Cargo.lock | 117 | ||||
| -rw-r--r-- | Cargo.toml | 6 | ||||
| -rw-r--r-- | misc/sellershut.toml | 20 | ||||
| -rw-r--r-- | src/config/cache.rs | 57 | ||||
| -rw-r--r-- | src/config/cli/cache/mod.rs | 46 | ||||
| -rw-r--r-- | src/config/cli/mod.rs (renamed from src/config/cli.rs) | 50 | ||||
| -rw-r--r-- | src/config/cli/oauth/mod.rs | 36 | ||||
| -rw-r--r-- | src/config/mod.rs | 35 | ||||
| -rw-r--r-- | src/main.rs | 4 | ||||
| -rw-r--r-- | src/server/driver/mod.rs | 16 | ||||
| -rw-r--r-- | src/server/routes/auth/mod.rs | 17 | ||||
| -rw-r--r-- | src/server/state/cache/cluster.rs | 58 | ||||
| -rw-r--r-- | src/server/state/cache/mod.rs | 240 | ||||
| -rw-r--r-- | src/server/state/cache/sentinel.rs | 56 | ||||
| -rw-r--r-- | src/server/state/mod.rs | 1 |
15 files changed, 693 insertions, 66 deletions
@@ -159,6 +159,21 @@ dependencies = [ ] [[package]] +name = "arc-swap" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ded5f9a03ac8f24d1b8a25101ee812cd32cdc8c50a4c50237de2c4915850e73" +dependencies = [ + "rustversion", +] + +[[package]] +name = "arcstr" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03918c3dbd7701a85c6b9887732e2921175f26c350b4563841d0958c21d57e6d" + +[[package]] name = "arrayref" version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -477,6 +492,15 @@ dependencies = [ ] [[package]] +name = "backon" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cffb0e931875b666fc4fcb20fee52e9bbd1ef836fd9e9e04ec21555f9f85f7ef" +dependencies = [ + "fastrand 2.3.0", +] + +[[package]] name = "base64" version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -495,6 +519,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" [[package]] +name = "bb8" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "457d7ed3f888dfd2c7af56d4975cade43c622f74bdcddfed6d4352f57acc6310" +dependencies = [ + "futures-util", + "parking_lot 0.12.5", + "portable-atomic", + "tokio", +] + +[[package]] +name = "bb8-redis" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1063effc7f6cf848bcbcc6e31b5962be75215835587d3109607c643d616f66" +dependencies = [ + "bb8", + "redis", +] + +[[package]] name = "bincode" version = "1.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -696,6 +742,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" [[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "tokio-util", +] + +[[package]] name = "concurrent-queue" version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -762,6 +822,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" [[package]] +name = "crc16" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "338089f42c427b86394a5ee60ff321da23a5c89c9d89514c829687b26359fcff" + +[[package]] name = "crc32fast" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1144,11 +1210,10 @@ checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" [[package]] name = "flate2" -version = "1.1.8" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b375d6465b98090a5f25b1c7703f3859783755aa9a80433b36e0379a3ec2f369" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" dependencies = [ - "crc32fast", "miniz_oxide", "zlib-rs", ] @@ -1599,14 +1664,13 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.19" +version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "727805d60e7938b76b826a6ef209eb70eaa1812794f9424d4a4e2d740662df5f" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" dependencies = [ "base64 0.22.1", "bytes", "futures-channel", - "futures-core", "futures-util", "http", "http-body", @@ -2448,6 +2512,34 @@ dependencies = [ ] [[package]] +name = "redis" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e969d1d702793536d5fda739a82b88ad7cbe7d04f8386ee8cd16ad3eff4854a5" +dependencies = [ + "arc-swap", + "arcstr", + "backon", + "bytes", + "cfg-if 1.0.4", + "combine", + "crc16", + "futures-channel", + "futures-util", + "itoa", + "log", + "percent-encoding", + "pin-project-lite", + "rand 0.9.2", + "ryu", + "socket2", + "tokio", + "tokio-util", + "url", + "xxhash-rust", +] + +[[package]] name = "redox_syscall" version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2777,10 +2869,12 @@ dependencies = [ "async-sqlx-session", "async-trait", "axum", + "bb8-redis", "bon", "clap", "oauth2", "rand 0.9.2", + "redis", "secrecy", "serde", "serde_json", @@ -3503,6 +3597,7 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot 0.12.5", "pin-project-lite", "signal-hook-registry", "socket2", @@ -4441,6 +4536,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" [[package]] +name = "xxhash-rust" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" + +[[package]] name = "yoke" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -4559,9 +4660,9 @@ dependencies = [ [[package]] name = "zlib-rs" -version = "0.5.5" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40990edd51aae2c2b6907af74ffb635029d5788228222c4bb811e9351c0caad3" +checksum = "a7948af682ccbc3342b6e9420e8c51c1fe5d7bf7756002b4a3c6cabfe96a7e3c" [[package]] name = "zmij" @@ -38,6 +38,7 @@ async-session = "3.0.0" async-sqlx-session = { version = "0.4.0", default-features = false, features = ["rustls", "pg"] } async-trait.workspace = true axum = { version = "0.8.8", features = ["macros"] } +bb8-redis = "0.26.0" bon.workspace = true clap = { version = "4.5.56", features = ["derive", "env"] } oauth2.workspace = true @@ -60,6 +61,11 @@ utoipa-scalar = { version = "0.3.0", optional = true } utoipa-swagger-ui = { version = "9.0.2", optional = true } uuid = { workspace = true, features = ["v7"] } +[dependencies.redis] +version = "1.0.3" +default-features = false +features = ["cluster-async", "connection-manager", "sentinel", "tokio-comp"] + [dependencies.sqlx] version = "0.8.6" default-features = false diff --git a/misc/sellershut.toml b/misc/sellershut.toml index 0c179f5..06102a0 100644 --- a/misc/sellershut.toml +++ b/misc/sellershut.toml @@ -10,11 +10,27 @@ environment = "dev" redirect-url = "https://example.com" [oauth.discord] -#client-id = "" -#client-secret = "" +client-id = "" +client-secret = "" token-url = "https://example.com" auth-url = "https://example.com" [database] url = "postgres://postres:password@localhost:5432/sellershut" pool-size = 100 + +[cache] +dsn = "redis://localhost:6379" +pooled = true +type = "non-clustered" # clustered, non-clustered or sentinel +max-connections = 100 + +[cache.sentinel] +master-name = "mymaster" +nodes = [ + { host = "127.0.0.1", port = 26379 }, + { host = "127.0.0.2", port = 26379 }, + { host = "127.0.0.3", port = 26379 }, +] + +# vim:ft=toml diff --git a/src/config/cache.rs b/src/config/cache.rs new file mode 100644 index 0000000..96f3a9b --- /dev/null +++ b/src/config/cache.rs @@ -0,0 +1,57 @@ +use serde::Deserialize; +use url::Url; + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "lowercase")] +pub struct CacheConfig { + #[serde(rename = "dsn")] + pub redis_dsn: Url, + #[serde(default)] + pub pooled: bool, + #[serde(rename = "type")] + pub kind: RedisVariant, + #[serde(default = "default_max_conns")] + #[serde(rename = "max-connections")] + pub max_connections: u16, +} + +#[derive(Debug, Deserialize, Clone, Default)] +#[serde(rename_all = "kebab-case")] +pub enum RedisVariant { + Clustered, + #[default] + NonClustered, + Sentinel(SentinelConfig), +} + +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +pub struct SentinelConfig { + #[serde(rename = "sentinel_service_name")] + pub service_name: String, + #[serde(default)] + pub redis_tls_mode_secure: bool, + pub redis_db: Option<i64>, + pub redis_username: Option<String>, + pub redis_password: Option<String>, + #[serde(default)] + pub redis_use_resp3: bool, +} + +fn default_max_conns() -> u16 { + 100 +} + +fn default_cache() -> Url { + Url::parse("redis://localhost:6379").expect("valid default DATABASE url") +} + +impl Default for CacheConfig { + fn default() -> Self { + Self { + redis_dsn: default_cache(), + pooled: Default::default(), + kind: Default::default(), + max_connections: default_max_conns(), + } + } +} diff --git a/src/config/cli/cache/mod.rs b/src/config/cli/cache/mod.rs new file mode 100644 index 0000000..04b36bc --- /dev/null +++ b/src/config/cli/cache/mod.rs @@ -0,0 +1,46 @@ +use clap::{Parser, ValueEnum}; +use serde::Deserialize; +use url::Url; + +#[derive(Debug, Clone, Parser, Deserialize, Default)] +pub struct Cache { + /// Cache connection string + #[arg(long, env = "CACHE_URL", default_value = "redis://localhost:6379")] + pub cache_url: Option<Url>, + #[arg(long, env = "CACHE_POOL_ENABLED", default_value = "true")] + pub cache_pooled: Option<bool>, + #[serde(rename = "type")] + #[arg(long, env = "CACHE_TYPE", default_value = "non-clustered")] + pub cache_type: Option<RedisVariant>, + #[serde(default = "default_max_conns")] + #[serde(rename = "max-connections")] + #[arg(long, env = "CACHE_MAX_CONNECTIONS", default_value = "100")] + pub cache_max_conn: Option<u16>, + #[command(flatten)] + pub sentinel_config: SentinelConfig, +} + +#[derive(Debug, Deserialize, Clone, ValueEnum)] +#[serde(rename_all = "kebab-case")] +pub enum RedisVariant { + Clustered, + NonClustered, + Sentinel, +} + +fn default_max_conns() -> Option<u16> { + Some(100) +} + +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Parser, Default)] +pub struct SentinelConfig { + #[serde(rename = "sentinel_service_name")] + #[arg(long, env = "CACHE_SENTINEL_NAME", default_value = "true")] + pub service_name: Option<String>, + #[serde(default)] + #[arg(long, env = "CACHE_TLS_MODE_SECURE")] + pub cache_tls_mode_secure: bool, + #[serde(default)] + #[arg(long, env = "CACHE_USE_RESP3")] + pub cache_use_resp3: bool, +} diff --git a/src/config/cli.rs b/src/config/cli/mod.rs index be1b913..81eb2fe 100644 --- a/src/config/cli.rs +++ b/src/config/cli/mod.rs @@ -1,9 +1,11 @@ +pub mod cache; + +#[cfg(feature = "oauth")] +pub mod oauth; + use std::path::PathBuf; use clap::Parser; -#[cfg(feature = "oauth-discord")] -use secrecy::SecretString; -use serde::Deserialize; use url::Url; use crate::config::{logging::LogLevel, port::port_in_range}; @@ -31,14 +33,17 @@ pub struct Cli { #[arg(short, long, env = "TIMEOUT_SECONDS", default_value = "10")] pub timeout_duration: Option<u64>, - /// Users database connection string + /// Database connection string #[arg( long, - env = "USERS_DATABASE_URL", + env = "DATABASE_URL", default_value = "postgres://postgres:password@localhost:5432/sellershut" )] pub db: Option<Url>, + #[command(flatten)] + pub cache: Option<cache::Cache>, + /// Server's system name #[arg(short, long, default_value = "sellershut", env = "SYSTEM_NAME")] pub system_name: Option<String>, @@ -50,37 +55,7 @@ pub struct Cli { /// Oauth optionas #[command(flatten)] #[cfg(feature = "oauth")] - pub oauth: OAuth, -} - -#[derive(Debug, Clone, Parser, Deserialize)] -pub struct OAuth { - #[cfg(feature = "oauth-discord")] - #[command(flatten)] - discord: DiscordOauth, - #[arg(long, env = "OAUTH_REDIRECT_URL")] - oauth_redirect_url: Option<Url>, -} - -#[cfg(feature = "oauth-discord")] -#[derive(Debug, Clone, Parser, Deserialize, Default)] -pub struct DiscordOauth { - #[arg(long, env = "OAUTH_DISCORD_CLIENT_ID")] - discord_client_id: Option<String>, - #[arg(long, env = "OAUTH_DISCORD_CLIENT_SECRET")] - discord_client_secret: Option<SecretString>, - #[arg( - long, - env = "OAUTH_DISCORD_TOKEN_URL", - default_value = "https://discord.com/api/oauth2/token" - )] - discord_token_url: Option<Url>, - #[arg( - long, - env = "OAUTH_DISCORD_AUTH_URL", - default_value = "https://discord.com/api/oauth2/authorize?response_type=code" - )] - discord_auth_url: Option<Url>, + pub oauth: oauth::OAuth, } #[cfg(test)] @@ -95,7 +70,8 @@ impl Default for Cli { domain: Default::default(), system_name: Default::default(), environment: Default::default(), - oauth: None, + oauth: Default::default(), + cache: Default::default(), db: url, } } diff --git a/src/config/cli/oauth/mod.rs b/src/config/cli/oauth/mod.rs new file mode 100644 index 0000000..4bf1c34 --- /dev/null +++ b/src/config/cli/oauth/mod.rs @@ -0,0 +1,36 @@ +use clap::Parser; +#[cfg(feature = "oauth-discord")] +use secrecy::SecretString; +use serde::Deserialize; +#[cfg(feature = "oauth")] +use url::Url; + +#[derive(Debug, Clone, Parser, Deserialize, Default)] +pub struct OAuth { + #[cfg(feature = "oauth-discord")] + #[command(flatten)] + discord: DiscordOauth, + #[arg(long, env = "OAUTH_REDIRECT_URL")] + oauth_redirect_url: Option<Url>, +} + +#[cfg(feature = "oauth-discord")] +#[derive(Debug, Clone, Parser, Deserialize, Default)] +pub struct DiscordOauth { + #[arg(long, env = "OAUTH_DISCORD_CLIENT_ID")] + discord_client_id: Option<String>, + #[arg(long, env = "OAUTH_DISCORD_CLIENT_SECRET")] + discord_client_secret: Option<SecretString>, + #[arg( + long, + env = "OAUTH_DISCORD_TOKEN_URL", + default_value = "https://discord.com/api/oauth2/token" + )] + discord_token_url: Option<Url>, + #[arg( + long, + env = "OAUTH_DISCORD_AUTH_URL", + default_value = "https://discord.com/api/oauth2/authorize?response_type=code" + )] + discord_auth_url: Option<Url>, +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 7495b22..e64ae5c 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,13 +1,15 @@ +pub mod cache; mod cli; mod logging; mod port; pub use cli::Cli; -#[cfg(feature = "oauth")] -use secrecy::SecretString; use serde::Deserialize; use url::Url; -use crate::config::logging::LogLevel; +use crate::config::{ + cache::{CacheConfig, RedisVariant}, + logging::LogLevel, +}; #[derive(Default, Deserialize, Debug, PartialEq, Eq)] #[serde(rename_all = "kebab-case")] @@ -23,6 +25,8 @@ pub struct Config { #[serde(default)] pub database: DatabaseOptions, #[serde(default)] + pub cache: CacheConfig, + #[serde(default)] pub server: Api, #[serde(default)] #[cfg(feature = "oauth")] @@ -65,7 +69,7 @@ pub struct OAuth { #[serde(rename_all = "kebab-case")] pub struct DiscordOauth { pub client_id: String, - pub client_secret: SecretString, + pub client_secret: secrecy::SecretString, #[serde(default = "discord_token_url")] pub token_url: Url, #[serde(default = "discord_auth_url")] @@ -94,7 +98,7 @@ impl Default for OAuth { #[cfg(feature = "oauth-discord")] discord: DiscordOauth { client_id: String::default(), - client_secret: SecretString::default(), + client_secret: secrecy::SecretString::default(), token_url: discord_token_url(), auth_url: discord_auth_url(), }, @@ -175,6 +179,7 @@ impl Config { pub fn merge_with_cli(&mut self, cli: &Cli) { let server = &mut self.server; let dsn = &mut self.database; + let cache = &mut self.cache; if let Some(port) = cli.port { server.port = port; @@ -195,6 +200,26 @@ impl Config { if let Some(db_url) = &cli.db { dsn.url = db_url.clone(); } + + if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_url.clone()) { + cache.redis_dsn = c; + } + + if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_pooled) { + cache.pooled = c; + } + + if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_max_conn) { + cache.max_connections = c; + } + + if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_type.clone()) { + cache.kind = match c { + cli::cache::RedisVariant::Clustered => RedisVariant::Clustered, + cli::cache::RedisVariant::NonClustered => RedisVariant::NonClustered, + cli::cache::RedisVariant::Sentinel => cache.kind.clone(), + }; + } } } diff --git a/src/main.rs b/src/main.rs index 2018956..971c08c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -mod config; +pub mod config; mod logging; mod server; @@ -29,7 +29,7 @@ async fn main() -> anyhow::Result<()> { initialise_logging(&config); - let driver = Services::new(&config.database).await?; + let driver = Services::new(&config.database, &config.cache).await?; let state = AppState::new(&config, driver).await?; let router = server::router(&config, state).await?; diff --git a/src/server/driver/mod.rs b/src/server/driver/mod.rs index 68bd18c..2eaf7dc 100644 --- a/src/server/driver/mod.rs +++ b/src/server/driver/mod.rs @@ -5,26 +5,28 @@ pub mod auth; use async_session::Session; use async_trait::async_trait; #[cfg(feature = "oauth")] -use axum::{ - http::HeaderMap, - response::{IntoResponse, Redirect}, -}; +use axum::{http::HeaderMap, response::Redirect}; #[cfg(feature = "oauth")] use oauth2::CsrfToken; use sqlx::PgPool; -use crate::{config::DatabaseOptions, server::state::database}; +use crate::{ + config::{DatabaseOptions, cache::CacheConfig}, + server::state::{cache::RedisManager, database}, +}; #[derive(Debug, Clone)] pub struct Services { database: PgPool, + cache: RedisManager, } impl Services { - pub async fn new(database: &DatabaseOptions) -> anyhow::Result<Self> { + pub async fn new(database: &DatabaseOptions, cache: &CacheConfig) -> anyhow::Result<Self> { let database = database::connect(database).await?; + let cache = RedisManager::new(cache).await?; - Ok(Self { database }) + Ok(Self { database, cache }) } } diff --git a/src/server/routes/auth/mod.rs b/src/server/routes/auth/mod.rs index 485983a..7d7ecf3 100644 --- a/src/server/routes/auth/mod.rs +++ b/src/server/routes/auth/mod.rs @@ -55,11 +55,14 @@ pub async fn auth( Query(params): Query<Params>, data: Data<AppState>, ) -> Result<impl IntoResponse, AppError> { - match params.provider { - #[cfg(feature = "oauth-discord")] + #[cfg(feature = "oauth-discord")] + return match params.provider { OauthProvider::Discord => discord::discord_auth(data), } - .await + .await; + + #[cfg(not(feature = "oauth-discord"))] + Ok(axum::http::StatusCode::INTERNAL_SERVER_ERROR) } #[utoipa::path( @@ -79,9 +82,13 @@ pub async fn authorised( Query(params): Query<Params>, data: Data<AppState>, ) -> Result<impl IntoResponse, AppError> { - match params.provider { + #[cfg(feature = "oauth-discord")] + return match params.provider { #[cfg(feature = "oauth-discord")] OauthProvider::Discord => discord::discord_auth(data), } - .await + .await; + + #[cfg(not(feature = "oauth-discord"))] + Ok(axum::http::StatusCode::INTERNAL_SERVER_ERROR) } diff --git a/src/server/state/cache/cluster.rs b/src/server/state/cache/cluster.rs new file mode 100644 index 0000000..ea71954 --- /dev/null +++ b/src/server/state/cache/cluster.rs @@ -0,0 +1,58 @@ +use bb8_redis::bb8; +use redis::{ + ErrorKind, FromRedisValue, IntoConnectionInfo, RedisError, + cluster::{ClusterClient, ClusterClientBuilder}, + cluster_routing::{MultipleNodeRoutingInfo, ResponsePolicy, RoutingInfo}, +}; + +/// ConnectionManager that implements `bb8::ManageConnection` and supports +/// asynchronous clustered connections via `redis_cluster_async::Connection` +#[derive(Clone)] +pub struct RedisClusterConnectionManager { + client: ClusterClient, +} + +impl RedisClusterConnectionManager { + pub fn new<T: IntoConnectionInfo>( + info: T, + ) -> Result<RedisClusterConnectionManager, RedisError> { + Ok(RedisClusterConnectionManager { + client: ClusterClientBuilder::new(vec![info]).retries(0).build()?, + }) + } +} + +impl bb8::ManageConnection for RedisClusterConnectionManager { + type Connection = redis::cluster_async::ClusterConnection; + type Error = RedisError; + + async fn connect(&self) -> Result<Self::Connection, Self::Error> { + self.client.get_async_connection().await + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + let cmd = redis::cmd("PING"); + let pong = conn + .route_command( + cmd, + RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::OneSucceeded), + )), + ) + .await + .and_then(|v| Ok(String::from_redis_value(v)?))?; + match pong.as_str() { + "PONG" => Ok(()), + _ => Err(( + ErrorKind::Server(redis::ServerErrorKind::ResponseError), + "ping request", + ) + .into()), + } + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + false + } +} diff --git a/src/server/state/cache/mod.rs b/src/server/state/cache/mod.rs new file mode 100644 index 0000000..09af5f7 --- /dev/null +++ b/src/server/state/cache/mod.rs @@ -0,0 +1,240 @@ +mod cluster; +mod sentinel; + +use anyhow::Result; +use redis::{ + AsyncConnectionConfig, ProtocolVersion, RedisConnectionInfo, RedisError, TlsMode, + aio::ConnectionManagerConfig, sentinel::SentinelNodeConnectionInfo, +}; +use std::{fmt::Debug, sync::Arc}; + +use bb8_redis::{ + RedisConnectionManager, + bb8::{self, Pool, RunError}, +}; +use tokio::sync::Mutex; + +use crate::{ + config::cache::{CacheConfig, RedisVariant}, + server::state::cache::{ + cluster::RedisClusterConnectionManager, sentinel::RedisSentinelConnectionManager, + }, +}; + +const REDIS_CONN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2); + +#[derive(Clone)] +pub enum RedisManager { + Clustered(Pool<RedisClusterConnectionManager>), + NonClustered(Pool<RedisConnectionManager>), + Sentinel(Pool<RedisSentinelConnectionManager>), + ClusteredUnpooled(redis::cluster_async::ClusterConnection), + NonClusteredUnpooled(redis::aio::ConnectionManager), + SentinelUnpooled(Arc<Mutex<redis::sentinel::SentinelClient>>), +} + +impl Debug for RedisManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Clustered(arg0) => f.debug_tuple("Clustered").field(arg0).finish(), + Self::NonClustered(arg0) => f.debug_tuple("NonClustered").field(arg0).finish(), + Self::Sentinel(arg0) => f.debug_tuple("Sentinel").field(arg0).finish(), + Self::ClusteredUnpooled(_arg0) => f.debug_tuple("ClusteredUnpooled").finish(), + Self::NonClusteredUnpooled(arg0) => { + f.debug_tuple("NonClusteredUnpooled").field(arg0).finish() + } + Self::SentinelUnpooled(_arg0) => f.debug_tuple("SentinelUnpooled").finish(), + } + } +} + +pub enum RedisConnection<'a> { + Clustered(bb8::PooledConnection<'a, RedisClusterConnectionManager>), + NonClustered(bb8::PooledConnection<'a, RedisConnectionManager>), + SentinelPooled(bb8::PooledConnection<'a, RedisSentinelConnectionManager>), + ClusteredUnpooled(redis::cluster_async::ClusterConnection), + NonClusteredUnpooled(redis::aio::ConnectionManager), + SentinelUnpooled(redis::aio::MultiplexedConnection), +} + +impl redis::aio::ConnectionLike for RedisConnection<'_> { + fn req_packed_command<'a>( + &'a mut self, + cmd: &'a redis::Cmd, + ) -> redis::RedisFuture<'a, redis::Value> { + match self { + RedisConnection::Clustered(conn) => conn.req_packed_command(cmd), + RedisConnection::NonClustered(conn) => conn.req_packed_command(cmd), + RedisConnection::ClusteredUnpooled(conn) => conn.req_packed_command(cmd), + RedisConnection::NonClusteredUnpooled(conn) => conn.req_packed_command(cmd), + RedisConnection::SentinelPooled(conn) => conn.req_packed_command(cmd), + RedisConnection::SentinelUnpooled(conn) => conn.req_packed_command(cmd), + } + } + + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a redis::Pipeline, + offset: usize, + count: usize, + ) -> redis::RedisFuture<'a, Vec<redis::Value>> { + match self { + RedisConnection::Clustered(conn) => conn.req_packed_commands(cmd, offset, count), + RedisConnection::NonClustered(conn) => conn.req_packed_commands(cmd, offset, count), + RedisConnection::ClusteredUnpooled(conn) => { + conn.req_packed_commands(cmd, offset, count) + } + RedisConnection::NonClusteredUnpooled(conn) => { + conn.req_packed_commands(cmd, offset, count) + } + RedisConnection::SentinelPooled(conn) => conn.req_packed_commands(cmd, offset, count), + RedisConnection::SentinelUnpooled(conn) => conn.req_packed_commands(cmd, offset, count), + } + } + + fn get_db(&self) -> i64 { + match self { + RedisConnection::Clustered(conn) => conn.get_db(), + RedisConnection::NonClustered(conn) => conn.get_db(), + RedisConnection::ClusteredUnpooled(conn) => conn.get_db(), + RedisConnection::NonClusteredUnpooled(conn) => conn.get_db(), + RedisConnection::SentinelPooled(conn) => conn.get_db(), + RedisConnection::SentinelUnpooled(conn) => conn.get_db(), + } + } +} + +impl RedisManager { + pub async fn new(config: &CacheConfig) -> Result<Self> { + if config.pooled { + Self::new_pooled( + config.redis_dsn.as_ref(), + &config.kind, + config.max_connections, + ) + .await + } else { + Self::new_unpooled(config.redis_dsn.as_ref(), &config.kind).await + } + } + async fn new_pooled(dsn: &str, variant: &RedisVariant, max_conns: u16) -> Result<Self> { + match variant { + RedisVariant::Clustered => { + let mgr = RedisClusterConnectionManager::new(dsn)?; + let pool = bb8::Pool::builder() + .max_size(max_conns.into()) + .build(mgr) + .await?; + Ok(RedisManager::Clustered(pool)) + } + RedisVariant::NonClustered => { + let mgr = RedisConnectionManager::new(dsn)?; + let pool = bb8::Pool::builder() + .max_size(max_conns.into()) + .build(mgr) + .await?; + Ok(RedisManager::NonClustered(pool)) + } + RedisVariant::Sentinel(cfg) => { + let mgr = RedisSentinelConnectionManager::new( + vec![dsn], + cfg.service_name.clone(), + Some(create_config(cfg)), + )?; + let pool = bb8::Pool::builder() + .max_size(max_conns.into()) + .build(mgr) + .await?; + Ok(RedisManager::Sentinel(pool)) + } + } + } + + async fn new_unpooled(dsn: &str, variant: &RedisVariant) -> Result<Self> { + match variant { + RedisVariant::Clustered => { + let cli = redis::cluster::ClusterClient::builder(vec![dsn]) + .retries(1) + .connection_timeout(REDIS_CONN_TIMEOUT) + .build()?; + let con = cli.get_async_connection().await?; + Ok(RedisManager::ClusteredUnpooled(con)) + } + RedisVariant::NonClustered => { + let cli = redis::Client::open(dsn)?; + let con = redis::aio::ConnectionManager::new_with_config( + cli, + ConnectionManagerConfig::new() + .set_number_of_retries(1) + .set_connection_timeout(Some(REDIS_CONN_TIMEOUT)), + ) + .await?; + Ok(RedisManager::NonClusteredUnpooled(con)) + } + RedisVariant::Sentinel(cfg) => { + let cli = redis::sentinel::SentinelClient::build( + vec![dsn], + cfg.service_name.clone(), + Some(create_config(cfg)), + redis::sentinel::SentinelServerType::Master, + )?; + + Ok(RedisManager::SentinelUnpooled(Arc::new(Mutex::new(cli)))) + } + } + } + + pub async fn get(&self) -> Result<RedisConnection<'_>, RunError<RedisError>> { + match self { + Self::Clustered(pool) => Ok(RedisConnection::Clustered(pool.get().await?)), + Self::NonClustered(pool) => Ok(RedisConnection::NonClustered(pool.get().await?)), + Self::Sentinel(pool) => Ok(RedisConnection::SentinelPooled(pool.get().await?)), + Self::ClusteredUnpooled(conn) => Ok(RedisConnection::ClusteredUnpooled(conn.clone())), + Self::NonClusteredUnpooled(conn) => { + Ok(RedisConnection::NonClusteredUnpooled(conn.clone())) + } + Self::SentinelUnpooled(conn) => { + let mut conn = conn.lock().await; + let con = conn + .get_async_connection_with_config( + &AsyncConnectionConfig::new() + .set_response_timeout(Some(REDIS_CONN_TIMEOUT)), + ) + .await?; + Ok(RedisConnection::SentinelUnpooled(con)) + } + } + } +} + +fn create_config(cfg: &crate::config::cache::SentinelConfig) -> SentinelNodeConnectionInfo { + let tls_mode = cfg.redis_tls_mode_secure.then_some(TlsMode::Secure); + let protocol = if cfg.redis_use_resp3 { + ProtocolVersion::RESP3 + } else { + ProtocolVersion::default() + }; + let info = RedisConnectionInfo::default(); + let info = if let Some(pass) = &cfg.redis_password { + info.set_password(pass.clone()) + } else { + info + }; + + let info = if let Some(user) = &cfg.redis_username { + info.set_username(user.clone()) + } else { + info + } + .set_protocol(protocol.clone()) + .set_db(cfg.redis_db.unwrap_or(0)); + + let sent_info = SentinelNodeConnectionInfo::default(); + + if let Some(tls) = tls_mode { + sent_info.set_tls_mode(tls) + } else { + sent_info + } + .set_redis_connection_info(info) +} diff --git a/src/server/state/cache/sentinel.rs b/src/server/state/cache/sentinel.rs new file mode 100644 index 0000000..8dcf394 --- /dev/null +++ b/src/server/state/cache/sentinel.rs @@ -0,0 +1,56 @@ +use bb8_redis::bb8; +use redis::{ + ErrorKind, IntoConnectionInfo, RedisError, + sentinel::{SentinelClient, SentinelNodeConnectionInfo, SentinelServerType}, +}; +use tokio::sync::Mutex; + +struct LockedSentinelClient(pub(crate) Mutex<SentinelClient>); + +/// ConnectionManager that implements `bb8::ManageConnection` and supports +/// asynchronous Sentinel connections via `redis::sentinel::SentinelClient` +pub struct RedisSentinelConnectionManager { + client: LockedSentinelClient, +} + +impl RedisSentinelConnectionManager { + pub fn new<T: IntoConnectionInfo>( + info: Vec<T>, + service_name: String, + node_connection_info: Option<SentinelNodeConnectionInfo>, + ) -> Result<RedisSentinelConnectionManager, RedisError> { + Ok(RedisSentinelConnectionManager { + client: LockedSentinelClient(Mutex::new(SentinelClient::build( + info, + service_name, + node_connection_info, + SentinelServerType::Master, + )?)), + }) + } +} + +impl bb8::ManageConnection for RedisSentinelConnectionManager { + type Connection = redis::aio::MultiplexedConnection; + type Error = RedisError; + + async fn connect(&self) -> Result<Self::Connection, Self::Error> { + self.client.0.lock().await.get_async_connection().await + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + let pong: String = redis::cmd("PING").query_async(conn).await?; + match pong.as_str() { + "PONG" => Ok(()), + _ => Err(( + ErrorKind::Server(redis::ServerErrorKind::ResponseError), + "ping request", + ) + .into()), + } + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + false + } +} diff --git a/src/server/state/mod.rs b/src/server/state/mod.rs index c86052d..f256949 100644 --- a/src/server/state/mod.rs +++ b/src/server/state/mod.rs @@ -1,3 +1,4 @@ +pub mod cache; pub mod database; pub mod federation; |
