1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
|
use api_core::models::user::User;
use async_session::{Session, serde_json};
use async_trait::async_trait;
use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse};
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use sh_util::cache::{CacheKey, RedisManager};
use sqlx::PgPool;
use crate::{
BasicClient, CSRF_TOKEN, OauthDriver, SessionResponse, client::AuthHttpClient, error::AuthError,
};
/// The subset of Discord's user object that we deserialize from the
/// "get current user" endpoint.
///
/// Field reference:
/// <https://discord.com/developers/docs/resources/user#user-object-user-structure>
#[derive(Deserialize, Serialize, Debug)]
struct DiscordUser {
    /// Discord user id, received as a string.
    id: String,
    /// Avatar hash; `None` when Discord sends `null`.
    avatar: Option<String>,
    /// Display username.
    username: String,
    /// Discriminator string as returned by Discord.
    discriminator: String,
}
/// Maps a Discord profile onto the application's [`User`] model.
///
/// NOTE(review): still a stub. It panics via `todo!()`, and `get_user` in
/// this file performs this conversion on every successful profile fetch.
impl From<DiscordUser> for User {
    fn from(_discord_user: DiscordUser) -> Self {
        todo!()
    }
}
/// Discord-backed implementation of [`OauthDriver`].
#[derive(Clone)]
pub struct AuthServiceDiscord {
    /// OAuth2 client used to build the authorize URL and exchange codes.
    client: BasicClient,
    /// Postgres connection pool.
    database: PgPool,
    /// Redis pool where pending OAuth sessions (CSRF tokens) are stored.
    cache: RedisManager,
}
impl AuthServiceDiscord {
    /// Assembles the service from a Postgres pool, a configured OAuth2
    /// client and a Redis pool.
    pub fn new(database: PgPool, client: BasicClient, cache: RedisManager) -> Self {
        Self {
            cache,
            client,
            database,
        }
    }
}
#[async_trait]
impl OauthDriver for AuthServiceDiscord {
    /// Exchanges the OAuth2 authorization `code` for an access token and
    /// fetches the authenticated user's profile from Discord.
    ///
    /// # Panics
    ///
    /// - Panics if the token exchange, the HTTP request or JSON decoding fails.
    /// - The `DiscordUser` -> `User` conversion is still `todo!()` in this
    ///   file, so this currently panics after a successful fetch.
    async fn get_user(&self, client: &AuthHttpClient, code: &str) -> Result<User, AuthError> {
        // NOTE(review): these `unwrap`s turn network failures and bad or
        // expired codes (user-controlled input) into panics. They should
        // become `AuthError`s once a suitable variant exists; none of the
        // variants used in this file fits.
        let token = self
            .client
            .exchange_code(AuthorizationCode::new(code.to_owned()))
            .request_async(client)
            .await
            .unwrap();

        // https://discord.com/developers/docs/resources/user#get-current-user
        // Use the current `discord.com` host (`discordapp.com` is deprecated)
        // and pin an explicit API version instead of relying on the default.
        let user_data = client
            .get("https://discord.com/api/v10/users/@me")
            .bearer_auth(token.access_token().secret())
            .send()
            .await
            .unwrap()
            .json::<DiscordUser>()
            .await
            .unwrap();
        Ok(user_data.into())
    }

    /// Checks the CSRF `state` echoed back by Discord against the token
    /// stored in the server-side session identified by `cookie`.
    ///
    /// The session is single-use. It is removed from Redis as soon as it has
    /// been read, whatever the outcome, so it cannot be replayed and failed
    /// attempts do not leave keys behind.
    ///
    /// # Errors
    ///
    /// - [`AuthError::MissingSession`] if nothing is stored under the cookie's
    ///   session id, or the stored session has expired.
    /// - [`AuthError::InvalidSession`] if the stored session cannot be decoded.
    /// - [`AuthError::NoCSRFToken`] if the session holds no CSRF token.
    /// - [`AuthError::TokenMismatch`] if `state` differs from the stored token.
    /// - Cookie-decoding and Redis errors are propagated via `?`.
    async fn validate_session(&self, cookie: &str, state: &str) -> Result<(), AuthError> {
        let id = Session::id_from_cookie_value(cookie)?;
        let cache_key = CacheKey::Session(&id);
        // NOTE(review): pool checkout failure panics; map it to an AuthError
        // once the pool's error type has a conversion.
        let mut cache = self.cache.get().await.unwrap();

        // Read as `Option` so an absent or TTL-expired key becomes
        // `MissingSession` rather than a Redis type-conversion error.
        let raw_session = cache
            .get::<_, Option<String>>(&cache_key)
            .await?
            .ok_or(AuthError::MissingSession)?;
        // Consume the session before inspecting it, so every path below
        // (expired, no token, mismatch, success) cleans up the key.
        cache.del::<_, ()>(cache_key).await?;

        let session: Session =
            serde_json::from_str(&raw_session).map_err(|_| AuthError::InvalidSession)?;
        let session = session.validate().ok_or(AuthError::MissingSession)?;
        let stored = session
            .get::<CsrfToken>(CSRF_TOKEN)
            .ok_or(AuthError::NoCSRFToken)?;

        if stored.secret().as_str() != state {
            return Err(AuthError::TokenMismatch);
        }
        Ok(())
    }

    /// Starts a Discord OAuth2 flow requesting the `identify` scope.
    ///
    /// A new server-side session holding the generated CSRF token is stored
    /// in Redis. Both the Redis key and the session itself expire after
    /// `OAUTH_SESSION_TTL_SECS`. The returned [`SessionResponse`] carries that
    /// session's cookie value and the Discord authorization URL.
    ///
    /// # Errors
    ///
    /// - [`AuthError::InvalidSession`] if the session cannot be serialized.
    /// - [`AuthError::MissingSession`] if no cookie value can be produced.
    /// - Redis errors are propagated via `?`.
    async fn create_oauth_session(&self) -> Result<SessionResponse, AuthError> {
        // Lifetime of an unfinished OAuth flow. Without an expiry, every flow
        // that is started but never completed would leave a key in Redis
        // forever, and anyone able to call this could grow the cache unbounded.
        const OAUTH_SESSION_TTL_SECS: u64 = 10 * 60;

        let (auth_url, csrf_token) = self
            .client
            .authorize_url(CsrfToken::new_random)
            .add_scope(Scope::new("identify".to_string()))
            .url();

        let mut session = Session::new();
        // Also expire the session itself, so `Session::validate` rejects it
        // even if the Redis key were to outlive the TTL.
        session.expire_in(std::time::Duration::from_secs(OAUTH_SESSION_TTL_SECS));
        session
            .insert(CSRF_TOKEN, &csrf_token)
            .map_err(|_| AuthError::InvalidSession)?;
        let serialized = serde_json::to_string(&session).map_err(|_| AuthError::InvalidSession)?;

        let cache_key = CacheKey::Session(session.id());
        // NOTE(review): pool checkout failure panics; see validate_session.
        let mut cache = self.cache.get().await.unwrap();
        // `as _`: the integer type of `set_ex`'s seconds argument has varied
        // across redis crate versions (usize vs u64).
        cache
            .set_ex::<_, _, ()>(cache_key, serialized, OAUTH_SESSION_TTL_SECS as _)
            .await?;

        let cookie = session
            .into_cookie_value()
            .ok_or(AuthError::MissingSession)?;
        Ok(SessionResponse {
            cookie_value: cookie,
            auth_url,
        })
    }

    /// Not implemented yet; panics via `todo!()`.
    async fn save_session(&self, _user: &User) -> Result<(), AuthError> {
        todo!()
    }
}
|