Skip to content

Commit

Permalink
Merge pull request #15029 from rgacogne/ddist-doh3-set-http-response
Browse files Browse the repository at this point in the history
dnsdist: Add the ability to set custom HTTP responses over DoH3
  • Loading branch information
rgacogne authored Jan 14, 2025
2 parents bfaf5ed + 7285f2f commit 72f679d
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 70 deletions.
36 changes: 24 additions & 12 deletions pdns/dnsdistdist/dnsdist-lua-actions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2019,7 +2019,7 @@ class ContinueAction : public DNSAction
std::shared_ptr<DNSAction> d_action;
};

#ifdef HAVE_DNS_OVER_HTTPS
#if defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_HTTP3)
class HTTPStatusAction : public DNSAction
{
public:
Expand All @@ -2030,17 +2030,29 @@ class HTTPStatusAction : public DNSAction

DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
{
if (!dnsquestion->ids.du) {
return Action::None;
#if defined(HAVE_DNS_OVER_HTTPS)
if (dnsquestion->ids.du) {
dnsquestion->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
header.qr = true; // for good measure
setResponseHeadersFromConfig(header, d_responseConfig);
return true;
});
return Action::HeaderModify;
}

dnsquestion->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
header.qr = true; // for good measure
setResponseHeadersFromConfig(header, d_responseConfig);
return true;
});
return Action::HeaderModify;
#endif /* HAVE_DNS_OVER_HTTPS */
#if defined(HAVE_DNS_OVER_HTTP3)
if (dnsquestion->ids.doh3u) {
dnsquestion->ids.doh3u->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
header.qr = true; // for good measure
setResponseHeadersFromConfig(header, d_responseConfig);
return true;
});
return Action::HeaderModify;
}
#endif /* HAVE_DNS_OVER_HTTP3 */
return Action::None;
}

[[nodiscard]] std::string toString() const override
Expand All @@ -2059,7 +2071,7 @@ class HTTPStatusAction : public DNSAction
std::string d_contentType;
int d_code;
};
#endif /* HAVE_DNS_OVER_HTTPS */
#endif /* HAVE_DNS_OVER_HTTPS || HAVE_DNS_OVER_HTTP3 */

#if defined(HAVE_LMDB) || defined(HAVE_CDB)
class KeyValueStoreLookupAction : public DNSAction
Expand Down
13 changes: 9 additions & 4 deletions pdns/dnsdistdist/dnsdist-lua-bindings-dnsquestion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
#endif /* HAVE_NET_SNMP */
});

#ifdef HAVE_DNS_OVER_HTTPS
#if defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_HTTP3)
luaCtx.registerFunction<std::string (DNSQuestion::*)(void) const>("getHTTPPath", [](const DNSQuestion& dnsQuestion) {
if (dnsQuestion.ids.du) {
return dnsQuestion.ids.du->getHTTPPath();
Expand Down Expand Up @@ -563,14 +563,19 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
});

luaCtx.registerFunction<void (DNSQuestion::*)(uint64_t statusCode, const std::string& body, const boost::optional<std::string> contentType)>("setHTTPResponse", [](DNSQuestion& dnsQuestion, uint64_t statusCode, const std::string& body, const boost::optional<std::string>& contentType) {
if (dnsQuestion.ids.du == nullptr) {
if (dnsQuestion.ids.du == nullptr && dnsQuestion.ids.doh3u == nullptr) {
return;
}
checkParameterBound("DNSQuestion::setHTTPResponse", statusCode, std::numeric_limits<uint16_t>::max());
PacketBuffer vect(body.begin(), body.end());
dnsQuestion.ids.du->setHTTPResponse(statusCode, std::move(vect), contentType ? *contentType : "");
if (dnsQuestion.ids.du) {
dnsQuestion.ids.du->setHTTPResponse(statusCode, std::move(vect), contentType ? *contentType : "");
}
else {
dnsQuestion.ids.doh3u->setHTTPResponse(statusCode, std::move(vect), contentType ? *contentType : "");
}
});
#endif /* HAVE_DNS_OVER_HTTPS */
#endif /* HAVE_DNS_OVER_HTTPS HAVE_DNS_OVER_HTTP3 */

luaCtx.registerFunction<bool (DNSQuestion::*)(bool nxd, const std::string& zone, uint64_t ttl, const std::string& mname, const std::string& rname, uint64_t serial, uint64_t refresh, uint64_t retry, uint64_t expire, uint64_t minimum)>("setNegativeAndAdditionalSOA", [](DNSQuestion& dnsQuestion, bool nxd, const std::string& zone, uint64_t ttl, const std::string& mname, const std::string& rname, uint64_t serial, uint64_t refresh, uint64_t retry, uint64_t expire, uint64_t minimum) {
checkParameterBound("setNegativeAndAdditionalSOA", ttl, std::numeric_limits<uint32_t>::max());
Expand Down
30 changes: 20 additions & 10 deletions pdns/dnsdistdist/dnsdist-lua-ffi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,17 +499,27 @@ void dnsdist_ffi_dnsquestion_set_result(dnsdist_ffi_dnsquestion_t* dq, const cha

void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, size_t bodyLen, const char* contentType)
{
if (dq->dq->ids.du == nullptr) {
return;
#if defined(HAVE_DNS_OVER_HTTPS)
if (dq->dq->ids.du) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): C API
PacketBuffer bodyVect(body, body + bodyLen);
dq->dq->ids.du->setHTTPResponse(statusCode, std::move(bodyVect), contentType);
dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [](dnsheader& header) {
header.qr = true;
return true;
});
}
#endif
#if defined(HAVE_DNS_OVER_HTTP3)
if (dq->dq->ids.doh3u) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): C API
PacketBuffer bodyVect(body, body + bodyLen);
dq->dq->ids.doh3u->setHTTPResponse(statusCode, std::move(bodyVect), contentType);
dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [](dnsheader& header) {
header.qr = true;
return true;
});
}

#ifdef HAVE_DNS_OVER_HTTPS
PacketBuffer bodyVect(body, body + bodyLen);
dq->dq->ids.du->setHTTPResponse(statusCode, std::move(bodyVect), contentType);
dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [](dnsheader& header) {
header.qr = true;
return true;
});
#endif
}

Expand Down
94 changes: 59 additions & 35 deletions pdns/dnsdistdist/doh3.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,40 +285,53 @@ static bool tryWriteResponse(H3Connection& conn, const uint64_t streamID, Packet
return true;
}

static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
static void addHeaderToList(std::vector<quiche_h3_header>& headers, const char* name, size_t nameLen, const char* value, size_t valueLen)
{
headers.emplace_back((quiche_h3_header){
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.name = reinterpret_cast<const uint8_t*>(name),
.name_len = nameLen,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.value = reinterpret_cast<const uint8_t*>(value),
.value_len = valueLen,
});
}

static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len, const std::string& contentType = {})
{
std::string status = std::to_string(statusCode);
std::string lenStr = std::to_string(len);
std::array<quiche_h3_header, 3> headers{
(quiche_h3_header){
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.name = reinterpret_cast<const uint8_t*>(":status"),
.name_len = sizeof(":status") - 1,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.value = reinterpret_cast<const uint8_t*>(status.data()),
.value_len = status.size(),
},
(quiche_h3_header){
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.name = reinterpret_cast<const uint8_t*>("content-length"),
.name_len = sizeof("content-length") - 1,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.value = reinterpret_cast<const uint8_t*>(lenStr.data()),
.value_len = lenStr.size(),
},
(quiche_h3_header){
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.name = reinterpret_cast<const uint8_t*>("content-type"),
.name_len = sizeof("content-type") - 1,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.value = reinterpret_cast<const uint8_t*>("application/dns-message"),
.value_len = sizeof("application/dns-message") - 1,
},
};
PacketBuffer location;
PacketBuffer responseBody;
std::vector<quiche_h3_header> headers;
headers.reserve(4);
addHeaderToList(headers, ":status", sizeof(":status") - 1, status.data(), status.size());

if (statusCode >= 300 && statusCode < 400) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
addHeaderToList(headers, "location", sizeof("location") - 1, reinterpret_cast<const char*>(body), len);
static const std::string s_redirectStart{"<!DOCTYPE html><TITLE>Moved</TITLE><P>The document has moved <A HREF=\""};
static const std::string s_redirectEnd{"\">here</A>"};
static const std::string s_redirectContentType("text/html; charset=utf-8");
addHeaderToList(headers, "content-type", sizeof("content-type") - 1, s_redirectContentType.data(), s_redirectContentType.size());
responseBody.reserve(s_redirectStart.size() + len + s_redirectEnd.size());
responseBody.insert(responseBody.begin(), s_redirectStart.begin(), s_redirectStart.end());
// NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
responseBody.insert(responseBody.end(), body, body + len);
responseBody.insert(responseBody.end(), s_redirectEnd.begin(), s_redirectEnd.end());
body = responseBody.data();
len = responseBody.size();
}
else if (len > 0 && (statusCode == 200U || !contentType.empty())) {
// do not include content-type header info if there is no content
addHeaderToList(headers, "content-type", sizeof("content-type") - 1, contentType.empty() ? "application/dns-message" : contentType.data(), contentType.empty() ? sizeof("application/dns-message") - 1 : contentType.size());
}

const std::string lenStr = std::to_string(len);
addHeaderToList(headers, "content-length", sizeof("content-length") - 1, lenStr.data(), lenStr.size());

auto returnValue = quiche_h3_send_response(conn.d_http3.get(), conn.d_conn.get(),
streamID, headers.data(),
// do not include content-type header info if there is no content
(len > 0 && statusCode == 200U ? headers.size() : headers.size() - 1),
headers.size(),
len == 0);
if (returnValue != 0) {
/* in theory it could be QUICHE_H3_ERR_STREAM_BLOCKED if the stream is not writable / congested, but we are not going to handle this case */
Expand Down Expand Up @@ -350,13 +363,13 @@ static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16
}
}

static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const std::string& content)
static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const std::string& content = {})
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
h3_send_response(conn, streamID, statusCode, reinterpret_cast<const uint8_t*>(content.data()), content.size());
}

static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const PacketBuffer& response)
static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const PacketBuffer& response, const std::string& contentType)
{
if (statusCode == 200) {
++frontend.d_validResponses;
Expand All @@ -368,7 +381,7 @@ static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uin
quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_UNSPECIFIED_ERROR));
}
else {
h3_send_response(conn, streamID, statusCode, &response.at(0), response.size());
h3_send_response(conn, streamID, statusCode, &response.at(0), response.size(), contentType);
}
}

Expand Down Expand Up @@ -471,7 +484,7 @@ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit)
const auto handleImmediateResponse = [](DOH3UnitUniquePtr&& unit, [[maybe_unused]] const char* reason) {
DEBUGLOG("handleImmediateResponse() reason=" << reason);
auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID);
handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response);
handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response, unit->d_contentTypeOut);
unit->ids.doh3u.reset();
};

Expand Down Expand Up @@ -658,7 +671,7 @@ static void flushResponses(pdns::channel::Receiver<DOH3Unit>& receiver)
auto unit = std::move(*tmp);
auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID);
if (conn) {
handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response);
handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response, unit->d_contentTypeOut);
}
}
catch (const std::exception& e) {
Expand Down Expand Up @@ -1078,6 +1091,13 @@ const dnsdist::doh3::h3_headers_t& DOH3Unit::getHTTPHeaders() const
return headers;
}

void DOH3Unit::setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType)
{
status_code = statusCode;
response = std::move(body);
d_contentTypeOut = contentType;
}

#else /* HAVE_DNS_OVER_HTTP3 */

std::string DOH3Unit::getHTTPPath() const
Expand Down Expand Up @@ -1106,4 +1126,8 @@ const dnsdist::doh3::h3_headers_t& DOH3Unit::getHTTPHeaders() const
return headers;
}

void DOH3Unit::setHTTPResponse(uint16_t, PacketBuffer&&, const std::string&)
{
}

#endif /* HAVE_DNS_OVER_HTTP3 */
5 changes: 4 additions & 1 deletion pdns/dnsdistdist/doh3.hh
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
#include <unordered_map>

#include "config.h"
#include "noinitvector.hh"

#ifdef HAVE_DNS_OVER_HTTP3
#include "channel.hh"
#include "iputils.hh"
#include "libssl.hh"
#include "noinitvector.hh"
#include "stat_t.hh"
#include "dnsdist-idstate.hh"

Expand Down Expand Up @@ -93,13 +93,15 @@ struct DOH3Unit
[[nodiscard]] std::string getHTTPHost() const;
[[nodiscard]] std::string getHTTPScheme() const;
[[nodiscard]] const dnsdist::doh3::h3_headers_t& getHTTPHeaders() const;
void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType = "");

InternalQueryState ids;
PacketBuffer query;
PacketBuffer response;
PacketBuffer serverConnID;
dnsdist::doh3::h3_headers_t headers;
std::shared_ptr<DownstreamState> downstream{nullptr};
std::string d_contentTypeOut;
DOH3ServerConfig* dsc{nullptr};
uint64_t streamID{0};
size_t proxyProtocolPayloadSize{0};
Expand All @@ -126,6 +128,7 @@ struct DOH3Unit
[[nodiscard]] std::string getHTTPHost() const;
[[nodiscard]] std::string getHTTPScheme() const;
[[nodiscard]] const dnsdist::doh3::h3_headers_t& getHTTPHeaders() const;
void setHTTPResponse(uint16_t, PacketBuffer&&, const std::string&);
};

struct DOH3Frontend
Expand Down
7 changes: 5 additions & 2 deletions regression-tests.dnsdist/dnsdisttests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,15 +1151,18 @@ def sendDOQQuery(cls, port, query, response=None, timeout=2.0, caFile=None, useQ
return (receivedQuery, message)

@classmethod
def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False, customHeaders=None):
def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False, customHeaders=None, rawResponse=False):

if response:
if toQueue:
toQueue.put(response, True, timeout)
else:
cls._toResponderQueue.put(response, True, timeout)

message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders)
if rawResponse:
return doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders, raw_response=rawResponse)

message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders, raw_response=rawResponse)

receivedQuery = None

Expand Down
14 changes: 10 additions & 4 deletions regression-tests.dnsdist/doh3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,16 @@ async def perform_http_request(
elapsed = time.time() - start

result = bytes()
headers = {}
for http_event in http_events:
if isinstance(http_event, DataReceived):
result += http_event.data
if isinstance(http_event, StreamReset):
result = http_event
return result
if isinstance(http_event, HeadersReceived):
for k, v in http_event.headers:
headers[k] = v
return (result, headers)


async def async_h3_query(
Expand Down Expand Up @@ -220,15 +224,15 @@ async def async_h3_query(

return answer
except asyncio.TimeoutError as e:
return e
return (e,{})


def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None):
def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None, raw_response=False):
configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True)
if verify:
configuration.load_verify_locations(verify)

result = asyncio.run(
(result, headers) = asyncio.run(
async_h3_query(
configuration=configuration,
baseurl=baseurl,
Expand All @@ -245,4 +249,6 @@ def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname
raise StreamResetError(result.error_code)
if (isinstance(result, asyncio.TimeoutError)):
raise TimeoutError()
if raw_response:
return (result, headers)
return dns.message.from_wire(result)
Loading

0 comments on commit 72f679d

Please sign in to comment.