From ce65d9eeafcd1f9d5c3adef1c9b1af6258ee711a Mon Sep 17 00:00:00 2001 From: rtkay123 Date: Sun, 1 Feb 2026 13:33:07 +0200 Subject: feat: conn to db --- src/config/cli.rs | 101 ++++++++++++++++++++++++++++++++ src/config/logging.rs | 68 ++++++++++++++++++++++ src/config/mod.rs | 155 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/config/port.rs | 42 ++++++++++++++ 4 files changed, 366 insertions(+) create mode 100644 src/config/cli.rs create mode 100644 src/config/logging.rs create mode 100644 src/config/mod.rs create mode 100644 src/config/port.rs (limited to 'src/config') diff --git a/src/config/cli.rs b/src/config/cli.rs new file mode 100644 index 0000000..dab7216 --- /dev/null +++ b/src/config/cli.rs @@ -0,0 +1,101 @@ +use std::path::PathBuf; + +use clap::Parser; +use url::Url; + +use crate::config::{logging::LogLevel, port::port_in_range}; + +#[derive(Parser, Debug)] +#[command(version, about, long_about = None, name = env!("CARGO_PKG_NAME"))] +pub struct Cli { + /// Sets the port the server listens on + #[arg(short, long, default_value = "2210", env = "PORT", value_parser = port_in_range)] + pub port: Option, + + /// Server's domain + #[arg(short, long, default_value = "localhost", env = "DOMAIN")] + pub domain: Option, + + /// Sets a custom config file + #[arg(short, long, value_name = "FILE")] + pub config: Option, + + /// Sets the application log level + #[arg(short, long, value_enum, env = "LOG_LEVEL", default_value = "debug")] + pub log_level: Option, + + /// Request timeout duration (in seconds) + #[arg(short, long, env = "TIMEOUT_SECONDS", default_value = "10")] + pub timeout_duration: Option, + + /// Users database connection string + #[arg( + long, + env = "USERS_DATABASE_URL", + default_value = "postgres://postgres:password@localhost:5432/sellershut" + )] + pub db: Option, + + /// Server's system name + #[arg(short, long, default_value = "sellershut", env = "SYSTEM_NAME")] + pub system_name: Option, + + /// Server's system name + #[arg(short, long, default_value = "prod", env = "SYSTEM_NAME")] + pub environment: Option, + + /// Oauth optionas + #[command(flatten)] + pub oauth: Option, +} + +#[derive(Debug, Clone, Parser)] +pub struct OAuth { + #[cfg(feature = "oauth-discord")] + #[command(flatten)] + discord: DiscordOauth, + #[arg(long)] + oauth_redirect_url: Option, +} + +#[cfg(feature = "oauth-discord")] +#[derive(Debug, Clone, Parser)] +pub struct DiscordOauth { + #[arg(long)] + discord_client_id: Option, + #[arg(long)] + discord_client_secret: Option, + #[arg(long)] + discord_token_url: Option, + #[arg(long)] + discord_auth_url: Option, +} + +#[cfg(test)] +impl Default for Cli { + fn default() -> Self { + let url = Url::parse("postgres://postgres:password@localhost:5432/sellershut").ok(); + Self { + port: Default::default(), + config: Default::default(), + log_level: Default::default(), + timeout_duration: Some(10), + domain: Default::default(), + system_name: Default::default(), + environment: Default::default(), + oauth: None, + db: url, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn verify_cli() { + use clap::CommandFactory; + Cli::command().debug_assert(); + } +} diff --git a/src/config/logging.rs b/src/config/logging.rs new file mode 100644 index 0000000..c6cfe7f --- /dev/null +++ b/src/config/logging.rs @@ -0,0 +1,68 @@ +use clap::ValueEnum; +use serde::Deserialize; + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug, Default, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + /// The "trace" level. + /// + /// Designates very low priority, often extremely verbose, information. + Trace = 0, + /// The "debug" level. + /// + /// Designates lower priority information. + #[default] + Debug = 1, + /// The "info" level. + /// + /// Designates useful information. + Info = 2, + /// The "warn" level. + /// + /// Designates hazardous situations. + Warn = 3, + /// The "error" level. + /// + /// Designates very serious errors. + Error = 4, +} + +impl From for tracing::Level { + fn from(value: LogLevel) -> Self { + match value { + LogLevel::Trace => tracing::Level::TRACE, + LogLevel::Debug => tracing::Level::DEBUG, + LogLevel::Info => tracing::Level::INFO, + LogLevel::Warn => tracing::Level::WARN, + LogLevel::Error => tracing::Level::ERROR, + } + } +} + +#[cfg(test)] +mod tests { + use crate::config::logging::LogLevel; + + fn check(level: LogLevel, value: &str) { + let level = tracing::Level::from(level); + assert_eq!(level.to_string().to_lowercase(), value); + } + + #[test] + fn loglevel() { + let level = LogLevel::Trace; + check(level, "trace"); + + let level = LogLevel::Debug; + check(level, "debug"); + + let level = LogLevel::Info; + check(level, "info"); + + let level = LogLevel::Warn; + check(level, "warn"); + + let level = LogLevel::Error; + check(level, "error"); + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs new file mode 100644 index 0000000..45e12c3 --- /dev/null +++ b/src/config/mod.rs @@ -0,0 +1,155 @@ +mod cli; +mod logging; +mod port; +pub use cli::Cli; +use serde::Deserialize; +use url::Url; + +use crate::{config::logging::LogLevel}; + +#[derive(Default, Deserialize, Debug, PartialEq, Eq)] +#[serde(rename_all = "kebab-case")] +pub enum Environment { + #[default] + Dev, + Prod, +} + +#[derive(Debug, Deserialize, Default)] +#[serde(rename_all = "kebab-case")] +pub struct Config { + #[serde(default)] + pub database: DatabaseOptions, + #[serde(default)] + pub server: Api, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub struct Api { + #[serde(default = "default_domain")] + pub domain: String, + + #[serde(default = "default_request_timeout")] + pub request_timeout: u64, + + #[serde(default = "default_port")] + pub port: u16, + + #[serde(default = "default_log_level")] + pub log_level: LogLevel, + + #[serde(default = "default_sys_name")] + pub system_name: String, + + #[serde(default)] + pub environment: Environment, +} + +impl Default for Api { + fn default() -> Self { + Self { + domain: default_domain(), + request_timeout: default_request_timeout(), + port: default_port(), + log_level: default_log_level(), + system_name: default_sys_name(), + environment: Environment::default(), + } + } +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub struct DatabaseOptions { + #[serde(default = "default_database")] + pub url: Url, + pub pool_size: u32, +} + +impl DatabaseOptions { + pub fn create(url: & Url, pool_size: Option) -> Self { + Self { + url: url.to_owned(), + pool_size: pool_size.unwrap_or_else(|| { + let def = 100; + tracing::debug!(size = def, "Setting default db pool size"); + def + }), + } + } +} + +fn default_database() -> Url { + Url::parse("postgres://postgres:password@localhost:5432/sellershut") + .expect("valid default DATABASE url") +} + +impl Default for DatabaseOptions { + fn default() -> Self { + Self { + url: default_database(), + pool_size: 100 + + } + } +} + +fn default_sys_name() -> String { + "sellershut".to_string() +} + +fn default_domain() -> String { + "localhost".to_string() +} + +fn default_request_timeout() -> u64 { + 10 +} + +fn default_port() -> u16 { + 2210 +} + +fn default_log_level() -> LogLevel { + LogLevel::Debug +} + + +impl Config { + pub fn merge_with_cli(&mut self, cli: &Cli) { + let server = &mut self.server; + let dsn = &mut self.database; + + if let Some(port) = cli.port { + server.port = port; + } + + if let Some(domain) = &cli.domain { + server.domain = domain.to_string(); + } + + if let Some(log_level) = &cli.log_level { + server.log_level = *log_level; + } + + if let Some(timeout) = cli.timeout_duration { + server.request_timeout = timeout; + } + + if let Some(db_url) = &cli.db { + dsn.url = db_url.clone(); + } + } +} + +#[cfg(test)] +mod tests { + use crate::config::Config; + + #[test] + fn config_file() { + let s = include_str!("../../sellershut.toml"); + assert!(toml::from_str::(s).is_ok()) + } +} diff --git a/src/config/port.rs b/src/config/port.rs new file mode 100644 index 0000000..01e305b --- /dev/null +++ b/src/config/port.rs @@ -0,0 +1,42 @@ +use std::ops::RangeInclusive; + +const PORT_RANGE: RangeInclusive = 1..=65535; + +pub fn port_in_range(s: &str) -> Result { + let port = s + .parse() + .map_err(|_| format!("{s} is not a valid port number"))?; + + if PORT_RANGE.contains(&port) { + Ok(port as u16) + } else { + Err(format!( + "port not in range {}-{}", + PORT_RANGE.start(), + PORT_RANGE.end() + )) + } +} + +#[cfg(test)] +mod tests { + use rand::Rng; + + use super::*; + + #[test] + fn in_port_range() { + let mut rng = rand::rng(); + let num = rng.random_range(PORT_RANGE); + + assert!(port_in_range(&num.to_string()).is_ok()); + } + + #[test] + fn outside_port_range() { + let mut rng = rand::rng(); + let num = rng.random_range((65535 + 1)..=usize::MAX); + + assert!(port_in_range(&num.to_string()).is_err()); + } +} -- cgit v1.2.3