diff options
Diffstat (limited to 'crates/auth-service')
22 files changed, 964 insertions, 0 deletions
diff --git a/crates/auth-service/Cargo.toml b/crates/auth-service/Cargo.toml new file mode 100644 index 0000000..bbbb10d --- /dev/null +++ b/crates/auth-service/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "sellershut-auth" +version = "0.1.0" +edition = "2024" +license.workspace = true +homepage.workspace = true +documentation.workspace = true +description.workspace = true + +[dependencies] +anyhow.workspace = true +axum = { workspace = true, features = ["macros"] } +axum-extra = { version = "0.10.1", features = ["typed-header"] } +base64.workspace = true +clap = { workspace = true, features = ["derive"] } +config = { workspace = true, features = ["convert-case", "toml"] } +futures-util.workspace = true +jsonwebtoken = "9.3.1" +nanoid.workspace = true +oauth2 = "5.0.0" +reqwest = { workspace = true, features = ["json", "rustls-tls"] } +sellershut-core = { workspace = true, features = ["auth", "serde"] } +serde = { workspace = true, features = ["derive"] } +serde_json.workspace = true +sqlx = { workspace = true, features = ["macros", "migrate", "runtime-tokio", "time", "tls-rustls", "uuid"] } +time = { workspace = true, features = ["parsing", "serde"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] } +tonic.workspace = true +tonic-reflection = "0.13.0" +tower = { workspace = true, features = ["steer", "util"] } +tower-http = { workspace = true, features = ["map-request-body", "trace", "util"] } +tower-sessions = "0.14.0" +tower-sessions-core = { version = "0.14.0", features = ["deletion-task"] } +tower-sessions-moka-store = "0.15.0" +tower-sessions-sqlx-store = { version = "0.15.0", features = ["postgres"] } +tracing.workspace = true +url.workspace = true +uuid = { workspace = true, features = ["serde", "v7"] } + +[dependencies.stack-up] +workspace = true +features = ["api", "postgres", "tracing"] diff --git a/crates/auth-service/auth.toml b/crates/auth-service/auth.toml new file mode 100644 index 0000000..8febe90 --- /dev/null +++ b/crates/auth-service/auth.toml @@ -0,0 +1,40 @@ +[application] +env = "development" +port = 1304 + +[nats] +hosts = ["nats://localhost:4222"] + +[misc] +profile-endpoint = "http://localhost:1610" + +[misc.oauth] +session-lifespan = 3600 # seconds +jwt-encoding-key = "secret" + +[misc.oauth.discord] +# query param for provider +redirect-url = "http://127.0.0.1:1304/auth/authorised?provider=discord" +#client-id = "" +#client-secret = "" +#auth-url = "" + + +[monitoring] +log-level = "auth_service=trace,info" + +[database] +pool_size = 100 +port = 5432 +name = "auth" +host = "localhost" +password = "password" +user = "postgres" + +[cache] +dsn = "redis://localhost:6379" +pooled = true +type = "non-clustered" # clustered, non-clustered or sentinel +max-connections = 100 + +# vim:ft=toml diff --git a/crates/auth-service/migrations/20250723100947_user.sql b/crates/auth-service/migrations/20250723100947_user.sql new file mode 100644 index 0000000..b5566fe --- /dev/null +++ b/crates/auth-service/migrations/20250723100947_user.sql @@ -0,0 +1,20 @@ +-- Add migration script here +create table auth_user ( + id uuid primary key, + email text unique not null, + updated_at timestamptz not null default now(), + created_at timestamptz not null default now() +); + +create or replace function set_updated_at() +returns trigger as $$ +begin + new.updated_at := now(); + return new; +end; +$$ language plpgsql; + +create trigger trigger_set_updated_at +before update on auth_user +for each row +execute function set_updated_at(); diff --git a/crates/auth-service/migrations/20250723121223_oauth_account.sql b/crates/auth-service/migrations/20250723121223_oauth_account.sql new file mode 100644 index 0000000..826fbcc --- /dev/null +++ b/crates/auth-service/migrations/20250723121223_oauth_account.sql @@ -0,0 +1,7 @@ +CREATE TABLE oauth_account ( + provider_id text not null, + provider_user_id text not null, + user_id uuid not null, + primary key (provider_id, provider_user_id), + foreign key (user_id) references auth_user(id) +); diff --git a/crates/auth-service/migrations/20250725160900_session.sql b/crates/auth-service/migrations/20250725160900_session.sql new file mode 100644 index 0000000..c5e76dc --- /dev/null +++ b/crates/auth-service/migrations/20250725160900_session.sql @@ -0,0 +1,8 @@ +-- Add migration script here +create schema if not exists "tower_sessions"; + +create table "tower_sessions"."session" ( + id text primary key not null, + data bytea not null, + expiry_date timestamptz not null +); diff --git a/crates/auth-service/migrations/20250725161014_token.sql b/crates/auth-service/migrations/20250725161014_token.sql new file mode 100644 index 0000000..68f476c --- /dev/null +++ b/crates/auth-service/migrations/20250725161014_token.sql @@ -0,0 +1,9 @@ +-- Add migration script here +create table token ( + user_id uuid not null, + token text not null, + session_id text not null, + primary key (user_id, session_id), + foreign key (session_id) references "tower_sessions"."session"(id) on delete cascade, + foreign key (user_id) references auth_user(id) on delete cascade +); diff --git a/crates/auth-service/src/auth.rs b/crates/auth-service/src/auth.rs new file mode 100644 index 0000000..04cb60a --- /dev/null +++ b/crates/auth-service/src/auth.rs @@ -0,0 +1,12 @@ +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Claims { + pub iss: String, + pub sub: Uuid, + pub exp: i64, + pub iat: i64, + pub sid: String, + pub aud: String, +} diff --git a/crates/auth-service/src/client.rs b/crates/auth-service/src/client.rs new file mode 100644 index 0000000..5aa4de0 --- /dev/null +++ b/crates/auth-service/src/client.rs @@ -0,0 +1,6 @@ +use oauth2::{EndpointNotSet, EndpointSet, basic::BasicClient}; + +pub mod discord; + +pub type OauthClient = + BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>; diff --git a/crates/auth-service/src/client/discord.rs b/crates/auth-service/src/client/discord.rs new file mode 100644 index 0000000..9217684 --- /dev/null +++ b/crates/auth-service/src/client/discord.rs @@ -0,0 +1,30 @@ +use crate::{client::OauthClient, cnfg::OauthCredentials, error::AppError}; +use anyhow::Context; +use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl, basic::BasicClient}; + +pub fn discord_client(config: &OauthCredentials) -> Result<OauthClient, AppError> { + let auth_url = config.auth_url.clone().unwrap_or_else(|| { + "https://discord.com/api/oauth2/authorize?response_type=code".to_string() + }); + + let token_url = config + .token_url + .clone() + .unwrap_or_else(|| "https://discord.com/api/oauth2/token".to_string()); + + let c = BasicClient::new(ClientId::new(config.client_id.to_owned())) + .set_client_secret(ClientSecret::new(config.client_secret.to_owned())) + .set_auth_uri( + AuthUrl::new(auth_url).context("failed to create new auth server url [discord]")?, + ) + .set_redirect_uri( + RedirectUrl::new(config.redirect_url.to_owned()) + .context("failed to create new redirect URL [discord]")?, + ) + .set_token_uri( + TokenUrl::new(token_url) + .context("failed to create new token endpoint URL [discord]")?, + ); + + Ok(c) +} diff --git a/crates/auth-service/src/cnfg.rs b/crates/auth-service/src/cnfg.rs new file mode 100644 index 0000000..9b765a5 --- /dev/null +++ b/crates/auth-service/src/cnfg.rs @@ -0,0 +1,26 @@ +use serde::Deserialize; + +#[derive(Deserialize, Clone)] +#[serde(rename_all = "kebab-case")] +pub struct LocalConfig { + pub oauth: OauthConfig, + pub profile_endpoint: String, +} + +#[derive(Deserialize, Clone)] +#[serde(rename_all = "kebab-case")] +pub struct OauthConfig { + pub discord: OauthCredentials, + pub session_lifespan: u64, + pub jwt_encoding_key: String, +} + +#[derive(Deserialize, Clone)] +#[serde(rename_all = "kebab-case")] +pub struct OauthCredentials { + pub client_id: String, + pub client_secret: String, + pub redirect_url: String, + pub auth_url: Option<String>, + pub token_url: Option<String>, +} diff --git a/crates/auth-service/src/error.rs b/crates/auth-service/src/error.rs new file mode 100644 index 0000000..730f99a --- /dev/null +++ b/crates/auth-service/src/error.rs @@ -0,0 +1,26 @@ +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, +}; + +#[derive(Debug)] +pub struct AppError(anyhow::Error); + +impl IntoResponse for AppError { + fn into_response(self) -> Response { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Something went wrong: {}", self.0), + ) + .into_response() + } +} + +impl<E> From<E> for AppError +where + E: Into<anyhow::Error>, +{ + fn from(err: E) -> Self { + Self(err.into()) + } +} diff --git a/crates/auth-service/src/main.rs b/crates/auth-service/src/main.rs new file mode 100644 index 0000000..50544fc --- /dev/null +++ b/crates/auth-service/src/main.rs @@ -0,0 +1,141 @@ +mod auth; +mod client; +mod cnfg; +mod error; +mod server; +mod state; + +use std::net::{Ipv6Addr, SocketAddr}; + +use clap::Parser; +use reqwest::header::CONTENT_TYPE; +use sellershut_core::auth::{AUTH_FILE_DESCRIPTOR_SET, auth_server::AuthServer}; +use stack_up::{Configuration, tracing::Tracing}; +use tokio::{signal, task::AbortHandle}; +use tonic::service::Routes; +use tower::{make::Shared, steer::Steer}; +use tracing::{info, trace}; + +use crate::{ + error::AppError, + server::grpc::interceptor::MyInterceptor, + state::{AppState, Services}, +}; + +/// sellershut-auth +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +struct Args { + /// Path to config file + #[arg(short, long)] + config_file: Option<std::path::PathBuf>, +} + +#[tokio::main] +async fn main() -> Result<(), AppError> { + let args = Args::parse(); + let config = include_str!("../auth.toml"); + + let mut config = config::Config::builder() + .add_source(config::File::from_str(config, config::FileFormat::Toml)) + .add_source( + config::Environment::with_prefix("APP") + .separator("__") + .convert_case(config::Case::Kebab), + ); + + if let Some(cf) = args.config_file.as_ref().and_then(|v| v.to_str()) { + config = config.add_source(config::File::new(cf, config::FileFormat::Toml)); + }; + + let mut config: Configuration = config.build()?.try_deserialize()?; + config.application.name = env!("CARGO_CRATE_NAME").into(); + config.application.version = env!("CARGO_PKG_VERSION").into(); + + let _tracing = Tracing::builder().build(&config.monitoring); + + let mut services = stack_up::Services::builder() + .postgres(&config.database) + .await + .inspect_err(|e| tracing::error!("database: {e}"))? + .build(); + + let postgres = services + .postgres + .take() + .ok_or_else(|| anyhow::anyhow!("database is not ready"))?; + + let services = Services { postgres }; + + trace!("running migrations"); + sqlx::migrate!("./migrations") + .run(&services.postgres) + .await?; + + let (state, deletion_task) = AppState::create(services, &config).await?; + + let addr = SocketAddr::from((Ipv6Addr::UNSPECIFIED, config.application.port)); + + let listener = tokio::net::TcpListener::bind(addr).await?; + info!(port = addr.port(), "serving api"); + + let service = AuthServer::with_interceptor(state.clone(), MyInterceptor); + let auth_reflector = tonic_reflection::server::Builder::configure() + .register_encoded_file_descriptor_set(AUTH_FILE_DESCRIPTOR_SET) + .build_v1()?; + + let grpc_server = Routes::new(service) + .add_service(auth_reflector) + .into_axum_router(); + + let service = Steer::new( + vec![server::router(state), grpc_server], + |req: &axum::extract::Request, _services: &[_]| { + if req + .headers() + .get(CONTENT_TYPE) + .map(|content_type| content_type.as_bytes()) + .filter(|content_type| content_type.starts_with(b"application/grpc")) + .is_some() + { + // grpc service + 1 + } else { + // http service + 0 + } + }, + ); + + axum::serve(listener, Shared::new(service)) + .with_graceful_shutdown(shutdown_signal(deletion_task.abort_handle())) + .await?; + + deletion_task.await??; + + Ok(()) +} + +async fn shutdown_signal(deletion_task_abort_handle: AbortHandle) { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => { deletion_task_abort_handle.abort() }, + _ = terminate => { deletion_task_abort_handle.abort() }, + } +} diff --git a/crates/auth-service/src/server.rs b/crates/auth-service/src/server.rs new file mode 100644 index 0000000..7b66c42 --- /dev/null +++ b/crates/auth-service/src/server.rs @@ -0,0 +1,37 @@ +use axum::{Router, routing::get}; +use tower_http::trace::TraceLayer; + +use crate::{ + server::routes::{authorised::login_authorised, health_check}, + state::AppHandle, +}; + +pub mod csrf_token_validation; +pub mod grpc; +pub mod routes; + +const CSRF_TOKEN: &str = "csrf_token"; +const OAUTH_CSRF_COOKIE: &str = "SESSION"; + +pub fn router(state: AppHandle) -> Router { + Router::new() + .route("/auth/authorised", get(login_authorised)) + .route("/", get(health_check)) + .with_state(state.clone()) + .merge(routes::discord::discord_router(state)) + .layer(TraceLayer::new_for_http()) +} + +#[cfg(test)] +pub(crate) fn test_config() -> stack_up::Configuration { + use stack_up::Configuration; + + let config_path = "auth.toml"; + + let config = config::Config::builder() + .add_source(config::File::new(config_path, config::FileFormat::Toml)) + .build() + .unwrap(); + + config.try_deserialize::<Configuration>().unwrap() +} diff --git a/crates/auth-service/src/server/csrf_token_validation.rs b/crates/auth-service/src/server/csrf_token_validation.rs new file mode 100644 index 0000000..94424c8 --- /dev/null +++ b/crates/auth-service/src/server/csrf_token_validation.rs @@ -0,0 +1,40 @@ +use anyhow::{Context, anyhow}; +use oauth2::CsrfToken; +use tower_sessions::{CachingSessionStore, SessionStore, session::Id}; +use tower_sessions_moka_store::MokaStore; +use tower_sessions_sqlx_store::PostgresStore; + +use crate::{ + error::AppError, + server::{CSRF_TOKEN, routes::authorised::AuthRequest}, +}; + +pub async fn csrf_token_validation_workflow( + auth_request: &AuthRequest, + store: &CachingSessionStore<MokaStore, PostgresStore>, + oauth_session_id: Id, +) -> Result<(), AppError> { + let oauth_session = store.load(&oauth_session_id).await.unwrap().unwrap(); + + // Extract the CSRF token from the session + let csrf_token_serialized = oauth_session + .data + .get(CSRF_TOKEN) + .context("failed to get value from session")?; + let csrf_token = serde_json::from_value::<CsrfToken>(csrf_token_serialized.clone()) + .context("CSRF token not found in session")? + .to_owned(); + + // Cleanup the CSRF token session + store + .delete(&oauth_session_id) + .await + .context("Failed to destroy old session")?; + + // Validate CSRF token is the same as the one in the auth request + if *csrf_token.secret() != auth_request.state { + return Err(anyhow!("CSRF token mismatch").into()); + } + + Ok(()) +} diff --git a/crates/auth-service/src/server/grpc.rs b/crates/auth-service/src/server/grpc.rs new file mode 100644 index 0000000..0fd775b --- /dev/null +++ b/crates/auth-service/src/server/grpc.rs @@ -0,0 +1,2 @@ +pub mod auth; +pub mod interceptor; diff --git a/crates/auth-service/src/server/grpc/auth.rs b/crates/auth-service/src/server/grpc/auth.rs new file mode 100644 index 0000000..fb00291 --- /dev/null +++ b/crates/auth-service/src/server/grpc/auth.rs @@ -0,0 +1,50 @@ +use std::str::FromStr; + +use jsonwebtoken::DecodingKey; +use sellershut_core::auth::{ValidationRequest, ValidationResponse, auth_server::Auth}; +use tonic::{Request, Response, Status, async_trait}; +use tower_sessions::{SessionStore, session::Id}; +use tracing::warn; + +use crate::{auth::Claims, state::AppHandle}; + +#[async_trait] +impl Auth for AppHandle { + async fn validate_auth_token( + &self, + request: Request<ValidationRequest>, + ) -> Result<Response<ValidationResponse>, Status> { + let token = request.into_inner().token; + + let token = jsonwebtoken::decode::<Claims>( + &token, + &DecodingKey::from_secret(self.local_config.oauth.jwt_encoding_key.as_bytes()), + &jsonwebtoken::Validation::default(), + ); + + match token { + Ok(value) => { + let session_id = value.claims.sid; + let store = &self.session_store; + match Id::from_str(&session_id) { + Ok(ref id) => { + if let Ok(Some(_)) = store.load(id).await { + return Ok(Response::new(ValidationResponse { valid: true })); + } else { + return Ok(Response::new(Default::default())); + } + } + Err(e) => { + warn!("{e}"); + + return Ok(Response::new(Default::default())); + } + } + } + Err(e) => { + warn!("{e}"); + Ok(Response::new(ValidationResponse::default())) + } + } + } +} diff --git a/crates/auth-service/src/server/grpc/interceptor.rs b/crates/auth-service/src/server/grpc/interceptor.rs new file mode 100644 index 0000000..155a306 --- /dev/null +++ b/crates/auth-service/src/server/grpc/interceptor.rs @@ -0,0 +1,17 @@ +use tonic::{ + Status, + service::{Interceptor, interceptor::InterceptedService}, + transport::Channel, +}; +use tracing::Span; + +pub type Intercepted = InterceptedService<Channel, MyInterceptor>; + +#[derive(Clone, Copy)] +pub struct MyInterceptor; + +impl Interceptor for MyInterceptor { + fn call(&mut self, request: tonic::Request<()>) -> Result<tonic::Request<()>, Status> { + Ok(request) + } +} diff --git a/crates/auth-service/src/server/routes.rs b/crates/auth-service/src/server/routes.rs new file mode 100644 index 0000000..6773962 --- /dev/null +++ b/crates/auth-service/src/server/routes.rs @@ -0,0 +1,62 @@ +pub mod authorised; +pub mod discord; + +use std::fmt::Display; + +use axum::response::IntoResponse; +use serde::Deserialize; + +#[derive(Debug, Clone, Copy, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Provider { + Discord, +} + +impl Display for Provider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Provider::Discord => "discord", + } + ) + } +} + +pub async fn health_check() -> impl IntoResponse { + let name = env!("CARGO_PKG_NAME"); + let ver = env!("CARGO_PKG_VERSION"); + + format!("{name} v{ver} is live") +} + +#[cfg(test)] +mod tests { + use axum::{ + body::Body, + http::{Request, StatusCode}, + }; + use sqlx::PgPool; + use stack_up::Services; + use tower::ServiceExt; + + use crate::{ + server::{self, test_config}, + state::AppState, + }; + + #[sqlx::test] + async fn health_check(pool: PgPool) { + let services = Services { postgres: pool }; + let (state, _) = AppState::create(services, &test_config()).await.unwrap(); + let app = server::router(state); + + let response = app + .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + } +} diff --git a/crates/auth-service/src/server/routes/authorised.rs b/crates/auth-service/src/server/routes/authorised.rs new file mode 100644 index 0000000..4d48299 --- /dev/null +++ b/crates/auth-service/src/server/routes/authorised.rs @@ -0,0 +1,236 @@ +use std::{str::FromStr, time::Duration}; + +use anyhow::Context; +use axum::{ + extract::{Query, State}, + http::HeaderMap, + response::{IntoResponse, Redirect}, +}; +use axum_extra::{TypedHeader, headers}; +use oauth2::{AuthorizationCode, TokenResponse}; +use reqwest::{StatusCode, header::SET_COOKIE}; +use sellershut_core::profile::CreateUserRequest; +use serde::{Deserialize, Serialize}; +use sqlx::types::uuid; +use time::OffsetDateTime; +use tower_sessions::{ + SessionStore, + session::{Id, Record}, +}; +use uuid::Uuid; + +use crate::{ + auth::Claims, + error::AppError, + server::{ + OAUTH_CSRF_COOKIE, csrf_token_validation::csrf_token_validation_workflow, routes::Provider, + }, + state::AppHandle, +}; + +#[derive(Debug, Deserialize)] +pub struct AuthRequest { + provider: Provider, + code: String, + pub state: String, +} + +#[derive(Debug, Deserialize, Serialize)] +struct User { + id: String, + avatar: Option<String>, + username: String, + discriminator: String, + verified: bool, + email: String, +} + +#[derive(Debug, Deserialize, Serialize)] +struct DbUser { + id: Uuid, + email: String, + created_at: OffsetDateTime, + updated_at: OffsetDateTime, +} + +/// The cookie to store the session id for user information. +const SESSION_COOKIE: &str = "info"; +const SESSION_DATA_KEY: &str = "data"; + +pub async fn login_authorised( + Query(query): Query<AuthRequest>, + State(state): State<AppHandle>, + TypedHeader(cookies): TypedHeader<headers::Cookie>, +) -> Result<impl IntoResponse, AppError> { + let provider = query.provider.to_string(); + let oauth_session_id = Id::from_str( + cookies + .get(OAUTH_CSRF_COOKIE) + .context("missing session cookie")?, + )?; + csrf_token_validation_workflow(&query, &state.session_store, oauth_session_id).await?; + + let client = state.http_client.clone(); + let store = state.session_store.clone(); + + // Get an auth token + let token = state + .discord_client + .exchange_code(AuthorizationCode::new(query.code.clone())) + .request_async(&client) + .await + .context("failed in sending request request to authorisation server")?; + + let user_data = client + // https://discord.com/developers/docs/resources/user#get-current-user + .get("https://discordapp.com/api/users/@me") + .bearer_auth(token.access_token().secret()) + .send() + .await + .context("failed in sending request to target Url")? + .json::<serde_json::Value>() + .await + .context("failed to deserialise response as JSON")?; + + dbg!(&user_data); + + let user_data: User = serde_json::from_value(user_data)?; + + if !user_data.verified { + return Ok((StatusCode::UNAUTHORIZED, "email is not verified").into_response()); + } + + // Create a new session filled with user data + let session_id = Id(i128::from_le_bytes(uuid::Uuid::new_v4().to_bytes_le())); + + let mut transaction = state.services.postgres.begin().await?; + + let user = sqlx::query_as!( + DbUser, + " + select + p.* + from + auth_user p + inner join + oauth_account a + on + p.id=a.user_id + where a.provider_id = $1 and a.provider_user_id = $2 + ", + provider, + user_data.id + ) + .fetch_optional(&mut *transaction) + .await?; + + let user = if let Some(user) = user { + user + } else { + let uuid = uuid::Uuid::now_v7(); + let user = sqlx::query_as!( + DbUser, + "insert into auth_user (id, email) values ($1, $2) + on conflict (email) do update + set email = excluded.email + returning *; + ", + uuid, + user_data.email, + ) + .fetch_one(&mut *transaction) + .await?; + + sqlx::query_as!( + DbUser, + "with upsert as ( + insert into oauth_account (provider_id, provider_user_id, user_id) values ($1, $2, $3) + on conflict (provider_id, provider_user_id) do update + set provider_id = excluded.provider_id -- no-op + returning user_id + ) + select u.* + from upsert + join auth_user u on u.id = upsert.user_id; + ", + provider, + user_data.id, + user.id + ) + .fetch_one(&mut *transaction) + .await? + }; + + let exp = OffsetDateTime::now_utc() + Duration::from_secs(15 * 60); + + let claims = Claims { + sub: user.id, + exp: exp.unix_timestamp(), + iss: "sellershut".to_owned(), + sid: session_id.to_string(), + aud: "sellershut".to_owned(), + iat: OffsetDateTime::now_utc().unix_timestamp(), + }; + + let token = jsonwebtoken::encode( + &jsonwebtoken::Header::default(), + &claims, + &jsonwebtoken::EncodingKey::from_secret( + state.local_config.oauth.jwt_encoding_key.as_bytes(), + ), + )?; + + let user_request = CreateUserRequest { + email: user_data.email.to_owned(), + avatar: user_data.avatar.as_ref().map(|value| { + format!( + "https://cdn.discordapp.com/avatars/{}/{value}", + user_data.id + ) + }), + }; + + store + .create(&mut Record { + id: session_id, + data: [( + SESSION_DATA_KEY.to_string(), + serde_json::to_value(user_data).unwrap(), + )] + .into(), + expiry_date: time::OffsetDateTime::now_utc() + + Duration::from_secs(state.local_config.oauth.session_lifespan), + }) + .await + .context("failed in inserting serialised value into session")?; + + sqlx::query!( + "insert into token (user_id, token, session_id) values ($1, $2, $3)", + user.id, + token, + session_id.to_string() + ) + .execute(&mut *transaction) + .await?; + + let cookie = format!("{SESSION_COOKIE}={session_id}; SameSite=Lax; HttpOnly; Secure; Path=/"); + + let mut profile_client = state.profile_client.clone(); + let resp = profile_client.create_user(user_request).await?.into_inner(); + + let user_id = resp.temp_id; + + let mut headers = HeaderMap::new(); + headers.insert( + SET_COOKIE, + cookie.parse().context("failed to parse cookie")?, + ); + + transaction.commit().await?; + + Ok(( + headers, + Redirect::to(&format!("/?user={user_id}&token={token}")), + ) + .into_response()) +} diff --git a/crates/auth-service/src/server/routes/discord.rs b/crates/auth-service/src/server/routes/discord.rs new file mode 100644 index 0000000..e1a834f --- /dev/null +++ b/crates/auth-service/src/server/routes/discord.rs @@ -0,0 +1,10 @@ +mod discord_auth; +use axum::{Router, routing::get}; + +use crate::state::AppHandle; + +pub fn discord_router(state: AppHandle) -> Router { + Router::new() + .route("/auth/discord", get(discord_auth::discord_auth)) + .with_state(state) +} diff --git a/crates/auth-service/src/server/routes/discord/discord_auth.rs b/crates/auth-service/src/server/routes/discord/discord_auth.rs new file mode 100644 index 0000000..a45de86 --- /dev/null +++ b/crates/auth-service/src/server/routes/discord/discord_auth.rs @@ -0,0 +1,58 @@ +use std::time::Duration; + +use anyhow::Context; +use axum::{ + extract::State, + http::HeaderMap, + response::{IntoResponse, Redirect}, +}; +use oauth2::{CsrfToken, Scope}; +use reqwest::header::SET_COOKIE; +use sqlx::types::uuid; +use tower_sessions::{ + SessionStore, + session::{Id, Record}, +}; + +use crate::{ + error::AppError, + server::{CSRF_TOKEN, OAUTH_CSRF_COOKIE}, + state::AppHandle, +}; + +pub async fn discord_auth(State(state): State<AppHandle>) -> Result<impl IntoResponse, AppError> { + let (auth_url, csrf_token) = state + .discord_client + .authorize_url(CsrfToken::new_random) + .add_scope(Scope::new("identify".to_string())) + .url(); + + // Store the token in the session and retrieve the session cookie. + let session_id = Id(i128::from_le_bytes(uuid::Uuid::new_v4().to_bytes_le())); + let store = state.session_store.clone(); + + store + .create(&mut Record { + id: session_id, + data: [( + CSRF_TOKEN.to_string(), + serde_json::to_value(csrf_token).unwrap(), + )] + .into(), + expiry_date: time::OffsetDateTime::now_utc() + + Duration::from_secs(state.local_config.oauth.session_lifespan), + }) + .await + .context("failed in inserting CSRF token into session")?; + + // Attach the session cookie to the response header + let cookie = + format!("{OAUTH_CSRF_COOKIE}={session_id}; SameSite=Lax; HttpOnly; Secure; Path=/"); + let mut headers = HeaderMap::new(); + headers.insert( + SET_COOKIE, + cookie.parse().context("failed to parse cookie")?, + ); + + Ok((headers, Redirect::to(auth_url.as_ref()))) +} diff --git a/crates/auth-service/src/state.rs b/crates/auth-service/src/state.rs new file mode 100644 index 0000000..5905948 --- /dev/null +++ b/crates/auth-service/src/state.rs @@ -0,0 +1,85 @@ +use std::{ops::Deref, sync::Arc}; + +use sellershut_core::profile::profile_client::ProfileClient; +use sqlx::PgPool; +use stack_up::Configuration; +use tokio::task::JoinHandle; +use tonic::transport::Endpoint; +use tower_sessions::{CachingSessionStore, ExpiredDeletion, session_store}; +use tower_sessions_moka_store::MokaStore; +use tower_sessions_sqlx_store::PostgresStore; +use tracing::error; + +use crate::{ + client::{OauthClient, discord::discord_client}, + cnfg::LocalConfig, + error::AppError, + server::grpc::interceptor::{Intercepted, MyInterceptor}, +}; + +#[derive(Clone)] +pub struct AppHandle(Arc<AppState>); + +impl Deref for AppHandle { + type Target = Arc<AppState>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Clone)] +pub struct Services { + pub postgres: PgPool, +} + +pub struct AppState { + pub services: Services, + pub local_config: LocalConfig, + pub discord_client: OauthClient, + pub http_client: reqwest::Client, + pub session_store: CachingSessionStore<MokaStore, PostgresStore>, + pub profile_client: ProfileClient<Intercepted>, +} + +impl AppState { + pub async fn create( + services: Services, + configuration: &Configuration, + ) -> Result<(AppHandle, JoinHandle<Result<(), session_store::Error>>), AppError> { + let local_config: LocalConfig = serde_json::from_value(configuration.misc.clone())?; + + let session_store_db = + tower_sessions_sqlx_store::PostgresStore::new(services.postgres.clone()); + session_store_db.migrate().await?; + let deletion_task = tokio::task::spawn( + session_store_db + .clone() + .continuously_delete_expired(tokio::time::Duration::from_secs(60)), + ); + let session_store_mem = MokaStore::new(Some(100)); + + let store = CachingSessionStore::new(session_store_mem, session_store_db); + + let discord_client = discord_client(&local_config.oauth.discord)?; + + let channel = Endpoint::new(local_config.profile_endpoint.to_string())? + .connect() + .await + .inspect_err(|e| error!("could not connect to profile service: {e}"))?; + + let profile_client = ProfileClient::with_interceptor(channel, MyInterceptor); + + Ok(( + AppHandle(Arc::new(Self { + services, + local_config, + discord_client, + http_client: reqwest::Client::new(), + session_store: store, + profile_client, + })), + deletion_task, + )) + } +} |