diff options
| author | rtkay123 <dev@kanjala.com> | 2026-04-04 10:51:18 +0200 |
|---|---|---|
| committer | rtkay123 <dev@kanjala.com> | 2026-04-04 10:51:18 +0200 |
| commit | 19c25138f88acf19c9a959a58de4f58e54026ebc (patch) | |
| tree | bd854f20c539770a92fb451503b4c6d132c110a6 /crates | |
| parent | 41d90f42c37df06dabfd717d19f3dc72b5ba2d11 (diff) | |
| download | sellershut-19c25138f88acf19c9a959a58de4f58e54026ebc.tar.bz2 sellershut-19c25138f88acf19c9a959a58de4f58e54026ebc.zip | |
feat: connect to db
Diffstat (limited to 'crates')
26 files changed, 730 insertions, 59 deletions
diff --git a/crates/api-auth/Cargo.toml b/crates/api-auth/Cargo.toml new file mode 100644 index 0000000..7df9411 --- /dev/null +++ b/crates/api-auth/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "api-auth" +version = "0.0.0" +edition = "2024" +license.workspace = true +readme.workspace = true +documentation.workspace = true +homepage.workspace = true + +[dependencies] +api-core = { workspace = true, features = ["auth", "users"] } +async-trait.workspace = true +oauth2 = "5.0.0" +secrecy.workspace = true +serde.workspace = true +sqlx.workspace = true +thiserror.workspace = true +utoipa = { workspace = true, optional = true } +url.workspace = true + +[features] +discord = [] +utoipa = ["dep:utoipa", "serde/derive"] diff --git a/crates/api-auth/src/discord/mod.rs b/crates/api-auth/src/discord/mod.rs new file mode 100644 index 0000000..a39722d --- /dev/null +++ b/crates/api-auth/src/discord/mod.rs @@ -0,0 +1,30 @@ +use api_core::models::user::User; +use async_trait::async_trait; +use sqlx::PgPool; + +use crate::{BasicClient, OauthDriver, error::AuthError}; + +#[derive(Clone, Debug)] +pub struct AuthServiceDiscord { + database: PgPool, + client: BasicClient, +} + +impl AuthServiceDiscord { + pub fn new(database: PgPool, client: BasicClient) -> Self { + Self { database, client } + } +} + +#[async_trait] +impl OauthDriver for AuthServiceDiscord { + async fn get_auth_token(&self) -> Result<String, AuthError> { + todo!() + } + async fn get_user(&self) -> Result<User, AuthError> { + todo!() + } + async fn create_session(&self, _user: &User) { + todo!() + } +} diff --git a/crates/api-auth/src/error.rs b/crates/api-auth/src/error.rs new file mode 100644 index 0000000..ec60e51 --- /dev/null +++ b/crates/api-auth/src/error.rs @@ -0,0 +1,25 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum AuthClientError { + #[error("missing field: {0}")] + MissingField(&'static str), + #[error("invalid auth url: {0}")] + InvalidAuthUrl(#[from] oauth2::url::ParseError), + #[error("invalid token url: {0}")] + InvalidTokenUrl(#[source] oauth2::url::ParseError), + #[error("invalid redirect url: {0}")] + InvalidRedirectUrl(#[source] oauth2::url::ParseError), +} + +#[derive(Debug, Error)] +pub enum AuthError { + #[error("missing field: {0}")] + MissingField(&'static str), + #[error("invalid auth url: {0}")] + InvalidAuthUrl(#[from] oauth2::url::ParseError), + #[error("invalid token url: {0}")] + InvalidTokenUrl(#[source] oauth2::url::ParseError), + #[error("invalid redirect url: {0}")] + InvalidRedirectUrl(#[source] oauth2::url::ParseError), +} diff --git a/crates/api-auth/src/lib.rs b/crates/api-auth/src/lib.rs new file mode 100644 index 0000000..284b772 --- /dev/null +++ b/crates/api-auth/src/lib.rs @@ -0,0 +1,69 @@ +#[cfg(feature = "discord")] +pub mod discord; + +mod error; +use api_core::auth::AuthClientConfig; +use api_core::auth::provider::OauthProvider; +use api_core::models::user::User; +pub use error::AuthClientError; + +use oauth2::{EndpointNotSet, EndpointSet}; + +type C = oauth2::basic::BasicClient< + EndpointSet, + EndpointNotSet, + EndpointNotSet, + EndpointNotSet, + EndpointSet, +>; + +#[derive(Clone, Debug)] +pub struct BasicClient(C); + +#[async_trait::async_trait] +pub trait OauthDriver: Send + Sync + std::fmt::Debug { + async fn get_auth_token(&self) -> Result<String, AuthError>; + async fn get_user(&self) -> Result<User, AuthError>; + async fn create_session(&self, user: &User); +} + +use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl}; +use sqlx::PgPool; +use std::collections::HashMap; +use std::sync::Arc; +use std::{convert::TryFrom, ops::Deref}; + +use crate::error::AuthError; + +pub struct OauthService { + clients: HashMap<OauthProvider, Arc<dyn OauthDriver>>, +} + +impl Deref for BasicClient { + type Target = C; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl TryFrom<AuthClientConfig> for BasicClient { + type Error = AuthClientError; + + fn try_from(value: AuthClientConfig) -> Result<Self, Self::Error> { + let auth_url = AuthUrl::new(value.auth_url).map_err(AuthClientError::InvalidAuthUrl)?; + + let token_url = TokenUrl::new(value.token_uri).map_err(AuthClientError::InvalidTokenUrl)?; + + let redirect_url = + RedirectUrl::new(value.redirect_uri).map_err(AuthClientError::InvalidRedirectUrl)?; + + Ok(Self( + oauth2::basic::BasicClient::new(ClientId::new(value.client_id)) + .set_client_secret(ClientSecret::new(value.client_secret)) + .set_auth_uri(auth_url) + .set_token_uri(token_url) + .set_redirect_uri(redirect_url), + )) + } +} diff --git a/crates/api-base/src/lib.rs b/crates/api-base/src/lib.rs deleted file mode 100644 index 9c632e0..0000000 --- a/crates/api-base/src/lib.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod health; -mod version; -pub use version::*; diff --git a/crates/api-base/src/version.rs b/crates/api-base/src/version.rs deleted file mode 100644 index 0652c6e..0000000 --- a/crates/api-base/src/version.rs +++ /dev/null @@ -1,45 +0,0 @@ -#[derive(Debug)] -#[cfg_attr( - feature = "utoipa", - derive(utoipa::ToSchema, serde::Deserialize, serde::Serialize), - schema(example = "v0"), - serde(rename_all = "lowercase") -)] -pub enum Version { - V0, -} - -#[cfg(feature = "axum")] -mod request { - use super::*; - use axum::RequestPartsExt; - use axum::extract::{FromRequestParts, Path}; - use axum::http::StatusCode; - use axum::http::request::Parts; - use axum::response::{IntoResponse, Response}; - use std::collections::HashMap; - - impl<S> FromRequestParts<S> for Version - where - S: Send + Sync, - { - type Rejection = Response; - - async fn from_request_parts( - parts: &mut Parts, - _state: &S, - ) -> Result<Self, Self::Rejection> { - let params: Path<HashMap<String, String>> = - parts.extract().await.map_err(IntoResponse::into_response)?; - - let version = params - .get("apiVersion") - .ok_or_else(|| (StatusCode::NOT_FOUND, "version param missing").into_response())?; - - match version.as_str() { - "v0" => Ok(Version::V0), - _ => Err((StatusCode::NOT_FOUND, "unknown version").into_response()), - } - } - } -} diff --git a/crates/api-base/Cargo.toml b/crates/api-core/Cargo.toml index e15c19b..ae9f8f7 100644 --- a/crates/api-base/Cargo.toml +++ b/crates/api-core/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "api-base" +name = "api-core" version = "0.0.0" edition = "2024" license.workspace = true @@ -13,5 +13,12 @@ serde.workspace = true utoipa = { workspace = true, optional = true } [features] +auth = [] +auth-discord = ["auth"] axum = ["dep:axum"] +users = [] utoipa = ["dep:utoipa", "serde/derive", "axum"] + +[dev-dependencies] +tokio = { workspace = true, features = ["macros"] } +tower = { workspace = true, features = ["util"] } diff --git a/crates/api-core/src/auth/mod.rs b/crates/api-core/src/auth/mod.rs new file mode 100644 index 0000000..1045122 --- /dev/null +++ b/crates/api-core/src/auth/mod.rs @@ -0,0 +1,9 @@ +pub mod provider; + +pub struct AuthClientConfig { + pub client_id: String, + pub client_secret: String, + pub redirect_uri: String, + pub token_uri: String, + pub auth_url: String, +} diff --git a/crates/api-core/src/auth/provider.rs b/crates/api-core/src/auth/provider.rs new file mode 100644 index 0000000..803472f --- /dev/null +++ b/crates/api-core/src/auth/provider.rs @@ -0,0 +1,14 @@ +#[non_exhaustive] +/// The oauth provider +#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[cfg_attr( + feature = "utoipa", + derive(utoipa::ToSchema, serde::Deserialize, serde::Serialize), + schema(example = "v0"), + serde(rename_all = "camelCase") +)] +pub enum OauthProvider { + /// Discord + #[cfg(feature = "auth-discord")] + Discord, +} diff --git a/crates/api-base/src/health/apidoc.rs b/crates/api-core/src/health/apidoc.rs index 45b8754..45b8754 100644 --- a/crates/api-base/src/health/apidoc.rs +++ b/crates/api-core/src/health/apidoc.rs diff --git a/crates/api-base/src/health/mod.rs b/crates/api-core/src/health/mod.rs index a84dc85..a84dc85 100644 --- a/crates/api-base/src/health/mod.rs +++ b/crates/api-core/src/health/mod.rs diff --git a/crates/api-core/src/lib.rs b/crates/api-core/src/lib.rs new file mode 100644 index 0000000..8c22b49 --- /dev/null +++ b/crates/api-core/src/lib.rs @@ -0,0 +1,7 @@ +pub mod health; +pub mod models; +mod version; +pub use version::*; + +#[cfg(feature = "auth")] +pub mod auth; diff --git a/crates/api-core/src/models/mod.rs b/crates/api-core/src/models/mod.rs new file mode 100644 index 0000000..0f2db76 --- /dev/null +++ b/crates/api-core/src/models/mod.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "users")] +pub mod user; diff --git a/crates/api-core/src/models/user.rs b/crates/api-core/src/models/user.rs new file mode 100644 index 0000000..e6ad9f0 --- /dev/null +++ b/crates/api-core/src/models/user.rs @@ -0,0 +1 @@ +pub struct User {} diff --git a/crates/api-core/src/version.rs b/crates/api-core/src/version.rs new file mode 100644 index 0000000..5f84f3e --- /dev/null +++ b/crates/api-core/src/version.rs @@ -0,0 +1,96 @@ +#[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[cfg_attr( + feature = "utoipa", + derive(utoipa::ToSchema, serde::Deserialize, serde::Serialize), + schema(example = "v0"), + serde(rename_all = "camelCase") +)] +pub enum Version { + V0, +} + +#[cfg(feature = "axum")] +mod request { + use super::*; + use axum::RequestPartsExt; + use axum::extract::{FromRequestParts, Path}; + use axum::http::StatusCode; + use axum::http::request::Parts; + use axum::response::{IntoResponse, Response}; + use std::collections::HashMap; + + impl<S> FromRequestParts<S> for Version + where + S: Send + Sync, + { + type Rejection = Response; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result<Self, Self::Rejection> { + let params: Path<HashMap<String, String>> = + parts.extract().await.map_err(IntoResponse::into_response)?; + + let version = params + .get("apiVersion") + .ok_or_else(|| (StatusCode::NOT_FOUND, "version param missing").into_response())?; + + match version.as_str() { + "v0" => Ok(Version::V0), + _ => Err((StatusCode::NOT_FOUND, "unknown version").into_response()), + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::{ + Router, + body::Body, + http::{Request, StatusCode}, + routing::get, + }; + use tower::ServiceExt; + + async fn handler(version: Version) -> &'static str { + match version { + Version::V0 => "ok", + } + } + + async fn check(endpoint: &str, expected: StatusCode) { + let app = app(); + let response = app + .oneshot( + Request::builder() + .uri(endpoint) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(expected, response.status()); + } + + fn app() -> Router { + Router::new().route("/{apiVersion}/test", get(handler)) + } + + #[tokio::test] + async fn valid_version_v0() { + check("/v0/test", StatusCode::OK).await + } + + #[tokio::test] + async fn unknown_version() { + check("/v1/test", StatusCode::NOT_FOUND).await + } + + #[tokio::test] + async fn missing_version_param() { + check("/test", StatusCode::NOT_FOUND).await + } +} diff --git a/crates/sellershut/Cargo.toml b/crates/sellershut/Cargo.toml index f7cd15a..14a686c 100644 --- a/crates/sellershut/Cargo.toml +++ b/crates/sellershut/Cargo.toml @@ -10,16 +10,21 @@ description = "A federated marketplace platform" [dependencies] anyhow = "1.0.102" -api-base = { workspace = true, features = ["utoipa"] } +api-auth = { path = "../api-auth", features = ["discord", "utoipa"] } +api-core = { workspace = true, features = ["auth-discord", "utoipa"] } axum = { version = "0.8.8", features = ["macros"] } bon = "3.9.1" clap = { version = "4.6.0", features = ["derive", "env"] } +secrecy = { workspace = true, features = ["serde"] } serde = { workspace = true, features = ["derive"] } -tokio = { version = "1.51.0", features = ["macros", "rt", "rt-multi-thread"] } +serde_json.workspace = true +sqlx = { workspace = true, features = ["migrate"] } +tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } toml = "1.1.2" tracing.workspace = true tracing-appender = "0.2.4" tracing-subscriber = { version = "0.3.23", features = ["env-filter"] } +url = { workspace = true, features = ["serde"] } utoipa = { workspace = true, features = ["axum_extras"] } utoipa-axum = "0.2.0" utoipa-rapidoc = { version = "6.0.0", features = ["axum"], optional = true } @@ -27,7 +32,12 @@ utoipa-redoc = { version = "6.0.0", features = ["axum"], optional = true } utoipa-scalar = { version = "0.3.0", features = ["axum"], optional = true } utoipa-swagger-ui = { version = "9.0.2", features = ["axum"], optional = true } +[dev-dependencies] +tower = { workspace = true, features = ["util"] } + [features] +default = ["auth-discord"] +auth-discord = [] swagger = ["dep:utoipa-swagger-ui"] redoc = ["dep:utoipa-redoc"] rapidoc = ["dep:utoipa-rapidoc"] diff --git a/crates/sellershut/src/config/auth/discord.rs b/crates/sellershut/src/config/auth/discord.rs new file mode 100644 index 0000000..24ad711 --- /dev/null +++ b/crates/sellershut/src/config/auth/discord.rs @@ -0,0 +1,91 @@ +use anyhow::{Context, Result}; +use api_core::auth::AuthClientConfig; +use clap::Parser; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Parser, Deserialize, Serialize, Default, PartialEq, Eq)] +#[serde(default, rename_all = "kebab-case")] +pub struct DiscordClientConfig { + /// Discord OAuth client ID. + #[arg(long, env = "HUT_DISCORD_CLIENT_ID")] + pub discord_client_id: Option<String>, + + /// Discord OAuth client secret. + #[arg(long, env = "HUT_DISCORD_CLIENT_SECRET")] + pub discord_client_secret: Option<String>, + + /// Redirect URI registered with Discord OAuth. + #[arg(long, env = "HUT_DISCORD_REDIRECT_URI")] + pub discord_redirect_uri: Option<String>, + + /// Discord token endpoint URI. + #[arg(long, env = "HUT_DISCORD_TOKEN_URI")] + pub discord_token_uri: Option<String>, + + /// Discord authorization URL. + #[arg(long, env = "HUT_DISCORD_AUTH_URL")] + pub discord_auth_url: Option<String>, +} + +impl DiscordClientConfig { + pub(super) fn merge(self, higher: Self) -> Self { + Self { + discord_client_id: higher.discord_client_id.or(self.discord_client_id), + discord_client_secret: higher.discord_client_secret.or(self.discord_client_secret), + discord_redirect_uri: higher.discord_redirect_uri.or(self.discord_redirect_uri), + discord_token_uri: higher.discord_token_uri.or(self.discord_token_uri), + discord_auth_url: higher.discord_auth_url.or(self.discord_auth_url), + } + } + + pub(super) fn with_defaults(self) -> Self { + Self { + discord_client_id: self.discord_client_id, + discord_client_secret: self.discord_client_secret, + discord_redirect_uri: Some( + self.discord_redirect_uri + .unwrap_or_else(|| "http://localhost:2210/auth/discord/callback".to_string()), + ), + discord_token_uri: Some( + self.discord_token_uri + .unwrap_or_else(|| "https://discord.com/api/oauth2/token".to_string()), + ), + discord_auth_url: Some( + self.discord_auth_url + .unwrap_or_else(|| "https://discord.com/api/oauth2/authorize".to_string()), + ), + } + } + + pub(super) fn defaults() -> Self { + Self { + discord_client_id: None, + discord_client_secret: None, + discord_redirect_uri: Some("http://localhost:2210/auth/discord/callback".to_string()), + discord_token_uri: Some("https://discord.com/api/oauth2/token".to_string()), + discord_auth_url: Some("https://discord.com/api/oauth2/authorize".to_string()), + } + } +} + +impl TryFrom<DiscordClientConfig> for AuthClientConfig { + type Error = anyhow::Error; + + fn try_from(value: DiscordClientConfig) -> Result<Self> { + Ok(Self { + client_id: value + .discord_client_id + .context("missing discord_client_id")?, + client_secret: value + .discord_client_secret + .context("missing discord_client_secret")?, + redirect_uri: value + .discord_redirect_uri + .context("missing discord_redirect_uri")?, + token_uri: value + .discord_token_uri + .context("missing discord_token_uri")?, + auth_url: value.discord_auth_url.context("missing discord_auth_url")?, + }) + } +} diff --git a/crates/sellershut/src/config/auth/mod.rs b/crates/sellershut/src/config/auth/mod.rs new file mode 100644 index 0000000..8fc2d5b --- /dev/null +++ b/crates/sellershut/src/config/auth/mod.rs @@ -0,0 +1,44 @@ +#[cfg(feature = "auth-discord")] +pub mod discord; +use clap::Parser; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Parser, Deserialize, Serialize, Default, PartialEq, Eq)] +#[serde(default, rename_all = "kebab-case")] +pub struct OauthConfig { + /// Discord OAuth configuration. + #[cfg(feature = "auth-discord")] + #[command(flatten)] + pub discord: Option<discord::DiscordClientConfig>, +} + +impl OauthConfig { + pub(super) fn merge(self, higher: Self) -> Self { + Self { + #[cfg(feature = "auth-discord")] + discord: match (self.discord, higher.discord) { + (Some(lower), Some(higher)) => Some(lower.merge(higher)), + (None, Some(higher)) => Some(higher), + (Some(lower), None) => Some(lower), + (None, None) => None, + }, + } + } + + pub(super) fn with_defaults(self) -> Self { + Self { + #[cfg(feature = "auth-discord")] + discord: self + .discord + .map(|d| d.with_defaults()) + .or_else(|| Some(discord::DiscordClientConfig::defaults())), + } + } + + pub(super) fn defaults() -> Self { + Self { + #[cfg(feature = "auth-discord")] + discord: Some(discord::DiscordClientConfig::defaults()), + } + } +} diff --git a/crates/sellershut/src/config/database/mod.rs b/crates/sellershut/src/config/database/mod.rs new file mode 100644 index 0000000..b04319b --- /dev/null +++ b/crates/sellershut/src/config/database/mod.rs @@ -0,0 +1,94 @@ +use clap::Parser; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Parser, Deserialize, Serialize, Default, PartialEq, Eq)] +#[serde(default)] +pub struct DatabaseConfig { + /// Full database connection URL. Takes precedence over the individual database fields. + #[arg(long, env = "HUT_DB_URL")] + #[serde(rename = "url", skip_serializing_if = "Option::is_none")] + pub db_url: Option<String>, + + /// Database host name or IP address. + #[arg(long, env = "HUT_DB_HOST")] + #[serde(rename = "host")] + pub db_host: Option<String>, + + /// Database port number. + #[arg(long, env = "HUT_DB_PORT")] + #[serde(rename = "port")] + pub db_port: Option<u16>, + + /// Database username. + #[arg(long, env = "HUT_DB_USERNAME")] + #[serde(rename = "username")] + pub db_username: Option<String>, + + /// Database password. + #[arg(long, env = "HUT_DB_PASSWORD")] + #[serde(rename = "password")] + pub db_password: Option<String>, + + /// Database name. + #[arg(long, env = "HUT_DB_NAME")] + #[serde(rename = "name")] + pub db_name: Option<String>, +} + +impl DatabaseConfig { + pub(super) fn merge(self, higher: Self) -> Self { + Self { + db_url: higher.db_url.or(self.db_url), + db_host: higher.db_host.or(self.db_host), + db_port: higher.db_port.or(self.db_port), + db_username: higher.db_username.or(self.db_username), + db_password: higher.db_password.or(self.db_password), + db_name: higher.db_name.or(self.db_name), + } + } + + pub(super) fn with_defaults(self) -> Self { + Self { + db_url: self.db_url, + db_host: Some(self.db_host.unwrap_or_else(|| "127.0.0.1".to_string())), + db_port: Some(self.db_port.unwrap_or(5432)), + db_username: Some(self.db_username.unwrap_or_else(|| "postgres".to_string())), + db_password: Some(self.db_password.unwrap_or_else(|| "password".to_string())), + db_name: Some(self.db_name.unwrap_or_else(|| "sellershut".to_string())), + } + } + pub(super) fn defaults() -> Self { + Self { + db_url: None, + db_host: Some("127.0.0.1".to_string()), + db_port: Some(5432), + db_username: Some("postgres".to_string()), + db_password: Some("password".to_string()), + db_name: Some("sellershut".to_string()), + } + } + + pub fn connection_url(&self) -> String { + if let Some(url) = &self.db_url { + return url.clone(); + } + + format!( + "postgres://{}:{}@{}:{}/{}", + self.db_username + .as_deref() + .expect("database username should be set after defaults"), + self.db_password + .as_deref() + .expect("database password should be set after defaults"), + self.db_host + .as_deref() + .expect("database host should be set after defaults"), + self.db_port + .expect("database port should be set after defaults"), + self.db_name + .as_deref() + .expect("database name should be set after defaults"), + ) + } +} diff --git a/crates/sellershut/src/config/mod.rs b/crates/sellershut/src/config/mod.rs index d35ba1e..389b4bc 100644 --- a/crates/sellershut/src/config/mod.rs +++ b/crates/sellershut/src/config/mod.rs @@ -1,4 +1,6 @@ +pub mod auth; pub mod cli; +pub mod database; mod server; use anyhow::Result; @@ -14,9 +16,15 @@ pub struct Config { #[arg(long, env = "HUT_CONFIG")] #[serde(skip)] config: Option<PathBuf>, - /// Server configuration. + /// General server configuration. #[command(flatten)] pub server: server::ServerConfig, + /// Auth configuration. + #[command(flatten)] + pub auth: auth::OauthConfig, + /// Database configuration. + #[command(flatten)] + pub database: database::DatabaseConfig, } impl Config { pub fn load(cli: Self) -> Result<Self> { @@ -33,6 +41,8 @@ impl Config { Self { config: higher.config.or(self.config), server: self.server.merge(higher.server), + auth: self.auth.merge(higher.auth), + database: self.database.merge(higher.database), } } @@ -40,13 +50,17 @@ impl Config { Self { config: self.config, server: self.server.with_defaults(), + auth: self.auth.with_defaults(), + database: self.database.with_defaults(), } } - fn defaults() -> Self { + pub fn defaults() -> Self { Self { config: None, server: server::ServerConfig::defaults(), + auth: auth::OauthConfig::defaults(), + database: database::DatabaseConfig::defaults(), } } } diff --git a/crates/sellershut/src/main.rs b/crates/sellershut/src/main.rs index cb7be07..fca10e1 100644 --- a/crates/sellershut/src/main.rs +++ b/crates/sellershut/src/main.rs @@ -3,17 +3,26 @@ mod server; mod state; use std::{ + collections::HashMap, net::{Ipv6Addr, SocketAddr}, sync::Arc, }; use anyhow::{Context, Result}; -use api_base::health::BaseService; +use api_auth::{BasicClient, OauthDriver, discord::AuthServiceDiscord}; +use api_core::{ + auth::{AuthClientConfig, provider::OauthProvider}, + health::BaseService, +}; use clap::Parser; -use tokio::net::TcpListener; +use sqlx::PgPool; +use tokio::{net::TcpListener}; use tracing::info; -use crate::{config::cli, state::AppState}; +use crate::{ + config::{auth::OauthConfig, cli}, + state::AppState, +}; #[tokio::main] async fn main() -> Result<()> { @@ -30,10 +39,16 @@ async fn main() -> Result<()> { cfg.server.log_directory.as_ref(), )?; + let database = state::postgres(&cfg.database.connection_url(), 100).await?; + + let auth_clients = build_oauth_client(&cfg.auth, database)?; + let state = AppState::builder() .log_handle(log_handle) .base_service(Arc::new(BaseService)) + .auth_clients(auth_clients) .build(); + let addr = SocketAddr::from(( Ipv6Addr::UNSPECIFIED, cfg.server.port.context("missing port")?, @@ -48,3 +63,23 @@ async fn main() -> Result<()> { Ok(()) } + +fn build_oauth_client( + config: &OauthConfig, + database: PgPool, +) -> Result<HashMap<OauthProvider, Arc<dyn OauthDriver>>> { + let auth = config.to_owned(); + let mut collection: HashMap<OauthProvider, Arc<dyn OauthDriver>> = HashMap::new(); + + #[cfg(feature = "auth-discord")] + { + use api_core::auth::provider::OauthProvider; + + let c = AuthClientConfig::try_from(auth.discord.context("missing discord config")?)?; + let client = BasicClient::try_from(c)?; + let auth_service = Arc::new(AuthServiceDiscord::new(database, client)); + collection.insert(OauthProvider::Discord, auth_service); + } + + Ok(collection) +} diff --git a/crates/sellershut/src/server/api/mod.rs b/crates/sellershut/src/server/api/mod.rs index 0fd48c6..c227f59 100644 --- a/crates/sellershut/src/server/api/mod.rs +++ b/crates/sellershut/src/server/api/mod.rs @@ -1,4 +1,4 @@ -use api_base::health::ApiDocBase; +use api_core::health::ApiDocBase; use axum::Router; use utoipa::OpenApi; use utoipa_axum::router::OpenApiRouter; diff --git a/crates/sellershut/src/server/api/routes/logs/mod.rs b/crates/sellershut/src/server/api/routes/logs/mod.rs index 8718d86..9ea0a39 100644 --- a/crates/sellershut/src/server/api/routes/logs/mod.rs +++ b/crates/sellershut/src/server/api/routes/logs/mod.rs @@ -52,3 +52,57 @@ pub async fn reload(State(state): State<AppState>, Json(body): Json<LogLevel>) - StatusCode::BAD_REQUEST } } + +#[cfg(test)] +mod tests { + use axum::{ + Router, + body::Body, + http::{Request, StatusCode, header}, + }; + + use anyhow::Result; + use tower::ServiceExt; + + use crate::server::{self}; + + async fn check( + app: Router, + method: &str, + body: String, + expected_result: StatusCode, + ) -> Result<()> { + let response = app + .oneshot( + Request::builder() + .method(method) + .header(header::CONTENT_TYPE, "application/json") + .uri("/api/logging") + .body(Body::from(body))?, + ) + .await?; + let actual_result = response.status(); + assert_eq!(expected_result, actual_result); + Ok(()) + } + + #[tokio::test] + async fn log_update() -> Result<()> { + let app = server::boostrap::test_app().await; + + let info = serde_json::json!({ + "logLevel": "info", + }); + + check( + app.clone(), + "GET", + info.to_string(), + StatusCode::METHOD_NOT_ALLOWED, + ) + .await?; + + check(app.clone(), "PATCH", info.to_string(), StatusCode::OK).await?; + Ok(()) + } +} diff --git a/crates/sellershut/src/server/api/routes/mod.rs b/crates/sellershut/src/server/api/routes/mod.rs index f343742..1de8e80 100644 --- a/crates/sellershut/src/server/api/routes/mod.rs +++ b/crates/sellershut/src/server/api/routes/mod.rs @@ -36,3 +36,39 @@ pub async fn health(State(state): State<AppState>) -> impl IntoResponse { .base_service .health(env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")) } + +#[cfg(test)] +mod tests { + use axum::{ + Router, + body::Body, + http::{Request, StatusCode}, + }; + + use anyhow::Result; + use tower::ServiceExt; + + use crate::server::{self}; + + async fn check(app: Router, method: &str, expected_result: StatusCode) -> Result<()> { + let response = app + .oneshot( + Request::builder() + .method(method) + .uri("/api/health") + .body(Body::empty())?, + ) + .await?; + let actual_result = response.status(); + assert_eq!(expected_result, actual_result); + Ok(()) + } + + #[tokio::test] + async fn health() -> Result<()> { + let app = server::boostrap::test_app().await; + check(app.clone(), "GET", StatusCode::OK).await?; + check(app.clone(), "HEAD", StatusCode::OK).await?; + Ok(()) + } +} diff --git a/crates/sellershut/src/server/mod.rs b/crates/sellershut/src/server/mod.rs index f669af9..a66eed5 100644 --- a/crates/sellershut/src/server/mod.rs +++ b/crates/sellershut/src/server/mod.rs @@ -1,2 +1,44 @@ pub mod api; pub mod logs; + +#[cfg(test)] +mod boostrap { + use std::{collections::HashMap, sync::{Arc, OnceLock}}; + + use api_core::health::BaseService; + use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, reload}; + + use crate::{ + config::Config, + server::{self, logs::LogHandle}, + state::AppState, + }; + + static TEST_LOG_DATA: OnceLock<LogHandle> = OnceLock::new(); + + pub async fn test_app() -> axum::Router { + let log_handle = TEST_LOG_DATA + .get_or_init(|| { + let filter = EnvFilter::new("warn"); + let (layer, handle) = reload::Layer::new(filter); + + let subscriber = Registry::default().with(layer); + + let _ = tracing::subscriber::set_global_default(subscriber); + + handle + }) + .clone(); + let state = Arc::new(BaseService); + let config = Config::defaults(); + let auth_clients = HashMap::default(); + + let state = AppState::builder() + .log_handle(log_handle) + .base_service(state) + .auth_clients(auth_clients) + .build(); + + server::api::router(state, config).await + } +} diff --git a/crates/sellershut/src/state/mod.rs b/crates/sellershut/src/state/mod.rs index 067cc62..821d4eb 100644 --- a/crates/sellershut/src/state/mod.rs +++ b/crates/sellershut/src/state/mod.rs @@ -1,7 +1,9 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; -use api_base::health::HealthDriver; +use api_auth::OauthDriver; +use api_core::{auth::provider::OauthProvider, health::HealthDriver}; use bon::Builder; +use sqlx::PgPool; use crate::server::logs::LogHandle; @@ -9,4 +11,18 @@ use crate::server::logs::LogHandle; pub struct AppState { pub base_service: Arc<dyn HealthDriver>, pub log_handle: LogHandle, + pub auth_clients: HashMap<OauthProvider, Arc<dyn OauthDriver>>, +} + +pub async fn postgres(config: &str, pool_size: u32) -> anyhow::Result<PgPool> { + let pg = sqlx::postgres::PgPoolOptions::new() + // The default connection limit for a Postgres server is 100 connections, with 3 reserved for superusers. + // + // If you're deploying your application with multiple replicas, then the total + // across all replicas should not exceed the Postgres connection limit + // (max_connections postgresql.conf). + .max_connections(pool_size) + .connect(config) + .await?; + Ok(pg) } |
