diff --git a/lib/vnet/dns/dns.go b/lib/vnet/dns/dns.go
new file mode 100644
index 0000000000000..546ce3de800ab
--- /dev/null
+++ b/lib/vnet/dns/dns.go
@@ -0,0 +1,402 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package dns
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log/slog"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/gravitational/trace"
+ "golang.org/x/net/dns/dnsmessage"
+ "golang.org/x/sync/errgroup"
+ "gvisor.dev/gvisor/pkg/tcpip"
+
+ "github.com/gravitational/teleport"
+)
+
+const (
+ // This is the recommended EDNS maximum payload size in https://www.rfc-editor.org/rfc/rfc6891.txt
+ // While this server doesn't directly support EDNS yet for queries that actually resolve to Teleport apps,
+ // upstream nameservers may, and we shouldn't drop valid requests to those upstream servers.
+ // This is not an absolute maximum for EDNS, but it is the usual maximum in practice, and the maximum
+ // supported by bind https://github.com/isc-projects/bind9/blob/9357019498d57aef95fff94d408198e21dcc93c9/lib/dns/resolver.c#L238
+ maxUDPDNSMessageSize = 4096
+
+ // https://www.rfc-editor.org/rfc/rfc1123#page-77 recommends 5 seconds as a minimum, and this seems to be
+ // common in practice.
+ forwardRequestTimeout = 5 * time.Second
+)
+
+// Resolver represents an entity that can resolve DNS requests.
+type Resolver interface {
+ // ResolveA should return a Result for an A record question. If an empty Result is returned with no error,
+ // the question will be forwarded upstream.
+ ResolveA(ctx context.Context, domain string) (Result, error)
+
+ // ResolveAAAA should return a Result for an AAAA record question. If an empty Result is returned with no
+ // error, the question will be forwarded upstream.
+ ResolveAAAA(ctx context.Context, domain string) (Result, error)
+}
+
+// Result holds the result of DNS resolution.
+type Result struct {
+ // A is an A record.
+ A [4]byte
+ // AAAA is an AAAA record.
+ AAAA [16]byte
+ // NXDomain indicates that the requested domain is invalid or unassigned, this is an authoritative answer.
+ NXDomain bool
+ // NoRecord indicates the domain exists but the requested record type doesn't, this is an authoritative
+ // answer.
+ NoRecord bool
+}
+
+// UpstreamNameserverSource provides the current set of upstream nameservers.
+type UpstreamNameserverSource interface {
+ // UpstreamNameservers should return the current set of upstream nameservers, requests that cannot be
+ // resolved will be forwarded to these addresses.
+ UpstreamNameservers(context.Context) ([]string, error)
+}
+
+// Server is a DNS server.
+type Server struct {
+ resolver Resolver
+ upstreamNameserverSource UpstreamNameserverSource
+ messageBuffers sync.Pool
+ slog *slog.Logger
+}
+
+// NewServer returns a DNS server that handles the details of the DNS protocol and asks [resolver] to answer
+// DNS questions. If [resolver] has no answer, requests will be forwarded to the upstream nameservers provided
+// by [upstreamNameserverSource].
+func NewServer(resolver Resolver, upstreamNameserverSource UpstreamNameserverSource) (*Server, error) {
+ return &Server{
+ resolver: resolver,
+ upstreamNameserverSource: upstreamNameserverSource,
+ messageBuffers: sync.Pool{
+ New: func() any {
+ buf := make([]byte, maxUDPDNSMessageSize)
+ return &buf
+ },
+ },
+ slog: slog.With(teleport.ComponentKey, "VNet.DNS"),
+ }, nil
+}
+
+// getMessageBuffer returns a buffer of size [maxUDPMessageSize]. Call [returnBuf] to return the buffer to the
+// shared pool. Use this to avoid large allocations on all DNS messages.
+func (s *Server) getMessageBuffer() (buf []byte, returnBuf func()) {
+ buf = *s.messageBuffers.Get().(*[]byte)
+ buf = buf[:cap(buf)]
+ return buf, func() {
+ s.messageBuffers.Put(&buf)
+ }
+}
+
+// HandleUDP reads and handles a single UDP message from [conn] and writes the response back to [conn].
+// This will be called by VNet code.
+func (s *Server) HandleUDP(ctx context.Context, conn net.Conn) error {
+ buf, returnBuf := s.getMessageBuffer()
+ defer returnBuf()
+
+ n, err := conn.Read(buf)
+ if err != nil {
+ return trace.Wrap(err, "failed to read from UDP conn")
+ }
+ if n >= maxUDPDNSMessageSize {
+ return trace.BadParameter("Dropping UDP message that is too large")
+ }
+ buf = buf[:n]
+
+ return trace.Wrap(s.handleDNSMessage(ctx, conn.RemoteAddr().String(), buf, conn))
+}
+
+// ListendAndServeUDP reads all incoming UDP messages from [conn], handles DNS questions, and writes the
+// responses back to [conn].
+// This is not called by VNet code and basically exists so we can test the resolver outside of VNet.
+func (s *Server) ListenAndServeUDP(ctx context.Context, conn *net.UDPConn) error {
+ buf, returnBuf := s.getMessageBuffer()
+ defer returnBuf()
+
+ for {
+ buf = buf[:cap(buf)]
+ n, remoteAddr, err := conn.ReadFromUDP(buf)
+ if err != nil {
+ return trace.Wrap(err, "failed to read from UDP conn")
+ }
+ if n >= maxUDPDNSMessageSize {
+ return trace.BadParameter("Dropping UDP message that is too large")
+ }
+ buf = buf[:n]
+
+ responseWriter := &udpWriter{
+ conn: conn,
+ remoteAddr: remoteAddr,
+ }
+ if err := s.handleDNSMessage(ctx, remoteAddr.String(), buf, responseWriter); err != nil {
+ s.slog.DebugContext(ctx, "Error handling DNS message.", "error", err)
+ }
+ }
+}
+
+type udpWriter struct {
+ conn *net.UDPConn
+ remoteAddr *net.UDPAddr
+}
+
+func (u *udpWriter) Write(b []byte) (int, error) {
+ n, _, err := u.conn.WriteMsgUDP(b, nil /*oob*/, u.remoteAddr)
+ return n, err
+}
+
+// handleDNSMessage handles the DNS message held in [buf] and writes the answer to [responseWriter].
+// This could handle DNS messages arriving over UDP or TCP.
+func (s *Server) handleDNSMessage(ctx context.Context, remoteAddr string, buf []byte, responseWriter io.Writer) error {
+ slog := s.slog.With("remote_addr", remoteAddr)
+ slog.DebugContext(ctx, "Handling DNS message.")
+ defer slog.DebugContext(ctx, "Done handling DNS message.")
+
+ var parser dnsmessage.Parser
+ requestHeader, err := parser.Start(buf)
+ if err != nil {
+ return trace.Wrap(err, "parsing DNS message")
+ }
+ if requestHeader.OpCode != 0 {
+ slog.DebugContext(ctx, "OpCode is not QUERY (0), forwarding.", "opcode", requestHeader.OpCode)
+ return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-Query DNS message")
+ }
+ question, err := parser.Question()
+ if err != nil {
+ return trace.Wrap(err, "parsing DNS question")
+ }
+ fqdn := question.Name.String()
+ slog = slog.With("fqdn", fqdn, "type", question.Type.String())
+ slog.DebugContext(ctx, "Received DNS question.", "question", question)
+ if question.Class != dnsmessage.ClassINET {
+ slog.DebugContext(ctx, "Query class is not INET, forwarding.", "class", question.Class)
+ return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding non-INET DNS query")
+ }
+
+ var result Result
+ switch question.Type {
+ case dnsmessage.TypeA:
+ result, err = s.resolver.ResolveA(ctx, fqdn)
+ if err != nil {
+ return trace.Wrap(err, "resolving A request for %q", fqdn)
+ }
+ case dnsmessage.TypeAAAA:
+ result, err = s.resolver.ResolveAAAA(ctx, fqdn)
+ if err != nil {
+ return trace.Wrap(err, "resolving AAAA request for %q", fqdn)
+ }
+ default:
+ slog.DebugContext(ctx, "Question type is not A or AAAA, forwarding.", "type", question.Type)
+ return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding %s DNS query", question.Type)
+ }
+
+ var response []byte
+ switch {
+ case result.NXDomain:
+ slog.DebugContext(ctx, "No match for name, responding with authoritative name error.")
+ response, err = buildNXDomainResponse(buf, &requestHeader, &question)
+ case result.NoRecord:
+ slog.DebugContext(ctx, "Name matched but no record, responding with authoritative non-answer.")
+ response, err = buildEmptyResponse(buf, &requestHeader, &question)
+ case question.Type == dnsmessage.TypeA && result.A != ([4]byte{}):
+ slog.DebugContext(ctx, "Matched DNS A.", "a", tcpip.AddrFrom4(result.A))
+ response, err = buildAResponse(buf, &requestHeader, &question, result.A)
+ case question.Type == dnsmessage.TypeAAAA && result.AAAA != ([16]byte{}):
+ slog.DebugContext(ctx, "Matched DNS AAAA.", "aaaa", tcpip.AddrFrom16(result.AAAA))
+ response, err = buildAAAAResponse(buf, &requestHeader, &question, result.AAAA)
+ default:
+ slog.DebugContext(ctx, "Forwarding unmatched query.")
+ return trace.Wrap(s.forward(ctx, slog, buf, responseWriter), "forwarding unmatched DNS query")
+ }
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ _, err = responseWriter.Write(response)
+ return trace.Wrap(err, "writing DNS response")
+}
+
+// forward forwards a raw DNS message to all upstream nameservers and writes the first response to
+// [responseWriter]. If there are no upstream nameservers, or none of them responds within the timeout, an
+// error is returned. This doesn't do any retries because the downstream resolver is likely to do its own
+// retries.
+func (s *Server) forward(ctx context.Context, slog *slog.Logger, buf []byte, responseWriter io.Writer) error {
+ ctx, cancel := context.WithTimeout(ctx, forwardRequestTimeout)
+ defer cancel()
+ deadline, _ := ctx.Deadline()
+
+ upstreamNameservers, err := s.upstreamNameserverSource.UpstreamNameservers(ctx)
+ if err != nil {
+ return trace.Wrap(err, "getting host default nameservers")
+ }
+ if len(upstreamNameservers) == 0 {
+ return trace.Errorf("no upstream nameservers")
+ }
+
+ // Forward the message to each upstream nameserver concurrently, the first to answer wins.
+ // Each goroutine will write a single error or a single response to the appropriate channel.
+ // Each goroutine should quickly exit after the context is canceled.
+ responses := make(chan []byte, len(upstreamNameservers))
+ errs := make(chan error, len(upstreamNameservers))
+ g, ctx := errgroup.WithContext(ctx)
+ ctx, cancel = context.WithCancel(ctx)
+ defer cancel()
+ for _, nameserver := range upstreamNameservers {
+ responseBuf, returnResponseBuf := s.getMessageBuffer()
+ defer returnResponseBuf()
+
+ nameserver := nameserver
+ g.Go(func() error {
+ slog := slog.With("nameserver", nameserver)
+ slog.DebugContext(ctx, "Forwarding request to upstream nameserver.")
+
+ upstreamConn, err := net.Dial("udp", nameserver)
+ if err != nil {
+ errs <- trace.Wrap(err, "dialing upstream nameserver")
+ return nil
+ }
+
+ // Immediately close the upstream conn after the context is canceled to unblock any i/o. This
+ // function will not return any answer until all errgroup goroutines have terminated.
+ go func() {
+ <-ctx.Done()
+ upstreamConn.Close()
+ }()
+
+ upstreamConn.SetDeadline(deadline)
+ _, err = upstreamConn.Write(buf)
+ if err != nil {
+ errs <- trace.Wrap(err, "writing message to upstream")
+ return nil
+ }
+ n, err := upstreamConn.Read(responseBuf)
+ if err != nil {
+ errs <- trace.Wrap(err, "reading forwarded DNS response")
+ return nil
+ }
+ if n == len(responseBuf) {
+ errs <- fmt.Errorf("DNS response too large")
+ return nil
+ }
+ // Cancel all other goroutines
+ cancel()
+ responses <- responseBuf[:n]
+ return nil
+ })
+ }
+
+ // Not using the errgroup err, errors were written to channel.
+ _ = g.Wait()
+
+ select {
+ case firstResponse := <-responses:
+ slog.DebugContext(ctx, "Got response to forwarded DNS query, responding to client.")
+ _, err := responseWriter.Write(firstResponse)
+ return trace.Wrap(err, "writing DNS response")
+ default:
+ }
+
+ close(errs)
+ return trace.Wrap(trace.NewAggregateFromChannel(errs, context.Background()), "no upstream answers")
+}
+
+func buildEmptyResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question) ([]byte, error) {
+ responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeSuccess)
+ if err != nil {
+ return buf, trace.Wrap(err)
+ }
+ // TODO(nklaassen): TTL in SOA record?
+ buf, err = responseBuilder.Finish()
+ return buf, trace.Wrap(err, "serializing DNS response")
+}
+
+func buildNXDomainResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question) ([]byte, error) {
+ responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeNameError)
+ if err != nil {
+ return buf, trace.Wrap(err)
+ }
+ // TODO(nklaassen): TTL in SOA record?
+ buf, err = responseBuilder.Finish()
+ return buf, trace.Wrap(err, "serializing DNS response")
+}
+
+func buildAResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, addr [4]byte) ([]byte, error) {
+ responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeSuccess)
+ if err != nil {
+ return buf, trace.Wrap(err)
+ }
+ if err := responseBuilder.StartAnswers(); err != nil {
+ return buf, trace.Wrap(err, "starting answers section of DNS response")
+ }
+ if err := responseBuilder.AResource(dnsmessage.ResourceHeader{
+ Name: question.Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ TTL: 10,
+ }, dnsmessage.AResource{A: addr}); err != nil {
+ return buf, trace.Wrap(err, "adding AResource to DNS response")
+ }
+ buf, err = responseBuilder.Finish()
+ return buf, trace.Wrap(err, "serializing DNS response")
+}
+
+func buildAAAAResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, addr [16]byte) ([]byte, error) {
+ responseBuilder, err := prepDNSResponse(buf, requestHeader, question, dnsmessage.RCodeSuccess)
+ if err != nil {
+ return buf, trace.Wrap(err)
+ }
+ if err := responseBuilder.StartAnswers(); err != nil {
+ return buf, trace.Wrap(err, "starting answers section of DNS response")
+ }
+ if err := responseBuilder.AAAAResource(dnsmessage.ResourceHeader{
+ Name: question.Name,
+ Type: dnsmessage.TypeAAAA,
+ Class: dnsmessage.ClassINET,
+ TTL: 10,
+ }, dnsmessage.AAAAResource{AAAA: addr}); err != nil {
+ return buf, trace.Wrap(err, "adding AAAAResource to DNS response")
+ }
+ buf, err = responseBuilder.Finish()
+ return buf, trace.Wrap(err, "serializing DNS response")
+}
+
+func prepDNSResponse(buf []byte, requestHeader *dnsmessage.Header, question *dnsmessage.Question, rcode dnsmessage.RCode) (*dnsmessage.Builder, error) {
+ buf = buf[:0]
+ responseBuilder := dnsmessage.NewBuilder(buf, dnsmessage.Header{
+ ID: requestHeader.ID,
+ Response: true,
+ Authoritative: true,
+ RCode: dnsmessage.RCodeSuccess,
+ })
+ responseBuilder.EnableCompression()
+ if err := responseBuilder.StartQuestions(); err != nil {
+ return nil, trace.Wrap(err, "starting questions section of DNS response")
+ }
+ if err := responseBuilder.Question(*question); err != nil {
+ return nil, trace.Wrap(err, "adding question to DNS response")
+ }
+ return &responseBuilder, nil
+}
diff --git a/lib/vnet/dns/dns_test.go b/lib/vnet/dns/dns_test.go
new file mode 100644
index 0000000000000..81f0bf33cc8b6
--- /dev/null
+++ b/lib/vnet/dns/dns_test.go
@@ -0,0 +1,207 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package dns
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "net"
+ "os"
+ "testing"
+
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/require"
+ "gvisor.dev/gvisor/pkg/tcpip"
+
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+var (
+ udpLocalhost = &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}
+)
+
+func TestMain(m *testing.M) {
+ utils.InitLogger(utils.LoggingForCLI, slog.LevelDebug)
+ os.Exit(m.Run())
+}
+
+// TestServer sets up a main DNS server and two upstream DNS servers, all using real UDP sockets, and tests
+// that net.Resolver can successfully use the stack to lookup hosts.
+func TestServer(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ defaultIP4 := tcpip.AddrFrom4([4]byte{1, 2, 3, 4})
+ defaultIP6 := tcpip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})
+
+ staticResolver := &staticResolver{Result{
+ A: defaultIP4.As4(),
+ AAAA: defaultIP6.As16(),
+ }}
+ noUpstreams := &stubUpstreamNamservers{}
+
+ // Create two upstream nameservers that are able to resolve A and AAAA records for all names.
+ var upstreamAddrs []string
+ for i := 0; i < 2; i++ {
+ upstreamServer, err := NewServer(staticResolver, noUpstreams)
+ require.NoError(t, err)
+ conn, err := net.ListenUDP("udp", udpLocalhost)
+ require.NoError(t, err)
+
+ utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{
+ Name: fmt.Sprintf("upstream nameserver %d", i),
+ Task: func(ctx context.Context) error {
+ err := upstreamServer.ListenAndServeUDP(ctx, conn)
+ if err == nil || utils.IsOKNetworkError(err) {
+ return nil
+ }
+ return trace.Wrap(err)
+ },
+ Terminate: conn.Close,
+ })
+
+ upstreamAddrs = append(upstreamAddrs, conn.LocalAddr().String())
+ }
+
+ // Create the nameserver under test.
+ goTeleportIPv4 := tcpip.AddrFrom4([4]byte{1, 1, 1, 1})
+ goTeleportIPv6 := tcpip.AddrFrom16([16]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
+ teleportShIPv6 := tcpip.AddrFrom16([16]byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2})
+ resolver := &stubResolver{
+ aRecords: map[string]Result{
+ "goteleport.com.": Result{
+ A: goTeleportIPv4.As4(),
+ },
+ "teleport.sh.": Result{
+ NoRecord: true,
+ },
+ "fake.example.com.": Result{
+ NXDomain: true,
+ },
+ },
+ aaaaRecords: map[string]Result{
+ "goteleport.com.": Result{
+ AAAA: goTeleportIPv6.As16(),
+ },
+ "teleport.sh.": Result{
+ AAAA: teleportShIPv6.As16(),
+ },
+ "fake.example.com.": Result{
+ NXDomain: true,
+ },
+ },
+ }
+ upstreams := &stubUpstreamNamservers{nameservers: upstreamAddrs}
+ server, err := NewServer(resolver, upstreams)
+ require.NoError(t, err)
+
+ conn, err := net.ListenUDP("udp", udpLocalhost)
+ require.NoError(t, err)
+
+ utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{
+ Name: "nameserver under test",
+ Task: func(ctx context.Context) error {
+ err := server.ListenAndServeUDP(ctx, conn)
+ if err == nil || utils.IsOKNetworkError(err) {
+ return nil
+ }
+ return trace.Wrap(err)
+ },
+ Terminate: conn.Close,
+ })
+
+ netResolver := &net.Resolver{
+ PreferGo: true,
+ Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
+ // Always dial the resolver under test.
+ return net.Dial(network, conn.LocalAddr().String())
+ },
+ }
+
+ for _, tc := range []struct {
+ desc string
+ host string
+ expectAddrs []string
+ expectErr string
+ }{
+ {
+ desc: "v4 and v6",
+ host: "goteleport.com.",
+ expectAddrs: []string{goTeleportIPv4.String(), goTeleportIPv6.String()},
+ },
+ {
+ desc: "only v6",
+ host: "teleport.sh.",
+ expectAddrs: []string{teleportShIPv6.String()},
+ },
+ {
+ desc: "forward to upstream",
+ host: "example.com.",
+ expectAddrs: []string{defaultIP4.String(), defaultIP6.String()},
+ },
+ {
+ desc: "no domain",
+ host: "fake.example.com.",
+ expectErr: "no such host",
+ },
+ } {
+ t.Run(tc.desc, func(t *testing.T) {
+ addrs, err := netResolver.LookupHost(ctx, tc.host)
+ if tc.expectErr != "" {
+ require.ErrorContains(t, err, tc.expectErr)
+ return
+ }
+ require.NoError(t, err)
+ require.ElementsMatch(t, tc.expectAddrs, addrs)
+ })
+ }
+}
+
+type stubResolver struct {
+ aRecords map[string]Result
+ aaaaRecords map[string]Result
+}
+
+func (s *stubResolver) ResolveA(ctx context.Context, fqdn string) (Result, error) {
+ return s.aRecords[fqdn], nil
+}
+
+func (s *stubResolver) ResolveAAAA(ctx context.Context, fqdn string) (Result, error) {
+ return s.aaaaRecords[fqdn], nil
+}
+
+type staticResolver struct {
+ result Result
+}
+
+func (s *staticResolver) ResolveA(ctx context.Context, fqdn string) (Result, error) {
+ return s.result, nil
+}
+
+func (s *staticResolver) ResolveAAAA(ctx context.Context, fqdn string) (Result, error) {
+ return s.result, nil
+}
+
+type stubUpstreamNamservers struct {
+ nameservers []string
+ err error
+}
+
+func (s *stubUpstreamNamservers) UpstreamNameservers(ctx context.Context) ([]string, error) {
+ return s.nameservers, s.err
+}
diff --git a/lib/vnet/dns/osnameservers_darwin.go b/lib/vnet/dns/osnameservers_darwin.go
new file mode 100644
index 0000000000000..ef398bb0791d2
--- /dev/null
+++ b/lib/vnet/dns/osnameservers_darwin.go
@@ -0,0 +1,101 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//go:build darwin
+// +build darwin
+
+package dns
+
+import (
+ "bufio"
+ "context"
+ "log/slog"
+ "net"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+const (
+ confFilePath = "/etc/resolv.conf"
+)
+
+type OSUpstreamNameserverSource struct {
+ ttlCache *utils.FnCache
+}
+
+func NewOSUpstreamNameserverSource() (*OSUpstreamNameserverSource, error) {
+ ttlCache, err := utils.NewFnCache(utils.FnCacheConfig{
+ TTL: 10 * time.Second,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &OSUpstreamNameserverSource{
+ ttlCache: ttlCache,
+ }, nil
+}
+
+func (s *OSUpstreamNameserverSource) UpstreamNameservers(ctx context.Context) ([]string, error) {
+ return utils.FnCacheGet(ctx, s.ttlCache, 0, s.upstreamNameservers)
+}
+
+func (s *OSUpstreamNameserverSource) upstreamNameservers(ctx context.Context) ([]string, error) {
+ f, err := os.Open(confFilePath)
+ if err != nil {
+ return nil, trace.Wrap(err, "opening %s", confFilePath)
+ }
+ defer f.Close()
+
+ var nameservers []string
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if !strings.HasPrefix(line, "nameserver ") {
+ continue
+ }
+ fields := strings.Fields(line)
+ if len(fields) < 2 {
+ continue
+ }
+ address := fields[1]
+
+ ip := net.ParseIP(address)
+ if ip == nil {
+ slog.DebugContext(ctx, "Skipping invalid IP", "ip", address)
+ continue
+ }
+
+ // Add port 53 suffix, the only port supported on MacOS.
+ var nameserver string
+ switch {
+ case ip.To4() != nil:
+ nameserver = address + ":53"
+ case ip.To16() != nil:
+ nameserver = "[" + address + "]:53"
+ default:
+ continue
+ }
+ nameservers = append(nameservers, nameserver)
+ }
+
+ slog.DebugContext(ctx, "Loaded host upstream nameservers.", "nameservers", nameservers, "source", confFilePath)
+ return nameservers, nil
+}
diff --git a/lib/vnet/dns/osnameservers_darwin_test.go b/lib/vnet/dns/osnameservers_darwin_test.go
new file mode 100644
index 0000000000000..5fdea3fb9cc4a
--- /dev/null
+++ b/lib/vnet/dns/osnameservers_darwin_test.go
@@ -0,0 +1,82 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//go:build darwin
+// +build darwin
+
+package dns
+
+import (
+ "context"
+ "net"
+ "testing"
+
+ "github.com/gravitational/teleport/lib/utils"
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/require"
+)
+
+// TestOSUpstreamNameservers configures the DNS server to forward requests for all addresses to the OS's real
+// upstream nameservers, to test that this logic is working correctly.
+func TestOSUpstreamNameservers(t *testing.T) {
+ t.Parallel()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ t.Cleanup(cancel)
+
+ resolver := &stubResolver{}
+ upstreams, err := NewOSUpstreamNameserverSource()
+ require.NoError(t, err)
+ server, err := NewServer(resolver, upstreams)
+ require.NoError(t, err)
+
+ conn, err := net.ListenUDP("udp", udpLocalhost)
+ require.NoError(t, err)
+
+ utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{
+ Name: "nameserver",
+ Task: func(ctx context.Context) error {
+ err := server.ListenAndServeUDP(ctx, conn)
+ if err == nil || utils.IsOKNetworkError(err) {
+ return nil
+ }
+ return trace.Wrap(err)
+ },
+ Terminate: conn.Close,
+ })
+
+ netResolver := &net.Resolver{
+ PreferGo: true,
+ Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
+ // Always dial the resolver under test.
+ return net.Dial(network, conn.LocalAddr().String())
+ },
+ }
+
+ for _, tc := range []struct {
+ host string
+ }{
+ {"goteleport.com"},
+ {"teleport.sh"},
+ {"example.com"},
+ } {
+ t.Run(tc.host, func(t *testing.T) {
+ addrs, err := netResolver.LookupHost(ctx, tc.host)
+ require.NoError(t, err)
+ require.NotEmpty(t, addrs)
+ })
+ }
+}
diff --git a/lib/vnet/dns/osnameservers_other.go b/lib/vnet/dns/osnameservers_other.go
new file mode 100644
index 0000000000000..a353020fa0b67
--- /dev/null
+++ b/lib/vnet/dns/osnameservers_other.go
@@ -0,0 +1,42 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//go:build !darwin
+// +build !darwin
+
+package dns
+
+import (
+ "context"
+ "runtime"
+
+ "github.com/gravitational/trace"
+)
+
+var (
+ // vnetNotImplemented is an error indicating that VNet is not implemented on the host OS.
+ vnetNotImplemented = &trace.NotImplementedError{Message: "VNet is not implemented on " + runtime.GOOS}
+)
+
+type OSUpstreamNameserverSource struct{}
+
+func NewOSUpstreamNameserverSource() (*OSUpstreamNameserverSource, error) {
+ return nil, trace.Wrap(vnetNotImplemented)
+}
+
+func (s *OSUpstreamNameserverSource) UpstreamNameservers(ctx context.Context) ([]string, error) {
+ return nil, trace.Wrap(vnetNotImplemented)
+}