summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock137
-rw-r--r--crates/auth/Cargo.toml4
-rw-r--r--crates/auth/auth.toml3
-rw-r--r--crates/auth/migrations/20250723100947_account.sql1
-rw-r--r--crates/auth/migrations/20250723100947_user.sql9
-rw-r--r--crates/auth/migrations/20250723121223_oauth_account.sql7
-rw-r--r--crates/auth/src/cnfg.rs1
-rw-r--r--crates/auth/src/main.rs34
-rw-r--r--crates/auth/src/server.rs5
-rw-r--r--crates/auth/src/server/csrf_token_validation.rs49
-rw-r--r--crates/auth/src/server/routes/authorised.rs95
-rw-r--r--crates/auth/src/server/routes/discord/discord_auth.rs41
-rw-r--r--crates/auth/src/state.rs35
13 files changed, 405 insertions, 16 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 78bb019..a06795b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -206,6 +206,10 @@ dependencies = [
"tokio",
"tower",
"tower-http",
+ "tower-sessions",
+ "tower-sessions-core",
+ "tower-sessions-moka-store",
+ "tower-sessions-sqlx-store",
"tracing",
"url",
]
@@ -513,6 +517,17 @@ dependencies = [
]
[[package]]
+name = "cookie"
+version = "0.18.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
+dependencies = [
+ "percent-encoding",
+ "time",
+ "version_check",
+]
+
+[[package]]
name = "core-foundation-sys"
version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1448,6 +1463,7 @@ checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765"
dependencies = [
"autocfg",
"scopeguard",
+ "serde",
]
[[package]]
@@ -1741,6 +1757,12 @@ dependencies = [
]
[[package]]
+name = "paste"
+version = "1.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
+
+[[package]]
name = "pathdiff"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2102,6 +2124,28 @@ dependencies = [
]
[[package]]
+name = "rmp"
+version = "0.8.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4"
+dependencies = [
+ "byteorder",
+ "num-traits",
+ "paste",
+]
+
+[[package]]
+name = "rmp-serde"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db"
+dependencies = [
+ "byteorder",
+ "rmp",
+ "serde",
+]
+
+[[package]]
name = "rsa"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2918,6 +2962,22 @@ dependencies = [
]
[[package]]
+name = "tower-cookies"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "151b5a3e3c45df17466454bb74e9ecedecc955269bdedbf4d150dfa393b55a36"
+dependencies = [
+ "axum-core",
+ "cookie",
+ "futures-util",
+ "http",
+ "parking_lot",
+ "pin-project-lite",
+ "tower-layer",
+ "tower-service",
+]
+
+[[package]]
name = "tower-http"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2949,6 +3009,83 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
+name = "tower-sessions"
+version = "0.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "43a05911f23e8fae446005fe9b7b97e66d95b6db589dc1c4d59f6a2d4d4927d3"
+dependencies = [
+ "async-trait",
+ "http",
+ "time",
+ "tokio",
+ "tower-cookies",
+ "tower-layer",
+ "tower-service",
+ "tower-sessions-core",
+ "tower-sessions-memory-store",
+ "tracing",
+]
+
+[[package]]
+name = "tower-sessions-core"
+version = "0.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ce8cce604865576b7751b7a6bc3058f754569a60d689328bb74c52b1d87e355b"
+dependencies = [
+ "async-trait",
+ "axum-core",
+ "base64",
+ "futures",
+ "http",
+ "parking_lot",
+ "rand 0.8.5",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.12",
+ "time",
+ "tokio",
+ "tracing",
+]
+
+[[package]]
+name = "tower-sessions-memory-store"
+version = "0.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fb05909f2e1420135a831dd5df9f5596d69196d0a64c3499ca474c4bd3d33242"
+dependencies = [
+ "async-trait",
+ "time",
+ "tokio",
+ "tower-sessions-core",
+]
+
+[[package]]
+name = "tower-sessions-moka-store"
+version = "0.15.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6a5e622001aa59953f422ade78a0fa0d1f4d2566c9bf697bffe6aa89f1438f08"
+dependencies = [
+ "async-trait",
+ "moka",
+ "time",
+ "tower-sessions-core",
+]
+
+[[package]]
+name = "tower-sessions-sqlx-store"
+version = "0.15.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e054622079f57fc1a7d6a6089c9334f963d62028fe21dc9eddd58af9a78480b3"
+dependencies = [
+ "async-trait",
+ "rmp-serde",
+ "sqlx",
+ "thiserror 1.0.69",
+ "time",
+ "tower-sessions-core",
+]
+
+[[package]]
name = "tracing"
version = "0.1.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/crates/auth/Cargo.toml b/crates/auth/Cargo.toml
index b5e53d9..b6ad707 100644
--- a/crates/auth/Cargo.toml
+++ b/crates/auth/Cargo.toml
@@ -25,6 +25,10 @@ time = { workspace = true, features = ["parsing", "serde"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] }
tower = { workspace = true, features = ["util"] }
tower-http = { workspace = true, features = ["map-request-body", "trace", "util"] }
+tower-sessions = "0.14.0"
+tower-sessions-core = { version = "0.14.0", features = ["deletion-task"] }
+tower-sessions-moka-store = "0.15.0"
+tower-sessions-sqlx-store = { version = "0.15.0", features = ["postgres"] }
tracing.workspace = true
url.workspace = true
diff --git a/crates/auth/auth.toml b/crates/auth/auth.toml
index cfb6501..4e5b263 100644
--- a/crates/auth/auth.toml
+++ b/crates/auth/auth.toml
@@ -2,6 +2,9 @@
env = "development"
port = 1304
+[misc.oauth]
+session-lifespan = 3600 # seconds
+
[misc.oauth.discord]
# query param for provider
redirect-url = "http://127.0.0.1:1304/auth/authorised?provider=discord"
diff --git a/crates/auth/migrations/20250723100947_account.sql b/crates/auth/migrations/20250723100947_account.sql
deleted file mode 100644
index 8ddc1d3..0000000
--- a/crates/auth/migrations/20250723100947_account.sql
+++ /dev/null
@@ -1 +0,0 @@
--- Add migration script here
diff --git a/crates/auth/migrations/20250723100947_user.sql b/crates/auth/migrations/20250723100947_user.sql
new file mode 100644
index 0000000..440afc7
--- /dev/null
+++ b/crates/auth/migrations/20250723100947_user.sql
@@ -0,0 +1,9 @@
+-- Add migration script here
+create table auth_user (
+ id uuid primary key,
+ email text unique not null,
+ avatar text,
+ description text,
+ updated_at text,
+ create_at timestamptz not null default now()
+);
diff --git a/crates/auth/migrations/20250723121223_oauth_account.sql b/crates/auth/migrations/20250723121223_oauth_account.sql
new file mode 100644
index 0000000..826fbcc
--- /dev/null
+++ b/crates/auth/migrations/20250723121223_oauth_account.sql
@@ -0,0 +1,7 @@
+CREATE TABLE oauth_account (
+ provider_id text not null,
+ provider_user_id text not null,
+ user_id uuid not null,
+ primary key (provider_id, provider_user_id),
+ foreign key (user_id) references auth_user(id)
+);
diff --git a/crates/auth/src/cnfg.rs b/crates/auth/src/cnfg.rs
index 6afe2f8..af7b0a0 100644
--- a/crates/auth/src/cnfg.rs
+++ b/crates/auth/src/cnfg.rs
@@ -10,6 +10,7 @@ pub struct LocalConfig {
#[serde(rename_all = "kebab-case")]
pub struct OauthConfig {
pub discord: OauthCredentials,
+ pub session_lifespan: u64,
}
#[derive(Deserialize, Clone)]
diff --git a/crates/auth/src/main.rs b/crates/auth/src/main.rs
index 4a71d69..ef8a358 100644
--- a/crates/auth/src/main.rs
+++ b/crates/auth/src/main.rs
@@ -8,6 +8,7 @@ use std::net::{Ipv6Addr, SocketAddr};
use clap::Parser;
use stack_up::{Configuration, Services, tracing::Tracing};
+use tokio::{signal, task::AbortHandle};
use tracing::{info, trace};
use crate::{error::AppError, state::AppState};
@@ -56,13 +57,42 @@ async fn main() -> Result<(), AppError> {
.run(&services.postgres)
.await?;
- let state = AppState::create(services, &config).await?;
+ let (state, deletion_task) = AppState::create(services, &config).await?;
let addr = SocketAddr::from((Ipv6Addr::UNSPECIFIED, config.application.port));
let listener = tokio::net::TcpListener::bind(addr).await?;
info!(port = addr.port(), "serving api");
- axum::serve(listener, server::router(state)).await?;
+ axum::serve(listener, server::router(state))
+ .with_graceful_shutdown(shutdown_signal(deletion_task.abort_handle()))
+ .await?;
+
+ deletion_task.await??;
+
Ok(())
}
+
+async fn shutdown_signal(deletion_task_abort_handle: AbortHandle) {
+ let ctrl_c = async {
+ signal::ctrl_c()
+ .await
+ .expect("failed to install Ctrl+C handler");
+ };
+
+ #[cfg(unix)]
+ let terminate = async {
+ signal::unix::signal(signal::unix::SignalKind::terminate())
+ .expect("failed to install signal handler")
+ .recv()
+ .await;
+ };
+
+ #[cfg(not(unix))]
+ let terminate = std::future::pending::<()>();
+
+ tokio::select! {
+ _ = ctrl_c => { deletion_task_abort_handle.abort() },
+ _ = terminate => { deletion_task_abort_handle.abort() },
+ }
+}
diff --git a/crates/auth/src/server.rs b/crates/auth/src/server.rs
index 3cfac60..d724d68 100644
--- a/crates/auth/src/server.rs
+++ b/crates/auth/src/server.rs
@@ -3,8 +3,13 @@ use tower_http::trace::TraceLayer;
use crate::{server::routes::health_check, state::AppHandle};
+pub mod csrf_token_validation;
pub mod routes;
+const CSRF_TOKEN: &str = "csrf_token";
+const COOKIE_NAME: &str = "SESSION";
+const OAUTH_CSRF_COOKIE: &str = "SESSION";
+
pub fn router(state: AppHandle) -> Router {
Router::new()
.merge(routes::discord::discord_router(state.clone()))
diff --git a/crates/auth/src/server/csrf_token_validation.rs b/crates/auth/src/server/csrf_token_validation.rs
new file mode 100644
index 0000000..c9a627c
--- /dev/null
+++ b/crates/auth/src/server/csrf_token_validation.rs
@@ -0,0 +1,49 @@
+use anyhow::{Context, anyhow};
+use axum_extra::headers;
+use oauth2::CsrfToken;
+use time::OffsetDateTime;
+use tower_sessions::{CachingSessionStore, SessionStore, session::Id};
+use tower_sessions_moka_store::MokaStore;
+use tower_sessions_sqlx_store::PostgresStore;
+
+use crate::{
+ error::AppError,
+ server::{COOKIE_NAME, CSRF_TOKEN, routes::authorised::AuthRequest},
+ state::AppHandle,
+};
+
+pub struct Session {
+ id: String,
+ expires_at: OffsetDateTime,
+ user_id: String,
+}
+
+pub async fn csrf_token_validation_workflow(
+ auth_request: &AuthRequest,
+ store: &CachingSessionStore<MokaStore, PostgresStore>,
+ oauth_session_id: Id,
+) -> Result<(), AppError> {
+ let oauth_session = store.load(&oauth_session_id).await.unwrap().unwrap();
+
+ // Extract the CSRF token from the session
+ let csrf_token_serialized = oauth_session
+ .data
+ .get(CSRF_TOKEN)
+ .context("failed to get value from session")?;
+ let csrf_token = serde_json::from_value::<CsrfToken>(csrf_token_serialized.clone())
+ .context("CSRF token not found in session")?
+ .to_owned();
+
+ // Cleanup the CSRF token session
+ store
+ .delete(&oauth_session_id)
+ .await
+ .context("Failed to destroy old session")?;
+
+ // Validate CSRF token is the same as the one in the auth request
+ if *csrf_token.secret() != auth_request.state {
+ return Err(anyhow!("CSRF token mismatch").into());
+ }
+
+ Ok(())
+}
diff --git a/crates/auth/src/server/routes/authorised.rs b/crates/auth/src/server/routes/authorised.rs
index ddf048d..42bbde2 100644
--- a/crates/auth/src/server/routes/authorised.rs
+++ b/crates/auth/src/server/routes/authorised.rs
@@ -1,23 +1,108 @@
+use std::{str::FromStr, time::Duration};
+
+use anyhow::Context;
use axum::{
extract::{Query, State},
- response::IntoResponse,
+ http::HeaderMap,
+ response::{IntoResponse, Redirect},
};
use axum_extra::{TypedHeader, headers};
-use serde::Deserialize;
+use oauth2::{AuthorizationCode, TokenResponse};
+use reqwest::header::SET_COOKIE;
+use serde::{Deserialize, Serialize};
+use sqlx::types::uuid;
+use tower_sessions::{
+ SessionStore,
+ session::{Id, Record},
+};
-use crate::{error::AppError, server::routes::Provider, state::AppHandle};
+use crate::{
+ error::AppError,
+ server::{
+ OAUTH_CSRF_COOKIE, csrf_token_validation::csrf_token_validation_workflow, routes::Provider,
+ },
+ state::AppHandle,
+};
#[derive(Debug, Deserialize)]
pub struct AuthRequest {
provider: Provider,
code: String,
- state: String,
+ pub state: String,
}
+#[derive(Debug, Deserialize, Serialize)]
+struct User {
+ id: String,
+ avatar: Option<String>,
+ username: String,
+ discriminator: String,
+}
+
+/// The cookie to store the session id for user information.
+const SESSION_COOKIE: &str = "info";
+const SESSION_DATA_KEY: &str = "data";
+
async fn login_authorized(
Query(query): Query<AuthRequest>,
State(state): State<AppHandle>,
TypedHeader(cookies): TypedHeader<headers::Cookie>,
) -> Result<impl IntoResponse, AppError> {
- Ok("")
+ let oauth_session_id = Id::from_str(
+ cookies
+ .get(OAUTH_CSRF_COOKIE)
+ .context("missing session cookie")?,
+ )
+ .unwrap();
+ csrf_token_validation_workflow(&query, &state.session_store, oauth_session_id).await?;
+
+ let client = state.http_client.clone();
+ let store = state.session_store.clone();
+
+ // Get an auth token
+ let token = state
+ .discord_client
+ .exchange_code(AuthorizationCode::new(query.code.clone()))
+ .request_async(&client)
+ .await
+ .context("failed in sending request request to authorization server")?;
+
+ let user_data: User = client
+ // https://discord.com/developers/docs/resources/user#get-current-user
+ .get("https://discordapp.com/api/users/@me")
+ .bearer_auth(token.access_token().secret())
+ .send()
+ .await
+ .context("failed in sending request to target Url")?
+ .json::<User>()
+ .await
+ .context("failed to deserialize response as JSON")?;
+
+ // Create a new session filled with user data
+ let session_id = Id(i128::from_le_bytes(uuid::Uuid::new_v4().to_bytes_le()));
+ store
+ .create(&mut Record {
+ id: session_id,
+ data: [(
+ SESSION_DATA_KEY.to_string(),
+ serde_json::to_value(user_data).unwrap(),
+ )]
+ .into(),
+ expiry_date: time::OffsetDateTime::now_utc()
+ + Duration::from_secs(state.local_config.oauth.session_lifespan),
+ })
+ .await
+ .context("failed in inserting serialized value into session")?;
+
+ // Store session and get corresponding cookie.
+ let cookie = format!("{SESSION_COOKIE}={session_id}; SameSite=Lax; HttpOnly; Secure; Path=/");
+
+ // Set cookie
+ let mut headers = HeaderMap::new();
+ headers.insert(
+ SET_COOKIE,
+ cookie.parse().context("failed to parse cookie")?,
+ );
+
+ Ok((headers, Redirect::to("/")))
}
diff --git a/crates/auth/src/server/routes/discord/discord_auth.rs b/crates/auth/src/server/routes/discord/discord_auth.rs
index b07fa7a..5257a33 100644
--- a/crates/auth/src/server/routes/discord/discord_auth.rs
+++ b/crates/auth/src/server/routes/discord/discord_auth.rs
@@ -1,11 +1,24 @@
+use std::time::Duration;
+
+use anyhow::Context;
use axum::{
extract::State,
http::HeaderMap,
response::{IntoResponse, Redirect},
};
use oauth2::{CsrfToken, Scope};
+use reqwest::header::SET_COOKIE;
+use sqlx::types::uuid;
+use tower_sessions::{
+ SessionStore,
+ session::{Id, Record},
+};
-use crate::{error::AppError, state::AppHandle};
+use crate::{
+ error::AppError,
+ server::{CSRF_TOKEN, OAUTH_CSRF_COOKIE},
+ state::AppHandle,
+};
pub async fn discord_auth(State(state): State<AppHandle>) -> Result<impl IntoResponse, AppError> {
let (auth_url, csrf_token) = state
@@ -14,7 +27,33 @@ pub async fn discord_auth(State(state): State<AppHandle>) -> Result<impl IntoRes
.add_scope(Scope::new("identify".to_string()))
.url();
+ // Store the token in the session and retrieve the session cookie.
+ let session_id = Id(i128::from_le_bytes(uuid::Uuid::new_v4().to_bytes_le()));
+ let store = state.session_store.clone();
+
+ store
+ .create(&mut Record {
+ id: session_id,
+ data: [(
+ CSRF_TOKEN.to_string(),
+ serde_json::to_value(csrf_token).unwrap(),
+ )]
+ .into(),
+ expiry_date: time::OffsetDateTime::now_utc()
+ + Duration::from_secs(state.local_config.oauth.session_lifespan),
+ })
+ .await
+ .unwrap();
+ // .context("failed in inserting CSRF token into session")?;
+
+ // Attach the session cookie to the response header
+ let cookie =
+ format!("{OAUTH_CSRF_COOKIE}={session_id}; SameSite=Lax; HttpOnly; Secure; Path=/");
let mut headers = HeaderMap::new();
+ headers.insert(
+ SET_COOKIE,
+ cookie.parse().context("failed to parse cookie")?,
+ );
Ok((headers, Redirect::to(auth_url.as_ref())))
}
diff --git a/crates/auth/src/state.rs b/crates/auth/src/state.rs
index 5a483c9..927823c 100644
--- a/crates/auth/src/state.rs
+++ b/crates/auth/src/state.rs
@@ -1,6 +1,10 @@
use std::{ops::Deref, sync::Arc};
use stack_up::{Configuration, Services};
+use tokio::task::JoinHandle;
+use tower_sessions::{CachingSessionStore, ExpiredDeletion, session_store};
+use tower_sessions_moka_store::MokaStore;
+use tower_sessions_sqlx_store::PostgresStore;
use crate::{
client::{OauthClient, discord::discord_client},
@@ -24,22 +28,39 @@ pub struct AppState {
pub local_config: LocalConfig,
pub discord_client: OauthClient,
pub http_client: reqwest::Client,
+ pub session_store: CachingSessionStore<MokaStore, PostgresStore>,
}
impl AppState {
pub async fn create(
services: Services,
configuration: &Configuration,
- ) -> Result<AppHandle, AppError> {
+ ) -> Result<(AppHandle, JoinHandle<Result<(), session_store::Error>>), AppError> {
let local_config: LocalConfig = serde_json::from_value(configuration.misc.clone())?;
+ let session_store_db =
+ tower_sessions_sqlx_store::PostgresStore::new(services.postgres.clone());
+ session_store_db.migrate().await?;
+ let deletion_task = tokio::task::spawn(
+ session_store_db
+ .clone()
+ .continuously_delete_expired(tokio::time::Duration::from_secs(60)),
+ );
+ let session_store_mem = MokaStore::new(Some(100));
+
+ let store = CachingSessionStore::new(session_store_mem, session_store_db);
+
let discord_client = discord_client(&local_config.oauth.discord)?;
- Ok(AppHandle(Arc::new(Self {
- services,
- local_config,
- discord_client,
- http_client: reqwest::Client::new(),
- })))
+ Ok((
+ AppHandle(Arc::new(Self {
+ services,
+ local_config,
+ discord_client,
+ http_client: reqwest::Client::new(),
+ session_store: store,
+ })),
+ deletion_task,
+ ))
}
}