Skip to content

Commit

Permalink
Implement extractors for AuthUser and SuperUser (#458)
Browse files Browse the repository at this point in the history
* start implementing authorisation

* feat(backend): implement extractor for AuthUser (user id)

* feat(backend): implement extractor for SuperUser (user id + authZ)

---------

Co-authored-by: kappamalone <[email protected]>
  • Loading branch information
KavikaPalletenne and Kappamalone authored Feb 11, 2024
1 parent 2c2ce07 commit 888a9f0
Show file tree
Hide file tree
Showing 11 changed files with 218 additions and 30 deletions.
6 changes: 3 additions & 3 deletions backend/prisma-cli/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ model User {
degree_starting_year Int?
created_at DateTime @default(now())
updated_at DateTime
role UserRole @default(USER)
role UserRole @default(User)
applications Application[]
OrganisationAdmins OrganisationAdmins[]
Ratings Ratings[]
}

enum UserRole {
USER
SUPERUSER
User
SuperUser
}

model Organisation {
Expand Down
7 changes: 5 additions & 2 deletions backend/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@ edition = "2021"
[dependencies]
# Primary crates
tokio = { version = "1.32.0", features = ["macros", "rt-multi-thread"] }
axum = { version = "0.6.20", features = ["macros"] }
axum = { version = "0.6.20", features = ["macros", "headers"] }
axum-extra = "0.8.0"
sqlx = { version = "0.7.1", features = ["runtime-tokio-rustls", "postgres", "time", "uuid"] }

# Important secondary crates
anyhow = "1.0.75"
serde = { version = "1.0.188", features = ["derive"] }
reqwest = "0.11.20"
reqwest = { version = "0.11.20", features = ["json"] }
serde_json = "1.0.105"
chrono = { version = "0.4.26", features = ["serde"] }
oauth2 = "4.4.1"
log = "0.4.20"
uuid = { version = "1.5.0", features = ["serde", "v4"] }
jsonwebtoken = "9.1.0"
35 changes: 21 additions & 14 deletions backend/server/src/handler/auth.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,51 @@
use axum::Extension;
use crate::models::app::AppState;
use crate::models::auth::{AuthRequest, UserProfile};
use crate::service::auth::create_or_get_user_id;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::Extension;
use log::error;
use oauth2::{AuthorizationCode, TokenResponse};
use oauth2::basic::BasicClient;
use oauth2::reqwest::async_http_client;
use crate::models::app::AppState;
use crate::models::auth::{AuthRequest, UserProfile};
use crate::service::auth::create_or_get_user_id;
use oauth2::{AuthorizationCode, TokenResponse};

/// This function handles the passing in of the Google OAuth code. After allowing our app the
/// requested permissions, the user is redirected to this url on our server, where we use the
/// code to get the user's email address from Google's OpenID Connect API.
pub async fn google_callback(
State(state): State<AppState>,
Query(query): Query<AuthRequest>,
Extension(oauth_client): Extension<BasicClient>
Extension(oauth_client): Extension<BasicClient>,
) -> Result<impl IntoResponse, impl IntoResponse> {
let token = match oauth_client
.exchange_code(AuthorizationCode::new(query.code))
.request_async(async_http_client)
.await {
.await
{
Ok(res) => res,
Err(e) => {
error!("An error occured while exchanging Google OAuth code");
return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))
return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string()));
}
};

let profile = match state.ctx.get("https://openidconnect.googleapis.com/v1/userinfo")
let profile = match state
.ctx
.get("https://openidconnect.googleapis.com/v1/userinfo")
.bearer_auth(token.access_token().secret().to_owned())
.send().await {
.send()
.await
{
Ok(res) => res,
Err(e) => return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))
Err(e) => return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
};

let profile = profile.json::<UserProfile>().await.unwrap();
// let profile = profile.json::<UserProfile>().await?;

let user_id = create_or_get_user_id(profile.email, state.db).await?;
// let user_id = create_or_get_user_id(profile.email, state.db).await?;

// TODO: Create a JWT from this user_id and return to the user.
}
Ok("woohoo")
}

40 changes: 36 additions & 4 deletions backend/server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,44 @@
use std::env;
use axum::{routing::get, Router};

mod service;
mod models;
use jsonwebtoken::{DecodingKey, EncodingKey};
use sqlx::postgres::PgPoolOptions;
use models::app::AppState;
mod handler;
mod models;
mod service;

#[tokio::main]
async fn main() {
let app = Router::new().route("/", get(|| async { "Hello, World!" }));
// Initialise DB connection
let db_url = env::var("DATABASE_URL")
.expect("Error getting DATABASE_URL")
.to_string();
let pool = PgPoolOptions::new()
.max_connections(5)
.connect(db_url.as_str()).await.expect("Cannot connect to database");

// Initialise JWT settings
let jwt_secret = env::var("JWT_SECRET")
.expect("Error getting JWT_SECRET")
.to_string();
// let jwt_secret = "I want to cry";
let encoding_key = EncodingKey::from_secret(jwt_secret.as_bytes());
let decoding_key = DecodingKey::from_secret(jwt_secret.as_bytes());

// Initialise reqwest client
let ctx = reqwest::Client::new();

// Add all data to AppState
let state = AppState {
db: pool,
ctx,
encoding_key,
decoding_key,
};

let app = Router::new()
.route("/", get(|| async { "Hello, World!" }))
.with_state(state);

axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
.serve(app.into_make_service())
Expand Down
10 changes: 7 additions & 3 deletions backend/server/src/models/app.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use sqlx::{Pool, Postgres};
use jsonwebtoken::{DecodingKey, EncodingKey};
use reqwest::Client as ReqwestClient;
use sqlx::{Pool, Postgres};

#[derive(Clone)]
pub struct AppState {
pub db: Pool<Postgres>,
pub ctx: ReqwestClient
}
pub ctx: ReqwestClient,
pub decoding_key: DecodingKey,
pub encoding_key: EncodingKey,
}
89 changes: 87 additions & 2 deletions backend/server/src/models/auth.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,96 @@
use axum::{async_trait, headers, http::{self, Request}, RequestPartsExt};
use axum::extract::{FromRef, FromRequestParts, TypedHeader};
use axum::http::request::Parts;
use axum::response::{IntoResponse, Redirect, Response};
use serde::{Deserialize, Serialize};
use crate::models::app::AppState;
use crate::service::auth::is_super_user;
use crate::service::jwt::decode_auth_token;

#[derive(Deserialize, Serialize)]
pub struct AuthRequest {
pub code: String
pub code: String,
}

#[derive(Deserialize, Serialize)]
pub struct UserProfile {
pub email: String
pub email: String,
}

pub struct AuthRedirect;

impl IntoResponse for AuthRedirect {
fn into_response(self) -> Response {
// TODO: Fix this redirect to point to front end login page
Redirect::temporary("/auth/google").into_response()
}
}

#[derive(Deserialize, Serialize)]
pub struct AuthUser {
pub user_id: i64,
}

#[async_trait]
impl<S> FromRequestParts<S> for AuthUser
where
AppState: FromRef<S>,
S: Send + Sync,
{
type Rejection = AuthRedirect;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let app_state = AppState::from_ref(state);
let decoding_key = &app_state.decoding_key;
let extracted_cookies = parts
.extract::<TypedHeader<headers::Cookie>>()
.await;

if let Ok(cookies) = extracted_cookies {
let token = cookies.get("auth_token").ok_or(AuthRedirect)?;
let claims = decode_auth_token(token.to_string(), decoding_key).ok_or(AuthRedirect)?;

Ok(AuthUser { user_id: claims.sub })
} else {
Err(AuthRedirect)
}
}
}

#[derive(Deserialize, Serialize)]
pub struct SuperUser {
pub user_id: i64,
}

#[async_trait]
impl<S> FromRequestParts<S> for SuperUser
where
AppState: FromRef<S>,
S: Send + Sync,
{
type Rejection = AuthRedirect;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let app_state = AppState::from_ref(state);
let decoding_key = &app_state.decoding_key;
let extracted_cookies = parts
.extract::<TypedHeader<headers::Cookie>>()
.await;

if let Ok(cookies) = extracted_cookies {
let token = cookies.get("auth_token").ok_or(AuthRedirect)?;
let claims = decode_auth_token(token.to_string(), decoding_key).ok_or(AuthRedirect)?;

let pool = &app_state.db;
let possible_user = is_super_user(claims.sub, pool).await;

if let Ok(is_auth_user) = possible_user {
if is_auth_user {
return Ok(SuperUser { user_id: claims.sub });
}
}
}

Err(AuthRedirect)
}
}
3 changes: 2 additions & 1 deletion backend/server/src/models/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod auth;
pub mod app;
pub mod app;
pub mod user;
8 changes: 8 additions & 0 deletions backend/server/src/models/user.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use serde::{Deserialize, Serialize};

#[derive(Deserialize, Serialize, sqlx::Type, Clone)]
#[sqlx(type_name = "user_role", rename_all = "PascalCase")]
pub enum UserRole {
User,
SuperUser,
}
11 changes: 10 additions & 1 deletion backend/server/src/service/auth.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use anyhow::Result;
use jsonwebtoken::{DecodingKey, EncodingKey};
use sqlx::{Pool, Postgres};

/// Checks if a user exists in DB based on given email address. If so, their user_id is returned.
Expand All @@ -11,4 +12,12 @@ pub async fn create_or_get_user_id(email: String, pool: Pool<Postgres>) -> Resul

let user_id = 1;
return Ok(user_id);
}
}

pub async fn is_super_user(user_id: i64, pool: &Pool<Postgres>) -> Result<bool> {
let is_super_user = sqlx::query!("SELECT EXISTS(SELECT 1 FROM users WHERE id = $1 AND role = $2)", user_id, UserRole::SuperUser)
.fetch_one(pool)
.await?;

Ok(is_super_user.exists.unwrap())
}
38 changes: 38 additions & 0 deletions backend/server/src/service/jwt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use axum::extract::State;
use jsonwebtoken::{Algorithm, DecodingKey};
use jsonwebtoken::{decode, Validation};
use serde::{Deserialize, Serialize};
use uuid::Uuid;

use crate::AppState;

#[derive(Debug, Deserialize, Serialize)]
pub struct AuthorizationJwtPayload {
pub iss: String, // issuer
pub sub: i64, // subject (user's id)
pub jti: Uuid, // id
pub aud: Vec<String>, // audience (uri the JWT is meant for)

// Time-based validity
pub exp: i64, // expiry (UNIX timestamp)
pub nbf: i64, // not-valid-before (UNIX timestamp)
pub iat: i64, // issued-at (UNIX timestamp)

pub username: String, // username
}

pub fn decode_auth_token(
token: String,
decoding_key: &DecodingKey,
) -> Option<AuthorizationJwtPayload> {
let decode_token = decode::<AuthorizationJwtPayload>(
token.as_str(),
decoding_key,
&Validation::new(Algorithm::HS256),
);

return match decode_token {
Ok(token) => Option::from(token.claims),
Err(_err) => None::<AuthorizationJwtPayload>,
};
}
1 change: 1 addition & 0 deletions backend/server/src/service/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod auth;
pub mod jwt;
pub mod oauth2;

0 comments on commit 888a9f0

Please sign in to comment.