diff options
| author | rtkay123 <dev@kanjala.com> | 2026-02-01 13:33:07 +0200 |
|---|---|---|
| committer | rtkay123 <dev@kanjala.com> | 2026-02-01 13:33:07 +0200 |
| commit | ce65d9eeafcd1f9d5c3adef1c9b1af6258ee711a (patch) | |
| tree | 953f6c49f8affd667ec740a949b2d93f82b7d31b /src | |
| parent | 6a9d21bc87f8a738e14f27a1305bf04d0c4b7a0c (diff) | |
| download | sellershut-ce65d9eeafcd1f9d5c3adef1c9b1af6258ee711a.tar.bz2 sellershut-ce65d9eeafcd1f9d5c3adef1c9b1af6258ee711a.zip | |
feat: conn to db
Diffstat (limited to 'src')
| -rw-r--r-- | src/config/cli.rs | 101 | ||||
| -rw-r--r-- | src/config/logging.rs | 68 | ||||
| -rw-r--r-- | src/config/mod.rs | 155 | ||||
| -rw-r--r-- | src/config/port.rs | 42 | ||||
| -rw-r--r-- | src/logging/mod.rs | 19 | ||||
| -rw-r--r-- | src/main.rs | 40 | ||||
| -rw-r--r-- | src/server/mod.rs | 10 | ||||
| -rw-r--r-- | src/server/shutdown/mod.rs | 29 | ||||
| -rw-r--r-- | src/server/state/database.rs | 17 | ||||
| -rw-r--r-- | src/server/state/mod.rs | 17 |
10 files changed, 496 insertions, 2 deletions
diff --git a/src/config/cli.rs b/src/config/cli.rs new file mode 100644 index 0000000..dab7216 --- /dev/null +++ b/src/config/cli.rs @@ -0,0 +1,101 @@ +use std::path::PathBuf; + +use clap::Parser; +use url::Url; + +use crate::config::{logging::LogLevel, port::port_in_range}; + +#[derive(Parser, Debug)] +#[command(version, about, long_about = None, name = env!("CARGO_PKG_NAME"))] +pub struct Cli { + /// Sets the port the server listens on + #[arg(short, long, default_value = "2210", env = "PORT", value_parser = port_in_range)] + pub port: Option<u16>, + + /// Server's domain + #[arg(short, long, default_value = "localhost", env = "DOMAIN")] + pub domain: Option<String>, + + /// Sets a custom config file + #[arg(short, long, value_name = "FILE")] + pub config: Option<PathBuf>, + + /// Sets the application log level + #[arg(short, long, value_enum, env = "LOG_LEVEL", default_value = "debug")] + pub log_level: Option<LogLevel>, + + /// Request timeout duration (in seconds) + #[arg(short, long, env = "TIMEOUT_SECONDS", default_value = "10")] + pub timeout_duration: Option<u64>, + + /// Users database connection string + #[arg( + long, + env = "USERS_DATABASE_URL", + default_value = "postgres://postgres:password@localhost:5432/sellershut" + )] + pub db: Option<Url>, + + /// Server's system name + #[arg(short, long, default_value = "sellershut", env = "SYSTEM_NAME")] + pub system_name: Option<String>, + + /// Server's system name + #[arg(short, long, default_value = "prod", env = "SYSTEM_NAME")] + pub environment: Option<String>, + + /// Oauth optionas + #[command(flatten)] + pub oauth: Option<OAuth>, +} + +#[derive(Debug, Clone, Parser)] +pub struct OAuth { + #[cfg(feature = "oauth-discord")] + #[command(flatten)] + discord: DiscordOauth, + #[arg(long)] + oauth_redirect_url: Option<Url>, +} + +#[cfg(feature = "oauth-discord")] +#[derive(Debug, Clone, Parser)] +pub struct DiscordOauth { + #[arg(long)] + discord_client_id: Option<String>, + #[arg(long)] + discord_client_secret: Option<String>, + #[arg(long)] + discord_token_url: Option<Url>, + #[arg(long)] + discord_auth_url: Option<Url>, +} + +#[cfg(test)] +impl Default for Cli { + fn default() -> Self { + let url = Url::parse("postgres://postgres:password@localhost:5432/sellershut").ok(); + Self { + port: Default::default(), + config: Default::default(), + log_level: Default::default(), + timeout_duration: Some(10), + domain: Default::default(), + system_name: Default::default(), + environment: Default::default(), + oauth: None, + db: url, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn verify_cli() { + use clap::CommandFactory; + Cli::command().debug_assert(); + } +} diff --git a/src/config/logging.rs b/src/config/logging.rs new file mode 100644 index 0000000..c6cfe7f --- /dev/null +++ b/src/config/logging.rs @@ -0,0 +1,68 @@ +use clap::ValueEnum; +use serde::Deserialize; + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug, Default, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + /// The "trace" level. + /// + /// Designates very low priority, often extremely verbose, information. + Trace = 0, + /// The "debug" level. + /// + /// Designates lower priority information. + #[default] + Debug = 1, + /// The "info" level. + /// + /// Designates useful information. + Info = 2, + /// The "warn" level. + /// + /// Designates hazardous situations. + Warn = 3, + /// The "error" level. + /// + /// Designates very serious errors. + Error = 4, +} + +impl From<LogLevel> for tracing::Level { + fn from(value: LogLevel) -> Self { + match value { + LogLevel::Trace => tracing::Level::TRACE, + LogLevel::Debug => tracing::Level::DEBUG, + LogLevel::Info => tracing::Level::INFO, + LogLevel::Warn => tracing::Level::WARN, + LogLevel::Error => tracing::Level::ERROR, + } + } +} + +#[cfg(test)] +mod tests { + use crate::config::logging::LogLevel; + + fn check(level: LogLevel, value: &str) { + let level = tracing::Level::from(level); + assert_eq!(level.to_string().to_lowercase(), value); + } + + #[test] + fn loglevel() { + let level = LogLevel::Trace; + check(level, "trace"); + + let level = LogLevel::Debug; + check(level, "debug"); + + let level = LogLevel::Info; + check(level, "info"); + + let level = LogLevel::Warn; + check(level, "warn"); + + let level = LogLevel::Error; + check(level, "error"); + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs new file mode 100644 index 0000000..45e12c3 --- /dev/null +++ b/src/config/mod.rs @@ -0,0 +1,155 @@ +mod cli; +mod logging; +mod port; +pub use cli::Cli; +use serde::Deserialize; +use url::Url; + +use crate::{config::logging::LogLevel}; + +#[derive(Default, Deserialize, Debug, PartialEq, Eq)] +#[serde(rename_all = "kebab-case")] +pub enum Environment { + #[default] + Dev, + Prod, +} + +#[derive(Debug, Deserialize, Default)] +#[serde(rename_all = "kebab-case")] +pub struct Config { + #[serde(default)] + pub database: DatabaseOptions, + #[serde(default)] + pub server: Api, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub struct Api { + #[serde(default = "default_domain")] + pub domain: String, + + #[serde(default = "default_request_timeout")] + pub request_timeout: u64, + + #[serde(default = "default_port")] + pub port: u16, + + #[serde(default = "default_log_level")] + pub log_level: LogLevel, + + #[serde(default = "default_sys_name")] + pub system_name: String, + + #[serde(default)] + pub environment: Environment, +} + +impl Default for Api { + fn default() -> Self { + Self { + domain: default_domain(), + request_timeout: default_request_timeout(), + port: default_port(), + log_level: default_log_level(), + system_name: default_sys_name(), + environment: Environment::default(), + } + } +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub struct DatabaseOptions { + #[serde(default = "default_database")] + pub url: Url, + pub pool_size: u32, +} + +impl DatabaseOptions { + pub fn create(url: & Url, pool_size: Option<u32>) -> Self { + Self { + url: url.to_owned(), + pool_size: pool_size.unwrap_or_else(|| { + let def = 100; + tracing::debug!(size = def, "Setting default db pool size"); + def + }), + } + } +} + +fn default_database() -> Url { + Url::parse("postgres://postgres:password@localhost:5432/sellershut") + .expect("valid default DATABASE url") +} + +impl Default for DatabaseOptions { + fn default() -> Self { + Self { + url: default_database(), + pool_size: 100 + + } + } +} + +fn default_sys_name() -> String { + "sellershut".to_string() +} + +fn default_domain() -> String { + "localhost".to_string() +} + +fn default_request_timeout() -> u64 { + 10 +} + +fn default_port() -> u16 { + 2210 +} + +fn default_log_level() -> LogLevel { + LogLevel::Debug +} + + +impl Config { + pub fn merge_with_cli(&mut self, cli: &Cli) { + let server = &mut self.server; + let dsn = &mut self.database; + + if let Some(port) = cli.port { + server.port = port; + } + + if let Some(domain) = &cli.domain { + server.domain = domain.to_string(); + } + + if let Some(log_level) = &cli.log_level { + server.log_level = *log_level; + } + + if let Some(timeout) = cli.timeout_duration { + server.request_timeout = timeout; + } + + if let Some(db_url) = &cli.db { + dsn.url = db_url.clone(); + } + } +} + +#[cfg(test)] +mod tests { + use crate::config::Config; + + #[test] + fn config_file() { + let s = include_str!("../../sellershut.toml"); + assert!(toml::from_str::<Config>(s).is_ok()) + } +} diff --git a/src/config/port.rs b/src/config/port.rs new file mode 100644 index 0000000..01e305b --- /dev/null +++ b/src/config/port.rs @@ -0,0 +1,42 @@ +use std::ops::RangeInclusive; + +const PORT_RANGE: RangeInclusive<usize> = 1..=65535; + +pub fn port_in_range(s: &str) -> Result<u16, String> { + let port = s + .parse() + .map_err(|_| format!("{s} is not a valid port number"))?; + + if PORT_RANGE.contains(&port) { + Ok(port as u16) + } else { + Err(format!( + "port not in range {}-{}", + PORT_RANGE.start(), + PORT_RANGE.end() + )) + } +} + +#[cfg(test)] +mod tests { + use rand::Rng; + + use super::*; + + #[test] + fn in_port_range() { + let mut rng = rand::rng(); + let num = rng.random_range(PORT_RANGE); + + assert!(port_in_range(&num.to_string()).is_ok()); + } + + #[test] + fn outside_port_range() { + let mut rng = rand::rng(); + let num = rng.random_range((65535 + 1)..=usize::MAX); + + assert!(port_in_range(&num.to_string()).is_err()); + } +} diff --git a/src/logging/mod.rs b/src/logging/mod.rs new file mode 100644 index 0000000..3d2ddfe --- /dev/null +++ b/src/logging/mod.rs @@ -0,0 +1,19 @@ +use crate::config::Config; +use tracing::Level; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +pub fn initialise_logging(config: &Config) { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!( + "{}={},tower_http=debug,axum=trace", + env!("CARGO_CRATE_NAME"), + Level::from(config.server.log_level) + ) + .into() + }), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); +} diff --git a/src/main.rs b/src/main.rs index e7a11a9..cb8c2a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,39 @@ -fn main() { - println!("Hello, world!"); +mod config; +mod logging; +mod server; + +use std::net::{Ipv6Addr, SocketAddr}; + +use clap::Parser; +use tokio::net::TcpListener; +use tracing::info; + +use crate::{config::Config, logging::initialise_logging, server::state::{AppState }}; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let cli = config::Cli::parse(); + let mut config: Config = if let Some(ref path) = cli.config { + let contents = std::fs::read_to_string(path)?; + toml::from_str(&contents)? + } else { + Default::default() + }; + config.merge_with_cli(&cli); + + initialise_logging(&config); + + let state = AppState::new(&config).await?; + let router = server::router(&config, state).await?; + + let addr = SocketAddr::from((Ipv6Addr::UNSPECIFIED, config.server.port)); + info!(port = addr.port(), "starting server"); + let listener = TcpListener::bind(addr).await?; + + // Run the server with graceful shutdown + axum::serve(listener, router) + .with_graceful_shutdown(server::shutdown::shutdown_signal()) + .await?; + + Ok(()) } diff --git a/src/server/mod.rs b/src/server/mod.rs new file mode 100644 index 0000000..803135f --- /dev/null +++ b/src/server/mod.rs @@ -0,0 +1,10 @@ +use axum::Router; + +use crate::{config::Config, server::state::AppState}; + +pub mod shutdown; +pub mod state; + +pub async fn router(config: &Config, state: AppState) -> anyhow::Result<Router<()>> { + todo!() +} diff --git a/src/server/shutdown/mod.rs b/src/server/shutdown/mod.rs new file mode 100644 index 0000000..0e89a0c --- /dev/null +++ b/src/server/shutdown/mod.rs @@ -0,0 +1,29 @@ +use tokio::signal; + +pub async fn shutdown_signal() { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => { + tracing::info!("Received Ctrl+C, starting graceful shutdown"); + }, + _ = terminate => { + tracing::info!("Received SIGTERM, starting graceful shutdown"); + }, + } +} diff --git a/src/server/state/database.rs b/src/server/state/database.rs new file mode 100644 index 0000000..32d3f98 --- /dev/null +++ b/src/server/state/database.rs @@ -0,0 +1,17 @@ +use anyhow::Result; +use sqlx::{PgPool, postgres::PgPoolOptions}; +use tracing::{debug, trace}; + +use crate::config::DatabaseOptions; + + +pub(super) async fn connect(opts: &DatabaseOptions) -> Result<PgPool> { + trace!(host = ?opts.url.host(), "connecting to database"); + let pg = PgPoolOptions::new() + .max_connections(opts.pool_size) + .connect(opts.url.as_str()) + .await?; + debug!(host = ?opts.url.host(), "connected to database"); + + Ok(pg) +} diff --git a/src/server/state/mod.rs b/src/server/state/mod.rs new file mode 100644 index 0000000..f4bf029 --- /dev/null +++ b/src/server/state/mod.rs @@ -0,0 +1,17 @@ +pub mod database; + +use sqlx::PgPool; + +use crate::{config::Config}; + +pub struct AppState { + database: PgPool, +} + +impl AppState { + pub async fn new(config: &Config) -> anyhow::Result<Self> { + let database = database::connect(&config.database).await?; + + Ok(Self{database}) + } +} |
