pub mod cli; mod log_level; use std::path::PathBuf; use clap::ValueEnum; pub use cli::Cli; pub use cli::Commands; use serde::Deserialize; use tracing_subscriber::EnvFilter; use crate::WardenError; use crate::config::cli::CliEnvironment; use crate::config::cli::database::Database; macro_rules! pick { ($cli:expr, $file:expr, $name:expr, $missing:expr) => {{ let val = $cli.clone().or($file.clone()); if val.is_none() { $missing.push($name); } val }}; } #[derive(Deserialize, Default, Debug, ValueEnum, Clone, Copy)] #[serde(rename_all = "lowercase")] pub enum Environment { Development, #[default] Production, } impl From for Environment { fn from(value: CliEnvironment) -> Self { match value { CliEnvironment::Dev | CliEnvironment::Development => Self::Development, CliEnvironment::Prod | CliEnvironment::Production => Self::Production, } } } #[derive(Debug, Clone)] pub struct Configuration { pub server: Server, pub database: Database, } #[derive(Debug, Clone)] pub struct Server { pub port: u16, pub environment: Environment, pub log_level: EnvFilter, pub log_dir: PathBuf, pub timeout_secs: u64, pub pagination_limit: i64, } impl Server { pub fn merge(cli: &Cli, file: &Cli, missing: &mut Vec<&str>) -> Result { let port = pick!(cli.server.port, file.server.port, "server.port", missing); let timeout = pick!( cli.server.timeout_secs, file.server.timeout_secs, "server.timeout", missing ); let log_dir = pick!( cli.server.log_dir.clone(), file.server.log_dir.clone(), "server.log_dir", missing ); let env = pick!( cli.server.environment, file.server.environment, "server.environment", missing ); let raw_log_level = pick!( cli.server.log_level.clone(), file.server.log_level.clone(), "server.log_level", missing ); let pagination_limit = pick!( cli.server.pagination_limit.clone(), file.server.pagination_limit.clone(), "server.pagination_limit", missing ); if !missing.is_empty() { let err_msg = missing .iter() .map(|f| format!(" - {}", f)) .collect::>() .join("\n"); return Err(WardenError::Config(format!( "Missing required fields:\n{}", err_msg ))); } let log_level = tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| raw_log_level.unwrap().into()); Ok(Self { port: port.unwrap(), environment: env.unwrap().into(), log_dir: log_dir.unwrap(), timeout_secs: timeout.unwrap(), log_level, pagination_limit: pagination_limit.unwrap(), }) } } impl Configuration { pub fn merge(cli: &Cli, file: &Cli) -> Result { let mut missing = Vec::new(); let server = Server::merge(cli, file, &mut missing)?; let database = Database::merge(&cli.database, &file.database)?; Ok(Self { server, database }) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_merge_config() { let mut cli = Cli::default(); cli.server.port = Some(8080); let timeout = 30; cli.server.timeout_secs = Some(timeout); let file = Cli { server: cli::Server { environment: Some(CliEnvironment::Dev), log_level: Some("info".into()), log_dir: Some(PathBuf::from("/tmp")), timeout_secs: Some(timeout), ..Default::default() }, ..Default::default() }; let result = Configuration::merge(&cli, &file); assert!( result.is_ok(), "Merge should succeed when all fields are covered" ); let server = result.unwrap(); assert_eq!(server.server.port, 8080); } #[test] fn test_merge_all_fields_present_success() { let mut cli = Cli::default(); cli.server.port = Some(8080); let timeout = 30; cli.server.timeout_secs = Some(timeout); let file = Cli { server: cli::Server { environment: Some(CliEnvironment::Dev), log_level: Some("info".into()), log_dir: Some(PathBuf::from("/tmp")), timeout_secs: Some(timeout), ..Default::default() }, ..Default::default() }; let mut missing = vec![]; let result = Server::merge(&cli, &file, &mut missing); assert!( result.is_ok(), "Merge should succeed when all fields are covered" ); let server = result.unwrap(); assert_eq!(server.port, 8080); assert_eq!(server.timeout_secs, timeout); } #[test] fn test_merge_error_accumulation() { let cli = Cli { server: cli::Server { port: None, environment: None, log_level: None, log_dir: None, timeout_secs: None, pagination_limit: None, }, ..Default::default() }; let mut missing = vec![]; let result = Server::merge(&cli, &cli, &mut missing); dbg!(&result); match result { Err(WardenError::Config(msg)) => { assert!(msg.contains("server.port")); assert!(msg.contains("server.environment")); } _ => panic!("Expected a Config error with multiple missing fields"), } } #[test] fn test_cli_priority_over_file() { let mut cli = Cli::default(); cli.server.port = Some(9999); let file = Cli { server: cli::Server { port: Some(1111), // This should be ignored environment: Some(CliEnvironment::Prod), log_level: Some("error".into()), log_dir: Some(PathBuf::from("/var/log")), timeout_secs: Some(60), ..Default::default() }, ..Default::default() }; let mut missing = vec![]; let server = Server::merge(&cli, &file, &mut missing).expect("Merge failed"); assert_eq!(server.port, 9999, "CLI port must override File port"); } #[test] fn test_env_filter_from_raw_string() { let log_level = "warn"; let cli = Cli { server: cli::Server { port: Some(80), environment: Some(CliEnvironment::Production), log_level: Some(log_level.to_string()), log_dir: Some(PathBuf::from(".")), timeout_secs: Some(5), ..Default::default() }, ..Default::default() }; let file = Cli::default(); let mut missing = vec![]; let server = Server::merge(&cli, &file, &mut missing).unwrap(); assert_eq!(&server.log_level.to_string(), log_level) } }