use api_core::models::user::User; use async_session::{Session, serde_json}; use async_trait::async_trait; use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse}; use redis::AsyncCommands; use serde::{Deserialize, Serialize}; use sh_util::cache::{CacheKey, RedisManager}; use sqlx::PgPool; use crate::{ BasicClient, CSRF_TOKEN, OauthDriver, SessionResponse, client::AuthHttpClient, error::AuthError, }; // The user data we'll get back from Discord. // https://discord.com/developers/docs/resources/user#user-object-user-structure #[derive(Debug, Serialize, Deserialize)] struct DiscordUser { id: String, avatar: Option, username: String, discriminator: String, } impl From for User { fn from(value: DiscordUser) -> Self { todo!() } } #[derive(Clone)] pub struct AuthServiceDiscord { database: PgPool, cache: RedisManager, client: BasicClient, } impl AuthServiceDiscord { pub fn new(database: PgPool, client: BasicClient, cache: RedisManager) -> Self { Self { database, client, cache, } } } #[async_trait] impl OauthDriver for AuthServiceDiscord { async fn get_user(&self, client: &AuthHttpClient, code: &str) -> Result { // Get an auth token let token = self .client .exchange_code(AuthorizationCode::new(code.to_owned())) .request_async(client) .await .unwrap(); // Fetch user data from discord let user_data: DiscordUser = client // https://discord.com/developers/docs/resources/user#get-current-user .get("https://discordapp.com/api/users/@me") .bearer_auth(token.access_token().secret()) .send() .await .unwrap() .json::() .await .unwrap(); Ok(user_data.into()) } async fn validate_session(&self, cookie: &str, state: &str) -> Result<(), AuthError> { let id = Session::id_from_cookie_value(cookie)?; let cache_key = CacheKey::Session(&id); let mut cache = self.cache.get().await.unwrap(); let session = cache.get::<_, String>(&cache_key).await?; let session: Session = serde_json::from_str(&session).map_err(|_e| AuthError::InvalidSession)?; match session.validate() { Some(session) => { // Extract the CSRF token from the session let stored_csrf_token = session.get::(CSRF_TOKEN); if let Some(stored) = stored_csrf_token { // Cleanup the CSRF token session cache.del::<_, ()>(cache_key).await?; // Validate CSRF token is the same as the one in the auth request if *stored.secret() != state { return Err(AuthError::TokenMismatch); } else { return Ok(()); } } else { return Err(AuthError::NoCSRFToken); } } None => return Err(AuthError::MissingSession), } } async fn create_oauth_session(&self) -> Result { let (auth_url, csrf_token) = self .client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("identify".to_string())) .url(); let mut session = Session::new(); session.insert(CSRF_TOKEN, &csrf_token).unwrap(); let cache_key = CacheKey::Session(session.id()); let mut cache = self.cache.get().await.unwrap(); cache .set::<_, _, ()>( cache_key, serde_json::to_string(&session).or(Err(AuthError::InvalidSession))?, ) .await?; let cookie = session .into_cookie_value() .ok_or(AuthError::MissingSession)?; Ok(SessionResponse { cookie_value: cookie, auth_url, }) } async fn save_session(&self, user: &User) -> Result<(), AuthError> { todo!() } }