aboutsummaryrefslogtreecommitdiffstats
path: root/crates/api-auth/src/lib.rs
blob: 284b7723392fc618c0691007ec262ff1d12ca508 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#[cfg(feature = "discord")]
pub mod discord;

mod error;
use api_core::auth::AuthClientConfig;
use api_core::auth::provider::OauthProvider;
use api_core::models::user::User;
pub use error::AuthClientError;

use oauth2::{EndpointNotSet, EndpointSet};

/// Concrete `oauth2` client type wrapped by [`BasicClient`].
///
/// The five type parameters are the endpoint typestates of `oauth2::basic::BasicClient`
/// (per the oauth2 5.x docs, in order: authorization, device authorization, introspection,
/// revocation, token). Only the authorization and token endpoints are marked as set, which
/// matches the `set_auth_uri` / `set_token_uri` calls made when converting an
/// `AuthClientConfig` into a `BasicClient`.
type C = oauth2::basic::BasicClient<
    EndpointSet,    // authorization endpoint
    EndpointNotSet, // device authorization endpoint
    EndpointNotSet, // introspection endpoint
    EndpointNotSet, // revocation endpoint
    EndpointSet,    // token endpoint
>;

/// Newtype around the configured `oauth2` client (the private alias `C`).
///
/// Built from an [`AuthClientConfig`] via `TryFrom`, and implements `Deref` to the inner
/// client so its methods can be called directly on this wrapper.
///
/// Not to be confused with `oauth2::basic::BasicClient`, which this file always refers to by
/// its full path.
#[derive(Clone, Debug)]
pub struct BasicClient(C);

/// Provider-specific OAuth driver, stored as `Arc<dyn OauthDriver>` in [`OauthService`].
///
/// `Send + Sync` allows one driver instance to be shared behind an `Arc`; `Debug` is a
/// supertrait so the trait objects can be debug-printed. `#[async_trait]` is needed because
/// the trait is used as `dyn OauthDriver`, and native `async fn` in traits is not
/// object-safe.
///
/// NOTE(review): the error type `AuthError` is imported privately from `crate::error` and is
/// not re-exported from this file (only `AuthClientError` is), so code outside this crate may
/// be unable to name it when implementing or calling this trait — confirm this is intended.
#[async_trait::async_trait]
pub trait OauthDriver: Send + Sync + std::fmt::Debug {
    /// Obtains a token string (presumably an access token — TODO confirm against the
    /// implementations).
    async fn get_auth_token(&self) -> Result<String, AuthError>;
    /// Resolves a [`User`]; how the user is looked up is left to the implementation.
    async fn get_user(&self) -> Result<User, AuthError>;
    /// Creates a session for `user`.
    ///
    /// NOTE(review): this returns `()`, so implementations have no way to report a failure
    /// to the caller; consider returning a `Result<_, AuthError>`.
    async fn create_session(&self, user: &User);
}

use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl};
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use std::{convert::TryFrom, ops::Deref};

use crate::error::AuthError;

/// Map of OAuth drivers keyed by [`OauthProvider`].
///
/// NOTE(review): `clients` is private and no constructor or methods are visible in this file,
/// so this type can only be built or used from this module or its child modules (e.g.
/// `discord`) — confirm that such an API exists elsewhere.
pub struct OauthService {
    /// At most one shared (`Arc`) driver per provider.
    clients: HashMap<OauthProvider, Arc<dyn OauthDriver>>,
}

impl Deref for BasicClient {
    type Target = C;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl TryFrom<AuthClientConfig> for BasicClient {
    type Error = AuthClientError;

    fn try_from(value: AuthClientConfig) -> Result<Self, Self::Error> {
        let auth_url = AuthUrl::new(value.auth_url).map_err(AuthClientError::InvalidAuthUrl)?;

        let token_url = TokenUrl::new(value.token_uri).map_err(AuthClientError::InvalidTokenUrl)?;

        let redirect_url =
            RedirectUrl::new(value.redirect_uri).map_err(AuthClientError::InvalidRedirectUrl)?;

        Ok(Self(
            oauth2::basic::BasicClient::new(ClientId::new(value.client_id))
                .set_client_secret(ClientSecret::new(value.client_secret))
                .set_auth_uri(auth_url)
                .set_token_uri(token_url)
                .set_redirect_uri(redirect_url),
        ))
    }
}