diff --git a/pkg/proxy/dnsserver/dnsserver.go b/pkg/proxy/dnsserver/dnsserver.go index e361ec46..eafb004a 100644 --- a/pkg/proxy/dnsserver/dnsserver.go +++ b/pkg/proxy/dnsserver/dnsserver.go @@ -154,11 +154,11 @@ func (s *DnsServer) getResolveServer() (address string, err error) { } // Look for domain record from upstream dns server -func (s *DnsServer) lookup(domain string, qtype uint16, name string) (rr []dns.RR, err error) { +func (s *DnsServer) lookup(domain string, qtype uint16, name string) ([]dns.RR, error) { address, err := s.getResolveServer() if err != nil { log.Error().Err(err).Msgf("Failed to fetch upstream dns") - return + return []dns.RR{}, nil } log.Debug().Msgf("Resolving domain %s (%d) via upstream %s", domain, qtype, address) @@ -169,40 +169,28 @@ func (s *DnsServer) lookup(domain string, qtype uint16, name string) (rr []dns.R } else { log.Warn().Err(err).Msgf("Failed to answer name %s (%d) query for %s", name, qtype, domain) } - return + return []dns.RR{}, nil } if len(res.Answer) == 0 { log.Debug().Msgf("Empty answer") } - for _, item := range res.Answer { - log.Debug().Msgf("Response: %s", item.String()) - r, errInLoop := s.convertAnswer(name, domain, item) - if errInLoop != nil { - err = errInLoop - return - } - rr = append(rr, r) - } - - return + return s.convertAnswer(name, res.Answer), nil } // Replace fully qualified domain name with short domain name in dns answer -func (s *DnsServer) convertAnswer(name, inClusterName string, actual dns.RR) (rr dns.RR, err error) { - if name != inClusterName { - var parts []string - parts = append(parts, name) - answer := strings.Split(actual.String(), "\t") - parts = append(parts, answer[1:]...) - rrStr := strings.Join(parts, " ") - rr, err = dns.NewRR(rrStr) - if err != nil { - return +func (s *DnsServer) convertAnswer(name string, answer []dns.RR) []dns.RR { + cnames := []string{name} + for _, item := range answer { + log.Debug().Msgf("Response: %s", item.String()) + if item.Header().Rrtype == dns.TypeCNAME { + cnames = append(cnames, item.(*dns.CNAME).Target) } - } else { - rr = actual } - rr.Header().Name = name - return + for _, item := range answer { + if !util.Contains(item.Header().Name, cnames) { + item.Header().Name = name + } + } + return answer } diff --git a/pkg/proxy/dnsserver/dnsserver_test.go b/pkg/proxy/dnsserver/dnsserver_test.go index 5e386267..d45fd551 100644 --- a/pkg/proxy/dnsserver/dnsserver_test.go +++ b/pkg/proxy/dnsserver/dnsserver_test.go @@ -1,20 +1,26 @@ package dnsserver import ( + "github.com/stretchr/testify/require" "testing" "github.com/miekg/dns" ) -func TestAnswerRewrite(t *testing.T) { +func TestShouldRewriteARecord(t *testing.T) { s := &DnsServer{} - actual, _ := dns.NewRR("tomcat.default.svc.cluster.local. 5 IN A 172.21.4.129") - r, err := s.convertAnswer("tomcat.", "tomcat.default.svc.cluster.local.", actual) - if err != nil { - t.Errorf("error") - return - } - if r.String() != "tomcat. 5 IN A 172.21.4.129" { - t.Errorf("error, get result: " + r.String()) - } + r1, _ := dns.NewRR("tomcat.default.svc.cluster.local. 5 IN A 172.21.4.129") + rr := []dns.RR{r1} + result := s.convertAnswer("tomcat.", rr) + require.Equal(t, "tomcat.\t5\tIN\tA\t172.21.4.129", result[0].String()) +} + +func TestShouldNotRewriteCnameRecord(t *testing.T) { + s := &DnsServer{} + r1, _ := dns.NewRR("tomcat.com. 465 IN CNAME www.tomcat.com.") + r2, _ := dns.NewRR("www.tomcat.com. 346 IN A 10.12.4.6") + rr := []dns.RR{r1, r2} + result := s.convertAnswer("tomcat.", rr) + require.Equal(t, "tomcat.\t465\tIN\tCNAME\twww.tomcat.com.", result[0].String()) + require.Equal(t, "www.tomcat.com.\t346\tIN\tA\t10.12.4.6", result[1].String()) }