diff options
Diffstat (limited to 'src/config')
| -rw-r--r-- | src/config/cli.rs | 23 | ||||
| -rw-r--r-- | src/config/mod.rs | 60 |
2 files changed, 70 insertions, 13 deletions
diff --git a/src/config/cli.rs b/src/config/cli.rs index dab7216..5254135 100644 --- a/src/config/cli.rs +++ b/src/config/cli.rs @@ -1,6 +1,7 @@ use std::path::PathBuf; use clap::Parser; +use serde::Deserialize; use url::Url; use crate::config::{logging::LogLevel, port::port_in_range}; @@ -49,25 +50,33 @@ pub struct Cli { pub oauth: Option<OAuth>, } -#[derive(Debug, Clone, Parser)] +#[derive(Debug, Clone, Parser, Deserialize)] pub struct OAuth { #[cfg(feature = "oauth-discord")] #[command(flatten)] discord: DiscordOauth, - #[arg(long)] + #[arg(long, env = "OAUTH_REDIRECT_URL")] oauth_redirect_url: Option<Url>, } #[cfg(feature = "oauth-discord")] -#[derive(Debug, Clone, Parser)] +#[derive(Debug, Clone, Parser, Deserialize)] pub struct DiscordOauth { - #[arg(long)] + #[arg(long, env = "OAUTH_DISCORD_CLIENT_ID")] discord_client_id: Option<String>, - #[arg(long)] + #[arg(long, env = "OAUTH_DISCORD_CLIENT_SECRET")] discord_client_secret: Option<String>, - #[arg(long)] + #[arg( + long, + env = "OAUTH_DISCORD_TOKEN_URL", + default_value = "https://discord.com/api/oauth2/token" + )] discord_token_url: Option<Url>, - #[arg(long)] + #[arg( + long, + env = "OAUTH_DISCORD_AUTH_URL", + default_value = "https://discord.com/api/oauth2/authorize?response_type=code" + )] discord_auth_url: Option<Url>, } diff --git a/src/config/mod.rs b/src/config/mod.rs index 45e12c3..19ee241 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,10 +2,12 @@ mod cli; mod logging; mod port; pub use cli::Cli; +#[cfg(feature = "oauth-discord")] +use secrecy::SecretString; use serde::Deserialize; use url::Url; -use crate::{config::logging::LogLevel}; +use crate::config::logging::LogLevel; #[derive(Default, Deserialize, Debug, PartialEq, Eq)] #[serde(rename_all = "kebab-case")] @@ -22,6 +24,8 @@ pub struct Config { pub database: DatabaseOptions, #[serde(default)] pub server: Api, + #[serde(default)] + pub oauth: OAuth, } #[derive(Debug, Deserialize)] @@ -46,6 +50,52 @@ pub struct Api { pub environment: Environment, } +#[derive(Debug, Clone, Deserialize)] +pub struct OAuth { + #[cfg(feature = "oauth-discord")] + pub discord: DiscordOauth, + #[serde(rename = "redirect-url")] + pub oauth_redirect_url: Url, +} + +#[cfg(feature = "oauth-discord")] +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub struct DiscordOauth { + pub client_id: String, + pub client_secret: SecretString, + #[serde(default = "discord_token_url")] + pub token_url: Url, + #[serde(default = "discord_auth_url")] + pub auth_url: Url, +} + +fn discord_token_url() -> Url { + Url::parse("https://discord.com/api/oauth2/authorize?response_type=code").expect("valid url") +} + +fn discord_auth_url() -> Url { + Url::parse("https://discord.com/api/oauth2/authorize?response_type=code").expect("valid url") +} + +fn redirect_url() -> Url { + Url::parse("http://127.0.0.1:2210/auth/authorised").expect("valid url") +} + +impl Default for OAuth { + fn default() -> Self { + Self { + discord: DiscordOauth { + client_id: String::default(), + client_secret: SecretString::default(), + token_url: discord_token_url(), + auth_url: discord_auth_url(), + }, + oauth_redirect_url: redirect_url(), + } +} +} + impl Default for Api { fn default() -> Self { Self { @@ -68,7 +118,7 @@ pub struct DatabaseOptions { } impl DatabaseOptions { - pub fn create(url: & Url, pool_size: Option<u32>) -> Self { + pub fn create(url: &Url, pool_size: Option<u32>) -> Self { Self { url: url.to_owned(), pool_size: pool_size.unwrap_or_else(|| { @@ -89,8 +139,7 @@ impl Default for DatabaseOptions { fn default() -> Self { Self { url: default_database(), - pool_size: 100 - + pool_size: 100, } } } @@ -115,7 +164,6 @@ fn default_log_level() -> LogLevel { LogLevel::Debug } - impl Config { pub fn merge_with_cli(&mut self, cli: &Cli) { let server = &mut self.server; @@ -149,7 +197,7 @@ mod tests { #[test] fn config_file() { - let s = include_str!("../../sellershut.toml"); + let s = include_str!("../../misc/sellershut.toml"); assert!(toml::from_str::<Config>(s).is_ok()) } } |
