From 92c1529ade5b87dd11d76eac09246c14533f4beb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Tue, 8 Oct 2024 14:01:41 +0200 Subject: [PATCH 01/21] poc for security restricted guc --- src/gucs.rs | 13 +++++++++++++ src/lib.rs | 13 +++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 src/gucs.rs diff --git a/src/gucs.rs b/src/gucs.rs new file mode 100644 index 0000000..56bac69 --- /dev/null +++ b/src/gucs.rs @@ -0,0 +1,13 @@ +use pgrx::*; +use std::ffi::CStr; + +pub static AUTH_FOO: GucSetting> = GucSetting::>::new(None); + +pub fn init() { + GucRegistry::define_string_guc( + "auth.foo", + "foo", + "bar", + &AUTH_FOO, + GucContext::Suset, GucFlags::NOT_WHILE_SEC_REST); +} diff --git a/src/lib.rs b/src/lib.rs index d32bd8a..8fac896 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +mod gucs; + use pgrx::prelude::*; pgrx::pg_module_magic!(); @@ -13,6 +15,12 @@ macro_rules! error_code { }}; } +#[allow(non_snake_case)] +#[pg_guard] +pub unsafe extern "C" fn _PG_init() { + gucs::init(); +} + #[pg_schema] pub mod auth { use std::cell::{OnceCell, RefCell}; @@ -179,6 +187,11 @@ pub mod auth { } } + #[pg_extern] + pub fn foo() -> String { + crate::gucs::AUTH_FOO.get().unwrap().to_owned().into_string().unwrap() + } + /// Decrypt the JWT and store it. /// /// # Panics From 59659a656233726c6b4f115e40625f9c5df34f20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Tue, 8 Oct 2024 20:42:25 +0200 Subject: [PATCH 02/21] use neon.auth namespace --- src/gucs.rs | 18 +++++++++++++----- src/lib.rs | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/gucs.rs b/src/gucs.rs index 56bac69..1aa700c 100644 --- a/src/gucs.rs +++ b/src/gucs.rs @@ -1,13 +1,21 @@ use pgrx::*; use std::ffi::CStr; -pub static AUTH_FOO: GucSetting> = GucSetting::>::new(None); +pub static NEON_AUTH_JWK: GucSetting> = GucSetting::>::new(None); +pub static NEON_AUTH_JWT: GucSetting> = GucSetting::>::new(None); pub fn init() { GucRegistry::define_string_guc( - "auth.foo", - "foo", - "bar", - &AUTH_FOO, + "neon.auth.jwk", + "JSON Web Key (JWK) userd for JWT validation", + "Generated per connection by Neon local proxy", + &NEON_AUTH_JWK, + GucContext::SuBackend, GucFlags::NOT_WHILE_SEC_REST); + + GucRegistry::define_string_guc( + "neon.auth.jwt", + "JSON Web Token (JWT) used for query authorization", + "Represents user session related claims like user id", + &NEON_AUTH_JWT, GucContext::Suset, GucFlags::NOT_WHILE_SEC_REST); } diff --git a/src/lib.rs b/src/lib.rs index 8fac896..e526146 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -189,7 +189,7 @@ pub mod auth { #[pg_extern] pub fn foo() -> String { - crate::gucs::AUTH_FOO.get().unwrap().to_owned().into_string().unwrap() + crate::gucs::NEON_AUTH_JWT.get().unwrap().to_owned().into_string().unwrap() } /// Decrypt the JWT and store it. From dbd61356efa04360e2906b97b452bcb0acfb3ad3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Wed, 9 Oct 2024 08:47:42 +0200 Subject: [PATCH 03/21] init() expects JWK in runtime params --- .gitignore | 1 + src/gucs.rs | 10 +++++---- src/lib.rs | 58 ++++++++++++++++++++++++++++++++++++++--------------- 3 files changed, 49 insertions(+), 20 deletions(-) diff --git a/.gitignore b/.gitignore index 3ea57d4..c619c96 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ /target *.iml **/*.rs.bk +*.swp diff --git a/src/gucs.rs b/src/gucs.rs index 1aa700c..29fd430 100644 --- a/src/gucs.rs +++ b/src/gucs.rs @@ -1,21 +1,23 @@ use pgrx::*; use std::ffi::CStr; +pub static NEON_AUTH_JWK_RUNTIME_PARAM: &'static str = "neon.auth.jwk"; pub static NEON_AUTH_JWK: GucSetting> = GucSetting::>::new(None); +pub static NEON_AUTH_JWT_RUNTIME_PARAM: &'static str = "neon.auth.jwt"; pub static NEON_AUTH_JWT: GucSetting> = GucSetting::>::new(None); pub fn init() { GucRegistry::define_string_guc( - "neon.auth.jwk", + &NEON_AUTH_JWK_RUNTIME_PARAM, "JSON Web Key (JWK) userd for JWT validation", "Generated per connection by Neon local proxy", &NEON_AUTH_JWK, - GucContext::SuBackend, GucFlags::NOT_WHILE_SEC_REST); + GucContext::Suset, GucFlags::NOT_WHILE_SEC_REST); GucRegistry::define_string_guc( - "neon.auth.jwt", + &NEON_AUTH_JWT_RUNTIME_PARAM, "JSON Web Token (JWT) used for query authorization", - "Represents user session related claims like user id", + "Represents authenticated user session related claims like user ID", &NEON_AUTH_JWT, GucContext::Suset, GucFlags::NOT_WHILE_SEC_REST); } diff --git a/src/lib.rs b/src/lib.rs index e526146..0b7449a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,6 +35,8 @@ pub mod auth { use pgrx::JsonB; use serde::de::DeserializeOwned; + use crate::gucs::{NEON_AUTH_JWK, NEON_AUTH_JWK_RUNTIME_PARAM}; + type Object = serde_json::Map; thread_local! { @@ -56,8 +58,20 @@ pub mod auth { /// This function will panic if called multiple times per session. /// This is to prevent replacing the key mid-session. #[pg_extern] - pub fn init(kid: i64, s: JsonB) { - let key: JwkEcKey = serde_json::from_value(s.0).unwrap_or_else(|e| { + pub fn init(kid: i64) { + let jwk = NEON_AUTH_JWK.get().unwrap_or_else(|| { + error_code!( + PgSqlErrorCode::ERRCODE_NO_DATA, + format!("Missing runtime parameter: {}", NEON_AUTH_JWK_RUNTIME_PARAM) + ) + }).to_str().unwrap_or_else(|e| { + error_code!( + PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, + format!("Couldn't parse {}", NEON_AUTH_JWK_RUNTIME_PARAM), + e.to_string() + ) + }); + let key: JwkEcKey = serde_json::from_str(jwk).unwrap_or_else(|e| { error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, "session init requires an ES256 JWK", @@ -294,6 +308,11 @@ mod tests { use serde_json::json; use crate::auth; + use crate::gucs::NEON_AUTH_JWK_RUNTIME_PARAM; + + fn set_jwk_in_guc(key: String) { + Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWK_RUNTIME_PARAM, key).as_str()).unwrap(); + } fn sign_jwt(sk: &SigningKey, header: &str, payload: impl Display) -> String { let header = Base64UrlUnpadded::encode_string(header.as_bytes()); @@ -312,9 +331,10 @@ mod tests { let point = sk.verifying_key().to_encoded_point(false); let jwk = JwkEcKey::from_encoded_point::(&point).unwrap(); let jwk = serde_json::to_value(&jwk).unwrap(); + set_jwk_in_guc(serde_json::to_string(&jwk).unwrap()); - auth::init(1, JsonB(jwk.clone())); - auth::init(2, JsonB(jwk)); + auth::init(1); + auth::init(2); } #[pg_test] @@ -322,9 +342,10 @@ mod tests { fn wrong_pid() { let sk = SigningKey::random(&mut OsRng); let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = JsonB(serde_json::to_value(&jwk).unwrap()); + let jwk = serde_json::to_string(&jwk).unwrap(); + set_jwk_in_guc(jwk); - auth::init(1, jwk); + auth::init(1); auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":2}"#, r#"{"jti":1}"#)); } @@ -333,9 +354,10 @@ mod tests { fn wrong_txid() { let sk = SigningKey::random(&mut OsRng); let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = JsonB(serde_json::to_value(&jwk).unwrap()); + let jwk = serde_json::to_string(&jwk).unwrap(); + set_jwk_in_guc(jwk); - auth::init(1, jwk); + auth::init(1); auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":2}"#)); auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":1}"#)); } @@ -345,9 +367,10 @@ mod tests { fn invalid_nbf() { let sk = SigningKey::random(&mut OsRng); let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = JsonB(serde_json::to_value(&jwk).unwrap()); + let jwk = serde_json::to_string(&jwk).unwrap(); + set_jwk_in_guc(jwk); - auth::init(1, jwk); + auth::init(1); let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -365,9 +388,10 @@ mod tests { fn invalid_exp() { let sk = SigningKey::random(&mut OsRng); let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = JsonB(serde_json::to_value(&jwk).unwrap()); + let jwk = serde_json::to_string(&jwk).unwrap(); + set_jwk_in_guc(jwk); - auth::init(1, jwk); + auth::init(1); let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -384,9 +408,10 @@ mod tests { fn valid_time() { let sk = SigningKey::random(&mut OsRng); let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = JsonB(serde_json::to_value(&jwk).unwrap()); + let jwk = serde_json::to_string(&jwk).unwrap(); + set_jwk_in_guc(jwk); - auth::init(1, jwk); + auth::init(1); let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -408,9 +433,10 @@ mod tests { fn test_pg_session_jwt() { let sk = SigningKey::random(&mut OsRng); let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = JsonB(serde_json::to_value(&jwk).unwrap()); + let jwk = serde_json::to_string(&jwk).unwrap(); + set_jwk_in_guc(jwk); - auth::init(1, jwk); + auth::init(1); let header = r#"{"kid":1}"#; auth::jwt_session_init(&sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#)); From c564509db48783ea230a494d6391bd80d5db4626 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Wed, 9 Oct 2024 08:53:11 +0200 Subject: [PATCH 04/21] GUC_PGC_SUBACKEND placeholder --- src/gucs.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gucs.rs b/src/gucs.rs index 29fd430..49b9374 100644 --- a/src/gucs.rs +++ b/src/gucs.rs @@ -12,7 +12,9 @@ pub fn init() { "JSON Web Key (JWK) userd for JWT validation", "Generated per connection by Neon local proxy", &NEON_AUTH_JWK, - GucContext::Suset, GucFlags::NOT_WHILE_SEC_REST); + GucContext::Suset, // we should use GucContext::SuBackend but this breaks unit tests + GucFlags::NOT_WHILE_SEC_REST + ); GucRegistry::define_string_guc( &NEON_AUTH_JWT_RUNTIME_PARAM, From 60c9c215feb0cf878aea1ab0e9779876d5c52159 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Wed, 9 Oct 2024 08:58:44 +0200 Subject: [PATCH 05/21] document intention where this draft is going --- src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 0b7449a..4d12da9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -252,6 +252,9 @@ pub mod auth { /// Extract a value from the shared state. #[pg_extern] pub fn session(s: &str) -> JsonB { + // todo: check if JWK is set in thread_local, if not assume its bgworker and call dedicated + // function that uses JWT_CACHE and relies only on runtime parameters (neon.auth.jwk and + // neon.auth.jwt) JWT.with_borrow(|j| { JsonB( j.as_ref() From 24361bea1bcb883ed5dafa711bdcffa6bb768456 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Wed, 9 Oct 2024 09:53:50 +0200 Subject: [PATCH 06/21] jwt_session_init() seeds JWT runtime param --- src/lib.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 4d12da9..99862ed 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,7 +35,7 @@ pub mod auth { use pgrx::JsonB; use serde::de::DeserializeOwned; - use crate::gucs::{NEON_AUTH_JWK, NEON_AUTH_JWK_RUNTIME_PARAM}; + use crate::gucs::{NEON_AUTH_JWK, NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT_RUNTIME_PARAM}; type Object = serde_json::Map; @@ -213,6 +213,14 @@ pub mod auth { /// This function will panic if the JWT could not be verified. #[pg_extern] pub fn jwt_session_init(s: &str) { + Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWT_RUNTIME_PARAM, s).as_str()).unwrap_or_else(|e| { + error_code!( + PgSqlErrorCode::ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED, + format!("Couldn't set {}", NEON_AUTH_JWT_RUNTIME_PARAM), + e.to_string(), + ) + }); + let key = JWK.with(|b| { b.get() .unwrap_or_else(|| { From d073a88c714cd3ca068b6101bba2c4f356778594 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Wed, 9 Oct 2024 09:54:16 +0200 Subject: [PATCH 07/21] refactor --- src/gucs.rs | 8 ++++---- src/lib.rs | 9 ++------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/gucs.rs b/src/gucs.rs index 49b9374..6273d19 100644 --- a/src/gucs.rs +++ b/src/gucs.rs @@ -1,14 +1,14 @@ use pgrx::*; use std::ffi::CStr; -pub static NEON_AUTH_JWK_RUNTIME_PARAM: &'static str = "neon.auth.jwk"; +pub static NEON_AUTH_JWK_RUNTIME_PARAM: &str = "neon.auth.jwk"; pub static NEON_AUTH_JWK: GucSetting> = GucSetting::>::new(None); -pub static NEON_AUTH_JWT_RUNTIME_PARAM: &'static str = "neon.auth.jwt"; +pub static NEON_AUTH_JWT_RUNTIME_PARAM: &str = "neon.auth.jwt"; pub static NEON_AUTH_JWT: GucSetting> = GucSetting::>::new(None); pub fn init() { GucRegistry::define_string_guc( - &NEON_AUTH_JWK_RUNTIME_PARAM, + NEON_AUTH_JWK_RUNTIME_PARAM, "JSON Web Key (JWK) userd for JWT validation", "Generated per connection by Neon local proxy", &NEON_AUTH_JWK, @@ -17,7 +17,7 @@ pub fn init() { ); GucRegistry::define_string_guc( - &NEON_AUTH_JWT_RUNTIME_PARAM, + NEON_AUTH_JWT_RUNTIME_PARAM, "JSON Web Token (JWT) used for query authorization", "Represents authenticated user session related claims like user ID", &NEON_AUTH_JWT, diff --git a/src/lib.rs b/src/lib.rs index 99862ed..fe82203 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -68,7 +68,7 @@ pub mod auth { error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, format!("Couldn't parse {}", NEON_AUTH_JWK_RUNTIME_PARAM), - e.to_string() + e.to_string(), ) }); let key: JwkEcKey = serde_json::from_str(jwk).unwrap_or_else(|e| { @@ -201,11 +201,6 @@ pub mod auth { } } - #[pg_extern] - pub fn foo() -> String { - crate::gucs::NEON_AUTH_JWT.get().unwrap().to_owned().into_string().unwrap() - } - /// Decrypt the JWT and store it. /// /// # Panics @@ -314,7 +309,7 @@ mod tests { elliptic_curve::JwkEcKey, }; use p256::{NistP256, PublicKey}; - use pgrx::{prelude::*, JsonB}; + use pgrx::prelude::*; use rand::rngs::OsRng; use serde_json::json; From d17ebef7be8f526fa99c5cf5a273809cbe6a588d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Wed, 9 Oct 2024 10:01:32 +0200 Subject: [PATCH 08/21] unit tests --- src/lib.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index fe82203..bd9411a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -314,7 +314,7 @@ mod tests { use serde_json::json; use crate::auth; - use crate::gucs::NEON_AUTH_JWK_RUNTIME_PARAM; + use crate::gucs::{NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT}; fn set_jwk_in_guc(key: String) { Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWK_RUNTIME_PARAM, key).as_str()).unwrap(); @@ -445,10 +445,14 @@ mod tests { auth::init(1); let header = r#"{"kid":1}"#; - auth::jwt_session_init(&sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#)); + let jwt = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); + auth::jwt_session_init(&jwt); + assert_eq!(NEON_AUTH_JWT.get().unwrap().to_str().unwrap(), &jwt); assert_eq!(auth::user_id(), "foo"); - auth::jwt_session_init(&sign_jwt(&sk, header, r#"{"sub":"bar","jti":2}"#)); + let jwt = sign_jwt(&sk, header, r#"{"sub":"bar","jti":2}"#); + auth::jwt_session_init(&jwt); + assert_eq!(NEON_AUTH_JWT.get().unwrap().to_str().unwrap(), &jwt); assert_eq!(auth::user_id(), "bar"); } } From b74378db0ba70bc0e548814728e43782209ed99c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Wed, 9 Oct 2024 11:54:51 +0200 Subject: [PATCH 09/21] refactor --- src/lib.rs | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index bd9411a..49b305a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,7 +35,7 @@ pub mod auth { use pgrx::JsonB; use serde::de::DeserializeOwned; - use crate::gucs::{NEON_AUTH_JWK, NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT_RUNTIME_PARAM}; + use crate::gucs::{NEON_AUTH_JWK, NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT, NEON_AUTH_JWT_RUNTIME_PARAM}; type Object = serde_json::Map; @@ -207,15 +207,30 @@ pub mod auth { /// /// This function will panic if the JWT could not be verified. #[pg_extern] - pub fn jwt_session_init(s: &str) { - Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWT_RUNTIME_PARAM, s).as_str()).unwrap_or_else(|e| { + pub fn jwt_session_init(jwt: &str) { + Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWT_RUNTIME_PARAM, jwt).as_str()).unwrap_or_else(|e| { error_code!( PgSqlErrorCode::ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED, format!("Couldn't set {}", NEON_AUTH_JWT_RUNTIME_PARAM), e.to_string(), ) }); + set_jwt_cache() + } + fn set_jwt_cache() { + let jwt = NEON_AUTH_JWT.get().unwrap_or_else(|| { + error_code!( + PgSqlErrorCode::ERRCODE_NO_DATA, + format!("Missing runtime parameter: {}", NEON_AUTH_JWT_RUNTIME_PARAM) + ) + }).to_str().unwrap_or_else(|e| { + error_code!( + PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, + format!("Couldn't parse {}", NEON_AUTH_JWT_RUNTIME_PARAM), + e.to_string(), + ) + }); let key = JWK.with(|b| { b.get() .unwrap_or_else(|| { @@ -226,7 +241,7 @@ pub mod auth { }) .clone() }); - let (body, sig) = s.rsplit_once('.').unwrap_or_else(|| { + let (body, sig) = jwt.rsplit_once('.').unwrap_or_else(|| { error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, "invalid JWT encoding", From 2b01ed1744b6d8778ab34e3bf7394bfe4d5637c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Wed, 9 Oct 2024 11:55:01 +0200 Subject: [PATCH 10/21] linter --- src/gucs.rs | 14 ++++++---- src/lib.rs | 75 ++++++++++++++++++++++++++++++----------------------- 2 files changed, 52 insertions(+), 37 deletions(-) diff --git a/src/gucs.rs b/src/gucs.rs index 6273d19..6daa779 100644 --- a/src/gucs.rs +++ b/src/gucs.rs @@ -2,9 +2,11 @@ use pgrx::*; use std::ffi::CStr; pub static NEON_AUTH_JWK_RUNTIME_PARAM: &str = "neon.auth.jwk"; -pub static NEON_AUTH_JWK: GucSetting> = GucSetting::>::new(None); +pub static NEON_AUTH_JWK: GucSetting> = + GucSetting::>::new(None); pub static NEON_AUTH_JWT_RUNTIME_PARAM: &str = "neon.auth.jwt"; -pub static NEON_AUTH_JWT: GucSetting> = GucSetting::>::new(None); +pub static NEON_AUTH_JWT: GucSetting> = + GucSetting::>::new(None); pub fn init() { GucRegistry::define_string_guc( @@ -13,13 +15,15 @@ pub fn init() { "Generated per connection by Neon local proxy", &NEON_AUTH_JWK, GucContext::Suset, // we should use GucContext::SuBackend but this breaks unit tests - GucFlags::NOT_WHILE_SEC_REST - ); + GucFlags::NOT_WHILE_SEC_REST, + ); GucRegistry::define_string_guc( NEON_AUTH_JWT_RUNTIME_PARAM, "JSON Web Token (JWT) used for query authorization", "Represents authenticated user session related claims like user ID", &NEON_AUTH_JWT, - GucContext::Suset, GucFlags::NOT_WHILE_SEC_REST); + GucContext::Suset, + GucFlags::NOT_WHILE_SEC_REST, + ); } diff --git a/src/lib.rs b/src/lib.rs index 49b305a..19bd9c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,7 +35,9 @@ pub mod auth { use pgrx::JsonB; use serde::de::DeserializeOwned; - use crate::gucs::{NEON_AUTH_JWK, NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT, NEON_AUTH_JWT_RUNTIME_PARAM}; + use crate::gucs::{ + NEON_AUTH_JWK, NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT, NEON_AUTH_JWT_RUNTIME_PARAM, + }; type Object = serde_json::Map; @@ -59,18 +61,22 @@ pub mod auth { /// This is to prevent replacing the key mid-session. #[pg_extern] pub fn init(kid: i64) { - let jwk = NEON_AUTH_JWK.get().unwrap_or_else(|| { - error_code!( - PgSqlErrorCode::ERRCODE_NO_DATA, - format!("Missing runtime parameter: {}", NEON_AUTH_JWK_RUNTIME_PARAM) - ) - }).to_str().unwrap_or_else(|e| { - error_code!( - PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, - format!("Couldn't parse {}", NEON_AUTH_JWK_RUNTIME_PARAM), - e.to_string(), - ) - }); + let jwk = NEON_AUTH_JWK + .get() + .unwrap_or_else(|| { + error_code!( + PgSqlErrorCode::ERRCODE_NO_DATA, + format!("Missing runtime parameter: {}", NEON_AUTH_JWK_RUNTIME_PARAM) + ) + }) + .to_str() + .unwrap_or_else(|e| { + error_code!( + PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, + format!("Couldn't parse {}", NEON_AUTH_JWK_RUNTIME_PARAM), + e.to_string(), + ) + }); let key: JwkEcKey = serde_json::from_str(jwk).unwrap_or_else(|e| { error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, @@ -208,29 +214,34 @@ pub mod auth { /// This function will panic if the JWT could not be verified. #[pg_extern] pub fn jwt_session_init(jwt: &str) { - Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWT_RUNTIME_PARAM, jwt).as_str()).unwrap_or_else(|e| { - error_code!( - PgSqlErrorCode::ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED, - format!("Couldn't set {}", NEON_AUTH_JWT_RUNTIME_PARAM), - e.to_string(), - ) - }); + Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWT_RUNTIME_PARAM, jwt).as_str()) + .unwrap_or_else(|e| { + error_code!( + PgSqlErrorCode::ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED, + format!("Couldn't set {}", NEON_AUTH_JWT_RUNTIME_PARAM), + e.to_string(), + ) + }); set_jwt_cache() } fn set_jwt_cache() { - let jwt = NEON_AUTH_JWT.get().unwrap_or_else(|| { - error_code!( - PgSqlErrorCode::ERRCODE_NO_DATA, - format!("Missing runtime parameter: {}", NEON_AUTH_JWT_RUNTIME_PARAM) - ) - }).to_str().unwrap_or_else(|e| { - error_code!( - PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, - format!("Couldn't parse {}", NEON_AUTH_JWT_RUNTIME_PARAM), - e.to_string(), - ) - }); + let jwt = NEON_AUTH_JWT + .get() + .unwrap_or_else(|| { + error_code!( + PgSqlErrorCode::ERRCODE_NO_DATA, + format!("Missing runtime parameter: {}", NEON_AUTH_JWT_RUNTIME_PARAM) + ) + }) + .to_str() + .unwrap_or_else(|e| { + error_code!( + PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, + format!("Couldn't parse {}", NEON_AUTH_JWT_RUNTIME_PARAM), + e.to_string(), + ) + }); let key = JWK.with(|b| { b.get() .unwrap_or_else(|| { From 52c3ada3ea5be7cecca29ddf5af6938456e26d97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Wed, 9 Oct 2024 12:07:59 +0200 Subject: [PATCH 11/21] wip with verbose intent --- src/lib.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 19bd9c1..f8fe27c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -281,9 +281,14 @@ pub mod auth { /// Extract a value from the shared state. #[pg_extern] pub fn session(s: &str) -> JsonB { - // todo: check if JWK is set in thread_local, if not assume its bgworker and call dedicated - // function that uses JWT_CACHE and relies only on runtime parameters (neon.auth.jwk and - // neon.auth.jwt) + JWK.with(|j| { + if j.get().is_none() { + // assuming that running as bgworker + init(1); + set_jwt_cache(); + } + }); + JWT.with_borrow(|j| { JsonB( j.as_ref() From 05bd41679e80404929d58890b57e27b6ed92b669 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Wed, 9 Oct 2024 13:44:12 +0200 Subject: [PATCH 12/21] init() expects key ID in runtime params --- src/gucs.rs | 13 +++++++++++++ src/lib.rs | 43 ++++++++++++++++++++++++------------------- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/src/gucs.rs b/src/gucs.rs index 6daa779..c5a2ccf 100644 --- a/src/gucs.rs +++ b/src/gucs.rs @@ -1,6 +1,8 @@ use pgrx::*; use std::ffi::CStr; +pub static NEON_AUTH_KID_RUNTIME_PARAM: &str = "neon.auth.kid"; +pub static NEON_AUTH_KID: GucSetting = GucSetting::::new(0); pub static NEON_AUTH_JWK_RUNTIME_PARAM: &str = "neon.auth.jwk"; pub static NEON_AUTH_JWK: GucSetting> = GucSetting::>::new(None); @@ -9,6 +11,17 @@ pub static NEON_AUTH_JWT: GucSetting> = GucSetting::>::new(None); pub fn init() { + GucRegistry::define_int_guc( + NEON_AUTH_KID_RUNTIME_PARAM, + "ID of JSON Web Key (JWK)", + "Generated per connection by Neon local proxy", + &NEON_AUTH_KID, + 0, + i32::MAX, + GucContext::Suset, // we should use GucContext::SuBackend but this breaks unit tests + GucFlags::NOT_WHILE_SEC_REST, + ); + GucRegistry::define_string_guc( NEON_AUTH_JWK_RUNTIME_PARAM, "JSON Web Key (JWK) userd for JWT validation", diff --git a/src/lib.rs b/src/lib.rs index f8fe27c..3ecda8c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,6 +37,7 @@ pub mod auth { use crate::gucs::{ NEON_AUTH_JWK, NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT, NEON_AUTH_JWT_RUNTIME_PARAM, + NEON_AUTH_KID, }; type Object = serde_json::Map; @@ -60,7 +61,8 @@ pub mod auth { /// This function will panic if called multiple times per session. /// This is to prevent replacing the key mid-session. #[pg_extern] - pub fn init(kid: i64) { + pub fn init() { + let kid: i64 = NEON_AUTH_KID.get().into(); let jwk = NEON_AUTH_JWK .get() .unwrap_or_else(|| { @@ -284,7 +286,7 @@ pub mod auth { JWK.with(|j| { if j.get().is_none() { // assuming that running as bgworker - init(1); + init(); set_jwt_cache(); } }); @@ -345,9 +347,10 @@ mod tests { use serde_json::json; use crate::auth; - use crate::gucs::{NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT}; + use crate::gucs::{NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT, NEON_AUTH_KID_RUNTIME_PARAM}; - fn set_jwk_in_guc(key: String) { + fn set_jwk_in_guc(kid: i32, key: String) { + Spi::run(format!("SET {} = {}", NEON_AUTH_KID_RUNTIME_PARAM, kid).as_str()).unwrap(); Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWK_RUNTIME_PARAM, key).as_str()).unwrap(); } @@ -368,10 +371,12 @@ mod tests { let point = sk.verifying_key().to_encoded_point(false); let jwk = JwkEcKey::from_encoded_point::(&point).unwrap(); let jwk = serde_json::to_value(&jwk).unwrap(); - set_jwk_in_guc(serde_json::to_string(&jwk).unwrap()); - auth::init(1); - auth::init(2); + set_jwk_in_guc(1, serde_json::to_string(&jwk).unwrap()); + auth::init(); + + set_jwk_in_guc(2, serde_json::to_string(&jwk).unwrap()); + auth::init(); } #[pg_test] @@ -380,9 +385,9 @@ mod tests { let sk = SigningKey::random(&mut OsRng); let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); let jwk = serde_json::to_string(&jwk).unwrap(); - set_jwk_in_guc(jwk); + set_jwk_in_guc(1, jwk); - auth::init(1); + auth::init(); auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":2}"#, r#"{"jti":1}"#)); } @@ -392,9 +397,9 @@ mod tests { let sk = SigningKey::random(&mut OsRng); let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); let jwk = serde_json::to_string(&jwk).unwrap(); - set_jwk_in_guc(jwk); + set_jwk_in_guc(1, jwk); - auth::init(1); + auth::init(); auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":2}"#)); auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":1}"#)); } @@ -405,9 +410,9 @@ mod tests { let sk = SigningKey::random(&mut OsRng); let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); let jwk = serde_json::to_string(&jwk).unwrap(); - set_jwk_in_guc(jwk); + set_jwk_in_guc(1, jwk); - auth::init(1); + auth::init(); let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -426,9 +431,9 @@ mod tests { let sk = SigningKey::random(&mut OsRng); let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); let jwk = serde_json::to_string(&jwk).unwrap(); - set_jwk_in_guc(jwk); + set_jwk_in_guc(1, jwk); - auth::init(1); + auth::init(); let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -446,9 +451,9 @@ mod tests { let sk = SigningKey::random(&mut OsRng); let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); let jwk = serde_json::to_string(&jwk).unwrap(); - set_jwk_in_guc(jwk); + set_jwk_in_guc(1, jwk); - auth::init(1); + auth::init(); let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -471,9 +476,9 @@ mod tests { let sk = SigningKey::random(&mut OsRng); let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); let jwk = serde_json::to_string(&jwk).unwrap(); - set_jwk_in_guc(jwk); + set_jwk_in_guc(1, jwk); - auth::init(1); + auth::init(); let header = r#"{"kid":1}"#; let jwt = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); From 782847d54a497cedcf114e73f57158c0b8c1e9c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Wed, 9 Oct 2024 13:54:17 +0200 Subject: [PATCH 13/21] unit test for bgworker --- src/lib.rs | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 3ecda8c..c14bae6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -347,13 +347,20 @@ mod tests { use serde_json::json; use crate::auth; - use crate::gucs::{NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT, NEON_AUTH_KID_RUNTIME_PARAM}; + use crate::gucs::{ + NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT, NEON_AUTH_JWT_RUNTIME_PARAM, + NEON_AUTH_KID_RUNTIME_PARAM, + }; fn set_jwk_in_guc(kid: i32, key: String) { Spi::run(format!("SET {} = {}", NEON_AUTH_KID_RUNTIME_PARAM, kid).as_str()).unwrap(); Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWK_RUNTIME_PARAM, key).as_str()).unwrap(); } + fn set_jwt_in_guc(jwt: String) { + Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWT_RUNTIME_PARAM, jwt).as_str()).unwrap(); + } + fn sign_jwt(sk: &SigningKey, header: &str, payload: impl Display) -> String { let header = Base64UrlUnpadded::encode_string(header.as_bytes()); let payload = Base64UrlUnpadded::encode_string(payload.to_string().as_bytes()); @@ -491,6 +498,21 @@ mod tests { assert_eq!(NEON_AUTH_JWT.get().unwrap().to_str().unwrap(), &jwt); assert_eq!(auth::user_id(), "bar"); } + + // bgworker process exits after execution, because of that we don't need to test case for more + // than one JWT + #[pg_test] + fn test_bgworker() { + let sk = SigningKey::random(&mut OsRng); + let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); + let jwk = serde_json::to_string(&jwk).unwrap(); + let header = r#"{"kid":1}"#; + let jwt = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); + set_jwk_in_guc(1, jwk); + set_jwt_in_guc(jwt); + + assert_eq!(auth::user_id(), "foo"); + } } /// This module is required by `cargo pgrx test` invocations. From ce7f120637d4c2cfd9ad453e9aea9e4c2caece36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Wed, 9 Oct 2024 14:07:33 +0200 Subject: [PATCH 14/21] just in case - test multiple execution --- src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib.rs b/src/lib.rs index c14bae6..a70da67 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -512,6 +512,7 @@ mod tests { set_jwt_in_guc(jwt); assert_eq!(auth::user_id(), "foo"); + assert_eq!(auth::user_id(), "foo"); } } From 10393d32125557d2c718722e06738b74d12eb5ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Wed, 9 Oct 2024 14:31:05 +0200 Subject: [PATCH 15/21] chore: bump version --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4dbc791..aec170f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1028,7 +1028,7 @@ dependencies = [ [[package]] name = "pg_session_jwt" -version = "0.0.1" +version = "0.1.0" dependencies = [ "base64ct", "heapless", diff --git a/Cargo.toml b/Cargo.toml index 102a80e..9c3c56d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pg_session_jwt" -version = "0.0.1" +version = "0.1.0" edition = "2021" [lib] From bcd0d531da1877fad55375ffa07f531f6a5f34dd Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 10 Oct 2024 15:27:25 +0200 Subject: [PATCH 16/21] refactor testing suite --- Cargo.lock | 117 +++- Cargo.toml | 14 +- pg_session_jwt.control | 2 +- pgrx-tests/.gitignore | 7 + pgrx-tests/Cargo.toml | 72 +++ pgrx-tests/LICENSE | 26 + pgrx-tests/README.md | 5 + pgrx-tests/src/framework.rs | 898 +++++++++++++++++++++++++++ pgrx-tests/src/framework/shutdown.rs | 125 ++++ pgrx-tests/src/lib.rs | 30 + src/gucs.rs | 19 +- src/lib.rs | 423 +++++++------ tests/pg_session_jwt.rs | 69 ++ 13 files changed, 1581 insertions(+), 226 deletions(-) create mode 100644 pgrx-tests/.gitignore create mode 100644 pgrx-tests/Cargo.toml create mode 100644 pgrx-tests/LICENSE create mode 100644 pgrx-tests/README.md create mode 100644 pgrx-tests/src/framework.rs create mode 100644 pgrx-tests/src/framework/shutdown.rs create mode 100644 pgrx-tests/src/lib.rs create mode 100644 tests/pg_session_jwt.rs diff --git a/Cargo.lock b/Cargo.lock index aec170f..c3eddf3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,12 +26,55 @@ dependencies = [ "memchr", ] +[[package]] +name = "anstream" +version = "0.6.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + [[package]] name = "anstyle" version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +[[package]] +name = "anstyle-parse" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" +dependencies = [ + "anstyle", + "windows-sys 0.52.0", +] + [[package]] name = "anyhow" version = "1.0.89" @@ -249,8 +292,10 @@ version = "4.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73" dependencies = [ + "anstream", "anstyle", "clap_lex", + "strsim", ] [[package]] @@ -271,6 +316,12 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +[[package]] +name = "colorchoice" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" + [[package]] name = "const-oid" version = "0.9.6" @@ -481,6 +532,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "escape8259" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5692dd7b5a1978a5aeb0ce83b7655c58ca8efdcb79d21036ea249da95afec2c6" + [[package]] name = "eyre" version = "0.6.12" @@ -721,6 +778,12 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + [[package]] name = "itertools" version = "0.12.1" @@ -823,6 +886,18 @@ dependencies = [ "libc", ] +[[package]] +name = "libtest-mimic" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5297962ef19edda4ce33aaa484386e0a5b3d7f2f4e037cbeee00503ef6b29d33" +dependencies = [ + "anstream", + "anstyle", + "clap", + "escape8259", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -1031,11 +1106,15 @@ name = "pg_session_jwt" version = "0.1.0" dependencies = [ "base64ct", + "eyre", "heapless", "jose-jwk", + "libtest-mimic", "p256", "pgrx", + "pgrx-pg-config", "pgrx-tests", + "postgres", "rand", "serde", "serde_json", @@ -1138,8 +1217,6 @@ dependencies = [ [[package]] name = "pgrx-tests" version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab3abc01e2bb930b072bd660d04c8eaa69a29d4727d5b2a641f946c603c1605e" dependencies = [ "clap-cargo", "eyre", @@ -1157,6 +1234,7 @@ dependencies = [ "serde_json", "sysinfo", "thiserror", + "trybuild", ] [[package]] @@ -1717,6 +1795,12 @@ dependencies = [ "unicode-properties", ] +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.6.1" @@ -1779,6 +1863,15 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "thiserror" version = "1.0.63" @@ -1902,6 +1995,20 @@ dependencies = [ "winnow", ] +[[package]] +name = "trybuild" +version = "1.0.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "207aa50d36c4be8d8c6ea829478be44a372c6a77669937bb39c698e52f1491e8" +dependencies = [ + "glob", + "serde", + "serde_derive", + "serde_json", + "termcolor", + "toml", +] + [[package]] name = "typenum" version = "1.17.0" @@ -1970,6 +2077,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.10.0" diff --git a/Cargo.toml b/Cargo.toml index 9c3c56d..3231149 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,6 @@ +[workspace] +members = ["pgrx-tests"] + [package] name = "pg_session_jwt" version = "0.1.0" @@ -25,7 +28,11 @@ serde_json = { version = "1.0.117", default-features = false } rand = { version = "0.8", optional = true } [dev-dependencies] -pgrx-tests = "=0.11.3" +eyre = "0.6.12" +libtest-mimic = "0.8.1" +pgrx-pg-config = "0.11.3" +pgrx-tests = { path = "./pgrx-tests" } +postgres = "0.19.9" rand = "0.8" [profile.dev] @@ -36,3 +43,8 @@ panic = "unwind" opt-level = 3 lto = "fat" codegen-units = 1 + +[[test]] +name = "tests" +harness = false +path = "tests/pg_session_jwt.rs" diff --git a/pg_session_jwt.control b/pg_session_jwt.control index 3745faa..825e99c 100644 --- a/pg_session_jwt.control +++ b/pg_session_jwt.control @@ -2,4 +2,4 @@ comment = 'pg_session_jwt: manage authentication sessions using JWTs' default_version = '@CARGO_VERSION@' module_pathname = '$libdir/pg_session_jwt' relocatable = false -superuser = true +superuser = false diff --git a/pgrx-tests/.gitignore b/pgrx-tests/.gitignore new file mode 100644 index 0000000..ab3ae30 --- /dev/null +++ b/pgrx-tests/.gitignore @@ -0,0 +1,7 @@ +.DS_Store +.idea/ +target/ +*.iml +**/*.rs.bk +Cargo.lock +sql/pgrx_tests-1.0.sql diff --git a/pgrx-tests/Cargo.toml b/pgrx-tests/Cargo.toml new file mode 100644 index 0000000..78c7cd3 --- /dev/null +++ b/pgrx-tests/Cargo.toml @@ -0,0 +1,72 @@ +#LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +#LICENSE +#LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +#LICENSE +#LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +#LICENSE +#LICENSE Portions Copyright 2024-2024 Neon, Inc. +#LICENSE +#LICENSE All rights reserved. +#LICENSE +#LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +[package] +name = "pgrx-tests" +version = "0.11.3" +authors = ["PgCentral Foundation, Inc. "] +license = "MIT" +description = "Test framework for 'pgrx'-based Postgres extensions" +homepage = "https://github.com/pgcentralfoundation/pgrx/" +repository = "https://github.com/pgcentralfoundation/pgrx/" +documentation = "https://docs.rs/pgrx-tests" +readme = "README.md" +edition = "2021" + +[lib] +crate-type = ["cdylib", "lib"] + +[features] +default = ["proptest"] +pg11 = ["pgrx/pg11"] +pg12 = ["pgrx/pg12"] +pg13 = ["pgrx/pg13"] +pg14 = ["pgrx/pg14"] +pg15 = ["pgrx/pg15"] +pg16 = ["pgrx/pg16"] +pg_test = [] +proptest = ["dep:proptest"] +cshim = ["pgrx/cshim"] +no-schema-generation = [ + "pgrx/no-schema-generation", + "pgrx-macros/no-schema-generation", +] + +[package.metadata.docs.rs] +features = ["pg14", "proptest"] +no-default-features = true +targets = ["x86_64-unknown-linux-gnu"] +# Enable `#[cfg(docsrs)]` (https://docs.rs/about/builds#cross-compiling) +rustc-args = ["--cfg", "docsrs"] +rustdoc-args = ["--cfg", "docsrs"] + +[dependencies] +clap-cargo = "0.11.0" +owo-colors = "3.5" +once_cell = "1.18.0" +libc = "0.2.149" +pgrx = "=0.11.3" +pgrx-macros = "=0.11.3" +pgrx-pg-config = "=0.11.3" +postgres = "0.19.7" +proptest = { version = "1", optional = true } +regex = "1.10.0" +serde = "1.0" +serde_json = "1.0" +sysinfo = "0.29.10" +eyre = "0.6.8" +thiserror = "1.0" +rand = "0.8.5" + +[dev-dependencies] +eyre = "0.6.8" # testing functions that return `eyre::Result` +trybuild = "1" diff --git a/pgrx-tests/LICENSE b/pgrx-tests/LICENSE new file mode 100644 index 0000000..03632a3 --- /dev/null +++ b/pgrx-tests/LICENSE @@ -0,0 +1,26 @@ +MIT License + +Portions Copyright 2019-2021 ZomboDB, LLC. +Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +Portions Copyright 2023 PgCentral Foundation, Inc. +Portions Copyright 2024 Neon, Inc. + +All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/pgrx-tests/README.md b/pgrx-tests/README.md new file mode 100644 index 0000000..8aea84d --- /dev/null +++ b/pgrx-tests/README.md @@ -0,0 +1,5 @@ +# pgrx-tests + +Test framework for [`pgrx`](https://crates.io/crates/pgrx/). + +Meant to be used as one of your `[dev-dependencies]` when using `pgrx`. \ No newline at end of file diff --git a/pgrx-tests/src/framework.rs b/pgrx-tests/src/framework.rs new file mode 100644 index 0000000..ec99e1d --- /dev/null +++ b/pgrx-tests/src/framework.rs @@ -0,0 +1,898 @@ +//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +//LICENSE +//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +//LICENSE +//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +//LICENSE +//LICENSE Portions Copyright 2024-2024 Neon, Inc. +//LICENSE +//LICENSE All rights reserved. +//LICENSE +//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. +use std::collections::HashSet; +use std::process::{Command, Stdio}; + +use eyre::{eyre, WrapErr}; +use once_cell::sync::Lazy; +use owo_colors::OwoColorize; +use pgrx::prelude::*; +use pgrx_pg_config::{ + cargo::PgrxManifestExt, createdb, get_c_locale_flags, get_target_dir, PgConfig, Pgrx, +}; +use postgres::error::DbError; +use std::collections::HashMap; +use std::io::{BufRead, BufReader, Write}; +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use sysinfo::{Pid, ProcessExt, System, SystemExt}; + +mod shutdown; +pub use shutdown::add_shutdown_hook; + +type LogLines = Arc>>>; + +struct SetupState { + installed: bool, + loglines: LogLines, + system_session_id: String, +} + +static TEST_MUTEX: Lazy> = Lazy::new(|| { + Mutex::new(SetupState { + installed: false, + loglines: Arc::new(Mutex::new(HashMap::new())), + system_session_id: "NONE".to_string(), + }) +}); + +// The goal of this closure is to allow "wrapping" of anything that might issue +// an SQL simple_query or query using either a postgres::Client or +// postgres::Transaction and capture the output. The use of this wrapper is +// completely optional, but it might help narrow down some errors later on. +fn query_wrapper( + query: Option, + query_params: Option<&[&(dyn postgres::types::ToSql + Sync)]>, + mut f: F, +) -> eyre::Result +where + T: IntoIterator, + F: FnMut( + Option, + Option<&[&(dyn postgres::types::ToSql + Sync)]>, + ) -> Result, +{ + let result = f(query.clone(), query_params.clone()); + + match result { + Ok(result) => Ok(result), + Err(e) => { + if let Some(dberror) = e.as_db_error() { + let query = query.unwrap(); + let query_message = dberror.message(); + + let code = dberror.code().code(); + let severity = dberror.severity(); + + let mut message = format!("{} SQLSTATE[{}]", severity, code) + .bold() + .red() + .to_string(); + + message.push_str(format!(": {}", query_message.bold().white()).as_str()); + message.push_str(format!("\nquery: {}", query.bold().white()).as_str()); + message.push_str( + format!( + "\nparams: {}", + match query_params { + Some(params) => format!("{:?}", params), + None => "None".to_string(), + } + ) + .as_str(), + ); + + if let Ok(var) = std::env::var("RUST_BACKTRACE") { + if var.eq("1") { + let detail = dberror.detail().unwrap_or("None"); + let hint = dberror.hint().unwrap_or("None"); + let schema = dberror.hint().unwrap_or("None"); + let table = dberror.table().unwrap_or("None"); + let more_info = format!( + "\ndetail: {detail}\nhint: {hint}\nschema: {schema}\ntable: {table}" + ); + message.push_str(more_info.as_str()); + } + } + + Err(eyre!(message)) + } else { + return Err(e).wrap_err("non-DbError"); + } + } + } +} + +pub fn run_test( + options: Option<&str>, + expected_error: Option<&str>, + postgresql_conf: Vec<&'static str>, + queries: impl for<'a> FnOnce(&'a mut postgres::Transaction) -> Result<(), postgres::Error>, +) -> eyre::Result<()> { + if std::env::var_os("PGRX_TEST_SKIP").unwrap_or_default() != "" { + eprintln!("Skipping test because `PGRX_TEST_SKIP` is set in the environment",); + return Ok(()); + } + let (loglines, system_session_id) = initialize_test_framework(postgresql_conf)?; + + { + let (mut client, _) = client(None, &get_pg_user())?; + client + .execute("ALTER ROLE pgrx WITH NOSUPERUSER LOGIN", &[]) + .unwrap(); + client + .execute("GRANT USAGE ON SCHEMA auth TO pgrx", &[]) + .unwrap(); + } + + let (mut client, session_id) = client(options, "pgrx")?; + + let result = client.transaction().map(|mut tx| { + let result = queries(&mut tx); + + // let schema = "tests"; // get_extension_schema(); + // let result = tx.simple_query(&format!("SELECT \"{schema}\".\"{sql_funcname}\"();")); + + if result.is_ok() { + // and abort the transaction when complete + tx.rollback()?; + } + + result + }); + + // flatten the above result + let result = match result { + Err(e) => Err(e), + Ok(Err(e)) => Err(e), + Ok(_) => Ok(()), + }; + + if let Err(e) = result { + let error_as_string = format!("{e}"); + let cause = e.into_source(); + + let (pg_location, rust_location, message) = + if let Some(Some(dberror)) = cause.map(|e| e.downcast_ref::().cloned()) { + let received_error_message = dberror.message(); + + if Some(received_error_message) == expected_error { + // the error received is the one we expected, so just return if they match + return Ok(()); + } + + let pg_location = dberror.file().unwrap_or("").to_string(); + let rust_location = dberror.where_().unwrap_or("").to_string(); + + ( + pg_location, + rust_location, + received_error_message.to_string(), + ) + } else { + ( + "".to_string(), + "".to_string(), + format!("{error_as_string}"), + ) + }; + + // wait a second for Postgres to get log messages written to stderr + std::thread::sleep(std::time::Duration::from_millis(1000)); + + let system_loglines = format_loglines(&system_session_id, &loglines); + let session_loglines = format_loglines(&session_id, &loglines); + panic!( + "\n\nPostgres Messages:\n{system_loglines}\n\nTest Function Messages:\n{session_loglines}\n\nClient Error:\n{message}\npostgres location: {pg_location}\nrust location: {rust_location}\n\n", + system_loglines = system_loglines.dimmed().white(), + session_loglines = session_loglines.cyan(), + message = message.bold().red(), + pg_location = pg_location.dimmed().white(), + rust_location = rust_location.yellow() + ); + } else if let Some(message) = expected_error { + // we expected an ERROR, but didn't get one + return Err(eyre!("Expected error: {message}")); + } else { + Ok(()) + } +} + +fn format_loglines(session_id: &str, loglines: &LogLines) -> String { + let mut result = String::new(); + + for line in loglines + .lock() + .unwrap() + .entry(session_id.to_string()) + .or_default() + .iter() + { + result.push_str(line); + result.push('\n'); + } + + result +} + +fn initialize_test_framework( + postgresql_conf: Vec<&'static str>, +) -> eyre::Result<(LogLines, String)> { + let mut state = TEST_MUTEX.lock().unwrap_or_else(|_| { + // This used to immediately throw an std::process::exit(1), but it + // would consume both stdout and stderr, resulting in error messages + // not being displayed unless you were running tests with --nocapture. + panic!( + "Could not obtain test mutex. A previous test may have hard-aborted while holding it." + ); + }); + + if !state.installed { + shutdown::register_shutdown_hook(); + install_extension()?; + initdb(postgresql_conf)?; + + let system_session_id = start_pg(state.loglines.clone())?; + let pg_config = get_pg_config()?; + dropdb()?; + createdb(&pg_config, get_pg_dbname(), true, false)?; + create_extension()?; + state.installed = true; + state.system_session_id = system_session_id; + } + + Ok((state.loglines.clone(), state.system_session_id.clone())) +} + +fn get_pg_config() -> eyre::Result { + let pgrx = Pgrx::from_config().wrap_err("Unable to get PGRX from config")?; + + let pg_version = pg_sys::get_pg_major_version_num(); + + let pg_config = pgrx + .get(&format!("pg{}", pg_version)) + .wrap_err_with(|| { + format!( + "Error getting pg_config: {} is not a valid postgres version", + pg_version + ) + }) + .unwrap() + .clone(); + + Ok(pg_config) +} + +pub fn client(options: Option<&str>, user: &str) -> eyre::Result<(postgres::Client, String)> { + let pg_config = get_pg_config()?; + + let mut config = postgres::Config::new(); + + config + .host(pg_config.host()) + .port( + pg_config + .test_port() + .expect("unable to determine test port"), + ) + .user(user) + .dbname(&get_pg_dbname()); + + if let Some(options) = options { + config.options(options); + } + + let mut client = config + .connect(postgres::NoTls) + .wrap_err("Error connecting to Postgres")?; + + let sid_query_result = query_wrapper( + Some("SELECT to_hex(trunc(EXTRACT(EPOCH FROM backend_start))::integer) || '.' || to_hex(pid) AS sid FROM pg_stat_activity WHERE pid = pg_backend_pid();".to_string()), + Some(&[]), + |query, query_params| client.query(&query.unwrap(), query_params.unwrap()), + ) + .wrap_err("There was an issue attempting to get the session ID from Postgres")?; + + let session_id = match sid_query_result.get(0) { + Some(row) => row.get::<&str, &str>("sid").to_string(), + None => Err(eyre!("Failed to obtain a client Session ID from Postgres"))?, + }; + + if user != "pgrx" { + query_wrapper( + Some("SET log_min_messages TO 'INFO';".to_string()), + None, + |query, _| client.simple_query(query.unwrap().as_str()), + ) + .wrap_err("Postgres Client setup failed to SET log_min_messages TO 'INFO'")?; + + query_wrapper( + Some("SET log_min_duration_statement TO 1000;".to_string()), + None, + |query, _| client.simple_query(query.unwrap().as_str()), + ) + .wrap_err("Postgres Client setup failed to SET log_min_duration_statement TO 1000;")?; + + query_wrapper( + Some("SET log_statement TO 'all';".to_string()), + None, + |query, _| client.simple_query(query.unwrap().as_str()), + ) + .wrap_err("Postgres Client setup failed to SET log_statement TO 'all';")?; + } + + Ok((client, session_id)) +} + +fn install_extension() -> eyre::Result<()> { + eprintln!("installing extension"); + let profile = std::env::var("PGRX_BUILD_PROFILE").unwrap_or("debug".into()); + let no_schema = std::env::var("PGRX_NO_SCHEMA").unwrap_or("false".into()) == "true"; + let mut features = std::env::var("PGRX_FEATURES") + .unwrap_or("".to_string()) + .split_ascii_whitespace() + .map(|s| s.to_string()) + .collect::>(); + features.insert("pg_test".into()); + + let no_default_features = + std::env::var("PGRX_NO_DEFAULT_FEATURES").unwrap_or("false".to_string()) == "true"; + let all_features = std::env::var("PGRX_ALL_FEATURES").unwrap_or("false".to_string()) == "true"; + + let pg_version = format!("pg{}", pg_sys::get_pg_major_version_string()); + let pgrx = Pgrx::from_config()?; + let pg_config = pgrx.get(&pg_version)?; + let cargo_test_args = get_cargo_test_features()?; + println!("detected cargo args: {:?}", cargo_test_args); + + features.extend(cargo_test_args.features.iter().cloned()); + + let mut command = cargo_pgrx(); + command + .arg("install") + .arg("--test") + .arg("--pg-config") + .arg(pg_config.path().ok_or(eyre!("No pg_config found"))?) + .stdout(Stdio::inherit()) + .stderr(Stdio::piped()) + .env("CARGO_TARGET_DIR", get_target_dir()?); + + if let Ok(manifest_path) = std::env::var("PGRX_MANIFEST_PATH") { + command.arg("--manifest-path"); + command.arg(manifest_path); + } + + if let Ok(rust_log) = std::env::var("RUST_LOG") { + command.env("RUST_LOG", rust_log); + } + + if !features.is_empty() { + command.arg("--features"); + command.arg(features.into_iter().collect::>().join(" ")); + } + + if no_default_features || cargo_test_args.no_default_features { + command.arg("--no-default-features"); + } + + if all_features || cargo_test_args.all_features { + command.arg("--all-features"); + } + + match profile.trim() { + // For legacy reasons, cargo has two names for the debug profile... (We + // also ignore the empty string here, just in case). + "debug" | "dev" | "" => {} + "release" => { + command.arg("--release"); + } + profile => { + command.args(["--profile", profile]); + } + } + + if no_schema { + command.arg("--no-schema"); + } + + let command_str = format!("{:?}", command); + + let child = command.spawn().wrap_err_with(|| { + format!( + "Failed to spawn process for installing extension using command: '{}': ", + command_str + ) + })?; + + let output = child.wait_with_output().wrap_err_with(|| { + format!( + "Failed waiting for spawned process attempting to install extension using command: '{}': ", + command_str + ) + })?; + + if !output.status.success() { + return Err(eyre!( + "Failure installing extension using command: {}\n\n{}{}", + command_str, + String::from_utf8(output.stdout).unwrap(), + String::from_utf8(output.stderr).unwrap() + )); + } + + Ok(()) +} + +fn initdb(postgresql_conf: Vec<&'static str>) -> eyre::Result<()> { + let pgdata = get_pgdata_path()?; + + if !pgdata.is_dir() { + let pg_config = get_pg_config()?; + let mut command = Command::new( + pg_config + .initdb_path() + .wrap_err("unable to determine initdb path")?, + ); + + command + .args(get_c_locale_flags()) + .arg("-D") + .arg(pgdata.to_str().unwrap()) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()); + + let command_str = format!("{:?}", command); + + let child = command.spawn().wrap_err_with(|| { + format!( + "Failed to spawn process for initializing database using command: '{}': ", + command_str + ) + })?; + + let output = child.wait_with_output().wrap_err_with(|| { + format!( + "Failed waiting for spawned process attempting to initialize database using command: '{}': ", + command_str + ) + })?; + + if !output.status.success() { + return Err(eyre!( + "Failed to initialize database using command: {}\n\n{}{}", + command_str, + String::from_utf8(output.stdout).unwrap(), + String::from_utf8(output.stderr).unwrap() + )); + } + } + + modify_postgresql_conf(pgdata, postgresql_conf) +} + +fn modify_postgresql_conf(pgdata: PathBuf, postgresql_conf: Vec<&'static str>) -> eyre::Result<()> { + let mut postgresql_conf_file = std::fs::OpenOptions::new() + .write(true) + .truncate(true) + .open(format!("{}/postgresql.auto.conf", pgdata.display())) + .wrap_err("couldn't open postgresql.auto.conf")?; + postgresql_conf_file + .write_all("log_line_prefix='[%m] [%p] [%c]: '\n".as_bytes()) + .wrap_err("couldn't append log_line_prefix")?; + + for setting in postgresql_conf { + postgresql_conf_file + .write_all(format!("{setting}\n").as_bytes()) + .wrap_err("couldn't append custom setting to postgresql.conf")?; + } + + postgresql_conf_file + .write_all( + format!( + "unix_socket_directories = '{}'", + Pgrx::home().unwrap().display() + ) + .as_bytes(), + ) + .wrap_err("couldn't append `unix_socket_directories` setting to postgresql.conf")?; + Ok(()) +} + +fn start_pg(loglines: LogLines) -> eyre::Result { + wait_for_pidfile()?; + + let pg_config = get_pg_config()?; + let postmaster_path = pg_config + .postmaster_path() + .wrap_err("unable to determine postmaster path")?; + + let mut command = if use_valgrind() { + let mut cmd = Command::new("valgrind"); + cmd.args([ + "--leak-check=no", + "--gen-suppressions=all", + "--time-stamp=yes", + "--error-markers=VALGRINDERROR-BEGIN,VALGRINDERROR-END", + "--trace-children=yes", + ]); + // Try to provide a suppressions file, we'll likely get false positives + // if we can't, but that might be better than nothing. + if let Ok(path) = valgrind_suppressions_path(&pg_config) { + if path.exists() { + cmd.arg(format!("--suppressions={}", path.display())); + } + } + + cmd.arg(postmaster_path); + cmd + } else { + Command::new(postmaster_path) + }; + command + .arg("-D") + .arg(get_pgdata_path()?.to_str().unwrap()) + .arg("-h") + .arg(pg_config.host()) + .arg("-p") + .arg( + pg_config + .test_port() + .expect("unable to determine test port") + .to_string(), + ) + // Redirecting logs to files can hang the test framework, override it + .args([ + "-c", + "log_destination=stderr", + "-c", + "logging_collector=off", + ]) + .stdout(Stdio::inherit()) + .stderr(Stdio::piped()); + + let command_str = format!("{command:?}"); + + // start Postgres and monitor its stderr in the background + // also notify the main thread when it's ready to accept connections + let session_id = monitor_pg(command, command_str, loglines); + + Ok(session_id) +} + +fn valgrind_suppressions_path(pg_config: &PgConfig) -> Result { + let mut home = Pgrx::home()?; + home.push(pg_config.version()?); + home.push("src/tools/valgrind.supp"); + Ok(home) +} + +fn wait_for_pidfile() -> Result<(), eyre::Report> { + const MAX_PIDFILE_RETRIES: usize = 10; + + let pidfile = get_pid_file()?; + + let mut retries = 0; + while pidfile.exists() { + if retries > MAX_PIDFILE_RETRIES { + // break out and try to start postgres anyways, maybe it'll report a decent error about what's going on + eprintln!("`{}` has existed for ~10s. There might be some problem with the pgrx testing Postgres instance", pidfile.display()); + break; + } + eprintln!("`{}` still exists. Waiting...", pidfile.display()); + std::thread::sleep(Duration::from_secs(1)); + retries += 1; + } + Ok(()) +} + +fn monitor_pg(mut command: Command, cmd_string: String, loglines: LogLines) -> String { + let (sender, receiver) = std::sync::mpsc::channel(); + + std::thread::spawn(move || { + let mut child = command.spawn().expect("postmaster didn't spawn"); + + let pid = child.id(); + // Add a shutdown hook so we can terminate it when the test framework + // exits. TODO: Consider finding a way to handle cases where we fail to + // clean up due to a SIGNAL? + add_shutdown_hook(move || unsafe { + libc::kill(pid as libc::pid_t, libc::SIGTERM); + let message_string = std::ffi::CString::new( + format!("stopping postgres (pid={pid})\n") + .bold() + .blue() + .to_string(), + ) + .unwrap(); + // IMPORTANT: Rust string literals are not naturally null-terminated + libc::printf("%s\0".as_ptr().cast(), message_string.as_ptr()); + }); + + eprintln!( + "{cmd}\npid={p}", + cmd = cmd_string.bold().blue(), + p = pid.to_string().yellow() + ); + eprintln!("{}", pg_sys::get_pg_version_string().bold().purple()); + + // wait for the database to say its ready to start up + let reader = BufReader::new( + child + .stderr + .take() + .expect("couldn't take postmaster stderr"), + ); + + let regex = regex::Regex::new(r#"\[.*?\] \[.*?\] \[(?P.*?)\]"#).unwrap(); + let mut is_started_yet = false; + let mut lines = reader.lines(); + while let Some(Ok(line)) = lines.next() { + let session_id = match get_named_capture(®ex, "session_id", &line) { + Some(sid) => sid, + None => "NONE".to_string(), + }; + + if line.contains("database system is ready to accept connections") { + // Postgres says it's ready to go + sender.send(session_id.clone()).unwrap(); + is_started_yet = true; + } + + if !is_started_yet || line.contains("TMSG: ") { + eprintln!("{}", line.cyan()); + } + + // if line.contains("INFO: ") { + // eprintln!("{}", line.cyan()); + // } else if line.contains("WARNING: ") { + // eprintln!("{}", line.bold().yellow()); + // } else if line.contains("ERROR: ") { + // eprintln!("{}", line.bold().red()); + // } else if line.contains("statement: ") || line.contains("duration: ") { + // eprintln!("{}", line.bold().blue()); + // } else if line.contains("LOG: ") { + // eprintln!("{}", line.dimmed().white()); + // } else { + // eprintln!("{}", line.bold().purple()); + // } + + let mut loglines = loglines.lock().unwrap(); + let session_lines = loglines.entry(session_id).or_insert_with(Vec::new); + session_lines.push(line); + } + + // wait for Postgres to really finish + match child.try_wait() { + Ok(status) => { + if let Some(_status) = status { + // we exited normally + } + } + Err(e) => panic!("was going to let Postgres finish, but errored this time:\n{e}"), + } + }); + + // wait for Postgres to indicate it's ready to accept connection + // and return its pid when it is + receiver.recv().expect("Postgres failed to start") +} + +fn dropdb() -> eyre::Result<()> { + let pg_config = get_pg_config()?; + let output = Command::new( + pg_config + .dropdb_path() + .expect("unable to determine dropdb path"), + ) + .env_remove("PGDATABASE") + .env_remove("PGHOST") + .env_remove("PGPORT") + .env_remove("PGUSER") + .arg("--if-exists") + .arg("-h") + .arg(pg_config.host()) + .arg("-p") + .arg( + pg_config + .test_port() + .expect("unable to determine test port") + .to_string(), + ) + .arg(get_pg_dbname()) + .output() + .unwrap(); + + if !output.status.success() { + // maybe the database didn't exist, and if so that's okay + let stderr = String::from_utf8_lossy(output.stderr.as_slice()); + if !stderr.contains(&format!( + "ERROR: database \"{}\" does not exist", + get_pg_dbname() + )) { + // got some error we didn't expect + let stdout = String::from_utf8_lossy(output.stdout.as_slice()); + eprintln!("unexpected error (stdout):\n{stdout}"); + eprintln!("unexpected error (stderr):\n{stderr}"); + panic!("failed to drop test database"); + } + } + + Ok(()) +} + +fn create_extension() -> eyre::Result<()> { + let (mut client, _) = client(None, &get_pg_user())?; + let extension_name = get_extension_name()?; + + query_wrapper( + Some(format!("CREATE EXTENSION {} CASCADE;", &extension_name)), + None, + |query, _| client.simple_query(query.unwrap().as_str()), + ) + .wrap_err(format!( + "There was an issue creating the extension '{}' in Postgres: ", + &extension_name + ))?; + + Ok(()) +} + +fn get_extension_name() -> eyre::Result { + // We could replace this with the following if cargo adds the lib name on env var on tests/runs. + // https://github.com/rust-lang/cargo/issues/11966 + // std::env::var("CARGO_LIB_NAME") + // .unwrap_or_else(|_| panic!("CARGO_LIB_NAME environment var is unset or invalid UTF-8")) + // .replace("-", "_") + + // CARGO_MANIFEST_DIRR — The directory containing the manifest of your package. + // https://doc.rust-lang.org/cargo/reference/environment-variables.html#environment-variables-cargo-sets-for-crates + let dir = std::env::var("CARGO_MANIFEST_DIR") + .map_err(|_| eyre!("CARGO_MANIFEST_DIR environment var is unset or invalid UTF-8"))?; + + // Cargo.toml is case sensitive atm so this is ok. + // https://github.com/rust-lang/cargo/issues/45 + let path = PathBuf::from(dir).join("Cargo.toml"); + let name = pgrx_pg_config::cargo::read_manifest(path)?.lib_name()?; + Ok(name.replace("-", "_")) +} + +fn get_pgdata_path() -> eyre::Result { + let mut target_dir = get_target_dir()?; + target_dir.push(&format!( + "pgrx-test-data-{}", + pg_sys::get_pg_major_version_num() + )); + Ok(target_dir) +} + +fn get_pid_file() -> eyre::Result { + let mut pgdata = get_pgdata_path()?; + pgdata.push("postmaster.pid"); + return Ok(pgdata); +} + +pub(crate) fn get_pg_dbname() -> &'static str { + "pgrx_tests" +} + +pub(crate) fn get_pg_user() -> String { + std::env::var("USER") + .unwrap_or_else(|_| panic!("USER environment var is unset or invalid UTF-8")) +} + +pub fn get_named_capture( + regex: ®ex::Regex, + name: &'static str, + against: &str, +) -> Option { + match regex.captures(against) { + Some(cap) => Some(cap[name].to_string()), + None => None, + } +} + +fn get_cargo_test_features() -> eyre::Result { + let mut features = clap_cargo::Features::default(); + let cargo_user_args = get_cargo_args(); + let mut iter = cargo_user_args.iter(); + while let Some(part) = iter.next() { + match part.as_str() { + "--no-default-features" => features.no_default_features = true, + "--features" => { + let configured_features = iter.next().ok_or(eyre!( + "no `--features` specified in the cargo argument list: {:?}", + cargo_user_args + ))?; + features.features = configured_features + .split(|c: char| c.is_ascii_whitespace() || c == ',') + .map(|s| s.to_string()) + .collect(); + } + "--all-features" => features.all_features = true, + _ => {} + } + } + + Ok(features) +} + +fn get_cargo_args() -> Vec { + // setup the sysinfo crate's "System" + let mut system = System::new_all(); + system.refresh_all(); + + // starting with our process, look for the full set of arguments for the top-most "cargo" command + // in our process tree. + // + // it's possible we've been called by: + // - the user from the command-line via `cargo test ...` + // - `cargo pgrx test ...` + // - `cargo test ...` + // - some other combination with a `cargo ...` in the middle, perhaps + // + // we're interested in the first arguments the **user** gave to cargo, so `framework.rs` + // can later figure out which set of features to pass to `cargo pgrx` + let mut pid = Pid::from(std::process::id() as usize); + while let Some(process) = system.process(pid) { + // only if it's "cargo"... (This works for now, but just because `cargo` + // is at the end of the path. How *should* this handle `CARGO`?) + if process.exe().ends_with("cargo") { + // ... and only if it's "cargo test"... + if process.cmd().iter().any(|arg| arg == "test") + && !process.cmd().iter().any(|arg| arg == "pgrx") + { + // ... do we want its args + return process.cmd().iter().cloned().collect(); + } + } + + // and we want to keep going to find the top-most "cargo" process in our tree + match process.parent() { + Some(parent_pid) => pid = parent_pid, + None => break, + } + } + + Vec::new() +} + +// TODO: this would be a good place to insert a check invoking to see if +// `cargo-pgrx` is a crate in the local workspace, and use it instead. +fn cargo_pgrx() -> std::process::Command { + fn var_path(s: &str) -> Option { + std::env::var_os(s).map(PathBuf::from) + } + // Use `CARGO_PGRX` (set by `cargo-pgrx` on first run), then fall back to + // `cargo-pgrx` if it is on the path, then `$CARGO pgrx` + let cargo_pgrx = var_path("CARGO_PGRX") + .or_else(|| find_on_path("cargo-pgrx")) + .or_else(|| var_path("CARGO")) + .unwrap_or_else(|| "cargo".into()); + let mut cmd = std::process::Command::new(cargo_pgrx); + cmd.arg("pgrx"); + cmd +} + +fn find_on_path(program: &str) -> Option { + assert!(!program.contains('/')); + // Technically we should check `libc::confstr(libc::_CS_PATH)` + // when `PATH` is unset... + let paths = std::env::var_os("PATH")?; + std::env::split_paths(&paths) + .map(|p| p.join(program)) + .find(|abs| abs.exists()) +} + +fn use_valgrind() -> bool { + std::env::var_os("USE_VALGRIND").is_some_and(|s| s.len() > 0) +} diff --git a/pgrx-tests/src/framework/shutdown.rs b/pgrx-tests/src/framework/shutdown.rs new file mode 100644 index 0000000..bf6c3da --- /dev/null +++ b/pgrx-tests/src/framework/shutdown.rs @@ -0,0 +1,125 @@ +//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +//LICENSE +//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +//LICENSE +//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +//LICENSE +//LICENSE All rights reserved. +//LICENSE +//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. +use std::panic::{self, AssertUnwindSafe, Location}; +use std::sync::{Mutex, PoisonError}; +use std::{any, io, mem, process}; + +/// Register a shutdown hook to be called when the process exits. +/// +/// Note that shutdown hooks are only run on the client, so must be added from +/// your `setup` function, not the `#[pg_test]` itself. +#[track_caller] +pub fn add_shutdown_hook(func: F) +where + F: Send + 'static, +{ + SHUTDOWN_HOOKS + .lock() + .unwrap_or_else(PoisonError::into_inner) + .push(ShutdownHook { source: Location::caller(), callback: Box::new(func) }); +} + +pub(super) fn register_shutdown_hook() { + unsafe { + libc::atexit(run_shutdown_hooks); + } +} + +/// The `atexit` callback. +/// +/// If we panic from `atexit`, we end up causing `exit` to unwind. Unwinding +/// from a nounwind + noreturn function can cause some destructors to run twice, +/// causing (for example) libtest to SIGSEGV. +/// +/// This ends up looking like a memory bug in either `pgrx` or the user code, and +/// is very hard to track down, so we go to some lengths to prevent it. +/// Essentially: +/// +/// - Panics in each user hook are caught and reported. +/// - As a stop-gap an abort-on-drop panic guard is used to ensure there isn't a +/// place we missed. +/// +/// We also write to stderr directly instead, since otherwise our output will +/// sometimes be redirected. +extern "C" fn run_shutdown_hooks() { + let guard = PanicGuard; + let mut any_panicked = false; + let mut hooks = SHUTDOWN_HOOKS.lock().unwrap_or_else(PoisonError::into_inner); + // Note: run hooks in the opposite order they were registered. + for hook in mem::take(&mut *hooks).into_iter().rev() { + any_panicked |= hook.run().is_err(); + } + if any_panicked { + write_stderr("error: one or more shutdown hooks panicked (see `stderr` for details).\n"); + std::process::abort() + } + mem::forget(guard); +} + +/// Prevent panics in a block of code. +/// +/// Prints a message and aborts in its drop. Intended usage is like: +/// ```ignore +/// let guard = PanicGuard; +/// // ...code that absolutely must never unwind goes here... +/// core::mem::forget(guard); +/// ``` +struct PanicGuard; +impl Drop for PanicGuard { + fn drop(&mut self) { + write_stderr("Failed to catch panic in the `atexit` callback, aborting!\n"); + process::abort(); + } +} + +static SHUTDOWN_HOOKS: Mutex> = Mutex::new(Vec::new()); + +struct ShutdownHook { + source: &'static Location<'static>, + callback: Box, +} + +impl ShutdownHook { + fn run(self) -> Result<(), ()> { + let Self { source, callback } = self; + let result = panic::catch_unwind(AssertUnwindSafe(callback)); + if let Err(e) = result { + let msg = failure_message(&e); + write_stderr(&format!( + "error: shutdown hook (registered at {source}) panicked: {msg}\n" + )); + Err(()) + } else { + Ok(()) + } + } +} + +fn failure_message(e: &(dyn any::Any + Send)) -> &str { + if let Some(&msg) = e.downcast_ref::<&'static str>() { + msg + } else if let Some(msg) = e.downcast_ref::() { + msg.as_str() + } else { + "" + } +} + +/// Write to stderr, bypassing libtest's output redirection. Doesn't append `\n`. +fn write_stderr(s: &str) { + loop { + let res = unsafe { libc::write(libc::STDERR_FILENO, s.as_ptr().cast(), s.len()) }; + // Handle EINTR to ensure we don't drop messages. + // `Error::last_os_error()` just reads from errno, so it's fine to use here. + if res >= 0 || io::Error::last_os_error().kind() != io::ErrorKind::Interrupted { + break; + } + } +} diff --git a/pgrx-tests/src/lib.rs b/pgrx-tests/src/lib.rs new file mode 100644 index 0000000..f2fe75c --- /dev/null +++ b/pgrx-tests/src/lib.rs @@ -0,0 +1,30 @@ +//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +//LICENSE +//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +//LICENSE +//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +//LICENSE +//LICENSE All rights reserved. +//LICENSE +//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. +mod framework; +// #[cfg(any(test, feature = "pg_test"))] +// mod tests; + +pub use framework::*; +// #[cfg(feature = "proptest")] +// pub mod proptest; + +// #[cfg(any(test, feature = "pg_test"))] +// pgrx::pg_sql_graph_magic!(); + +// #[cfg(test)] +// pub mod pg_test { +// pub fn setup(_options: Vec<&str>) { +// // noop +// } + +// pub fn postgresql_conf_options() -> Vec<&'static str> { +// vec!["shared_preload_libraries='pgrx_tests'"] +// } +// } diff --git a/src/gucs.rs b/src/gucs.rs index c5a2ccf..e52c925 100644 --- a/src/gucs.rs +++ b/src/gucs.rs @@ -1,8 +1,6 @@ use pgrx::*; use std::ffi::CStr; -pub static NEON_AUTH_KID_RUNTIME_PARAM: &str = "neon.auth.kid"; -pub static NEON_AUTH_KID: GucSetting = GucSetting::::new(0); pub static NEON_AUTH_JWK_RUNTIME_PARAM: &str = "neon.auth.jwk"; pub static NEON_AUTH_JWK: GucSetting> = GucSetting::>::new(None); @@ -11,23 +9,12 @@ pub static NEON_AUTH_JWT: GucSetting> = GucSetting::>::new(None); pub fn init() { - GucRegistry::define_int_guc( - NEON_AUTH_KID_RUNTIME_PARAM, - "ID of JSON Web Key (JWK)", - "Generated per connection by Neon local proxy", - &NEON_AUTH_KID, - 0, - i32::MAX, - GucContext::Suset, // we should use GucContext::SuBackend but this breaks unit tests - GucFlags::NOT_WHILE_SEC_REST, - ); - GucRegistry::define_string_guc( NEON_AUTH_JWK_RUNTIME_PARAM, - "JSON Web Key (JWK) userd for JWT validation", + "JSON Web Key (JWK) used for JWT validation", "Generated per connection by Neon local proxy", &NEON_AUTH_JWK, - GucContext::Suset, // we should use GucContext::SuBackend but this breaks unit tests + GucContext::Backend, // we should use GucContext::Backend but this breaks unit tests GucFlags::NOT_WHILE_SEC_REST, ); @@ -36,7 +23,7 @@ pub fn init() { "JSON Web Token (JWT) used for query authorization", "Represents authenticated user session related claims like user ID", &NEON_AUTH_JWT, - GucContext::Suset, + GucContext::Userset, GucFlags::NOT_WHILE_SEC_REST, ); } diff --git a/src/lib.rs b/src/lib.rs index a70da67..13af319 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ mod gucs; +use p256::elliptic_curve::JwkEcKey; use pgrx::prelude::*; pgrx::pg_module_magic!(); @@ -21,6 +22,21 @@ pub unsafe extern "C" fn _PG_init() { gucs::init(); } +/// An Elliptic Curve JSON Web Key. +/// +/// This type is defined in [RFC7517 Section 4]. +/// +/// [RFC7517 Section 4]: https://datatracker.ietf.org/doc/html/rfc7517#section-4 +#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize)] +struct JwkEc { + /// The key material. + #[serde(flatten)] + key: JwkEcKey, + + // The key parameters. + kid: i64, +} + #[pg_schema] pub mod auth { use std::cell::{OnceCell, RefCell}; @@ -29,7 +45,6 @@ pub mod auth { use p256::ecdsa::signature::Verifier; use p256::ecdsa::{Signature, VerifyingKey}; use p256::elliptic_curve::generic_array::GenericArray; - use p256::elliptic_curve::JwkEcKey; use p256::PublicKey; use pgrx::prelude::*; use pgrx::JsonB; @@ -37,8 +52,8 @@ pub mod auth { use crate::gucs::{ NEON_AUTH_JWK, NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT, NEON_AUTH_JWT_RUNTIME_PARAM, - NEON_AUTH_KID, }; + use crate::JwkEc; type Object = serde_json::Map; @@ -62,7 +77,6 @@ pub mod auth { /// This is to prevent replacing the key mid-session. #[pg_extern] pub fn init() { - let kid: i64 = NEON_AUTH_KID.get().into(); let jwk = NEON_AUTH_JWK .get() .unwrap_or_else(|| { @@ -79,14 +93,15 @@ pub mod auth { e.to_string(), ) }); - let key: JwkEcKey = serde_json::from_str(jwk).unwrap_or_else(|e| { + + let jwk: JwkEc = serde_json::from_str(jwk).unwrap_or_else(|e| { error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, "session init requires an ES256 JWK", e.to_string(), ) }); - let key = PublicKey::from_jwk(&key).unwrap_or_else(|p256::elliptic_curve::Error| { + let key = PublicKey::from_jwk(&jwk.key).unwrap_or_else(|p256::elliptic_curve::Error| { error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, "session init requires an ES256 JWK", @@ -94,7 +109,7 @@ pub mod auth { }); let key = VerifyingKey::from(key); JWK.with(|j| { - if j.set(Key { kid, key }).is_err() { + if j.set(Key { kid: jwk.kid, key }).is_err() { error_code!( PgSqlErrorCode::ERRCODE_UNIQUE_VIOLATION, "JWK state can only be set once per session.", @@ -329,203 +344,199 @@ pub mod auth { } } -#[cfg(any(test, feature = "pg_test"))] -#[pg_schema] -mod tests { - use std::fmt::Display; - use std::time::{SystemTime, UNIX_EPOCH}; - - use base64ct::{Base64UrlUnpadded, Encoding}; - use p256::ecdsa::signature::Signer; - use p256::{ - ecdsa::{Signature, SigningKey}, - elliptic_curve::JwkEcKey, - }; - use p256::{NistP256, PublicKey}; - use pgrx::prelude::*; - use rand::rngs::OsRng; - use serde_json::json; - - use crate::auth; - use crate::gucs::{ - NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT, NEON_AUTH_JWT_RUNTIME_PARAM, - NEON_AUTH_KID_RUNTIME_PARAM, - }; - - fn set_jwk_in_guc(kid: i32, key: String) { - Spi::run(format!("SET {} = {}", NEON_AUTH_KID_RUNTIME_PARAM, kid).as_str()).unwrap(); - Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWK_RUNTIME_PARAM, key).as_str()).unwrap(); - } - - fn set_jwt_in_guc(jwt: String) { - Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWT_RUNTIME_PARAM, jwt).as_str()).unwrap(); - } - - fn sign_jwt(sk: &SigningKey, header: &str, payload: impl Display) -> String { - let header = Base64UrlUnpadded::encode_string(header.as_bytes()); - let payload = Base64UrlUnpadded::encode_string(payload.to_string().as_bytes()); - - let message = format!("{header}.{payload}"); - let sig: Signature = sk.sign(message.as_bytes()); - let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes()); - format!("{message}.{base64_sig}") - } - - #[pg_test] - #[should_panic = "JWK state can only be set once per session."] - fn init_jwk_twice() { - let sk = SigningKey::random(&mut OsRng); - let point = sk.verifying_key().to_encoded_point(false); - let jwk = JwkEcKey::from_encoded_point::(&point).unwrap(); - let jwk = serde_json::to_value(&jwk).unwrap(); - - set_jwk_in_guc(1, serde_json::to_string(&jwk).unwrap()); - auth::init(); - - set_jwk_in_guc(2, serde_json::to_string(&jwk).unwrap()); - auth::init(); - } - - #[pg_test] - #[should_panic = "Key ID mismatch"] - fn wrong_pid() { - let sk = SigningKey::random(&mut OsRng); - let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = serde_json::to_string(&jwk).unwrap(); - set_jwk_in_guc(1, jwk); - - auth::init(); - auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":2}"#, r#"{"jti":1}"#)); - } - - #[pg_test] - #[should_panic = "Token ID must be strictly monotonically increasing"] - fn wrong_txid() { - let sk = SigningKey::random(&mut OsRng); - let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = serde_json::to_string(&jwk).unwrap(); - set_jwk_in_guc(1, jwk); - - auth::init(); - auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":2}"#)); - auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":1}"#)); - } - - #[pg_test] - #[should_panic = "Token used before it is ready"] - fn invalid_nbf() { - let sk = SigningKey::random(&mut OsRng); - let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = serde_json::to_string(&jwk).unwrap(); - set_jwk_in_guc(1, jwk); - - auth::init(); - - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - auth::jwt_session_init(&sign_jwt( - &sk, - r#"{"kid":1}"#, - json!({"jti": 1, "nbf": now + 10}), - )); - } - - #[pg_test] - #[should_panic = "Token used after it has expired"] - fn invalid_exp() { - let sk = SigningKey::random(&mut OsRng); - let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = serde_json::to_string(&jwk).unwrap(); - set_jwk_in_guc(1, jwk); - - auth::init(); - - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - auth::jwt_session_init(&sign_jwt( - &sk, - r#"{"kid":1}"#, - json!({"jti": 1, "nbf": now - 10, "exp": now - 5}), - )); - } - - #[pg_test] - fn valid_time() { - let sk = SigningKey::random(&mut OsRng); - let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = serde_json::to_string(&jwk).unwrap(); - set_jwk_in_guc(1, jwk); - - auth::init(); - - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - - let header = r#"{"kid":1}"#; - - auth::jwt_session_init(&sign_jwt( - &sk, - header, - json!({"jti": 1, "nbf": now - 10, "exp": now + 10}), - )); - auth::jwt_session_init(&sign_jwt(&sk, header, json!({"jti": 2, "nbf": now - 10}))); - auth::jwt_session_init(&sign_jwt(&sk, header, json!({"jti": 3, "exp": now + 10}))); - } - - #[pg_test] - fn test_pg_session_jwt() { - let sk = SigningKey::random(&mut OsRng); - let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = serde_json::to_string(&jwk).unwrap(); - set_jwk_in_guc(1, jwk); - - auth::init(); - let header = r#"{"kid":1}"#; - - let jwt = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); - auth::jwt_session_init(&jwt); - assert_eq!(NEON_AUTH_JWT.get().unwrap().to_str().unwrap(), &jwt); - assert_eq!(auth::user_id(), "foo"); - - let jwt = sign_jwt(&sk, header, r#"{"sub":"bar","jti":2}"#); - auth::jwt_session_init(&jwt); - assert_eq!(NEON_AUTH_JWT.get().unwrap().to_str().unwrap(), &jwt); - assert_eq!(auth::user_id(), "bar"); - } - - // bgworker process exits after execution, because of that we don't need to test case for more - // than one JWT - #[pg_test] - fn test_bgworker() { - let sk = SigningKey::random(&mut OsRng); - let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); - let jwk = serde_json::to_string(&jwk).unwrap(); - let header = r#"{"kid":1}"#; - let jwt = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); - set_jwk_in_guc(1, jwk); - set_jwt_in_guc(jwt); - - assert_eq!(auth::user_id(), "foo"); - assert_eq!(auth::user_id(), "foo"); - } -} - -/// This module is required by `cargo pgrx test` invocations. -/// It must be visible at the root of your extension crate. -#[cfg(test)] -pub mod pg_test { - pub fn setup(_options: Vec<&str>) { - // perform one-off initialization when the pg_test framework starts - } - - pub fn postgresql_conf_options() -> Vec<&'static str> { - // return any postgresql.conf settings that are required for your tests - vec![] - } -} +// #[cfg(any(test, feature = "pg_test"))] +// #[pg_schema] +// mod tests { +// use std::fmt::Display; +// use std::time::{SystemTime, UNIX_EPOCH}; + +// use base64ct::{Base64UrlUnpadded, Encoding}; +// use p256::ecdsa::signature::Signer; +// use p256::{ +// ecdsa::{Signature, SigningKey}, +// elliptic_curve::JwkEcKey, +// }; +// use p256::{NistP256, PublicKey}; +// use pgrx::prelude::*; +// use rand::rngs::OsRng; +// use serde_json::json; + +// use crate::auth; +// use crate::gucs::{NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT, NEON_AUTH_JWT_RUNTIME_PARAM}; + +// fn set_jwk_in_guc(kid: i32, key: String) { +// Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWK_RUNTIME_PARAM, key).as_str()).unwrap(); +// } + +// fn set_jwt_in_guc(jwt: String) { +// Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWT_RUNTIME_PARAM, jwt).as_str()).unwrap(); +// } + +// fn sign_jwt(sk: &SigningKey, header: &str, payload: impl Display) -> String { +// let header = Base64UrlUnpadded::encode_string(header.as_bytes()); +// let payload = Base64UrlUnpadded::encode_string(payload.to_string().as_bytes()); + +// let message = format!("{header}.{payload}"); +// let sig: Signature = sk.sign(message.as_bytes()); +// let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes()); +// format!("{message}.{base64_sig}") +// } + +// #[pg_test] +// #[should_panic = "JWK state can only be set once per session."] +// fn init_jwk_twice() { +// let sk = SigningKey::random(&mut OsRng); +// let point = sk.verifying_key().to_encoded_point(false); +// let jwk = JwkEcKey::from_encoded_point::(&point).unwrap(); +// let jwk = serde_json::to_value(&jwk).unwrap(); + +// set_jwk_in_guc(1, serde_json::to_string(&jwk).unwrap()); +// auth::init(); + +// set_jwk_in_guc(2, serde_json::to_string(&jwk).unwrap()); +// auth::init(); +// } + +// #[pg_test] +// #[should_panic = "Key ID mismatch"] +// fn wrong_pid() { +// let sk = SigningKey::random(&mut OsRng); +// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); +// let jwk = serde_json::to_string(&jwk).unwrap(); +// set_jwk_in_guc(1, jwk); + +// auth::init(); +// auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":2}"#, r#"{"jti":1}"#)); +// } + +// #[pg_test] +// #[should_panic = "Token ID must be strictly monotonically increasing"] +// fn wrong_txid() { +// let sk = SigningKey::random(&mut OsRng); +// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); +// let jwk = serde_json::to_string(&jwk).unwrap(); +// set_jwk_in_guc(1, jwk); + +// auth::init(); +// auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":2}"#)); +// auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":1}"#)); +// } + +// #[pg_test] +// #[should_panic = "Token used before it is ready"] +// fn invalid_nbf() { +// let sk = SigningKey::random(&mut OsRng); +// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); +// let jwk = serde_json::to_string(&jwk).unwrap(); +// set_jwk_in_guc(1, jwk); + +// auth::init(); + +// let now = SystemTime::now() +// .duration_since(UNIX_EPOCH) +// .unwrap() +// .as_secs(); +// auth::jwt_session_init(&sign_jwt( +// &sk, +// r#"{"kid":1}"#, +// json!({"jti": 1, "nbf": now + 10}), +// )); +// } + +// #[pg_test] +// #[should_panic = "Token used after it has expired"] +// fn invalid_exp() { +// let sk = SigningKey::random(&mut OsRng); +// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); +// let jwk = serde_json::to_string(&jwk).unwrap(); +// set_jwk_in_guc(1, jwk); + +// auth::init(); + +// let now = SystemTime::now() +// .duration_since(UNIX_EPOCH) +// .unwrap() +// .as_secs(); +// auth::jwt_session_init(&sign_jwt( +// &sk, +// r#"{"kid":1}"#, +// json!({"jti": 1, "nbf": now - 10, "exp": now - 5}), +// )); +// } + +// #[pg_test] +// fn valid_time() { +// let sk = SigningKey::random(&mut OsRng); +// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); +// let jwk = serde_json::to_string(&jwk).unwrap(); +// set_jwk_in_guc(1, jwk); + +// auth::init(); + +// let now = SystemTime::now() +// .duration_since(UNIX_EPOCH) +// .unwrap() +// .as_secs(); + +// let header = r#"{"kid":1}"#; + +// auth::jwt_session_init(&sign_jwt( +// &sk, +// header, +// json!({"jti": 1, "nbf": now - 10, "exp": now + 10}), +// )); +// auth::jwt_session_init(&sign_jwt(&sk, header, json!({"jti": 2, "nbf": now - 10}))); +// auth::jwt_session_init(&sign_jwt(&sk, header, json!({"jti": 3, "exp": now + 10}))); +// } + +// #[pg_test] +// fn test_pg_session_jwt() { +// let sk = SigningKey::random(&mut OsRng); +// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); +// let jwk = serde_json::to_string(&jwk).unwrap(); +// set_jwk_in_guc(1, jwk); + +// auth::init(); +// let header = r#"{"kid":1}"#; + +// let jwt = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); +// auth::jwt_session_init(&jwt); +// assert_eq!(NEON_AUTH_JWT.get().unwrap().to_str().unwrap(), &jwt); +// assert_eq!(auth::user_id(), "foo"); + +// let jwt = sign_jwt(&sk, header, r#"{"sub":"bar","jti":2}"#); +// auth::jwt_session_init(&jwt); +// assert_eq!(NEON_AUTH_JWT.get().unwrap().to_str().unwrap(), &jwt); +// assert_eq!(auth::user_id(), "bar"); +// } + +// // bgworker process exits after execution, because of that we don't need to test case for more +// // than one JWT +// #[pg_test] +// fn test_bgworker() { +// let sk = SigningKey::random(&mut OsRng); +// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); +// let jwk = serde_json::to_string(&jwk).unwrap(); +// let header = r#"{"kid":1}"#; +// let jwt = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); +// set_jwk_in_guc(1, jwk); +// set_jwt_in_guc(jwt); + +// assert_eq!(auth::user_id(), "foo"); +// assert_eq!(auth::user_id(), "foo"); +// } +// } + +// /// This module is required by `cargo pgrx test` invocations. +// /// It must be visible at the root of your extension crate. +// // #[cfg(test)] +// pub mod pg_test { +// pub fn setup(_options: Vec<&str>) { +// // perform one-off initialization when the pg_test framework starts +// } + +// pub fn postgresql_conf_options() -> Vec<&'static str> { +// // return any postgresql.conf settings that are required for your tests +// vec![] +// } +// } diff --git a/tests/pg_session_jwt.rs b/tests/pg_session_jwt.rs new file mode 100644 index 0000000..57a5b03 --- /dev/null +++ b/tests/pg_session_jwt.rs @@ -0,0 +1,69 @@ +use std::process::ExitCode; + +use base64ct::{Base64UrlUnpadded, Encoding}; +use libtest_mimic::{run, Trial}; +use p256::ecdsa::signature::Signer; +use p256::{ + ecdsa::{Signature, SigningKey}, + elliptic_curve::JwkEcKey, + NistP256, +}; +use rand::rngs::OsRng; +use serde::Serialize; + +pub static NEON_AUTH_JWK_RUNTIME_PARAM: &str = "neon.auth.jwk"; +pub static NEON_AUTH_JWT_RUNTIME_PARAM: &str = "neon.auth.jwt"; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize)] +pub struct JwkEc { + /// The key material. + #[serde(flatten)] + pub key: JwkEcKey, + + // The key parameters. + pub kid: i64, +} + +fn main() -> ExitCode { + let args = libtest_mimic::Arguments::from_args(); + + let mut tests = vec![]; + + tests.push(Trial::test("wrong_txid", move || { + let sk = SigningKey::random(&mut OsRng); + let jwk = create_jwk(&sk, 1); + let options = format!("-c {NEON_AUTH_JWK_RUNTIME_PARAM}={jwk}"); + + let error = "Token ID must be strictly monotonically increasing."; + pgrx_tests::run_test(Some(&options), Some(error), vec![], |tx| { + let jwt1 = sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":1}"#); + let jwt2 = sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":2}"#); + + tx.execute("select auth.init()", &[])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt2])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt1])?; + + Ok(()) + }) + .map_err(|e| libtest_mimic::Failed::from(e)) + })); + + run(&args, tests).exit_code() +} + +fn sign_jwt(sk: &SigningKey, header: &str, payload: impl ToString) -> String { + let header = Base64UrlUnpadded::encode_string(header.as_bytes()); + let payload = Base64UrlUnpadded::encode_string(payload.to_string().as_bytes()); + + let message = format!("{header}.{payload}"); + let sig: Signature = sk.sign(message.as_bytes()); + let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes()); + format!("{message}.{base64_sig}") +} + +fn create_jwk(sk: &SigningKey, kid: i64) -> String { + let point = sk.verifying_key().to_encoded_point(false); + let key = JwkEcKey::from_encoded_point::(&point).unwrap(); + let jwk = JwkEc { key, kid }; + serde_json::to_string(&jwk).unwrap() +} From 8f522abf637d501a378180498fcde756c9694154 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 10 Oct 2024 21:30:24 +0200 Subject: [PATCH 17/21] re-impl all tests --- Cargo.lock | 1 - Cargo.toml | 5 +- pgrx-tests/README.md | 8 +- pgrx-tests/src/framework.rs | 50 +++---- pgrx-tests/src/framework/shutdown.rs | 9 +- pgrx-tests/src/lib.rs | 19 --- src/lib.rs | 197 -------------------------- tests/pg_session_jwt.rs | 201 +++++++++++++++++++++++---- 8 files changed, 206 insertions(+), 284 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c3eddf3..6a1c5ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1112,7 +1112,6 @@ dependencies = [ "libtest-mimic", "p256", "pgrx", - "pgrx-pg-config", "pgrx-tests", "postgres", "rand", diff --git a/Cargo.toml b/Cargo.toml index 3231149..72040e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ default = ["pg16"] pg14 = ["pgrx/pg14", "pgrx-tests/pg14" ] pg15 = ["pgrx/pg15", "pgrx-tests/pg15" ] pg16 = ["pgrx/pg16", "pgrx-tests/pg16" ] -pg_test = ["dep:rand", "base64ct/alloc"] +pg_test = [] [dependencies] base64ct = { version = "1.6.0", features = ["std"] } @@ -25,12 +25,9 @@ pgrx = "=0.11.3" serde = { version = "1.0.203", features = ["derive"], default-features = false } serde_json = { version = "1.0.117", default-features = false } -rand = { version = "0.8", optional = true } - [dev-dependencies] eyre = "0.6.12" libtest-mimic = "0.8.1" -pgrx-pg-config = "0.11.3" pgrx-tests = { path = "./pgrx-tests" } postgres = "0.19.9" rand = "0.8" diff --git a/pgrx-tests/README.md b/pgrx-tests/README.md index 8aea84d..befa98b 100644 --- a/pgrx-tests/README.md +++ b/pgrx-tests/README.md @@ -1,5 +1,9 @@ # pgrx-tests -Test framework for [`pgrx`](https://crates.io/crates/pgrx/). +Test framework for [`pgrx`](https://crates.io/crates/pgrx/). -Meant to be used as one of your `[dev-dependencies]` when using `pgrx`. \ No newline at end of file +Meant to be used as one of your `[dev-dependencies]` when using `pgrx`. + +Forked off of pgrx 0.11.3 by Conrad Ludgate for the purposes of adding support for +1. Providing options used for initialising GucContext::Backend +2. Running tests as non-superuser diff --git a/pgrx-tests/src/framework.rs b/pgrx-tests/src/framework.rs index ec99e1d..cfc1ad6 100644 --- a/pgrx-tests/src/framework.rs +++ b/pgrx-tests/src/framework.rs @@ -28,7 +28,7 @@ use std::time::Duration; use sysinfo::{Pid, ProcessExt, System, SystemExt}; mod shutdown; -pub use shutdown::add_shutdown_hook; +use shutdown::add_shutdown_hook; type LogLines = Arc>>>; @@ -117,7 +117,7 @@ pub fn run_test( options: Option<&str>, expected_error: Option<&str>, postgresql_conf: Vec<&'static str>, - queries: impl for<'a> FnOnce(&'a mut postgres::Transaction) -> Result<(), postgres::Error>, + queries: impl for<'a> FnOnce(&'a mut postgres::Client) -> Result<(), postgres::Error>, ) -> eyre::Result<()> { if std::env::var_os("PGRX_TEST_SKIP").unwrap_or_default() != "" { eprintln!("Skipping test because `PGRX_TEST_SKIP` is set in the environment",); @@ -127,36 +127,28 @@ pub fn run_test( { let (mut client, _) = client(None, &get_pg_user())?; - client - .execute("ALTER ROLE pgrx WITH NOSUPERUSER LOGIN", &[]) + + let resp = client + .query_opt("SELECT rolname FROM pg_roles WHERE rolname = 'pgrx'", &[]) .unwrap(); + + if resp.is_none() { + client + .execute("CREATE ROLE pgrx WITH NOSUPERUSER LOGIN", &[]) + .unwrap(); + } else { + client + .execute("ALTER ROLE pgrx WITH NOSUPERUSER LOGIN", &[]) + .unwrap(); + } + client .execute("GRANT USAGE ON SCHEMA auth TO pgrx", &[]) .unwrap(); } let (mut client, session_id) = client(options, "pgrx")?; - - let result = client.transaction().map(|mut tx| { - let result = queries(&mut tx); - - // let schema = "tests"; // get_extension_schema(); - // let result = tx.simple_query(&format!("SELECT \"{schema}\".\"{sql_funcname}\"();")); - - if result.is_ok() { - // and abort the transaction when complete - tx.rollback()?; - } - - result - }); - - // flatten the above result - let result = match result { - Err(e) => Err(e), - Ok(Err(e)) => Err(e), - Ok(_) => Ok(()), - }; + let result = queries(&mut client); if let Err(e) = result { let error_as_string = format!("{e}"); @@ -273,7 +265,7 @@ fn get_pg_config() -> eyre::Result { Ok(pg_config) } -pub fn client(options: Option<&str>, user: &str) -> eyre::Result<(postgres::Client, String)> { +fn client(options: Option<&str>, user: &str) -> eyre::Result<(postgres::Client, String)> { let pg_config = get_pg_config()?; let mut config = postgres::Config::new(); @@ -790,11 +782,7 @@ pub(crate) fn get_pg_user() -> String { .unwrap_or_else(|_| panic!("USER environment var is unset or invalid UTF-8")) } -pub fn get_named_capture( - regex: ®ex::Regex, - name: &'static str, - against: &str, -) -> Option { +fn get_named_capture(regex: ®ex::Regex, name: &'static str, against: &str) -> Option { match regex.captures(against) { Some(cap) => Some(cap[name].to_string()), None => None, diff --git a/pgrx-tests/src/framework/shutdown.rs b/pgrx-tests/src/framework/shutdown.rs index bf6c3da..e2168ca 100644 --- a/pgrx-tests/src/framework/shutdown.rs +++ b/pgrx-tests/src/framework/shutdown.rs @@ -23,7 +23,10 @@ where SHUTDOWN_HOOKS .lock() .unwrap_or_else(PoisonError::into_inner) - .push(ShutdownHook { source: Location::caller(), callback: Box::new(func) }); + .push(ShutdownHook { + source: Location::caller(), + callback: Box::new(func), + }); } pub(super) fn register_shutdown_hook() { @@ -51,7 +54,9 @@ pub(super) fn register_shutdown_hook() { extern "C" fn run_shutdown_hooks() { let guard = PanicGuard; let mut any_panicked = false; - let mut hooks = SHUTDOWN_HOOKS.lock().unwrap_or_else(PoisonError::into_inner); + let mut hooks = SHUTDOWN_HOOKS + .lock() + .unwrap_or_else(PoisonError::into_inner); // Note: run hooks in the opposite order they were registered. for hook in mem::take(&mut *hooks).into_iter().rev() { any_panicked |= hook.run().is_err(); diff --git a/pgrx-tests/src/lib.rs b/pgrx-tests/src/lib.rs index f2fe75c..86905f8 100644 --- a/pgrx-tests/src/lib.rs +++ b/pgrx-tests/src/lib.rs @@ -8,23 +8,4 @@ //LICENSE //LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. mod framework; -// #[cfg(any(test, feature = "pg_test"))] -// mod tests; - pub use framework::*; -// #[cfg(feature = "proptest")] -// pub mod proptest; - -// #[cfg(any(test, feature = "pg_test"))] -// pgrx::pg_sql_graph_magic!(); - -// #[cfg(test)] -// pub mod pg_test { -// pub fn setup(_options: Vec<&str>) { -// // noop -// } - -// pub fn postgresql_conf_options() -> Vec<&'static str> { -// vec!["shared_preload_libraries='pgrx_tests'"] -// } -// } diff --git a/src/lib.rs b/src/lib.rs index 13af319..a0badae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -343,200 +343,3 @@ pub mod auth { }) } } - -// #[cfg(any(test, feature = "pg_test"))] -// #[pg_schema] -// mod tests { -// use std::fmt::Display; -// use std::time::{SystemTime, UNIX_EPOCH}; - -// use base64ct::{Base64UrlUnpadded, Encoding}; -// use p256::ecdsa::signature::Signer; -// use p256::{ -// ecdsa::{Signature, SigningKey}, -// elliptic_curve::JwkEcKey, -// }; -// use p256::{NistP256, PublicKey}; -// use pgrx::prelude::*; -// use rand::rngs::OsRng; -// use serde_json::json; - -// use crate::auth; -// use crate::gucs::{NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT, NEON_AUTH_JWT_RUNTIME_PARAM}; - -// fn set_jwk_in_guc(kid: i32, key: String) { -// Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWK_RUNTIME_PARAM, key).as_str()).unwrap(); -// } - -// fn set_jwt_in_guc(jwt: String) { -// Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWT_RUNTIME_PARAM, jwt).as_str()).unwrap(); -// } - -// fn sign_jwt(sk: &SigningKey, header: &str, payload: impl Display) -> String { -// let header = Base64UrlUnpadded::encode_string(header.as_bytes()); -// let payload = Base64UrlUnpadded::encode_string(payload.to_string().as_bytes()); - -// let message = format!("{header}.{payload}"); -// let sig: Signature = sk.sign(message.as_bytes()); -// let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes()); -// format!("{message}.{base64_sig}") -// } - -// #[pg_test] -// #[should_panic = "JWK state can only be set once per session."] -// fn init_jwk_twice() { -// let sk = SigningKey::random(&mut OsRng); -// let point = sk.verifying_key().to_encoded_point(false); -// let jwk = JwkEcKey::from_encoded_point::(&point).unwrap(); -// let jwk = serde_json::to_value(&jwk).unwrap(); - -// set_jwk_in_guc(1, serde_json::to_string(&jwk).unwrap()); -// auth::init(); - -// set_jwk_in_guc(2, serde_json::to_string(&jwk).unwrap()); -// auth::init(); -// } - -// #[pg_test] -// #[should_panic = "Key ID mismatch"] -// fn wrong_pid() { -// let sk = SigningKey::random(&mut OsRng); -// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); -// let jwk = serde_json::to_string(&jwk).unwrap(); -// set_jwk_in_guc(1, jwk); - -// auth::init(); -// auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":2}"#, r#"{"jti":1}"#)); -// } - -// #[pg_test] -// #[should_panic = "Token ID must be strictly monotonically increasing"] -// fn wrong_txid() { -// let sk = SigningKey::random(&mut OsRng); -// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); -// let jwk = serde_json::to_string(&jwk).unwrap(); -// set_jwk_in_guc(1, jwk); - -// auth::init(); -// auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":2}"#)); -// auth::jwt_session_init(&sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":1}"#)); -// } - -// #[pg_test] -// #[should_panic = "Token used before it is ready"] -// fn invalid_nbf() { -// let sk = SigningKey::random(&mut OsRng); -// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); -// let jwk = serde_json::to_string(&jwk).unwrap(); -// set_jwk_in_guc(1, jwk); - -// auth::init(); - -// let now = SystemTime::now() -// .duration_since(UNIX_EPOCH) -// .unwrap() -// .as_secs(); -// auth::jwt_session_init(&sign_jwt( -// &sk, -// r#"{"kid":1}"#, -// json!({"jti": 1, "nbf": now + 10}), -// )); -// } - -// #[pg_test] -// #[should_panic = "Token used after it has expired"] -// fn invalid_exp() { -// let sk = SigningKey::random(&mut OsRng); -// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); -// let jwk = serde_json::to_string(&jwk).unwrap(); -// set_jwk_in_guc(1, jwk); - -// auth::init(); - -// let now = SystemTime::now() -// .duration_since(UNIX_EPOCH) -// .unwrap() -// .as_secs(); -// auth::jwt_session_init(&sign_jwt( -// &sk, -// r#"{"kid":1}"#, -// json!({"jti": 1, "nbf": now - 10, "exp": now - 5}), -// )); -// } - -// #[pg_test] -// fn valid_time() { -// let sk = SigningKey::random(&mut OsRng); -// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); -// let jwk = serde_json::to_string(&jwk).unwrap(); -// set_jwk_in_guc(1, jwk); - -// auth::init(); - -// let now = SystemTime::now() -// .duration_since(UNIX_EPOCH) -// .unwrap() -// .as_secs(); - -// let header = r#"{"kid":1}"#; - -// auth::jwt_session_init(&sign_jwt( -// &sk, -// header, -// json!({"jti": 1, "nbf": now - 10, "exp": now + 10}), -// )); -// auth::jwt_session_init(&sign_jwt(&sk, header, json!({"jti": 2, "nbf": now - 10}))); -// auth::jwt_session_init(&sign_jwt(&sk, header, json!({"jti": 3, "exp": now + 10}))); -// } - -// #[pg_test] -// fn test_pg_session_jwt() { -// let sk = SigningKey::random(&mut OsRng); -// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); -// let jwk = serde_json::to_string(&jwk).unwrap(); -// set_jwk_in_guc(1, jwk); - -// auth::init(); -// let header = r#"{"kid":1}"#; - -// let jwt = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); -// auth::jwt_session_init(&jwt); -// assert_eq!(NEON_AUTH_JWT.get().unwrap().to_str().unwrap(), &jwt); -// assert_eq!(auth::user_id(), "foo"); - -// let jwt = sign_jwt(&sk, header, r#"{"sub":"bar","jti":2}"#); -// auth::jwt_session_init(&jwt); -// assert_eq!(NEON_AUTH_JWT.get().unwrap().to_str().unwrap(), &jwt); -// assert_eq!(auth::user_id(), "bar"); -// } - -// // bgworker process exits after execution, because of that we don't need to test case for more -// // than one JWT -// #[pg_test] -// fn test_bgworker() { -// let sk = SigningKey::random(&mut OsRng); -// let jwk = PublicKey::from(sk.verifying_key()).to_jwk(); -// let jwk = serde_json::to_string(&jwk).unwrap(); -// let header = r#"{"kid":1}"#; -// let jwt = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); -// set_jwk_in_guc(1, jwk); -// set_jwt_in_guc(jwt); - -// assert_eq!(auth::user_id(), "foo"); -// assert_eq!(auth::user_id(), "foo"); -// } -// } - -// /// This module is required by `cargo pgrx test` invocations. -// /// It must be visible at the root of your extension crate. -// // #[cfg(test)] -// pub mod pg_test { -// pub fn setup(_options: Vec<&str>) { -// // perform one-off initialization when the pg_test framework starts -// } - -// pub fn postgresql_conf_options() -> Vec<&'static str> { -// // return any postgresql.conf settings that are required for your tests -// vec![] -// } -// } diff --git a/tests/pg_session_jwt.rs b/tests/pg_session_jwt.rs index 57a5b03..7b062c6 100644 --- a/tests/pg_session_jwt.rs +++ b/tests/pg_session_jwt.rs @@ -1,4 +1,5 @@ use std::process::ExitCode; +use std::time::{SystemTime, UNIX_EPOCH}; use base64ct::{Base64UrlUnpadded, Encoding}; use libtest_mimic::{run, Trial}; @@ -10,45 +11,189 @@ use p256::{ }; use rand::rngs::OsRng; use serde::Serialize; +use serde_json::json; -pub static NEON_AUTH_JWK_RUNTIME_PARAM: &str = "neon.auth.jwk"; -pub static NEON_AUTH_JWT_RUNTIME_PARAM: &str = "neon.auth.jwt"; +fn main() -> ExitCode { + let mut args = libtest_mimic::Arguments::from_args(); + // fixes concurrent update failures + args.test_threads = Some(1); -#[derive(Clone, Debug, PartialEq, Eq, Serialize)] -pub struct JwkEc { - /// The key material. - #[serde(flatten)] - pub key: JwkEcKey, + let mut tests = vec![]; - // The key parameters. - pub kid: i64, + let err = "Token ID must be strictly monotonically increasing."; + tests.push(test_fn("wrong_txid", Some(err), wrong_txid)); + + let err = "Token used before it is ready"; + tests.push(test_fn("invalid_nbf", Some(err), invalid_nbf)); + + let err = "Token used after it has expired"; + tests.push(test_fn("invalid_exp", Some(err), invalid_exp)); + + tests.push(test_fn("valid_time", None, valid_time)); + tests.push(test_fn("test_pg_session_jwt", None, test_pg_session_jwt)); + tests.push(test_fn("test_bgworker", None, test_bgworker)); + + run(&args, tests).exit_code() } -fn main() -> ExitCode { - let args = libtest_mimic::Arguments::from_args(); +// bgworker process exits after execution, because of that we don't need to test case for more +// than one JWT +fn test_fn(name: &str, error: Option<&'static str>, f: F) -> Trial +where + F: for<'a, 'b> FnOnce(&'a SigningKey, &'b mut postgres::Client) -> Result<(), postgres::Error> + + Send + + 'static, +{ + let sk = SigningKey::random(&mut OsRng); + let jwk = create_jwk(&sk, 1); + let options = format!("-c {NEON_AUTH_JWK_RUNTIME_PARAM}={jwk}"); - let mut tests = vec![]; + Trial::test(name, move || { + pgrx_tests::run_test(Some(&options), error, vec![], move |tx| f(&sk, tx)) + .map_err(|e| libtest_mimic::Failed::from(e)) + }) +} + +fn wrong_txid(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { + let jwt1 = sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":1}"#); + let jwt2 = sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":2}"#); - tests.push(Trial::test("wrong_txid", move || { - let sk = SigningKey::random(&mut OsRng); - let jwk = create_jwk(&sk, 1); - let options = format!("-c {NEON_AUTH_JWK_RUNTIME_PARAM}={jwk}"); + tx.execute("select auth.init()", &[])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt2])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt1])?; - let error = "Token ID must be strictly monotonically increasing."; - pgrx_tests::run_test(Some(&options), Some(error), vec![], |tx| { - let jwt1 = sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":1}"#); - let jwt2 = sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":2}"#); + Ok(()) +} - tx.execute("select auth.init()", &[])?; - tx.execute("select auth.jwt_session_init($1)", &[&jwt2])?; - tx.execute("select auth.jwt_session_init($1)", &[&jwt1])?; +fn invalid_nbf(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); - Ok(()) - }) - .map_err(|e| libtest_mimic::Failed::from(e)) - })); + let jwt = sign_jwt(&sk, r#"{"kid":1}"#, json!({"jti": 1, "nbf": now + 10})); - run(&args, tests).exit_code() + tx.execute("select auth.init()", &[])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt])?; + + Ok(()) +} + +fn invalid_exp(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let jwt = sign_jwt( + &sk, + r#"{"kid":1}"#, + json!({"jti": 1, "nbf": now - 10, "exp": now - 5}), + ); + + tx.execute("select auth.init()", &[])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt])?; + + Ok(()) +} + +fn valid_time(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let header = r#"{"kid":1}"#; + let jwt1 = sign_jwt( + &sk, + header, + json!({"jti": 1, "nbf": now - 10, "exp": now + 10}), + ); + let jwt2 = sign_jwt(&sk, header, json!({"jti": 2, "nbf": now - 10})); + let jwt3 = sign_jwt(&sk, header, json!({"jti": 3, "exp": now + 10})); + + tx.execute("select auth.init()", &[])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt1])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt2])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt3])?; + + Ok(()) +} + +fn test_pg_session_jwt(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { + let header = r#"{"kid":1}"#; + let jwt1 = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); + let jwt2 = sign_jwt(&sk, header, r#"{"sub":"bar","jti":2}"#); + + tx.execute("select auth.init()", &[])?; + tx.execute("select auth.jwt_session_init($1)", &[&jwt1])?; + let user_id = tx.query_one("select auth.user_id()", &[])?; + let user_id = user_id.get::<_, String>("user_id"); + assert_eq!(user_id, "foo"); + + tx.execute("select auth.jwt_session_init($1)", &[&jwt2])?; + let user_id = tx.query_one("select auth.user_id()", &[])?; + let user_id = user_id.get::<_, String>("user_id"); + assert_eq!(user_id, "bar"); + + Ok(()) +} + +// bgworker process exits after execution, because of that we don't need to test case for more +// than one JWT +fn test_bgworker(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { + let header = r#"{"kid":1}"#; + let jwt = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); + + tx.execute(&format!("set neon.auth.jwt = '{jwt}'"), &[])?; + let user_id = tx.query_one("select auth.user_id()", &[])?; + let user_id = user_id.get::<_, String>("user_id"); + assert_eq!(user_id, "foo"); + + Ok(()) +} + +// fn discard() -> eyre::Result<()> { +// let sk = SigningKey::random(&mut OsRng); +// let jwk = create_jwk(&sk, 1); +// let options = format!("-c {NEON_AUTH_JWK_RUNTIME_PARAM}={jwk}"); + +// let header = r#"{"kid":1}"#; +// let jwt1 = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); +// let jwt2 = sign_jwt(&sk, header, r#"{"sub":"bar","jti":2}"#); + +// pgrx_tests::run_test(Some(&options), None, vec![], |tx| { +// tx.execute("select auth.init()", &[])?; +// tx.execute("select auth.jwt_session_init($1)", &[&jwt1])?; +// let user_id = tx.query_one("select auth.user_id()", &[])?; +// let user_id = user_id.get::<_, Option>("user_id"); +// assert_eq!(user_id.as_deref(), Some("foo")); + +// tx.simple_query("reset neon.auth.jwt")?; + +// let user_id = tx.query_one("select auth.user_id()", &[])?; +// let user_id = user_id.get::<_, Option>("user_id"); +// assert_eq!(user_id.as_deref(), None); + +// tx.execute("select auth.jwt_session_init($1)", &[&jwt2])?; +// let user_id = tx.query_one("select auth.user_id()", &[])?; +// let user_id = user_id.get::<_, Option>("user_id"); +// assert_eq!(user_id.as_deref(), Some("bar")); + +// Ok(()) +// }) +// } + +static NEON_AUTH_JWK_RUNTIME_PARAM: &str = "neon.auth.jwk"; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize)] +struct JwkEc { + /// The key material. + #[serde(flatten)] + key: JwkEcKey, + + // The key parameters. + kid: i64, } fn sign_jwt(sk: &SigningKey, header: &str, payload: impl ToString) -> String { From 9ee3a6b31ae3f883990f2d75c557d2603ee11020 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 10 Oct 2024 21:52:06 +0200 Subject: [PATCH 18/21] small changes --- src/gucs.rs | 4 ++-- src/lib.rs | 35 +++++++++++++++++++++-------------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/gucs.rs b/src/gucs.rs index e52c925..bfafc73 100644 --- a/src/gucs.rs +++ b/src/gucs.rs @@ -14,8 +14,8 @@ pub fn init() { "JSON Web Key (JWK) used for JWT validation", "Generated per connection by Neon local proxy", &NEON_AUTH_JWK, - GucContext::Backend, // we should use GucContext::Backend but this breaks unit tests - GucFlags::NOT_WHILE_SEC_REST, + GucContext::Backend, + GucFlags::NOT_WHILE_SEC_REST | GucFlags::NO_RESET_ALL, ); GucRegistry::define_string_guc( diff --git a/src/lib.rs b/src/lib.rs index a0badae..e01266f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -231,14 +231,21 @@ pub mod auth { /// This function will panic if the JWT could not be verified. #[pg_extern] pub fn jwt_session_init(jwt: &str) { - Spi::run(format!("SET {} = '{}'", NEON_AUTH_JWT_RUNTIME_PARAM, jwt).as_str()) - .unwrap_or_else(|e| { - error_code!( - PgSqlErrorCode::ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED, - format!("Couldn't set {}", NEON_AUTH_JWT_RUNTIME_PARAM), - e.to_string(), - ) - }); + Spi::run( + format!( + "SET {} = {}", + NEON_AUTH_JWT_RUNTIME_PARAM, + spi::quote_literal(jwt) + ) + .as_str(), + ) + .unwrap_or_else(|e| { + error_code!( + PgSqlErrorCode::ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED, + format!("Couldn't set {}", NEON_AUTH_JWT_RUNTIME_PARAM), + e.to_string(), + ) + }); set_jwt_cache() } @@ -297,7 +304,7 @@ pub mod auth { /// Extract a value from the shared state. #[pg_extern] - pub fn session(s: &str) -> JsonB { + pub fn session() -> JsonB { JWK.with(|j| { if j.get().is_none() { // assuming that running as bgworker @@ -309,16 +316,16 @@ pub mod auth { JWT.with_borrow(|j| { JsonB( j.as_ref() - .and_then(|j| j.get(s).cloned()) - .unwrap_or(serde_json::Value::Null), + .cloned() + .map_or(serde_json::Value::Null, serde_json::Value::Object), ) }) } #[pg_extern] - pub fn user_id() -> String { - match session("sub").0 { - serde_json::Value::String(s) => s, + pub fn user_id() -> Option { + match session().0.get("sub")? { + serde_json::Value::String(s) => Some(s.clone()), _ => error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, "invalid subject claim in the JWT" From a70969e02c29253982e65386ad9c0f6573c12f72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Fri, 11 Oct 2024 11:05:17 +0200 Subject: [PATCH 19/21] update docs --- CONTRIBUTING.md | 22 +++++++++++++++++++++- README.md | 18 ++++++++++++++---- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 81e578f..acf0b7a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,8 +15,15 @@ Let's initialize pgrx. cargo pgrx init ``` -It's time to run `pg_session_jwt` locally with +## How to run the extension locally + +It's time to run `pg_session_jwt` locally. Please note that `neon.auth.jwk` +parameter MUST be set when new connection is created (for more details please +refer to the README file). ```console +MY_JWK=... +export PGOPTIONS="-c neon.auth.jwk=$MY_JWK" + cargo pgrx run pg16 ``` @@ -35,3 +42,16 @@ If you introduce new function make sure to reload the extension with DROP EXTENSION pg_session_jwt; CREATE EXTENSION pg_session_jwt; ``` + +## Before sending a PR + +You can lint your code with +```console +rustfmt src/*.rs tests/*.rs +cargo clippy --fix --allow-staged +``` + +You can run test-suite +```console +cargo test +``` diff --git a/README.md b/README.md index 6d74a9c..c00c66e 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,9 @@ pg\_session\_jwt ================ -`pg_session_jwt` is a PostgreSQL extension designed to handle authenticated sessions through a JWT. This JWT is then verified against a JWK (JSON Web Key) to ensure its authenticity. Both the JWK and the JWT must be provided to the extension by a Postgres superuser. The extension then stores the JWT in the database for later retrieval, and exposes functions to retrieve the user ID (the `sub` subject field) and other parts of the payload. +`pg_session_jwt` is a PostgreSQL extension designed to handle authenticated sessions through a JWT. This JWT is then verified against a JWK (JSON Web Key) to ensure its authenticity. + +**JWK can only be set at postmaster startup, from the configuration file, or by client request in the connection startup packet** (e.g., from libpq's PGOPTIONS variable), whereas the JWT can be set anytime at runtime. The extension then stores the JWT in the database for later retrieval, and exposes functions to retrieve the user ID (the `sub` subject field) and other parts of the payload. The goal of this extension is to provide a secure and efficient way to manage authenticated sessions in a PostgreSQL database. The JWTs can be generated by third-party auth providers, and then developers can leverage the JWT for [Row Level Security](https://www.postgresql.org/docs/current/ddl-rowsecurity.html) (RLS) policies, or to retrieve the user ID for other purposes (column defaults, filters, etc.). @@ -20,15 +22,23 @@ Features Usage ----- +Before calling functions make sure that `neon.auth.jwk` parameter is properly initialized. [libpq connect options](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-OPTIONS) can be used for that. + +For example: +```console +MY_JWK=... +export PGOPTIONS="-c neon.auth.jwk=$MY_JWK" +``` + `pg_session_jwt` exposes four main functions: -### 1\. auth.init(kid bigint, jwk jsonb) → void +### 1\. auth.init() → void -Initializes a session with a given key identifier (KID) and JWK data in JSONB format. +Initializes a session using JWK stored in `neon.auth.jwk` [run-time parameter](https://www.postgresql.org/docs/current/sql-show.html). Please remember that this parameter is fixed for a given connection once it's started (but it can vary across different connections) ### 2\. auth.jwt\_session\_init(jwt text) → void -Initializes the JWT session with the provided `jwt` as a string. +Initializes the JWT session with the provided `jwt` as a string. JWT must be signed by the JWK that was initialized with `auth.init()` ### 3\. auth.session(s text) → jsonb From 37840d6d9b436002737ae73cbd111d192be3191d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Fri, 11 Oct 2024 11:19:50 +0200 Subject: [PATCH 20/21] stop relying on kid --- src/lib.rs | 56 ++++++----------------------------------- tests/pg_session_jwt.rs | 18 +++---------- 2 files changed, 11 insertions(+), 63 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e01266f..23a8935 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,5 @@ mod gucs; -use p256::elliptic_curve::JwkEcKey; use pgrx::prelude::*; pgrx::pg_module_magic!(); @@ -22,21 +21,6 @@ pub unsafe extern "C" fn _PG_init() { gucs::init(); } -/// An Elliptic Curve JSON Web Key. -/// -/// This type is defined in [RFC7517 Section 4]. -/// -/// [RFC7517 Section 4]: https://datatracker.ietf.org/doc/html/rfc7517#section-4 -#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize)] -struct JwkEc { - /// The key material. - #[serde(flatten)] - key: JwkEcKey, - - // The key parameters. - kid: i64, -} - #[pg_schema] pub mod auth { use std::cell::{OnceCell, RefCell}; @@ -45,6 +29,7 @@ pub mod auth { use p256::ecdsa::signature::Verifier; use p256::ecdsa::{Signature, VerifyingKey}; use p256::elliptic_curve::generic_array::GenericArray; + use p256::elliptic_curve::JwkEcKey; use p256::PublicKey; use pgrx::prelude::*; use pgrx::JsonB; @@ -53,22 +38,15 @@ pub mod auth { use crate::gucs::{ NEON_AUTH_JWK, NEON_AUTH_JWK_RUNTIME_PARAM, NEON_AUTH_JWT, NEON_AUTH_JWT_RUNTIME_PARAM, }; - use crate::JwkEc; type Object = serde_json::Map; thread_local! { - static JWK: OnceCell = const { OnceCell::new() }; + static JWK: OnceCell = const { OnceCell::new() }; static JWT: RefCell> = const { RefCell::new(None) }; static JTI: RefCell = const { RefCell::new(0) }; } - #[derive(Clone)] - struct Key { - kid: i64, - key: VerifyingKey, - } - /// Set the public key and key ID for this postgres session. /// /// # Panics @@ -94,14 +72,14 @@ pub mod auth { ) }); - let jwk: JwkEc = serde_json::from_str(jwk).unwrap_or_else(|e| { + let jwk: JwkEcKey = serde_json::from_str(jwk).unwrap_or_else(|e| { error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, "session init requires an ES256 JWK", e.to_string(), ) }); - let key = PublicKey::from_jwk(&jwk.key).unwrap_or_else(|p256::elliptic_curve::Error| { + let key = PublicKey::from_jwk(&jwk).unwrap_or_else(|p256::elliptic_curve::Error| { error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, "session init requires an ES256 JWK", @@ -109,7 +87,7 @@ pub mod auth { }); let key = VerifyingKey::from(key); JWK.with(|j| { - if j.set(Key { kid: jwk.kid, key }).is_err() { + if j.set(key).is_err() { error_code!( PgSqlErrorCode::ERRCODE_UNIQUE_VIOLATION, "JWK state can only be set once per session.", @@ -118,7 +96,7 @@ pub mod auth { }) } - fn verify_signature(key: &Key, body: &str, sig: &str) { + fn verify_signature(key: &VerifyingKey, body: &str, sig: &str) { let mut sig_bytes = GenericArray::default(); Base64UrlUnpadded::decode(sig, &mut sig_bytes).unwrap_or_else(|_| { error_code!( @@ -133,7 +111,7 @@ pub mod auth { ) }); - key.key.verify(body.as_bytes(), &sig).unwrap_or_else(|_| { + key.verify(body.as_bytes(), &sig).unwrap_or_else(|_| { error_code!( PgSqlErrorCode::ERRCODE_CHECK_VIOLATION, "invalid JWT signature", @@ -141,22 +119,6 @@ pub mod auth { }); } - fn verify_key_id(key: &Key, header: &Object) { - let kid = header - .get("kid") - .and_then(|x| x.as_i64()) - .unwrap_or_else(|| { - error_code!( - PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, - "JWT header must contain a valid 'kid' (key ID)", - ) - }); - - if key.kid != kid { - error_code!(PgSqlErrorCode::ERRCODE_CHECK_VIOLATION, "Key ID mismatch"); - } - } - fn verify_token_id(payload: &Object) -> i64 { let jti = payload .get("jti") @@ -282,15 +244,13 @@ pub mod auth { "invalid JWT encoding", ) }); - let (header, payload) = body.split_once('.').unwrap_or_else(|| { + let (_, payload) = body.split_once('.').unwrap_or_else(|| { error_code!( PgSqlErrorCode::ERRCODE_DATATYPE_MISMATCH, "invalid JWT encoding", ) }); - let header: Object = json_base64_decode(header); - verify_key_id(&key, &header); verify_signature(&key, body, sig); let payload: Object = json_base64_decode(payload); diff --git a/tests/pg_session_jwt.rs b/tests/pg_session_jwt.rs index 7b062c6..a9aafae 100644 --- a/tests/pg_session_jwt.rs +++ b/tests/pg_session_jwt.rs @@ -10,7 +10,6 @@ use p256::{ NistP256, }; use rand::rngs::OsRng; -use serde::Serialize; use serde_json::json; fn main() -> ExitCode { @@ -45,7 +44,7 @@ where + 'static, { let sk = SigningKey::random(&mut OsRng); - let jwk = create_jwk(&sk, 1); + let jwk = create_jwk(&sk); let options = format!("-c {NEON_AUTH_JWK_RUNTIME_PARAM}={jwk}"); Trial::test(name, move || { @@ -186,16 +185,6 @@ fn test_bgworker(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postg static NEON_AUTH_JWK_RUNTIME_PARAM: &str = "neon.auth.jwk"; -#[derive(Clone, Debug, PartialEq, Eq, Serialize)] -struct JwkEc { - /// The key material. - #[serde(flatten)] - key: JwkEcKey, - - // The key parameters. - kid: i64, -} - fn sign_jwt(sk: &SigningKey, header: &str, payload: impl ToString) -> String { let header = Base64UrlUnpadded::encode_string(header.as_bytes()); let payload = Base64UrlUnpadded::encode_string(payload.to_string().as_bytes()); @@ -206,9 +195,8 @@ fn sign_jwt(sk: &SigningKey, header: &str, payload: impl ToString) -> String { format!("{message}.{base64_sig}") } -fn create_jwk(sk: &SigningKey, kid: i64) -> String { +fn create_jwk(sk: &SigningKey) -> String { let point = sk.verifying_key().to_encoded_point(false); let key = JwkEcKey::from_encoded_point::(&point).unwrap(); - let jwk = JwkEc { key, kid }; - serde_json::to_string(&jwk).unwrap() + serde_json::to_string(&key).unwrap() } From c71030c86deddaf887dbc8919e46a347f327f19f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= Date: Fri, 11 Oct 2024 11:23:58 +0200 Subject: [PATCH 21/21] linter --- src/lib.rs | 14 ++++++-------- tests/pg_session_jwt.rs | 22 +++++++++++----------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 23a8935..762713b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -229,14 +229,12 @@ pub mod auth { ) }); let key = JWK.with(|b| { - b.get() - .unwrap_or_else(|| { - error_code!( - PgSqlErrorCode::ERRCODE_NOT_NULL_VIOLATION, - "JWK state has not been initialised", - ) - }) - .clone() + *b.get().unwrap_or_else(|| { + error_code!( + PgSqlErrorCode::ERRCODE_NOT_NULL_VIOLATION, + "JWK state has not been initialised", + ) + }) }); let (body, sig) = jwt.rsplit_once('.').unwrap_or_else(|| { error_code!( diff --git a/tests/pg_session_jwt.rs b/tests/pg_session_jwt.rs index a9aafae..b40285c 100644 --- a/tests/pg_session_jwt.rs +++ b/tests/pg_session_jwt.rs @@ -49,13 +49,13 @@ where Trial::test(name, move || { pgrx_tests::run_test(Some(&options), error, vec![], move |tx| f(&sk, tx)) - .map_err(|e| libtest_mimic::Failed::from(e)) + .map_err(libtest_mimic::Failed::from) }) } fn wrong_txid(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { - let jwt1 = sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":1}"#); - let jwt2 = sign_jwt(&sk, r#"{"kid":1}"#, r#"{"jti":2}"#); + let jwt1 = sign_jwt(sk, r#"{"kid":1}"#, r#"{"jti":1}"#); + let jwt2 = sign_jwt(sk, r#"{"kid":1}"#, r#"{"jti":2}"#); tx.execute("select auth.init()", &[])?; tx.execute("select auth.jwt_session_init($1)", &[&jwt2])?; @@ -70,7 +70,7 @@ fn invalid_nbf(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgre .unwrap() .as_secs(); - let jwt = sign_jwt(&sk, r#"{"kid":1}"#, json!({"jti": 1, "nbf": now + 10})); + let jwt = sign_jwt(sk, r#"{"kid":1}"#, json!({"jti": 1, "nbf": now + 10})); tx.execute("select auth.init()", &[])?; tx.execute("select auth.jwt_session_init($1)", &[&jwt])?; @@ -85,7 +85,7 @@ fn invalid_exp(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgre .as_secs(); let jwt = sign_jwt( - &sk, + sk, r#"{"kid":1}"#, json!({"jti": 1, "nbf": now - 10, "exp": now - 5}), ); @@ -104,12 +104,12 @@ fn valid_time(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres let header = r#"{"kid":1}"#; let jwt1 = sign_jwt( - &sk, + sk, header, json!({"jti": 1, "nbf": now - 10, "exp": now + 10}), ); - let jwt2 = sign_jwt(&sk, header, json!({"jti": 2, "nbf": now - 10})); - let jwt3 = sign_jwt(&sk, header, json!({"jti": 3, "exp": now + 10})); + let jwt2 = sign_jwt(sk, header, json!({"jti": 2, "nbf": now - 10})); + let jwt3 = sign_jwt(sk, header, json!({"jti": 3, "exp": now + 10})); tx.execute("select auth.init()", &[])?; tx.execute("select auth.jwt_session_init($1)", &[&jwt1])?; @@ -121,8 +121,8 @@ fn valid_time(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres fn test_pg_session_jwt(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { let header = r#"{"kid":1}"#; - let jwt1 = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); - let jwt2 = sign_jwt(&sk, header, r#"{"sub":"bar","jti":2}"#); + let jwt1 = sign_jwt(sk, header, r#"{"sub":"foo","jti":1}"#); + let jwt2 = sign_jwt(sk, header, r#"{"sub":"bar","jti":2}"#); tx.execute("select auth.init()", &[])?; tx.execute("select auth.jwt_session_init($1)", &[&jwt1])?; @@ -142,7 +142,7 @@ fn test_pg_session_jwt(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), // than one JWT fn test_bgworker(sk: &SigningKey, tx: &mut postgres::Client) -> Result<(), postgres::Error> { let header = r#"{"kid":1}"#; - let jwt = sign_jwt(&sk, header, r#"{"sub":"foo","jti":1}"#); + let jwt = sign_jwt(sk, header, r#"{"sub":"foo","jti":1}"#); tx.execute(&format!("set neon.auth.jwt = '{jwt}'"), &[])?; let user_id = tx.query_one("select auth.user_id()", &[])?;