-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[vnet][2] DNS library #40972
[vnet][2] DNS library #40972
Changes from 5 commits
6ec5750
e9a8eb9
65cee5b
7a3f027
1858bfa
32ce558
befb942
a3d1ea1
90c284e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,392 @@ | ||
// Teleport | ||
// Copyright (C) 2024 Gravitational, Inc. | ||
// | ||
// This program is free software: you can redistribute it and/or modify | ||
// it under the terms of the GNU Affero General Public License as published by | ||
// the Free Software Foundation, either version 3 of the License, or | ||
// (at your option) any later version. | ||
// | ||
// This program is distributed in the hope that it will be useful, | ||
// but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
// GNU Affero General Public License for more details. | ||
// | ||
// You should have received a copy of the GNU Affero General Public License | ||
// along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
package dns | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"io" | ||
"log/slog" | ||
"net" | ||
"sync" | ||
"time" | ||
|
||
"github.com/gravitational/trace" | ||
"golang.org/x/net/dns/dnsmessage" | ||
"golang.org/x/sync/errgroup" | ||
"gvisor.dev/gvisor/pkg/tcpip" | ||
|
||
"github.com/gravitational/teleport" | ||
) | ||
|
||
const ( | ||
maxUDPMessageSize = 65535 | ||
forwardRequestTimeout = 5 * time.Second | ||
) | ||
|
||
// Resolver represents an entity that can resolve DNS requests. | ||
type Resolver interface { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I assume There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. right - |
||
// ResolveA should return a Result for an A record question. If an empty Result is returned with no error, | ||
// the question will be forwarded upstream. | ||
ResolveA(ctx context.Context, domain string) (Result, error) | ||
|
||
// ResolveAAAA should return a Result for an AAAA record question. If an empty Result is returned with no | ||
// error, the question will be forwarded upstream. | ||
ResolveAAAA(ctx context.Context, domain string) (Result, error) | ||
Comment on lines
+51
to
+57
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since "empty Result" means something here, it would be helpful to codify what "empty" means with a |
||
} | ||
|
||
// Result holds the result of DNS resolution. | ||
type Result struct { | ||
// A is an A record. | ||
A [4]byte | ||
// AAAA is an AAAA record. | ||
AAAA [16]byte | ||
// NXDomain indicates that the requested domain is invalid or unassigned, this is an authoritative answer. | ||
NXDomain bool | ||
// NoRecord indicates the domain exists but the requested record type doesn't, this is an authoritative | ||
// answer. | ||
NoRecord bool | ||
} | ||
|
||
// UpstreamNameserverSource provides the current set of upstream nameservers. | ||
type UpstreamNameserverSource interface { | ||
// UpstreamNameservers should return the current set of upstream nameservers, requests that cannot be | ||
// resolved will be forwarded to these addresses. | ||
UpstreamNameservers(context.Context) ([]string, error) | ||
} | ||
|
||
// Server is a DNS server. | ||
type Server struct { | ||
resolver Resolver | ||
upstreamNameserverSource UpstreamNameserverSource | ||
messageBuffers sync.Pool | ||
slog *slog.Logger | ||
} | ||
|
||
// NewServer returns a DNS server that handles the details of the DNS protocol and asks [resolver] to answer | ||
// DNS questions. If [resolver] has no answer, requests will be forwarded to the upstream nameservers provided | ||
// by [upstreamNameserverSource]. | ||
func NewServer(resolver Resolver, upstreamNameserverSource UpstreamNameserverSource) (*Server, error) { | ||
return &Server{ | ||
resolver: resolver, | ||
upstreamNameserverSource: upstreamNameserverSource, | ||
messageBuffers: sync.Pool{ | ||
New: func() any { | ||
buf := make([]byte, maxUDPMessageSize) | ||
return &buf | ||
}, | ||
}, | ||
slog: slog.With(teleport.ComponentKey, "VNet.DNS"), | ||
}, nil | ||
} | ||
|
||
// getMessageBuffer returns a buffer of size [maxUDPMessageSize]. Call [returnBuf] to return the buffer to the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. `// ... Callers MUST call [returnBuf] to return the buffer to the shared pool after use. |
||
// shared pool. Use this to avoid large allocations on all DNS messages. | ||
func (s *Server) getMessageBuffer() (buf []byte, returnBuf func()) { | ||
buf = *s.messageBuffers.Get().(*[]byte) | ||
buf = buf[:cap(buf)] | ||
return buf, func() { | ||
s.messageBuffers.Put(&buf) | ||
} | ||
} | ||
|
||
// HandleUDP reads and handles a single UDP message from [conn] and writes the response back to [conn]. | ||
// This will be called by VNet code. | ||
func (s *Server) HandleUDP(ctx context.Context, conn net.Conn) error { | ||
buf, returnBuf := s.getMessageBuffer() | ||
defer returnBuf() | ||
|
||
n, err := conn.Read(buf) | ||
rosstimothy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if err != nil { | ||
return trace.Wrap(err, "failed to read from UDP conn") | ||
} | ||
if n >= maxUDPMessageSize { | ||
return trace.BadParameter("Dropping UDP message that is too large") | ||
} | ||
buf = buf[:n] | ||
|
||
return trace.Wrap(s.handleDNSMessage(ctx, conn.RemoteAddr().String(), buf, conn)) | ||
} | ||
|
||
// ListendAndServeUDP reads all incoming UDP messages from [conn], handles DNS questions, and writes the | ||
// responses back to [conn]. | ||
// This is not called by VNet code and basically exists so we can test the resolver outside of VNet. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Smells a bit fishy, like there should be a better way to simulate calling |
||
func (s *Server) ListenAndServeUDP(ctx context.Context, conn *net.UDPConn) error { | ||
buf, returnBuf := s.getMessageBuffer() | ||
defer returnBuf() | ||
|
||
for { | ||
buf = buf[:cap(buf)] | ||
n, remoteAddr, err := conn.ReadFromUDP(buf) | ||
if err != nil { | ||
return trace.Wrap(err, "failed to read from UDP conn") | ||
} | ||
if n >= maxUDPMessageSize { | ||
return trace.BadParameter("Dropping UDP message that is too large") | ||
} | ||
buf = buf[:n] | ||
|
||
responseWriter := &udpWriter{ | ||
conn: conn, | ||
remoteAddr: remoteAddr, | ||
} | ||
if err := s.handleDNSMessage(ctx, remoteAddr.String(), buf, responseWriter); err != nil { | ||
s.slog.DebugContext(ctx, "Error handling DNS message.", "error", err) | ||
} | ||
} | ||
} | ||
|
||
type udpWriter struct { | ||
conn *net.UDPConn | ||
remoteAddr *net.UDPAddr | ||
} | ||
|
||
func (u *udpWriter) Write(b []byte) (int, error) { | ||
n, _, err := u.conn.WriteMsgUDP(b, nil /*oob*/, u.remoteAddr) | ||
return n, err | ||
} | ||
|
||
// handleDNSMessage handles the DNS message held in [buf] and writes the answer to [responseWriter]. | ||
// This could handle DNS messages arriving over UDP or TCP. | ||
func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []byte, responseWriter io.Writer) error { | ||
slog := s.slog.With("remote_addr", remoteAddr) | ||
slog.DebugContext(ctx, "Handling DNS message.") | ||
defer slog.DebugContext(ctx, "Done handling DNS message.") | ||
|
||
var parser dnsmessage.Parser | ||
requestHeader, err := parser.Start(buf) | ||
if err != nil { | ||
return trace.Wrap(err, "parsing DNS message") | ||
} | ||
if requestHeader.OpCode != 0 { | ||
slog.DebugContext(ctx, "OpCode is not QUERY (0), forwarding.", "opcode", requestHeader.OpCode) | ||
return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-Query DNS message") | ||
} | ||
question, err := parser.Question() | ||
if err != nil { | ||
return trace.Wrap(err, "parsing DNS question") | ||
} | ||
fqdn := question.Name.String() | ||
slog = slog.With("fqdn", fqdn, "type", question.Type.String()) | ||
slog.DebugContext(ctx, "Received DNS question.", "question", question) | ||
if question.Class != dnsmessage.ClassINET { | ||
slog.DebugContext(ctx, "Query class is not INET, forwarding.", "class", question.Class) | ||
return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-INET DNS query") | ||
} | ||
|
||
var result Result | ||
switch question.Type { | ||
case dnsmessage.TypeA: | ||
result, err = s.resolver.ResolveA(ctx, fqdn) | ||
if err != nil { | ||
return trace.Wrap(err, "resolving A request for %q", fqdn) | ||
} | ||
case dnsmessage.TypeAAAA: | ||
result, err = s.resolver.ResolveAAAA(ctx, fqdn) | ||
if err != nil { | ||
return trace.Wrap(err, "resolving AAAA request for %q", fqdn) | ||
} | ||
default: | ||
} | ||
nklaassen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
var response []byte | ||
switch { | ||
case result.NXDomain: | ||
slog.DebugContext(ctx, "No match for name, responding with authoritative name error.") | ||
response, err = buildNXDomainResponse(buf, &requestHeader, &question) | ||
case result.NoRecord: | ||
slog.DebugContext(ctx, "Name matched but no record, responding with authoritative non-answer.") | ||
response, err = buildEmptyResponse(buf, &requestHeader, &question) | ||
case question.Type == dnsmessage.TypeA && result.A != ([4]byte{}): | ||
slog.DebugContext(ctx, "Matched DNS A.", "a", tcpip.AddrFrom4(result.A)) | ||
response, err = buildAResponse(buf, &requestHeader, &question, result.A) | ||
case question.Type == dnsmessage.TypeAAAA && result.AAAA != ([16]byte{}): | ||
slog.DebugContext(ctx, "Matched DNS AAAA.", "aaaa", tcpip.AddrFrom16(result.AAAA)) | ||
response, err = buildAAAAResponse(buf, &requestHeader, &question, result.AAAA) | ||
default: | ||
slog.DebugContext(ctx, "Forwarding unmatched query.") | ||
return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding unmatched DNS query") | ||
} | ||
if err != nil { | ||
return trace.Wrap(err) | ||
} | ||
|
||
_, err = responseWriter.Write(response) | ||
return trace.Wrap(err, "writing DNS response") | ||
} | ||
|
||
// forward forwards raw DNS messages to all upstream nameservers and writes the first response to | ||
// [responseWriter]. If there are no upstream nameservers, or none of them responds within the timeout, an | ||
// error is returned. | ||
func (s *Server) forward(ctx context.Context, slog *slog.Logger, buf []byte, responseWriter io.Writer) error { | ||
ctx, cancel := context.WithTimeout(ctx, forwardRequestTimeout) | ||
defer cancel() | ||
deadline, _ := ctx.Deadline() | ||
|
||
upstreamNameservers, err := s.upstreamNameserverSource.UpstreamNameservers(ctx) | ||
if err != nil { | ||
return trace.Wrap(err, "getting host default nameservers") | ||
} | ||
if len(upstreamNameservers) == 0 { | ||
return trace.Errorf("no upstream nameservers") | ||
} | ||
|
||
// Forward the message to each upstream nameserver concurrently, the first to answer wins. | ||
// Each goroutine will write a single error or a single response to the appropriate channel. | ||
// Each goroutine should quickly exit after the context is canceled. | ||
responses := make(chan []byte, len(upstreamNameservers)) | ||
errs := make(chan error, len(upstreamNameservers)) | ||
g, ctx := errgroup.WithContext(ctx) | ||
ctx, cancel = context.WithCancel(ctx) | ||
defer cancel() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this |
||
for _, nameserver := range upstreamNameservers { | ||
responseBuf, returnResponseBuf := s.getMessageBuffer() | ||
defer returnResponseBuf() | ||
|
||
nameserver := nameserver | ||
g.Go(func() error { | ||
slog := slog.With("nameserver", nameserver) | ||
slog.DebugContext(ctx, "Forwarding request to upstream nameserver.") | ||
|
||
upstreamConn, err := net.Dial("udp", nameserver) | ||
rosstimothy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if err != nil { | ||
errs <- trace.Wrap(err, "dialing upstream nameserver") | ||
return nil | ||
} | ||
|
||
// Immediately close the upstream conn after the context is canceled to unblock any i/o. This | ||
// function will not return any answer until all errgroup goroutines have terminated. | ||
go func() { | ||
<-ctx.Done() | ||
upstreamConn.Close() | ||
}() | ||
|
||
upstreamConn.SetWriteDeadline(deadline) | ||
_, err = upstreamConn.Write(buf) | ||
if err != nil { | ||
errs <- trace.Wrap(err, "writing message to upstream") | ||
return nil | ||
} | ||
upstreamConn.SetReadDeadline(deadline) | ||
nklaassen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
n, err := upstreamConn.Read(responseBuf) | ||
if err != nil { | ||
errs <- trace.Wrap(err, "reading forwarded DNS response") | ||
return nil | ||
} | ||
if n == cap(responseBuf) { | ||
nklaassen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
errs <- fmt.Errorf("DNS response too large") | ||
return nil | ||
} | ||
// Cancel all other goroutines | ||
cancel() | ||
responses <- responseBuf[:n] | ||
return nil | ||
}) | ||
} | ||
|
||
// Not using the errgroup err, errors were written to channel. | ||
_ = g.Wait() | ||
Comment on lines
+311
to
+312
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wherever you do this I suspect there's a simpler way just using go primitives instead of |
||
|
||
select { | ||
case firstResponse := <-responses: | ||
slog.DebugContext(ctx, "Got response to forwarded DNS query, responding to client.") | ||
_, err := responseWriter.Write(firstResponse) | ||
return trace.Wrap(err, "writing DNS response") | ||
default: | ||
} | ||
|
||
close(errs) | ||
return trace.Wrap(trace.NewAggregateFromChannel(errs, context.Background()), "no upstream answers") | ||
} | ||
|
||
func buildEmptyResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question) ([]byte, error) { | ||
responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeSuccess) | ||
if err != nil { | ||
return buf, trace.Wrap(err) | ||
} | ||
// TODO(nklaassen): TTL in SOA record? | ||
buf, err = responseBuilder.Finish() | ||
return buf, trace.Wrap(err, "serializing DNS response") | ||
} | ||
|
||
func buildNXDomainResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question) ([]byte, error) { | ||
responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeNameError) | ||
if err != nil { | ||
return buf, trace.Wrap(err) | ||
} | ||
// TODO(nklaassen): TTL in SOA record? | ||
buf, err = responseBuilder.Finish() | ||
return buf, trace.Wrap(err, "serializing DNS response") | ||
} | ||
|
||
func buildAResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, addr [4]byte) ([]byte, error) { | ||
responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeSuccess) | ||
if err != nil { | ||
return buf, trace.Wrap(err) | ||
} | ||
if err := responseBuilder.StartAnswers(); err != nil { | ||
return buf, trace.Wrap(err, "starting answers section of DNS response") | ||
} | ||
if err := responseBuilder.AResource(dnsmessage.ResourceHeader{ | ||
Name: question.Name, | ||
Type: dnsmessage.TypeA, | ||
Class: dnsmessage.ClassINET, | ||
TTL: 10, | ||
}, dnsmessage.AResource{A: addr}); err != nil { | ||
return buf, trace.Wrap(err, "adding AResource to DNS response") | ||
} | ||
buf, err = responseBuilder.Finish() | ||
return buf, trace.Wrap(err, "serializing DNS response") | ||
} | ||
|
||
func buildAAAAResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, addr [16]byte) ([]byte, error) { | ||
responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeSuccess) | ||
if err != nil { | ||
return buf, trace.Wrap(err) | ||
} | ||
if err := responseBuilder.StartAnswers(); err != nil { | ||
return buf, trace.Wrap(err, "starting answers section of DNS response") | ||
} | ||
if err := responseBuilder.AAAAResource(dnsmessage.ResourceHeader{ | ||
Name: question.Name, | ||
Type: dnsmessage.TypeAAAA, | ||
Class: dnsmessage.ClassINET, | ||
TTL: 10, | ||
}, dnsmessage.AAAAResource{AAAA: addr}); err != nil { | ||
return buf, trace.Wrap(err, "adding AAAAResource to DNS response") | ||
} | ||
buf, err = responseBuilder.Finish() | ||
return buf, trace.Wrap(err, "serializing DNS response") | ||
} | ||
|
||
func prepDNSResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, rcode dnsmessage.RCode) (*dnsmessage.Builder, error) { | ||
buf = buf[:0] | ||
responseBuilder := dnsmessage.NewBuilder(buf, dnsmessage.Header{ | ||
ID: requestHeader.ID, | ||
Response: true, | ||
Authoritative: true, | ||
RCode: dnsmessage.RCodeSuccess, | ||
}) | ||
responseBuilder.EnableCompression() | ||
if err := responseBuilder.StartQuestions(); err != nil { | ||
return nil, trace.Wrap(err, "starting questions section of DNS response") | ||
} | ||
if err := responseBuilder.Question(*question); err != nil { | ||
return nil, trace.Wrap(err, "adding question to DNS response") | ||
} | ||
return &responseBuilder, nil | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: would be good to document the significance of this value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for encouraging me to actually put a bit more thought into this, since we're only handling DNS we can actually shrink this quite a bit to the (typical) maximum EDNS message size of 4096 bytes.