Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Address Whitelist Functionality #813

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ tracing-subscriber = { version = "0.3.17", features = [
"std",
] }
lru = "0.12.0"
ipnet = "2.9.0"

[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"
1 change: 1 addition & 0 deletions src/auth_passthrough.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ impl AuthPassthrough {
username: self.user.clone(),
auth_type: AuthType::MD5,
password: Some(self.password.clone()),
address_whitelist: None,
server_username: None,
server_password: None,
pool_size: 1,
Expand Down
101 changes: 86 additions & 15 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
use crate::errors::{ClientIdentifier, Error};
use crate::pool::BanReason;
/// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, info, trace, warn};
use once_cell::sync::Lazy;
use std::collections::{HashMap, VecDeque};
use std::sync::{atomic::AtomicUsize, Arc};
use std::time::Instant;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender;

use crate::admin::{generate_server_parameters_for_admin, handle_admin};
use crate::auth_passthrough::refetch_auth_hash;
use crate::config::{
get_config, get_idle_client_in_transaction_timeout, Address, AuthType, PoolMode,
};
use crate::constants::*;
use crate::dns_cache::CACHED_RESOLVER;
use crate::errors::{ClientIdentifier, Error};
use crate::messages::*;
use crate::plugins::PluginOutput;
use crate::pool::BanReason;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter};
use crate::server::{Server, ServerParameters};
use crate::stats::{ClientStats, ServerStats};
use crate::tls::Tls;
/// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut};
use ipnet::IpNet;
use log::{debug, error, info, trace, warn};
use once_cell::sync::Lazy;
use std::collections::{HashMap, VecDeque};
use std::sync::{atomic::AtomicUsize, Arc};
use std::time::Instant;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender;

use tokio_rustls::server::TlsStream;

Expand Down Expand Up @@ -415,6 +416,60 @@ pub async fn startup_tls(
}
}

pub async fn check_hostname_whitelist_entry(hostname: String, addr: std::net::SocketAddr) -> bool {
let cached_resolver = CACHED_RESOLVER.load();
let addr_set = match cached_resolver.lookup_ip(&hostname).await {
Ok(ok) => {
debug!("Obtained: {:?}", ok);
Some(ok)
}
Err(err) => {
warn!("Error trying to resolve {}, ({:?})", &hostname, err);
None
}
};

// Look through address set
if addr_set.is_some(){
for ip in &addr_set.unwrap().set {
if addr.ip() == *ip {
return true;
}
}
};
false
}

pub async fn check_whitelist_entries(
addr: std::net::SocketAddr,
whitelist_entries: &Option<Vec<String>>,
) -> bool {
match whitelist_entries {
Some(whitelist_entries_value) => {
for entry in whitelist_entries_value {
let parsed_ip_result: Result<IpNet, _> = entry.parse();
match parsed_ip_result {
// Compare ip to address ip
Ok(parsed_ip) => {
if parsed_ip.contains(&addr.ip()) {
return true;
}
}

// If ip is hostname then convert to ip/iplist then check
Err(_) => {
if check_hostname_whitelist_entry(entry.clone(), addr).await {
return true;
}
}
}
}
false
}
None => true,
}
}

impl<S, T> Client<S, T>
where
S: tokio::io::AsyncRead + std::marker::Unpin,
Expand Down Expand Up @@ -488,6 +543,15 @@ where
// Authenticate admin user.
let (transaction_mode, mut server_parameters) = if admin {
let config = get_config();

if !check_whitelist_entries(addr, &config.general.admin_address_whitelist).await {
let error =
Error::ClientGeneralError("IP Address not allowed".into(), client_identifier);
warn!("{}", error);
ip_address_whitelist_fail(&mut write, username, &addr.ip().to_string()).await?;
return Err(error);
}

// TODO: Add SASL support.
// Perform MD5 authentication.
match config.general.admin_auth_type {
Expand All @@ -512,7 +576,6 @@ where
code as char
)));
}

let len = match read.read_i32().await {
Ok(len) => len,
Err(_) => {
Expand Down Expand Up @@ -576,6 +639,14 @@ where
}
};

if !check_whitelist_entries(addr, &pool.settings.user.address_whitelist).await {
let error =
Error::ClientGeneralError("IP Address not allowed".into(), client_identifier);
warn!("{}", error);
ip_address_whitelist_fail(&mut write, username, &addr.ip().to_string()).await?;
return Err(error);
};

// Obtain the hash to compare, we give preference to that written in cleartext in config
// if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained
// when the pool was created. If there is no hash there, we try to fetch it one more time.
Expand Down
4 changes: 4 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ pub struct User {

#[serde(default = "User::default_auth_type")]
pub auth_type: AuthType,
pub address_whitelist: Option<Vec<String>>,
pub server_username: Option<String>,
pub server_password: Option<String>,
pub pool_size: u32,
Expand All @@ -229,6 +230,7 @@ impl Default for User {
username: String::from("postgres"),
password: None,
auth_type: AuthType::MD5,
address_whitelist: None,
server_username: None,
server_password: None,
pool_size: 15,
Expand Down Expand Up @@ -341,6 +343,7 @@ pub struct General {

pub admin_username: String,
pub admin_password: String,
pub admin_address_whitelist: Option<Vec<String>>,

#[serde(default = "General::default_admin_auth_type")]
pub admin_auth_type: AuthType,
Expand Down Expand Up @@ -472,6 +475,7 @@ impl Default for General {
admin_username: String::from("admin"),
admin_password: String::from("admin"),
admin_auth_type: AuthType::MD5,
admin_address_whitelist: None,
validate_config: true,
auth_query: None,
auth_query_user: None,
Expand Down
2 changes: 1 addition & 1 deletion src/dns_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
// so we can compare.
#[derive(Clone, PartialEq, Debug)]
pub struct AddrSet {
set: HashSet<IpAddr>,
pub set: HashSet<IpAddr>,
}

impl AddrSet {
Expand Down
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ extern crate async_trait;
extern crate bb8;
extern crate bytes;
extern crate exitcode;
extern crate ipnet;
extern crate log;
extern crate md5;
extern crate num_cpus;
Expand Down
46 changes: 46 additions & 0 deletions src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,52 @@ where
write_all_half(stream, &res).await
}

pub async fn ip_address_whitelist_fail<S>(
stream: &mut S,
user: &str,
source_ip: &str,
) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut error = BytesMut::new();

// Error level
error.put_u8(b'S');
error.put_slice(&b"FATAL\0"[..]);

// Error level (non-translatable)
error.put_u8(b'V');
error.put_slice(&b"FATAL\0"[..]);

// Error code: not sure how much this matters.
error.put_u8(b'C');
error.put_slice(&b"28P01\0"[..]); // system_error, see Appendix A.

// The short error message.
error.put_u8(b'M');
error.put_slice(
format!(
"No IP whitelist entry for (\"{}\", \"{}\")\0",
user, source_ip
)
.as_bytes(),
);

// No more fields follow.
error.put_u8(0);

// Compose the two message reply.
let mut res = BytesMut::new();

res.put_u8(b'E');
res.put_i32(error.len() as i32 + 4);

res.put(error);

write_all(stream, res).await
}

pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
Expand Down
Loading