diff --git a/filter.go b/filter.go index 05421b2..1452c36 100644 --- a/filter.go +++ b/filter.go @@ -3,6 +3,7 @@ package main import ( "errors" "github.com/miekg/dns" + "strings" ) type nameFilter map[string]bool @@ -12,6 +13,7 @@ func (n *nameFilter) AddString(name string) error { *n = make(map[string]bool) } + name = strings.ToLower(name) if name[len(name)-1] != '.' { name += "." } @@ -38,7 +40,7 @@ func (n nameFilter) Lookup(name []byte) bool { break } - if n[string(name)] { + if n[strings.ToLower(string(name))] { return true } name = rest[llen:] @@ -76,6 +78,14 @@ var errShortMessage = errors.New("DNS Message too short") var errTruncMessage = errors.New("DNS Message truncated") var errInvalidQname = errors.New("Invalid qname") +func lowerByte(b byte) byte { + const lower = byte('a') - byte('A') + if b >= byte('A') && b <= byte('Z') { + return b + lower + } + return b +} + func (n nameFilter) FilterMsgQname(m []byte) (bool, error) { if n == nil { return false, nil @@ -93,19 +103,31 @@ func (n nameFilter) FilterMsgQname(m []byte) (bool, error) { m = m[12:] - var qnameLen int + var name []byte for i := 0; i < len(m); i += int(m[i]) + 1 { - if m[i] == 0 { - qnameLen = i + 1 + llen := int(m[i]) + if llen == 0 { + name = append(name, 0) break } - if m[i] > 63 { + if llen > 63 { return false, errInvalidQname } + + lend := llen + i + 1 + if lend >= len(m) { + return false, errTruncMessage + } + name = append(name, byte(llen)) + + label := m[i+1 : lend] + for _, b := range label { + name = append(name, lowerByte(b)) + } } - if qnameLen == 0 { + if len(name) == 0 { return false, errTruncMessage } - return n.Lookup(m[:qnameLen]), nil + return n.Lookup(name), nil }