From f06288f156ccb8f9ebf35782a179bf57e6bc8fc2 Mon Sep 17 00:00:00 2001 From: rtkay123 Date: Fri, 10 Apr 2026 23:48:24 +0200 Subject: feat(auth): get user --- crates/api-auth/Cargo.toml | 1 + crates/api-auth/src/client.rs | 58 ++++++++++++++++++++++++++++ crates/api-auth/src/discord/mod.rs | 77 +++++++++++++++++++++++++++++++++++--- crates/api-auth/src/error.rs | 7 ++++ crates/api-auth/src/lib.rs | 10 ++++- 5 files changed, 145 insertions(+), 8 deletions(-) create mode 100644 crates/api-auth/src/client.rs (limited to 'crates/api-auth') diff --git a/crates/api-auth/Cargo.toml b/crates/api-auth/Cargo.toml index 5ce0647..518762b 100644 --- a/crates/api-auth/Cargo.toml +++ b/crates/api-auth/Cargo.toml @@ -12,6 +12,7 @@ api-core = { workspace = true, features = ["auth", "users"] } async-trait.workspace = true oauth2 = "5.0.0" redis.workspace = true +reqwest = { workspace = true, features = ["json"] } secrecy.workspace = true serde.workspace = true sh-util = { workspace = true, optional = true } diff --git a/crates/api-auth/src/client.rs b/crates/api-auth/src/client.rs new file mode 100644 index 0000000..d696162 --- /dev/null +++ b/crates/api-auth/src/client.rs @@ -0,0 +1,58 @@ +use std::pin::Pin; +use std::{future::Future, ops::Deref}; + +#[cfg(not(target_arch = "wasm32"))] +use oauth2::HttpResponse; +use oauth2::{AsyncHttpClient, HttpClientError, HttpRequest, http}; + +#[derive(Clone)] +pub struct AuthHttpClient(reqwest::Client); + +impl Deref for AuthHttpClient { + type Target = reqwest::Client; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for AuthHttpClient { + fn from(value: reqwest::Client) -> Self { + Self(value) + } +} + +impl<'c> AsyncHttpClient<'c> for AuthHttpClient { + type Error = HttpClientError; + + #[cfg(target_arch = "wasm32")] + type Future = Pin> + 'c>>; + #[cfg(not(target_arch = "wasm32"))] + type Future = + Pin> + Send + Sync + 'c>>; + + fn call(&'c self, request: HttpRequest) -> Self::Future { + Box::pin(async move { + let response = self + .0 + .execute(request.try_into().map_err(Box::new)?) + .await + .map_err(Box::new)?; + + let mut builder = http::Response::builder().status(response.status()); + + #[cfg(not(target_arch = "wasm32"))] + { + builder = builder.version(response.version()); + } + + for (name, value) in response.headers().iter() { + builder = builder.header(name, value); + } + + builder + .body(response.bytes().await.map_err(Box::new)?.to_vec()) + .map_err(HttpClientError::Http) + }) + } +} diff --git a/crates/api-auth/src/discord/mod.rs b/crates/api-auth/src/discord/mod.rs index 1a7d47d..0844f58 100644 --- a/crates/api-auth/src/discord/mod.rs +++ b/crates/api-auth/src/discord/mod.rs @@ -1,12 +1,31 @@ use api_core::models::user::User; use async_session::{Session, serde_json}; use async_trait::async_trait; -use oauth2::{CsrfToken, Scope}; +use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse}; use redis::AsyncCommands; +use serde::{Deserialize, Serialize}; use sh_util::cache::{CacheKey, RedisManager}; use sqlx::PgPool; -use crate::{BasicClient, CSRF_TOKEN, OauthDriver, SessionResponse, error::AuthError}; +use crate::{ + BasicClient, CSRF_TOKEN, OauthDriver, SessionResponse, client::AuthHttpClient, error::AuthError, +}; + +// The user data we'll get back from Discord. +// https://discord.com/developers/docs/resources/user#user-object-user-structure +#[derive(Debug, Serialize, Deserialize)] +struct DiscordUser { + id: String, + avatar: Option, + username: String, + discriminator: String, +} + +impl From for User { + fn from(value: DiscordUser) -> Self { + todo!() + } +} #[derive(Clone)] pub struct AuthServiceDiscord { @@ -27,11 +46,57 @@ impl AuthServiceDiscord { #[async_trait] impl OauthDriver for AuthServiceDiscord { - async fn get_auth_token(&self) -> Result { - todo!() + async fn get_user(&self, client: &AuthHttpClient, code: &str) -> Result { + // Get an auth token + let token = self + .client + .exchange_code(AuthorizationCode::new(code.to_owned())) + .request_async(client) + .await + .unwrap(); + // Fetch user data from discord + let user_data: DiscordUser = client + // https://discord.com/developers/docs/resources/user#get-current-user + .get("https://discordapp.com/api/users/@me") + .bearer_auth(token.access_token().secret()) + .send() + .await + .unwrap() + .json::() + .await + .unwrap(); + + Ok(user_data.into()) } - async fn get_user(&self) -> Result { - todo!() + async fn validate_session(&self, cookie: &str, state: &str) -> Result<(), AuthError> { + let id = Session::id_from_cookie_value(cookie)?; + let cache_key = CacheKey::Session(&id); + let mut cache = self.cache.get().await.unwrap(); + let session = cache.get::<_, String>(&cache_key).await?; + let session: Session = + serde_json::from_str(&session).map_err(|_e| AuthError::InvalidSession)?; + + match session.validate() { + Some(session) => { + // Extract the CSRF token from the session + let stored_csrf_token = session.get::(CSRF_TOKEN); + + if let Some(stored) = stored_csrf_token { + // Cleanup the CSRF token session + cache.del::<_, ()>(cache_key).await?; + + // Validate CSRF token is the same as the one in the auth request + if *stored.secret() != state { + return Err(AuthError::TokenMismatch); + } else { + return Ok(()); + } + } else { + return Err(AuthError::NoCSRFToken); + } + } + None => return Err(AuthError::MissingSession), + } } async fn create_oauth_session(&self) -> Result { let (auth_url, csrf_token) = self diff --git a/crates/api-auth/src/error.rs b/crates/api-auth/src/error.rs index 72a7fba..2db3281 100644 --- a/crates/api-auth/src/error.rs +++ b/crates/api-auth/src/error.rs @@ -1,3 +1,4 @@ +use async_session::base64; use thiserror::Error; #[derive(Debug, Error)] @@ -28,4 +29,10 @@ pub enum AuthError { MissingSession, #[error("invalid session")] InvalidSession, + #[error("invalid session")] + CorruptedCookie(#[from] base64::DecodeError), + #[error("CSRF token mismatch")] + TokenMismatch, + #[error("CSRF token missing")] + NoCSRFToken, } diff --git a/crates/api-auth/src/lib.rs b/crates/api-auth/src/lib.rs index 85fdb01..815b170 100644 --- a/crates/api-auth/src/lib.rs +++ b/crates/api-auth/src/lib.rs @@ -1,6 +1,8 @@ #[cfg(feature = "discord")] pub mod discord; +pub mod client; + mod error; use api_core::auth::AuthClientConfig; use api_core::models::user::User; @@ -21,8 +23,12 @@ pub struct BasicClient(C); #[async_trait::async_trait] pub trait OauthDriver: Send + Sync { - async fn get_auth_token(&self) -> Result; - async fn get_user(&self) -> Result; + async fn get_user( + &self, + client: &client::AuthHttpClient, + code: &str, + ) -> Result; + async fn validate_session(&self, cookie: &str, state: &str) -> Result<(), AuthError>; async fn create_oauth_session(&self) -> Result; async fn save_session(&self, user: &User) -> Result<(), AuthError>; } -- cgit v1.2.3