aboutsummaryrefslogtreecommitdiffstats
path: root/crates/api-auth/src
diff options
context:
space:
mode:
authorrtkay123 <dev@kanjala.com>2026-04-10 23:48:24 +0200
committerrtkay123 <dev@kanjala.com>2026-04-10 23:48:24 +0200
commitf06288f156ccb8f9ebf35782a179bf57e6bc8fc2 (patch)
tree2e9eb80237094d930b4f3a54261fac0cb3350129 /crates/api-auth/src
parentbe2af8a5fe2e58953b4970e3fed970165fc4b4ca (diff)
downloadsellershut-f06288f156ccb8f9ebf35782a179bf57e6bc8fc2.tar.bz2
sellershut-f06288f156ccb8f9ebf35782a179bf57e6bc8fc2.zip
feat(auth): get user
Diffstat (limited to 'crates/api-auth/src')
-rw-r--r--crates/api-auth/src/client.rs58
-rw-r--r--crates/api-auth/src/discord/mod.rs77
-rw-r--r--crates/api-auth/src/error.rs7
-rw-r--r--crates/api-auth/src/lib.rs10
4 files changed, 144 insertions, 8 deletions
diff --git a/crates/api-auth/src/client.rs b/crates/api-auth/src/client.rs
new file mode 100644
index 0000000..d696162
--- /dev/null
+++ b/crates/api-auth/src/client.rs
@@ -0,0 +1,58 @@
+use std::pin::Pin;
+use std::{future::Future, ops::Deref};
+
+#[cfg(not(target_arch = "wasm32"))]
+use oauth2::HttpResponse;
+use oauth2::{AsyncHttpClient, HttpClientError, HttpRequest, http};
+
+#[derive(Clone)]
+pub struct AuthHttpClient(reqwest::Client);
+
+impl Deref for AuthHttpClient {
+ type Target = reqwest::Client;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl From<reqwest::Client> for AuthHttpClient {
+ fn from(value: reqwest::Client) -> Self {
+ Self(value)
+ }
+}
+
+impl<'c> AsyncHttpClient<'c> for AuthHttpClient {
+ type Error = HttpClientError<reqwest::Error>;
+
+ #[cfg(target_arch = "wasm32")]
+ type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + 'c>>;
+ #[cfg(not(target_arch = "wasm32"))]
+ type Future =
+ Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + Send + Sync + 'c>>;
+
+ fn call(&'c self, request: HttpRequest) -> Self::Future {
+ Box::pin(async move {
+ let response = self
+ .0
+ .execute(request.try_into().map_err(Box::new)?)
+ .await
+ .map_err(Box::new)?;
+
+ let mut builder = http::Response::builder().status(response.status());
+
+ #[cfg(not(target_arch = "wasm32"))]
+ {
+ builder = builder.version(response.version());
+ }
+
+ for (name, value) in response.headers().iter() {
+ builder = builder.header(name, value);
+ }
+
+ builder
+ .body(response.bytes().await.map_err(Box::new)?.to_vec())
+ .map_err(HttpClientError::Http)
+ })
+ }
+}
diff --git a/crates/api-auth/src/discord/mod.rs b/crates/api-auth/src/discord/mod.rs
index 1a7d47d..0844f58 100644
--- a/crates/api-auth/src/discord/mod.rs
+++ b/crates/api-auth/src/discord/mod.rs
@@ -1,12 +1,31 @@
use api_core::models::user::User;
use async_session::{Session, serde_json};
use async_trait::async_trait;
-use oauth2::{CsrfToken, Scope};
+use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse};
use redis::AsyncCommands;
+use serde::{Deserialize, Serialize};
use sh_util::cache::{CacheKey, RedisManager};
use sqlx::PgPool;
-use crate::{BasicClient, CSRF_TOKEN, OauthDriver, SessionResponse, error::AuthError};
+use crate::{
+ BasicClient, CSRF_TOKEN, OauthDriver, SessionResponse, client::AuthHttpClient, error::AuthError,
+};
+
+// The user data we'll get back from Discord.
+// https://discord.com/developers/docs/resources/user#user-object-user-structure
+#[derive(Debug, Serialize, Deserialize)]
+struct DiscordUser {
+ id: String,
+ avatar: Option<String>,
+ username: String,
+ discriminator: String,
+}
+
+impl From<DiscordUser> for User {
+ fn from(value: DiscordUser) -> Self {
+ todo!()
+ }
+}
#[derive(Clone)]
pub struct AuthServiceDiscord {
@@ -27,11 +46,57 @@ impl AuthServiceDiscord {
#[async_trait]
impl OauthDriver for AuthServiceDiscord {
- async fn get_auth_token(&self) -> Result<String, AuthError> {
- todo!()
+ async fn get_user(&self, client: &AuthHttpClient, code: &str) -> Result<User, AuthError> {
+ // Get an auth token
+ let token = self
+ .client
+ .exchange_code(AuthorizationCode::new(code.to_owned()))
+ .request_async(client)
+ .await
+ .unwrap();
+ // Fetch user data from discord
+ let user_data: DiscordUser = client
+ // https://discord.com/developers/docs/resources/user#get-current-user
+ .get("https://discordapp.com/api/users/@me")
+ .bearer_auth(token.access_token().secret())
+ .send()
+ .await
+ .unwrap()
+ .json::<DiscordUser>()
+ .await
+ .unwrap();
+
+ Ok(user_data.into())
}
- async fn get_user(&self) -> Result<User, AuthError> {
- todo!()
+ async fn validate_session(&self, cookie: &str, state: &str) -> Result<(), AuthError> {
+ let id = Session::id_from_cookie_value(cookie)?;
+ let cache_key = CacheKey::Session(&id);
+ let mut cache = self.cache.get().await.unwrap();
+ let session = cache.get::<_, String>(&cache_key).await?;
+ let session: Session =
+ serde_json::from_str(&session).map_err(|_e| AuthError::InvalidSession)?;
+
+ match session.validate() {
+ Some(session) => {
+ // Extract the CSRF token from the session
+ let stored_csrf_token = session.get::<CsrfToken>(CSRF_TOKEN);
+
+ if let Some(stored) = stored_csrf_token {
+ // Cleanup the CSRF token session
+ cache.del::<_, ()>(cache_key).await?;
+
+ // Validate CSRF token is the same as the one in the auth request
+ if *stored.secret() != state {
+ return Err(AuthError::TokenMismatch);
+ } else {
+ return Ok(());
+ }
+ } else {
+ return Err(AuthError::NoCSRFToken);
+ }
+ }
+ None => return Err(AuthError::MissingSession),
+ }
}
async fn create_oauth_session(&self) -> Result<SessionResponse, AuthError> {
let (auth_url, csrf_token) = self
diff --git a/crates/api-auth/src/error.rs b/crates/api-auth/src/error.rs
index 72a7fba..2db3281 100644
--- a/crates/api-auth/src/error.rs
+++ b/crates/api-auth/src/error.rs
@@ -1,3 +1,4 @@
+use async_session::base64;
use thiserror::Error;
#[derive(Debug, Error)]
@@ -28,4 +29,10 @@ pub enum AuthError {
MissingSession,
#[error("invalid session")]
InvalidSession,
+ #[error("invalid session")]
+ CorruptedCookie(#[from] base64::DecodeError),
+ #[error("CSRF token mismatch")]
+ TokenMismatch,
+ #[error("CSRF token missing")]
+ NoCSRFToken,
}
diff --git a/crates/api-auth/src/lib.rs b/crates/api-auth/src/lib.rs
index 85fdb01..815b170 100644
--- a/crates/api-auth/src/lib.rs
+++ b/crates/api-auth/src/lib.rs
@@ -1,6 +1,8 @@
#[cfg(feature = "discord")]
pub mod discord;
+pub mod client;
+
mod error;
use api_core::auth::AuthClientConfig;
use api_core::models::user::User;
@@ -21,8 +23,12 @@ pub struct BasicClient(C);
#[async_trait::async_trait]
pub trait OauthDriver: Send + Sync {
- async fn get_auth_token(&self) -> Result<String, AuthError>;
- async fn get_user(&self) -> Result<User, AuthError>;
+ async fn get_user(
+ &self,
+ client: &client::AuthHttpClient,
+ code: &str,
+ ) -> Result<User, AuthError>;
+ async fn validate_session(&self, cookie: &str, state: &str) -> Result<(), AuthError>;
async fn create_oauth_session(&self) -> Result<SessionResponse, AuthError>;
async fn save_session(&self, user: &User) -> Result<(), AuthError>;
}