Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use simple_connection for auth_query #342

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.1.55
1.1.56
49 changes: 49 additions & 0 deletions lib/single_connection.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
defmodule Supavisor.SingleConnection do
require Logger
@behaviour Postgrex.SimpleConnection

def connect(conf), do: Postgrex.SimpleConnection.start_link(__MODULE__, conf, conf)

@impl true
def init(args) do
Logger.debug("init args: #{inspect(args, pretty: true)}")
Process.monitor(args[:caller])
# put the hostname in the process dictionary to be able to find it in an emergency
Process.put(:auth_host, args[:hostname])
{:ok, %{from: nil, caller: args[:caller]}}
end

@impl true
def handle_call({:query, query}, from, state), do: {:query, query, %{state | from: from}}

def handle_result(results, state) when is_list(results) do
result =
case results do
[%Postgrex.Result{} = res] -> res
other -> other
end

Postgrex.SimpleConnection.reply(state.from, result)
{:noreply, state}
end

@impl true
def handle_result(%Postgrex.Error{} = error, state) do
Postgrex.SimpleConnection.reply(state.from, error)
{:noreply, state}
end

@impl true
def handle_info({:DOWN, _, _, caller, _}, %{caller: caller} = state) do
Logger.notice("Caller #{inspect(caller)} is down")
{:stop, state}
end

def handle_info(msg, state) do
Logger.error("Undefined message #{inspect(msg, pretty: true)}")
{:noreply, state}
end

@impl true
def notify(_, _, _), do: :ok
end
15 changes: 11 additions & 4 deletions lib/supavisor/client_handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ defmodule Supavisor.ClientHandler do
alias Supavisor.DbHandler, as: Db
alias Supavisor.Helpers, as: H
alias Supavisor.HandlerHelpers, as: HH
alias Supavisor.{Tenants, Monitoring.Telem, Protocol.Client, Protocol.Server}
alias Supavisor.{Tenants, Monitoring.Telem, Protocol.Client, Protocol.Server, SingleConnection}

@impl true
def start_link(ref, _sock, transport, opts) do
Expand Down Expand Up @@ -822,6 +822,8 @@ defmodule Supavisor.ClientHandler do

@spec get_secrets(map, String.t()) :: {:ok, {:auth_query, fun()}} | {:error, term()}
def get_secrets(%{user: user, tenant: tenant}, db_user) do
Logger.info("ClientHandler: Get secrets started")

ssl_opts =
if tenant.upstream_ssl and tenant.upstream_verify == "peer" do
[
Expand All @@ -833,7 +835,7 @@ defmodule Supavisor.ClientHandler do
end

{:ok, conn} =
Postgrex.start_link(
SingleConnection.connect(
hostname: tenant.db_host,
port: tenant.db_port,
database: tenant.db_database,
Expand All @@ -846,9 +848,14 @@ defmodule Supavisor.ClientHandler do
],
queue_target: 1_000,
queue_interval: 5_000,
ssl_opts: ssl_opts || []
ssl_opts: ssl_opts || [],
caller: self()
)

Logger.debug(
"ClientHandler: Connected to db #{tenant.db_host} #{tenant.db_port} #{tenant.db_database} #{user.db_user}"
)

resp =
case H.get_user_secret(conn, tenant.auth_query, db_user) do
{:ok, secret} ->
Expand All @@ -859,7 +866,7 @@ defmodule Supavisor.ClientHandler do
{:error, reason}
end

GenServer.stop(conn, :normal)
Logger.info("ClientHandler: Get secrets finished")
resp
end

Expand Down
6 changes: 5 additions & 1 deletion lib/supavisor/helpers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ defmodule Supavisor.Helpers do
@spec get_user_secret(pid(), String.t(), String.t()) :: {:ok, map()} | {:error, String.t()}
def get_user_secret(conn, auth_query, user) do
try do
Postgrex.query!(conn, auth_query, [user])
# sanitize the user input by removing all characters that are not alphanumeric or underscores
user = String.replace(user, ~r/[^a-zA-Z0-9_]/, "")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that instead we should reject user if these characters are present instead of just pretending that these do not exists.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I already have a check for invalid characters during connection, but it only checks for slashes. I added this sanitization as an additional guard. It might be worth applying the same rule (a-zA-Z0-9_) during connection as well

not_allowed = ["\"", "\\"]
if String.contains?(user, not_allowed) or String.contains?(db_name, not_allowed) do
reason = "Invalid characters in user or db_name"
Logger.error("ClientHandler: #{inspect(reason)}")
Telem.client_join(:fail, data.id)
HH.send_error(data.sock, "XX000", "Authentication error, reason: #{inspect(reason)}")
{:stop, {:shutdown, :invalid_characters}}

$ psql postgresql://p\"ostgres.tenant:password@localhost:5432/postgres
psql: error: connection to server at "localhost" (::1), port 5432 failed: Connection refused
	Is the server running on that host and accepting TCP/IP connections?
connection to server at "localhost" (127.0.0.1), port 5432 failed: FATAL:  Authentication error, reason: "Invalid characters in user or db_name"

auth_query = String.replace(auth_query, "$1", "'#{user}'")
abc3 marked this conversation as resolved.
Show resolved Hide resolved

Postgrex.SimpleConnection.call(conn, {:query, auth_query})
catch
_error, reason ->
{:error, "Authentication query failed: #{inspect(reason)}"}
Expand Down
36 changes: 36 additions & 0 deletions test/integration/single_connection_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
defmodule Supavisor.Integration.SingleConnectionTest do
require Logger
use Supavisor.DataCase, async: true
alias Postgrex, as: P

@tenant "proxy_tenant1"

test "connects to database and executes a simple query" do
db_conf = Application.get_env(:supavisor, Repo)

args = [
hostname: db_conf[:hostname],
port: Application.get_env(:supavisor, :proxy_port_transaction),
database: "postgres",
password: db_conf[:password],
username: "transaction.#{@tenant}"
]

spawn(fn ->
{:ok, pid} =
args
|> Keyword.put_new(:caller, self())
|> Supavisor.SingleConnection.connect()

assert %Postgrex.Result{rows: [["1"]]} =
Postgrex.SimpleConnection.call(pid, {:query, "SELECT 1"})
end)

:timer.sleep(250)

# check that the connection dies after the caller dies
assert Enum.filter(Process.list(), fn pid ->
Process.info(pid)[:dictionary][:auth_host] == db_conf[:hostname]
end) == []
Comment on lines +32 to +34
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert Enum.filter(Process.list(), fn pid ->
Process.info(pid)[:dictionary][:auth_host] == db_conf[:hostname]
end) == []
assert Enum.all?(Process.list(), fn pid ->
Process.info(pid)[:dictionary][:auth_host] == db_conf[:hostname]
end)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For what is worth, using filter for this test has a tiny benefit that it shows the left side with the failed PIDs in exception reports. :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. That indeed may be useful.

end
end
Loading