diff options
-rw-r--r-- | Cargo.lock | 137 | ||||
-rw-r--r-- | crates/auth/Cargo.toml | 4 | ||||
-rw-r--r-- | crates/auth/auth.toml | 3 | ||||
-rw-r--r-- | crates/auth/migrations/20250723100947_account.sql | 1 | ||||
-rw-r--r-- | crates/auth/migrations/20250723100947_user.sql | 9 | ||||
-rw-r--r-- | crates/auth/migrations/20250723121223_oauth_account.sql | 7 | ||||
-rw-r--r-- | crates/auth/src/cnfg.rs | 1 | ||||
-rw-r--r-- | crates/auth/src/main.rs | 34 | ||||
-rw-r--r-- | crates/auth/src/server.rs | 5 | ||||
-rw-r--r-- | crates/auth/src/server/csrf_token_validation.rs | 49 | ||||
-rw-r--r-- | crates/auth/src/server/routes/authorised.rs | 95 | ||||
-rw-r--r-- | crates/auth/src/server/routes/discord/discord_auth.rs | 41 | ||||
-rw-r--r-- | crates/auth/src/state.rs | 35 |
13 files changed, 405 insertions, 16 deletions
@@ -206,6 +206,10 @@ dependencies = [ "tokio", "tower", "tower-http", + "tower-sessions", + "tower-sessions-core", + "tower-sessions-moka-store", + "tower-sessions-sqlx-store", "tracing", "url", ] @@ -513,6 +517,17 @@ dependencies = [ ] [[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + +[[package]] name = "core-foundation-sys" version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1448,6 +1463,7 @@ checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", + "serde", ] [[package]] @@ -1741,6 +1757,12 @@ dependencies = [ ] [[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] name = "pathdiff" version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2102,6 +2124,28 @@ dependencies = [ ] [[package]] +name = "rmp" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + +[[package]] name = "rsa" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2918,6 +2962,22 @@ dependencies = [ ] [[package]] +name = "tower-cookies" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "151b5a3e3c45df17466454bb74e9ecedecc955269bdedbf4d150dfa393b55a36" +dependencies = [ + "axum-core", + "cookie", + "futures-util", + "http", + "parking_lot", + "pin-project-lite", + "tower-layer", + "tower-service", +] + +[[package]] name = "tower-http" version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2949,6 +3009,83 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] +name = "tower-sessions" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a05911f23e8fae446005fe9b7b97e66d95b6db589dc1c4d59f6a2d4d4927d3" +dependencies = [ + "async-trait", + "http", + "time", + "tokio", + "tower-cookies", + "tower-layer", + "tower-service", + "tower-sessions-core", + "tower-sessions-memory-store", + "tracing", +] + +[[package]] +name = "tower-sessions-core" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce8cce604865576b7751b7a6bc3058f754569a60d689328bb74c52b1d87e355b" +dependencies = [ + "async-trait", + "axum-core", + "base64", + "futures", + "http", + "parking_lot", + "rand 0.8.5", + "serde", + "serde_json", + "thiserror 2.0.12", + "time", + "tokio", + "tracing", +] + +[[package]] +name = "tower-sessions-memory-store" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb05909f2e1420135a831dd5df9f5596d69196d0a64c3499ca474c4bd3d33242" +dependencies = [ + "async-trait", + "time", + "tokio", + "tower-sessions-core", +] + +[[package]] +name = "tower-sessions-moka-store" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a5e622001aa59953f422ade78a0fa0d1f4d2566c9bf697bffe6aa89f1438f08" +dependencies = [ + "async-trait", + "moka", + "time", + "tower-sessions-core", +] + +[[package]] +name = "tower-sessions-sqlx-store" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e054622079f57fc1a7d6a6089c9334f963d62028fe21dc9eddd58af9a78480b3" +dependencies = [ + "async-trait", + "rmp-serde", + "sqlx", + "thiserror 1.0.69", + "time", + "tower-sessions-core", +] + +[[package]] name = "tracing" version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/crates/auth/Cargo.toml b/crates/auth/Cargo.toml index b5e53d9..b6ad707 100644 --- a/crates/auth/Cargo.toml +++ b/crates/auth/Cargo.toml @@ -25,6 +25,10 @@ time = { workspace = true, features = ["parsing", "serde"] } tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] } tower = { workspace = true, features = ["util"] } tower-http = { workspace = true, features = ["map-request-body", "trace", "util"] } +tower-sessions = "0.14.0" +tower-sessions-core = { version = "0.14.0", features = ["deletion-task"] } +tower-sessions-moka-store = "0.15.0" +tower-sessions-sqlx-store = { version = "0.15.0", features = ["postgres"] } tracing.workspace = true url.workspace = true diff --git a/crates/auth/auth.toml b/crates/auth/auth.toml index cfb6501..4e5b263 100644 --- a/crates/auth/auth.toml +++ b/crates/auth/auth.toml @@ -2,6 +2,9 @@ env = "development" port = 1304 +[misc.oauth] +session-lifespan = 3600 # seconds + [misc.oauth.discord] # query param for provider redirect-url = "http://127.0.0.1:1304/auth/authorised?provider=discord" diff --git a/crates/auth/migrations/20250723100947_account.sql b/crates/auth/migrations/20250723100947_account.sql deleted file mode 100644 index 8ddc1d3..0000000 --- a/crates/auth/migrations/20250723100947_account.sql +++ /dev/null @@ -1 +0,0 @@ --- Add migration script here diff --git a/crates/auth/migrations/20250723100947_user.sql b/crates/auth/migrations/20250723100947_user.sql new file mode 100644 index 0000000..440afc7 --- /dev/null +++ b/crates/auth/migrations/20250723100947_user.sql @@ -0,0 +1,9 @@ +-- Add migration script here +create table auth_user ( + id uuid primary key, + email text unique not null, + avatar text, + description text, + updated_at text, + create_at timestamptz not null default now() +); diff --git a/crates/auth/migrations/20250723121223_oauth_account.sql b/crates/auth/migrations/20250723121223_oauth_account.sql new file mode 100644 index 0000000..826fbcc --- /dev/null +++ b/crates/auth/migrations/20250723121223_oauth_account.sql @@ -0,0 +1,7 @@ +CREATE TABLE oauth_account ( + provider_id text not null, + provider_user_id text not null, + user_id uuid not null, + primary key (provider_id, provider_user_id), + foreign key (user_id) references auth_user(id) +); diff --git a/crates/auth/src/cnfg.rs b/crates/auth/src/cnfg.rs index 6afe2f8..af7b0a0 100644 --- a/crates/auth/src/cnfg.rs +++ b/crates/auth/src/cnfg.rs @@ -10,6 +10,7 @@ pub struct LocalConfig { #[serde(rename_all = "kebab-case")] pub struct OauthConfig { pub discord: OauthCredentials, + pub session_lifespan: u64, } #[derive(Deserialize, Clone)] diff --git a/crates/auth/src/main.rs b/crates/auth/src/main.rs index 4a71d69..ef8a358 100644 --- a/crates/auth/src/main.rs +++ b/crates/auth/src/main.rs @@ -8,6 +8,7 @@ use std::net::{Ipv6Addr, SocketAddr}; use clap::Parser; use stack_up::{Configuration, Services, tracing::Tracing}; +use tokio::{signal, task::AbortHandle}; use tracing::{info, trace}; use crate::{error::AppError, state::AppState}; @@ -56,13 +57,42 @@ async fn main() -> Result<(), AppError> { .run(&services.postgres) .await?; - let state = AppState::create(services, &config).await?; + let (state, deletion_task) = AppState::create(services, &config).await?; let addr = SocketAddr::from((Ipv6Addr::UNSPECIFIED, config.application.port)); let listener = tokio::net::TcpListener::bind(addr).await?; info!(port = addr.port(), "serving api"); - axum::serve(listener, server::router(state)).await?; + axum::serve(listener, server::router(state)) + .with_graceful_shutdown(shutdown_signal(deletion_task.abort_handle())) + .await?; + + deletion_task.await??; + Ok(()) } + +async fn shutdown_signal(deletion_task_abort_handle: AbortHandle) { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => { deletion_task_abort_handle.abort() }, + _ = terminate => { deletion_task_abort_handle.abort() }, + } +} diff --git a/crates/auth/src/server.rs b/crates/auth/src/server.rs index 3cfac60..d724d68 100644 --- a/crates/auth/src/server.rs +++ b/crates/auth/src/server.rs @@ -3,8 +3,13 @@ use tower_http::trace::TraceLayer; use crate::{server::routes::health_check, state::AppHandle}; +pub mod csrf_token_validation; pub mod routes; +const CSRF_TOKEN: &str = "csrf_token"; +const COOKIE_NAME: &str = "SESSION"; +const OAUTH_CSRF_COOKIE: &str = "SESSION"; + pub fn router(state: AppHandle) -> Router { Router::new() .merge(routes::discord::discord_router(state.clone())) diff --git a/crates/auth/src/server/csrf_token_validation.rs b/crates/auth/src/server/csrf_token_validation.rs new file mode 100644 index 0000000..c9a627c --- /dev/null +++ b/crates/auth/src/server/csrf_token_validation.rs @@ -0,0 +1,49 @@ +use anyhow::{Context, anyhow}; +use axum_extra::headers; +use oauth2::CsrfToken; +use time::OffsetDateTime; +use tower_sessions::{CachingSessionStore, SessionStore, session::Id}; +use tower_sessions_moka_store::MokaStore; +use tower_sessions_sqlx_store::PostgresStore; + +use crate::{ + error::AppError, + server::{COOKIE_NAME, CSRF_TOKEN, routes::authorised::AuthRequest}, + state::AppHandle, +}; + +pub struct Session { + id: String, + expires_at: OffsetDateTime, + user_id: String, +} + +pub async fn csrf_token_validation_workflow( + auth_request: &AuthRequest, + store: &CachingSessionStore<MokaStore, PostgresStore>, + oauth_session_id: Id, +) -> Result<(), AppError> { + let oauth_session = store.load(&oauth_session_id).await.unwrap().unwrap(); + + // Extract the CSRF token from the session + let csrf_token_serialized = oauth_session + .data + .get(CSRF_TOKEN) + .context("failed to get value from session")?; + let csrf_token = serde_json::from_value::<CsrfToken>(csrf_token_serialized.clone()) + .context("CSRF token not found in session")? + .to_owned(); + + // Cleanup the CSRF token session + store + .delete(&oauth_session_id) + .await + .context("Failed to destroy old session")?; + + // Validate CSRF token is the same as the one in the auth request + if *csrf_token.secret() != auth_request.state { + return Err(anyhow!("CSRF token mismatch").into()); + } + + Ok(()) +} diff --git a/crates/auth/src/server/routes/authorised.rs b/crates/auth/src/server/routes/authorised.rs index ddf048d..42bbde2 100644 --- a/crates/auth/src/server/routes/authorised.rs +++ b/crates/auth/src/server/routes/authorised.rs @@ -1,23 +1,108 @@ +use std::{str::FromStr, time::Duration}; + +use anyhow::Context; use axum::{ extract::{Query, State}, - response::IntoResponse, + http::HeaderMap, + response::{IntoResponse, Redirect}, }; use axum_extra::{TypedHeader, headers}; -use serde::Deserialize; +use oauth2::{AuthorizationCode, TokenResponse}; +use reqwest::header::SET_COOKIE; +use serde::{Deserialize, Serialize}; +use sqlx::types::uuid; +use tower_sessions::{ + SessionStore, + session::{Id, Record}, +}; -use crate::{error::AppError, server::routes::Provider, state::AppHandle}; +use crate::{ + error::AppError, + server::{ + OAUTH_CSRF_COOKIE, csrf_token_validation::csrf_token_validation_workflow, routes::Provider, + }, + state::AppHandle, +}; #[derive(Debug, Deserialize)] pub struct AuthRequest { provider: Provider, code: String, - state: String, + pub state: String, } +#[derive(Debug, Deserialize, Serialize)] +struct User { + id: String, + avatar: Option<String>, + username: String, + discriminator: String, +} + +/// The cookie to store the session id for user information. +const SESSION_COOKIE: &str = "info"; +const SESSION_DATA_KEY: &str = "data"; + async fn login_authorized( Query(query): Query<AuthRequest>, State(state): State<AppHandle>, TypedHeader(cookies): TypedHeader<headers::Cookie>, ) -> Result<impl IntoResponse, AppError> { - Ok("") + let oauth_session_id = Id::from_str( + cookies + .get(OAUTH_CSRF_COOKIE) + .context("missing session cookie")?, + ) + .unwrap(); + csrf_token_validation_workflow(&query, &state.session_store, oauth_session_id).await?; + + let client = state.http_client.clone(); + let store = state.session_store.clone(); + + // Get an auth token + let token = state + .discord_client + .exchange_code(AuthorizationCode::new(query.code.clone())) + .request_async(&client) + .await + .context("failed in sending request request to authorization server")?; + + let user_data: User = client + // https://discord.com/developers/docs/resources/user#get-current-user + .get("https://discordapp.com/api/users/@me") + .bearer_auth(token.access_token().secret()) + .send() + .await + .context("failed in sending request to target Url")? + .json::<User>() + .await + .context("failed to deserialize response as JSON")?; + + // Create a new session filled with user data + let session_id = Id(i128::from_le_bytes(uuid::Uuid::new_v4().to_bytes_le())); + store + .create(&mut Record { + id: session_id, + data: [( + SESSION_DATA_KEY.to_string(), + serde_json::to_value(user_data).unwrap(), + )] + .into(), + expiry_date: time::OffsetDateTime::now_utc() + + Duration::from_secs(state.local_config.oauth.session_lifespan), + }) + .await + .context("failed in inserting serialized value into session")?; + + // Store session and get corresponding cookie. + let cookie = format!("{SESSION_COOKIE}={session_id}; SameSite=Lax; HttpOnly; Secure; Path=/"); + + // Set cookie + let mut headers = HeaderMap::new(); + headers.insert( + SET_COOKIE, + cookie.parse().context("failed to parse cookie")?, + ); + + Ok((headers, Redirect::to("/"))) } diff --git a/crates/auth/src/server/routes/discord/discord_auth.rs b/crates/auth/src/server/routes/discord/discord_auth.rs index b07fa7a..5257a33 100644 --- a/crates/auth/src/server/routes/discord/discord_auth.rs +++ b/crates/auth/src/server/routes/discord/discord_auth.rs @@ -1,11 +1,24 @@ +use std::time::Duration; + +use anyhow::Context; use axum::{ extract::State, http::HeaderMap, response::{IntoResponse, Redirect}, }; use oauth2::{CsrfToken, Scope}; +use reqwest::header::SET_COOKIE; +use sqlx::types::uuid; +use tower_sessions::{ + SessionStore, + session::{Id, Record}, +}; -use crate::{error::AppError, state::AppHandle}; +use crate::{ + error::AppError, + server::{CSRF_TOKEN, OAUTH_CSRF_COOKIE}, + state::AppHandle, +}; pub async fn discord_auth(State(state): State<AppHandle>) -> Result<impl IntoResponse, AppError> { let (auth_url, csrf_token) = state @@ -14,7 +27,33 @@ pub async fn discord_auth(State(state): State<AppHandle>) -> Result<impl IntoRes .add_scope(Scope::new("identify".to_string())) .url(); + // Store the token in the session and retrieve the session cookie. + let session_id = Id(i128::from_le_bytes(uuid::Uuid::new_v4().to_bytes_le())); + let store = state.session_store.clone(); + + store + .create(&mut Record { + id: session_id, + data: [( + CSRF_TOKEN.to_string(), + serde_json::to_value(csrf_token).unwrap(), + )] + .into(), + expiry_date: time::OffsetDateTime::now_utc() + + Duration::from_secs(state.local_config.oauth.session_lifespan), + }) + .await + .unwrap(); + // .context("failed in inserting CSRF token into session")?; + + // Attach the session cookie to the response header + let cookie = + format!("{OAUTH_CSRF_COOKIE}={session_id}; SameSite=Lax; HttpOnly; Secure; Path=/"); let mut headers = HeaderMap::new(); + headers.insert( + SET_COOKIE, + cookie.parse().context("failed to parse cookie")?, + ); Ok((headers, Redirect::to(auth_url.as_ref()))) } diff --git a/crates/auth/src/state.rs b/crates/auth/src/state.rs index 5a483c9..927823c 100644 --- a/crates/auth/src/state.rs +++ b/crates/auth/src/state.rs @@ -1,6 +1,10 @@ use std::{ops::Deref, sync::Arc}; use stack_up::{Configuration, Services}; +use tokio::task::JoinHandle; +use tower_sessions::{CachingSessionStore, ExpiredDeletion, session_store}; +use tower_sessions_moka_store::MokaStore; +use tower_sessions_sqlx_store::PostgresStore; use crate::{ client::{OauthClient, discord::discord_client}, @@ -24,22 +28,39 @@ pub struct AppState { pub local_config: LocalConfig, pub discord_client: OauthClient, pub http_client: reqwest::Client, + pub session_store: CachingSessionStore<MokaStore, PostgresStore>, } impl AppState { pub async fn create( services: Services, configuration: &Configuration, - ) -> Result<AppHandle, AppError> { + ) -> Result<(AppHandle, JoinHandle<Result<(), session_store::Error>>), AppError> { let local_config: LocalConfig = serde_json::from_value(configuration.misc.clone())?; + let session_store_db = + tower_sessions_sqlx_store::PostgresStore::new(services.postgres.clone()); + session_store_db.migrate().await?; + let deletion_task = tokio::task::spawn( + session_store_db + .clone() + .continuously_delete_expired(tokio::time::Duration::from_secs(60)), + ); + let session_store_mem = MokaStore::new(Some(100)); + + let store = CachingSessionStore::new(session_store_mem, session_store_db); + let discord_client = discord_client(&local_config.oauth.discord)?; - Ok(AppHandle(Arc::new(Self { - services, - local_config, - discord_client, - http_client: reqwest::Client::new(), - }))) + Ok(( + AppHandle(Arc::new(Self { + services, + local_config, + discord_client, + http_client: reqwest::Client::new(), + session_store: store, + })), + deletion_task, + )) } } |