Skip to content

Commit

Permalink
refactor: server URL parser (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
natesales committed Oct 21, 2023
1 parent 33d5cdf commit 14abd73
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 132 deletions.
238 changes: 109 additions & 129 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"reflect"
"regexp"
"slices"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -97,161 +98,130 @@ func txtConcat(m *dns.Msg) {
m.Answer = answers
}

// parseServer parses opts.Server and returns the server address and transport type
func parseServer() (string, transport.Type, error) {
var txp transport.Type
var host, port, scopeId string
var isHTTPS bool

// Set default protocol
if !strings.Contains(opts.Server, "://") {
txp = transport.TypePlain
} else {
txp = transport.Type(strings.Split(opts.Server, "://")[0])
if txp == "https" {
isHTTPS = true
txp = transport.TypeHTTP
}
}
// dnsStampToURL converts a DNS stamp string to a URL string
func dnsStampToURL(s string) (string, error) {
var u url.URL

// Parse DNS stamp
if strings.HasPrefix(opts.Server, "sdns://") {
parsedStamp, err := dnsstamps.NewServerStampFromString(opts.Server)
if err != nil {
return "", "", err
}
parsedStamp, err := dnsstamps.NewServerStampFromString(s)
if err != nil {
return "", err
}

switch parsedStamp.Proto {
case dnsstamps.StampProtoTypePlain:
txp = transport.TypePlain
case dnsstamps.StampProtoTypeTLS:
txp = transport.TypeTLS
case dnsstamps.StampProtoTypeDoH:
isHTTPS = true // Default to DoH (HTTPS)
txp = transport.TypeHTTP
case dnsstamps.StampProtoTypeDNSCrypt:
// DNS stamp parsing happens again in the DNSCrypt transport
return opts.Server, transport.TypeDNSCrypt, nil
default:
return "", "", fmt.Errorf("unsupported protocol %s in DNS stamp", parsedStamp.Proto.String())
}
log.Tracef("DNS stamp parsed as %s", txp)

// TODO: This might be a source of problems...we might want to be using parsedStamp.ServerAddrStr
host = parsedStamp.ProviderName
} else { // Not DNS stamp
// Remove anything before and including the first ://
host = regexp.MustCompile(`^.*://`).ReplaceAllString(opts.Server, "")

// Remove port from host
switch {
case strings.Contains(host, "[") && !strings.Contains(host, "]") ||
!strings.Contains(host, "[") && strings.Contains(host, "]"):
return "", "", fmt.Errorf("invalid IPv6 bracket notation")
case strings.Contains(host, "[") && strings.Contains(host, "]"): // IPv6 in bracket notation
portSuffix := strings.Split(host, "]:")
if len(portSuffix) > 1 { // With explicit port
port = portSuffix[1]
} else {
port = ""
}
switch parsedStamp.Proto {
case dnsstamps.StampProtoTypePlain:
u.Scheme = string(transport.TypePlain)
case dnsstamps.StampProtoTypeTLS:
u.Scheme = string(transport.TypeTLS)
case dnsstamps.StampProtoTypeDoH:
u.Scheme = string(transport.TypeHTTP) + "s" // default to HTTPS
case dnsstamps.StampProtoTypeDNSCrypt:
// DNS stamp parsing happens again in the DNSCrypt transport, so pass the input along unchanged
return s, nil
default:
return "", fmt.Errorf("unsupported protocol %s in DNS stamp", parsedStamp.Proto.String())
}

host = strings.Split(strings.Split(host, "[")[1], "]")[0]
// TODO: This might be a source of problems...we might want to be using parsedStamp.ServerAddrStr
u.Host = parsedStamp.ProviderName

// Remove IPv6 scope ID
if strings.Contains(host, "%") {
parts := strings.Split(host, "%")
host = parts[0]
scopeId = parts[1]
}
log.Tracef("DNS stamp parsed into URL as %s", u.String())
return u.String(), nil
}

host = "[" + host + "]"
log.Tracef("host contains ], treating as v6 with port. host: %s port: %s", host, port)
case strings.Contains(host, ".") && strings.Contains(host, ":"): // IPv4 or hostname with port
parts := strings.Split(host, ":")
host = parts[0]
port = parts[1]
log.Tracef("host contains . and :, treating as (v4 or host) with explicit port. host %s port %s", host, port)
case strings.Contains(host, ":") && !strings.Contains(host, "/"): // IPv6 no port
// Remove IPv6 scope ID
if strings.Contains(host, "%") {
parts := strings.Split(host, "%")
host = parts[0]
scopeId = parts[1]
}
// setPort sets the port of a url.URL
func setPort(u *url.URL, port int) {
if strings.Contains(u.Host, ":") {
if strings.Contains(u.Host, "[") && strings.Contains(u.Host, "]") {
u.Host = fmt.Sprintf("%s]:%d", strings.Split(u.Host, "]")[0], port)
return
}
u.Host = "[" + u.Host + "]"
}
u.Host = fmt.Sprintf("%s:%d", u.Host, port)
}

host = "[" + host + "]"
log.Tracef("host contains :, treating as v6 without port. host %s", host)
default:
log.Tracef("no cases matched for host %s port %s", host, port)
// parseServer is a revised version of parseServer that uses the URL package for parsing
func parseServer(s string) (string, transport.Type, error) {
// Remove IPv6 scope ID if present
var scopeId string
v6scopeRe := regexp.MustCompile(`\[[a-fA-F0-9:]+%[a-zA-Z0-9]+]`)
if v6scopeRe.MatchString(s) {
v6scopeRemoveRe := regexp.MustCompile(`(%[a-zA-Z0-9]+)`)
matches := v6scopeRemoveRe.FindStringSubmatch(s)
if len(matches) > 1 {
scopeId = matches[1]
s = v6scopeRemoveRe.ReplaceAllString(s, "")
}
log.Tracef("Removed IPv6 scope ID %s from server %s", scopeId, s)
}

// Validate ODoH
if opts.ODoHProxy != "" {
if !strings.HasPrefix(opts.ODoHProxy, "https://") {
return "", "", fmt.Errorf("ODoH proxy must use HTTPS")
// Handle DNS stamp
if strings.HasPrefix(s, "sdns://") {
var err error
s, err = dnsStampToURL(s)
if err != nil {
return "", "", fmt.Errorf("converting DNS stamp to URL: %s", err)
}
if !strings.HasPrefix(opts.Server, "https://") {
return "", "", fmt.Errorf("ODoH target must use HTTPS")
// If s is still a DNS stamp, it's DNSCrypt
if strings.HasPrefix(s, "sdns://") {
return s, transport.TypeDNSCrypt, nil
}
}

if port == "" {
switch txp {
case transport.TypeQUIC:
port = "853"
case transport.TypeTLS:
port = "853"
case transport.TypeHTTP:
if isHTTPS {
port = "443"
} else {
port = "80"
}
case transport.TypePlain, transport.TypeTCP:
port = "53"
}
log.Tracef("Setting port to %s", port)
} else {
log.Tracef("Port is %s, not overriding", port)
// Check if server starts with a scheme, if not, default to plain
schemeRe := regexp.MustCompile(`^[a-zA-Z0-9]+://`)
if !schemeRe.MatchString(s) {
s = "plain://" + s
}

urlScheme := string(txp)
if isHTTPS {
urlScheme = "https"
// Parse server as URL
tu, err := url.Parse(s)
if err != nil {
return "", "", fmt.Errorf("parsing %s: %s", s, err)
}

fqdn := urlScheme + "://" + host
if txp != transport.TypeHTTP {
fqdn += ":" + port
// Parse transport type
ts := transport.Type(tu.Scheme)
if tu.Scheme == "https" { // Override HTTPS to HTTP, preserving tu.Scheme as HTTPS
ts = transport.TypeHTTP
}
log.Tracef("checking FQDN %s", fqdn)
u, err := url.Parse(fqdn)
if err != nil {
return "", "", err
if !slices.Contains(transport.Types, ts) {
return "", "", fmt.Errorf("unsupported transport %s. expected: %+v", ts, transport.Types)
}

server := host + ":" + port
// Set default port
if tu.Port() == "" {
switch ts {
case transport.TypeQUIC, transport.TypeTLS:
setPort(tu, 853)
case transport.TypeHTTP:
if tu.Scheme == "https" {
setPort(tu, 443)
} else {
setPort(tu, 80)
}
case transport.TypePlain, transport.TypeTCP:
setPort(tu, 53)
}
}

if txp == transport.TypeHTTP {
port = strings.Split(port, "/")[0]
u.Host += ":" + port
server = u.String()
// Add default path if missing
if ts == transport.TypeHTTP && tu.Path == "" {
tu.Path = "/dns-query"
}

// Add default path if missing
if u.Path == "" {
server += "/dns-query"
log.Tracef("HTTPS scheme and no path, setting server to %s", server)
}
server := tu.String()
// Remove scheme from server if irrelevant to protocol
if ts != transport.TypeHTTP {
server = strings.Split(server, "://")[1]
}

// Insert scope ID before ']'
// Add IPv6 scope ID back to server
if scopeId != "" {
server = strings.Replace(server, "]", "%"+scopeId+"]", 1)
server = strings.Replace(server, "]", scopeId+"]", 1)
}

return server, txp, nil
return server, ts, nil
}

// driver is the "main" function for this program that accepts a flag slice for testing
Expand Down Expand Up @@ -387,6 +357,16 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
}
}

// Validate ODoH
if opts.ODoHProxy != "" {
if !strings.HasPrefix(opts.ODoHProxy, "https://") {
return fmt.Errorf("ODoH proxy must use HTTPS")
}
if !strings.HasPrefix(opts.Server, "https://") {
return fmt.Errorf("ODoH target must use HTTPS")
}
}

if opts.Chaos {
log.Debug("Flag set, using chaos class")
opts.Class = dns.ClassCHAOS
Expand Down Expand Up @@ -439,7 +419,7 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
)

// Parse server address and transport type
server, transportType, err := parseServer()
server, transportType, err := parseServer(opts.Server)
if err != nil {
return err
}
Expand Down
4 changes: 1 addition & 3 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,9 +492,7 @@ func TestMainParseServer(t *testing.T) {
},
} {
t.Run(tc.Server, func(t *testing.T) {
clearOpts()
opts.Server = tc.Server
server, transportType, err := parseServer()
server, transportType, err := parseServer(tc.Server)
assert.Nil(t, err)
assert.Equal(t, tc.ExpectedHost, server)
assert.Equal(t, tc.Type, transportType)
Expand Down
3 changes: 3 additions & 0 deletions transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ const (
TypeDNSCrypt Type = "dnscrypt"
)

// Types is a list of all supported transports
var Types = []Type{TypePlain, TypeTCP, TypeTLS, TypeHTTP, TypeQUIC, TypeDNSCrypt}

// Interface guards
var (
_ Transport = (*Plain)(nil)
Expand Down

0 comments on commit 14abd73

Please sign in to comment.